add_resnet breed prediction

This commit is contained in:
Cornelius Specht 2024-03-11 10:28:11 +01:00
parent e72cb31cef
commit dc60647dd9
4 changed files with 1453 additions and 0 deletions

View File

@ -0,0 +1,453 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 77,
"id": "2ecee6b6-fb14-472e-9614-d7be5cff6689",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting torch\n",
" Using cached torch-2.2.1-cp311-cp311-manylinux1_x86_64.whl.metadata (26 kB)\n",
"Collecting filelock (from torch)\n",
" Using cached filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.11/site-packages (from torch) (4.9.0)\n",
"Requirement already satisfied: sympy in /opt/conda/lib/python3.11/site-packages (from torch) (1.12)\n",
"Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch) (3.1.3)\n",
"Requirement already satisfied: fsspec in /opt/conda/lib/python3.11/site-packages (from torch) (2023.12.2)\n",
"Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)\n",
" Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n",
" Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n",
" Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-curand-cu12==10.3.2.106 (from torch)\n",
" Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n",
" Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n",
" Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-nccl-cu12==2.19.3 (from torch)\n",
" Using cached nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl.metadata (1.8 kB)\n",
"Collecting nvidia-nvtx-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\n",
"Collecting triton==2.2.0 (from torch)\n",
" Using cached triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)\n",
"Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n",
" Using cached nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.11/site-packages (from sympy->torch) (1.3.0)\n",
"Using cached torch-2.2.1-cp311-cp311-manylinux1_x86_64.whl (755.6 MB)\n",
"Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
"Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
"Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
"Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
"Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
"Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
"Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
"Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
"Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
"Using cached nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)\n",
"Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
"Using cached triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (167.9 MB)\n",
"Using cached filelock-3.13.1-py3-none-any.whl (11 kB)\n",
"Using cached nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
"Installing collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch\n",
"Successfully installed filelock-3.13.1 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.4.99 nvidia-nvtx-cu12-12.1.105 torch-2.2.1 triton-2.2.0\n",
"Collecting torchvision\n",
" Using cached torchvision-0.17.1-cp311-cp311-manylinux1_x86_64.whl.metadata (6.6 kB)\n",
"Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from torchvision) (1.26.3)\n",
"Requirement already satisfied: torch==2.2.1 in /opt/conda/lib/python3.11/site-packages (from torchvision) (2.2.1)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.11/site-packages (from torchvision) (10.2.0)\n",
"Requirement already satisfied: filelock in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (3.13.1)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (4.9.0)\n",
"Requirement already satisfied: sympy in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (1.12)\n",
"Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (3.1.3)\n",
"Requirement already satisfied: fsspec in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (2023.12.2)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (12.1.105)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (8.9.2.26)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (12.1.3.1)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (11.0.2.54)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (10.3.2.106)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (11.4.5.107)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (12.1.0.106)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (2.19.3)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (12.1.105)\n",
"Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.11/site-packages (from torch==2.2.1->torchvision) (2.2.0)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.11/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2.1->torchvision) (12.4.99)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch==2.2.1->torchvision) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.11/site-packages (from sympy->torch==2.2.1->torchvision) (1.3.0)\n",
"Using cached torchvision-0.17.1-cp311-cp311-manylinux1_x86_64.whl (6.9 MB)\n",
"Installing collected packages: torchvision\n",
"Successfully installed torchvision-0.17.1\n",
"Requirement already satisfied: pillow in /opt/conda/lib/python3.11/site-packages (10.2.0)\n"
]
}
],
"source": [
"!pip install torch\n",
"!pip install torchvision\n",
"!pip install pillow"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "daa858b3-4270-4c00-aad7-b761a4cb6ddc",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchvision import models, transforms\n",
"from PIL import Image"
]
},
{
"cell_type": "code",
"execution_count": 79,
"id": "34134c75-20a7-4864-8fbc-dc0a26726cba",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.11/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/opt/conda/lib/python3.11/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
},
{
"data": {
"text/plain": [
"ResNet(\n",
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (layer1): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer2): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer3): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (4): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (5): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (layer4): Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
" (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
")"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"resnet_model = models.resnet50(pretrained=True)\n",
"resnet_model.eval()\n"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "6806320f-56fd-460e-9546-edc1be25994b",
"metadata": {},
"outputs": [],
"source": [
"transform = transforms.Compose([\n",
" transforms.Resize((224,224)), \n",
" transforms.ToTensor(), \n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "f8148b14-8239-4088-9781-d6e5f7705485",
"metadata": {},
"outputs": [],
"source": [
"image_path = \"hund123.jpg\"\n",
"image = Image.open(image_path)\n",
"input_tensor = transform(image)\n",
"input_batch = input_tensor.unsqueeze(0)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "ddf5ebf4-ee64-4fc9-8bc6-1c2fc495ecfa",
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad(): \n",
" output = resnet_model(input_batch)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "d0f2416b-bd1d-465d-b106-10067ad80bec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted class index: 250\n"
]
}
],
"source": [
"_, predicted_class = torch.max(output,1)\n",
"print(f\"Predicted class index: {predicted_class.item()}\")"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "346cf68b-dcfd-4a5d-a256-35b74c1e5641",
"metadata": {},
"outputs": [],
"source": [
"## check predicted class on list\n",
"## https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "6bf1a1f6-b4c5-42b8-a390-79416427cdf2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"250 Siberian husky\n"
]
}
],
"source": [
"import requests\n",
"import re\n",
"\n",
"## read list of labels\n",
"url = \"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt\"\n",
"content = requests.get(url)\n",
"rawd = content.text\n",
"\n",
"##using regex to convert key from int to string value\n",
"t = re.sub(r'(\\d+)', '\\'\\\\1\\'',rawd)\n",
"\n",
"##convert into dict\n",
"data = eval(t)\n",
"\n",
"##iterate through dict and get key of predicted class\n",
"for k,v in data.items(): \n",
" if k == str(predicted_class.item()): \n",
" print(k,v)\n",
"\n",
"\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 180 KiB