Technologische_Grundlagen/digit_recognition/digit_recognition_NN.ipynb

484 lines
289 KiB
Plaintext
Raw Normal View History

2024-09-27 07:00:19 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "47d72f0e-bb8b-40f6-9e4f-bfeb1225b1ea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting torch\n",
" Using cached torch-2.2.1-cp311-cp311-manylinux1_x86_64.whl.metadata (26 kB)\n",
"Collecting torchvision\n",
" Using cached torchvision-0.17.1-cp311-cp311-manylinux1_x86_64.whl.metadata (6.6 kB)\n",
"Collecting filelock (from torch)\n",
" Using cached filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.11/site-packages (from torch) (4.9.0)\n",
"Requirement already satisfied: sympy in /opt/conda/lib/python3.11/site-packages (from torch) (1.12)\n",
"Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch) (3.1.3)\n",
"Requirement already satisfied: fsspec in /opt/conda/lib/python3.11/site-packages (from torch) (2023.12.2)\n",
"Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)\n",
" Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n",
" Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n",
" Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-curand-cu12==10.3.2.106 (from torch)\n",
" Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n",
" Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n",
" Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-nccl-cu12==2.19.3 (from torch)\n",
" Using cached nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl.metadata (1.8 kB)\n",
"Collecting nvidia-nvtx-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\n",
"Collecting triton==2.2.0 (from torch)\n",
" Using cached triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)\n",
"Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n",
" Downloading nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from torchvision) (1.26.3)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.11/site-packages (from torchvision) (10.2.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.11/site-packages (from sympy->torch) (1.3.0)\n",
"Using cached torch-2.2.1-cp311-cp311-manylinux1_x86_64.whl (755.6 MB)\n",
"Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
"Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
"Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
"Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
"Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
"Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
"Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
"Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
"Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
"Using cached nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)\n",
"Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
"Using cached triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (167.9 MB)\n",
"Using cached torchvision-0.17.1-cp311-cp311-manylinux1_x86_64.whl (6.9 MB)\n",
"Using cached filelock-3.13.1-py3-none-any.whl (11 kB)\n",
"Downloading nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m27.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
"\u001b[?25hInstalling collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, torchvision\n",
"Successfully installed filelock-3.13.1 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.4.99 nvidia-nvtx-cu12-12.1.105 torch-2.2.1 torchvision-0.17.1 triton-2.2.0\n"
]
}
],
"source": [
"## install PyTorch into the kernel's own environment (%pip targets the running interpreter, !pip3 may not);\n",
"## versions pinned to the ones this notebook was last executed with\n",
"%pip install torch==2.2.1 torchvision==0.17.1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3da9ca25-5156-4216-8ac7-b2347e60c34e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d11c9b95-ea6b-457e-a3e7-c5c8de58b450",
"metadata": {},
"outputs": [],
"source": [
"## core PyTorch: tensors, layers, functional ops\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"\n",
"## torchvision: MNIST dataset, image transforms, grid utility\n",
"import torchvision\n",
"from torchvision import transforms\n",
"\n",
"## plotting / array helpers used to display images\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4aad0f0f-545e-416b-97f7-2678c782215a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.2.1+cu121\n"
]
}
],
"source": [
"## report the installed PyTorch build (a local suffix such as +cu121 identifies the build variant)\n",
"print(str(torch.__version__))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "54c98b09-a734-4edf-96c0-20b78bed9d1e",
"metadata": {},
"outputs": [],
"source": [
"## parameter denoting the batch size\n",
"BATCH_SIZE = 32\n",
"## where MNIST is downloaded to / loaded from\n",
"DATA_ROOT = './data'\n",
"## seed PyTorch's global RNG so shuffling, weight init and dropout are reproducible\n",
"SEED = 42\n",
"torch.manual_seed(SEED)\n",
"\n",
"## transformations: PIL image -> float tensor in [0, 1] of shape (1, 28, 28)\n",
"transform = transforms.Compose([transforms.ToTensor()])\n",
"\n",
"## download and load training dataset (order reshuffled on every pass)\n",
"trainset = torchvision.datasets.MNIST(root=DATA_ROOT, train=True,\n",
"                                      download=True, transform=transform)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,\n",
"                                          shuffle=True, num_workers=2)\n",
"\n",
"## download and load testing dataset (fixed order)\n",
"testset = torchvision.datasets.MNIST(root=DATA_ROOT, train=False,\n",
"                                     download=True, transform=transform)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,\n",
"                                         shuffle=False, num_workers=2)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "23880f4e-4790-4055-bbb4-1b0c47755b9b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f2b57e46c50>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGgCAYAAAAD9NhnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAADA/0lEQVR4nOy9ZXRVWdb9veNGSNCEIEmABC/c3d3d3V0LCvfC3Qt3d3cKKZzCEywhSAgSd3s/0H13fqe7q7uep/tN/k+vOQZjnMm+98je+5yTu+Zec5klJycnK4FAIBAIUgHmqX0CAoFAIPjvhbyEBAKBQJBqkJeQQCAQCFIN8hISCAQCQapBXkICgUAgSDXIS0ggEAgEqQZ5CQkEAoEg1SAvIYFAIBCkGuQlJBAIBIJUg7yEBAKBQJBq+I+9hFauXKk8PT2Vra2tKlmypPr111//U4cSCAQCwf+jsPxP7HT37t1q2LBhauXKlapixYpqzZo1qn79+urp06cqV65cf/jdpKQk9eHDB+Xo6KjMzMz+E6cnEAgEgv8gkpOTVXh4uHJzc1Pm5v/kt07yfwBlypRJ7tevH/4vf/78yWPHjv2n3w0ICEhWSsk/+Sf/5J/8+3/8X0BAwD995v/bfwnFxcWpu3fvqrFjx+L/69Spo65fv/43n4+NjVWxsbEmnvwXU+8cS8YoczsbpZRSZoYXqdlHW/A827+Ztr3XvkZb/8zXwAf27Qf+rpYNeFL2GHBb+zhwp13pTNsOJx6g7dXKguBWr+3A2zW9BP5b3az8/sgC4KMaHTZtzznfmOf1yQLcJjQZPKREPHj6h9bg+4csNm032DsUbXtaLQXfF1oC/OiWyuAxmXjsHJfZhynxphnPo0tFhmmvlWefDXrwDHzmz13BT05dDV53uh7fQ+OXoa363sHgSa48T5fMYeDpx/CXeMVtj8F/rZrRtG3hkhltEYVcwRN6fwUPjrAHj4tiv2S+TD5izC7T9k/n26DNe9Yr8OTsnFcBdZ3Bcx0KAvdvyc/v7LLEtD2yV2+0RWXjvefoHwlu8YV9+LW8G/ioH3eA/9JTz+sDu/agrczaXuDmnNIqx7YX4D6TPcCP1Fxh2n4S54K2cZdbg+fzeg/uYhsOfuMt932jwlbwCtv0vLOM4LyxDuP9cWgU76+Gq3n/RbongNt+5GP6du9fTNv1RnRDW4ItH5Yhef/4V8ixLovBewwbCB6dWR/bPojnZT00ENxsspM+j8RY9evvi5Sjo+MfHl+p/0A47suXLyoxMVG5uHDQXVxcVGBg4N98fvbs2Wrq1Kl/8//mdjbK3P77hP+bl5AtbwRLC/0isUlnhTZHR37Z0pLfNbflS0jx2aAs7DmhLK309y3NeKy/nq/puzbktoZzszTjg8bccF126fTwmNsZ982XkIU1J7q5naHdhsdK2S/G46Yz9JlNIs/beF0Wtjy25R/MKnM7nsff9gm5g6PxOnns9IZzTdluHHvjdRrH2tIhltyCY/9H52phznmUcp4opZRyYLtFkmEeJrNfLKzJ7VP0g3EuWJrzs8kWhmPZ/OP75e+1pxx/4/1ivC5Li0Tuy5x9aBwve8N4Wlro9r8ZS8N5GaM6xus29kvK67CPs/jDz1oaxsfajtdhvLeN55pyblnEc94Y703jvPyb67RLMLTzhkp57L+ZZ1bGff/xS+hvno+G/VlYW6Zo43kZ+8zMwnAuSv1LkopZcvK/t57Qhw8fVPbs2dX169dV+fLlTf8/c+ZMtXXrVvX8+XN83vhLKCwsTOXMmVO5r59gGnhra1589Df+tVxgvv4l1PPYWbT9UrIoeOIPecBfdOcAu1wij3XiICXa/P1tpZTKeY5/PUW58TxjMvBG8OjlC940ywPwtaNamraLT7mHNt/27uC1DtwHP9+wEHik4S/zz0X1A3R0131o210iL3hyPk9ws2f8y/tL++LgqyctAf8tWvf5+9gMaNt/oiJ43pX+4K/6eP
DYBdnHh8vwl9BPb5uatu8+yY22koX4Kzk8znDDDSD/Voq/br7+wBsqz159LpmXvkNbk8wcjx+v8i9v8zDOMzPDg8t7DueG7QE9D6NGcix3H1wLXnv8CB4rkbf4p2q8nwpO/qD+EWK9eKwQL076K5M51vO+FgPfdKMSuKMrx29CwROm7V+6NUOb1dsv4AluGcF9+/MllHsbqLJ58Ma0/WpEPrTZffrjh2MGX/7sGrR0N7iD4WU74HpHve/nnEfxRfhrMeEr261C+Izp3OQi+J7XvL82Ft1s2u7zuBPaQn3ZR8XKvAT/PJf3RJFJv4OfP14SPDaT/iNjUHU+W8/XY+Tm0M0jpu2w8CSVNZ+/Cg0NVenTp1d/hH/7L6HMmTMrCwuLv/nVExQU9De/jpRSysbGRtnY2PzN/wsEAoHg/z7+7Uu0ra2tVcmSJdXZs3xrnj17VlWoUOHffTiBQCAQ/D+M/8gS7REjRqjOnTurUqVKqfLly6u1a9eqt2/fqn79+v3zLwsEAoHgvwb/dk3or1i5cqWaO3eu+vjxoypcuLBatGiRqlKlyj/9XlhYmHJyclI5lkw1iYfunp/xmY+3soFnuZ9k2raKTELb+6p8z5rHGVauFA0GP1h8HXireWPA+w88ZNre268u2vz7U6TN8QuF7BoLuVLvoP8P4KHPM4HnPhBl2k62NAiIodHgytcP1MyOelSly4z7X/zsbdp+/Sg72tyPMR5udZ56lKULV1MluTAO/bYBdZ+kFN3g9IrjE1Sbx5pe/hD4jTDqUx+iGV++/4p5Z3aOOlbvOYjz5lV/6oHmXPiooj14LgWnss8+NKMOZxOiryWoDG+jzlWvgp+fTm0k/WXqU6vvHgTvX7k9eMASvSrzduktaCt0qQ+4T7X14F6H+oM3KsvxvLK5NHhEeT3vEsOou7he5jwMKgOqalR4BP4smCH4Dz6cO/Y5U2hEt5zQdmLAXPC663kvWnMhnnK5FQUe7aLD/Jb9KQ/Y9+I8/FI1B3imI0/BjYuhAptSW3F+oSeTzUM/tEWW47xzH09d/MUyrqoNzs9n1JXu88EXfNE66p4H1HAsbfgMcj7L58DM8b+A9z3XHTzTHerWUS76XE715Xj8/KkWuF97/VxOSIxV518vTR1N6K8YMGCAGjBgwH9q9wKBQCD4PwDxjhMIBAJBqkFeQgKBQCBINfzHNKH/Kf6qCRXpPtOU6BbG8Ksa25Tx8xVLm5u20zVj7Df0FPWjPr2Pgi95WB3cfRVjokHDDVn183ScOes8P7R97ceciiRbQ5JnMPMFjNl3A48fA19Zs7Zp+1VPxqzzrGE+TVJmxtP9G1GXGduZ2ej7PpUybb+4wA4+0YOx3+F+LcED1zJvKNSL1zGjAxM21pcqpok1++T5ZGo+lhHclxlD9ypLyU/g1gupR5X8+a5p++CFsmhLMiQNOgRwrAs1Y6z+5SbmlmQ9xJyLxC86j8XSjfPsXRsP8FPD2afd2jBU/fkn5p2kX22Io6eQCaIzMoo+bhKz93+ewtyRr0WoMSRbsh8utaHm8DhOa5P9z9KhwsqZ5zm7JO/F3yKofzRzvgs+yof5UpZrdS6WRX+O7QIvztmRA5nN/74TNbwblVeAl90/0rSdZy/v46ASzFReMGwN+OANfcFjs3AiehzlsWNGa225WQ7m3qy8UhP8eVOeZ7OG7OO4+RHgFpM5x+Oc9T1kFWHIoRwXCl7Rhdrjna/UUF/7U6PLt4r99O4nPVfcR/H5lWzBezXdRn3s+Mg4daTOxn9JE5JfQgKBQCBINchLSCAQCASphv/Y6rj/LTr2P6Ns/+KdFprIZYarX3Gpd3QN/fM19jLDIu27XwC3UPxZnXcqlzp/LkerllKuNEm8Xlsvq/54kEsr1x+mhUn/n4eAJ9jRzC+8OH/6Dj7Nn+UbLujl4lMH9kRbchTP2yyG4YWMPlyq2djhLfiOmh76s8259DheMXzz+A7Db852bE+yYniniQOXvW/IpE
ODyeEMNWS6x7+DHD4xvPCuBkNmXdx/Az94heGFcwE6hJaYgfsyNyxfdTKYhN58xrBk0U4MZbzI7AUen16HEo1L/41
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"## a dummy single-channel 96x96 noise 'image' to demonstrate transforms on.\n",
"## torch.rand keeps values in [0, 1): ToPILImage multiplies float tensors by 255 and casts\n",
"## to uint8, so the negative / greater-than-1 values of torch.randn would fall out of range.\n",
"image = transforms.ToPILImage(mode='L')(torch.rand(1, 96, 96))\n",
"plt.imshow(image, cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "aadba5ab-c84c-4f80-8c79-d30633343a61",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f2b55422f50>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGgCAYAAAAD9NhnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC+s0lEQVR4nOy9dXQWWdb9f+NGBE0IkgRI8MahkYbgLo27uwZtaNzd3Rp3d3caGndJsECQECRCXH9/8L7Pzadmpmd6vjO/0G+fvRZr3c19nqeqbt2qSp199zlmKSkpKUogEAgEgjSAeVrvgEAgEAj+upCHkEAgEAjSDPIQEggEAkGaQR5CAoFAIEgzyENIIBAIBGkGeQgJBAKBIM0gDyGBQCAQpBnkISQQCASCNIM8hAQCgUCQZpCHkEAgEAjSDP+1h9CSJUuUl5eXsrW1VSVKlFAXLlz4b21KIBAIBH9SWP43fnTbtm3Kz89PLVmyRJUvX14tX75c1a5dWz18+FDlzJnzd7+bnJys3r59qxwdHZWZmdl/Y/cEAoFA8F9ESkqK+vLli3J3d1fm5v/kXSflv4DSpUun9OzZE/+XL1++lOHDh//T7wYFBaUopeSf/JN/8k/+/cn/BQUF/dN7/n/8TSg+Pl7duHFDDR8+HP9fo0YNdenSpb/5fFxcnIqLizPxlP9J6l1B1VGWykoppdSL1YXxHSvrRPDYUDtTO++CUPS1234GfH2lQuBJBb3An7XlkGS5QB7nrJ/qSdY8liQb8uxnv4BHu9mBx7lYgOfs8BS8bqa7pvaakQ3RV2TEHfCnnXOA+266C36uWT7wqHyupvbHwlboG9B6L3gZ25fgO8OLgx9Y/wN4bMYU8OznYtU/wotGHMT25Rm2/bUsx6zv7Ufgk6d1AD8yfpmpXXNiT/TtHbkQvPKOfuDJbtxP10wR4E7D+GZefuN9U/tCpQzos3DNBB5Z0A08sdsn8NBIe/D4aI5LpnOaDxq2FX0/n2oO7jPlGXhKtizgQTVdwHPuDQF/2UR/fkv7+egb3LUbeHRWW3DHl1HgFh85hp/KuoMP+Wmzqb2qS3307d66Hbz0iq7g5gmgKvvGJ+BJnz4rQdogUSWoi+qwcnR0/Kef/Y8/hD5+/KiSkpKUq6sr/t/V1VUFBwf/zeenTp2qxo8f/3d2zEpZmn29OZrbc6JbGB5C5rG639KCTwJ7R97oLc14cZtZ8rfN7TgkFtZGnurV0vDQMXJLC14lllbcVqI1983Kgftml05v2/hd63R8cBiP2zYd99vSnP2pf8/Chr9lZ/huOju+Ttsk8fMWNobzY8uHkOXvzDJzOx6zrfG4zMgdDOfTwprbdnI0/4d9jo48DnNb9is+B5SlQxy5BR9CqffVuJ8WvzPeSimlHNhvkWyYhykcFwtrzY1z2tyO37U053dTDHPDeL6Mcyd1fzrDmFkarhfjcVlaJPG3zDmGxnOS+lgsLf7xufx7+22M8hiP28xwTgT/P+J/bgH/iqRilpLyn60n9PbtW5UtWzZ16dIlVbZsWdP/T548WW3YsEE9fvwYnze+CUVERKgcOXIoX9XQdGEHLC+F73h4fQB/dzWrqZ35VjL6rKLI31TiHdE8noNkXYRvUnuKrQRvOnOYqd2rz1707ehZE/xlL16Q2Vfxoqgy51du6+V34OGPM5rauXZHoy/F0nBzCI8BVwGBoGZ2fKOocO6tqX3mgw/6nt/LBu7zM9+qUvLy7dHsEf/y/tiqGPiyMfqv6d9icqPvTVx68F2Hy4PnWcK3sGfdPbntAnzb3Fdavwn9/Ipvjzce5AIvUfA5+Jd4w8O0N/nnkny7+fSdnju5d3A/Mi14Dd4g0y3wny42AzeP4Lw0S+C89JkeYGrb7ua5jx7Mt6xte1aAVx85iNtK4iX/3pd/1BUY+1b9I8R5c1th3nyAnR/LN6eZn4qCr71cAdzRTY/bqAKH0beqYyNwq1cfwRPd+fYZ0MsQmk
jQ43Sy5lx03YvncQw81Ro8f16ev6x2fKO7EMi5dO+H1ab2d2v6c7+/GO4x4Rz/4yNngVdeMBQ80ovnx+6NnisP+yxBX8Xe3cETbTlXQn1+X5853XUGeKueA03t6Mycow7vuV/WQ9/p7UbFqbP1l6rw8HDl5OT0u9v8j78JZcqUSVlYWPzNW09ISMjfvB0ppZSNjY2ysTG+UggEAoHgr4D/+BJta2trVaJECXXixAn8/4kTJ1S5cuX+05sTCAQCwZ8Y/5Ul2oMGDVLt2rVTJUuWVGXLllUrVqxQr169Uj179vznXxYIBALBXwb/lYdQixYt1KdPn9SECRPUu3fvVKFChdThw4eVh4fHv/V7GW5wN9+EZgUf3mSPqb34zY/oS9eIK3/sjvK73bsdAJ9/tzJ4z/aGFVQDtWa0s2sN9LnOfAFu2ZNx52SDYH+pSQF+35z61YRDa0ztJYuqo+9Zl+zguZe/57YMuk1gPWovLa3vmdp2llxAYR3KF+S4cvnBrU7dBLdw5eqrTDfDwNuv8NP7ZdCKnZ/xmBOqx4O3Pn0F/HIEV5W9jWG8ueZpHY+3c6QoXmAc9aX7vaiFmXPTKmYIx6XA+CDwuPR6PoflS4e+gOt5wXNXoo6Z7TAXFzido6627MYe8F6LWpnaTz9zW9f2/gJe8GxvcP8pi8G99/YCr1+UKy3PN9YabGRZapFJBu3K7RzndOEd1EOqlLsHns2Tus5bfz13JiTXQ5+q6AB6uDePo+bqYeBOlC5Vprv6/Pcb0QB9zwbx/Di/p27z9qEneEwA58LUBdvAz8Toc5KQ07Cg5TG1xchKHNPS+6jZWTlzTLtUOA++/bnWXG/EcdJGdQkDDw+gbla0dAD4hxnUtia/rwr+vrS+YOMyUuNuWfkU+Klaqe4Tyf/6UoP/ykNIKaV69+6tevfu/c8/KBAIBIK/LCR3nEAgEAjSDPIQEggEAkGa4T/uE/p/RUREhHJ2doZPqN4DenfCk+h52fdK+2siY7jc2/wWHbutWp4Gz2RJf8f+JlzB9+F7ekO+66Fj3JeO0tdjSbO4Wt2bnole0xgvT7RjHPpLMUNmgUgdj/2lFv1K4/t0Abe7wmwLKktG0IjC5Afmat9EK6+K3I8fS4BvnDUbfGBgE/DgFdSfwr35t82k1htN7dUli3I/rSkSPR6bB9wykr9lRglJZS5BLcx6jo6Bl5h2A317TpcBT7bm1HcIok5TsBE9bU/XUkfIslePedJHah2W7tQeXzf3BD86kH6Mjs0Zuv7wM3UFp2WptC+D/y8mA6PqI8ZsAJ82ri34p8L8gRRLjsPZ5tq3cj+e86bXCWaosHLhfk4tQS3rt0j6whq58JwM8dd+KcsVvNYsevHczvZmBoXBffqAv2lL3cYrlYQU42owD/eihcS+KyfWx0rUXDPufwhuZjA6BzfU2orLE+o0NncDwaO+55h4jOQ8e7KQWnFoPp6v8530+Zn9kb667bd57VraUMdxOcF75+SRq8B7nOwEnvG6viaiXbkfR3twDk97X83Ujo+MV2t8t/9LPiF5ExIIBAJBmkEeQgKBQCBIM/zXVsf9J7FhZh3wygMug1ts0SGYxO/4ypicma/Z4Yl8HX0cyWXUz8axf2bxdeCLvfWy3myVGT5zMizhHTSkL3jZ4VzabFxe7H/EGzxzZZ0+pevuHujL84FpRKIq8Luf8/PUZpt7HfzHqAGmtn36QPRVG3URPMEQ/7l/neE3F0NYMdmK4Z0GDjqc+ktGLhVP+RIJnvEm/y4ypgZ5XYUhs/Yev4HvOa9LhZwMYvgsKb0h56AhVOF8jmlfrjzi8tUibZnm50kmPeYJTgwjGtNBGcOIWyIYyq20gkvRE1J4nOcSdZjYfRxDr3d3Mnwz6Egb8Hy3mMiz5c9Xwdcvqg3erYFO/ZLozDBWq3kc7wNbmIanQFmGuX5pwOS2J34sCx6VXc8Va0PYaY33DvAuc/3Asz16A56UwNDh5q06YW2ZXYPRl3
sc5+GbBkwcONtvOXi/HLz+4gz3Fc8DOgQXO4zyQa3sgeBLzjMcdzwnlzo3eshQoH1nXiNtW+owZLwLw9l5IjnHY0b
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"## dummy transformation: rotate by a random angle in [-45, 45] degrees\n",
"dummy_transform = transforms.Compose(\n",
"    [transforms.RandomRotation(45)])\n",
"\n",
"dummy_result = dummy_transform(image)\n",
"\n",
"## the image is single-channel ('L'), so render it in grayscale instead of the default colormap\n",
"plt.imshow(dummy_result, cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "11cd3a28-2caa-411e-8091-4f24dc600ba5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f2b55496790>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGgCAYAAAAD9NhnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC2FElEQVR4nOyddXRW17r1V9yIoIFAIAESgrtbcCju7u5BCwWKuxeH4hQv7g5FirsEDQQJQeJu3x/c+6789rntOe295wvteeYYjLEm6323rL323nmfuZ75mKWkpKQogUAgEAjSAOZpfQACgUAg+M+FvIQEAoFAkGaQl5BAIBAI0gzyEhIIBAJBmkFeQgKBQCBIM8hLSCAQCARpBnkJCQQCgSDNIC8hgUAgEKQZ5CUkEAgEgjSDvIQEAoFAkGb4t72Eli5dqjw9PZWtra0qWbKk+uWXX/5duxIIBALBXxSW/46Nbtu2Tfn5+amlS5eqihUrqhUrVqh69eqpBw8eqJw5c/7ud5OTk9Xbt2+Vo6OjMjMz+3ccnkAgEAj+jUhJSVERERHKzc1NmZv/k986Kf8GlClTJqVPnz74Px8fn5RRo0b90+8GBgamKKXkn/yTf/JP/v3F/wUGBv7TZ/7/+S+h+Ph4df36dTVq1Cj8f+3atdXFixf/4fNxcXEqLi7OxFP+y9S7kvpGWSqr//XxJNQoDh7cNRbcaW868CL974Df/aEIj7dNiKkd9jw9+ryXvQN/X90NPKRoErjPjNfgSa4ZwKNz2JvajneD0BefIyN4QEMb8DKlH4PfPuwDnmvLS1N7+vE96EtnlgLeYr4feJa118E/tykBnvF2GLjXkmemdtcMF9DXa85AcO8O/uC3jvK43edz35vuXQbvXLuRJokc7+fTOGaeM+PBVUIi6KvGWcAd3nFcnJ/HmNqWHyPRF9iQ3825I5D7WpkM2tvtLPj45Z3AJ/bZYGovb16H217NeXd7Kees0/arSiD4/4lElaDOq0PK0dHxn372//wl9PHjR5WUlKRcXV3x/66uriooKOgfPj99+nQ1ceLE/+HArJSl2f/+JZRiaQtuYc8HiYUV+63TWfM4DP2J9vphb27LPktzvggsrNlvbseHoqU592Vmwe+n3rdx28mG8zK3Zb+VA7dtYfPbx5rOkT+XHQ0vIeN5GK/LP/Rb8EVvnU5/3rgv43f/6XEb9u1k2B7GyZzjbW5vPE5DuDfZ4nf3bWHNcbG01NzSIuH3j9tw/cwceGz2jr+/79T9loZ5Ypyz/+x6CQT/dvzXrfGvSCr/Fk3of9p5SkrK/3hAo0ePVkOHDjXx8PBw5e7u/n92HC+a8SHlM97w1/G4aPCl2X8FL+XMX1KuQ/TD5tezS9A3okpZ8Lev7cF3FtkA3jJuEPid5gvAfScOMbU/dOaYJPjwuLPt4gPy0f384Em52J+cydnUbrx2BPoWdlgFnugAqixy8BdeKHelPpV04rFc0b+UnszIhr6MLy+B38pYATyhCM9z6KPb4OVn+4E7lNO/ME7M/QF9J2JcwEc16QLueo2/jNxrvQTf6f0zeIVrXfW2S/LadqneETykQg7wD1d4Lww72h08hUOoBl9vbWqbt+ev94c3+AuvwEX+MnK+wF/sw9yOgn9foxX4k576GlnE8ThL1n4Abm7GX3S/ni2ofg+eoy79br/gPw//5y+hTJkyKQsLi3/41RMcHPwPv46UUsrGxkbZ2Nj8w/8LBAKB4O+P//Ml2tbW1qpkyZLq+PHj+P/jx4+rChUq/Ma3BAKBQPCfiH9LOG7o0KGqY8eOqlSpUqp8+fJq5cqV6tWrV6pPnz7/jt0JBAKB4C8Ks5T/Xo72f4ylS5eqWbNmqXfv3qlChQqp+fPnqypVqvzT74WHhytnZ2flqxr/nwiq8XVKgVecydVUjZxvgK8K9gV/04kaRoqVfm+bR1Kv6HbsDPjkRR3As9zg530W3gc3NywIeNzOw9R+XzUz+r
K3fwEek8ixmp17J3jnH4aAJ6bSrjPd46ow12+fgW/x5K/aGj16gwe24/fzTeNKMf/eWrOwyMoxaJ//GnhR+1fgc79tD+7oH8pt96DeUbaMXl33aRh1tCft7MDNEql3OOXhtnt7McF644SG4C4X9Yq3B+Ook+XeRq0kdEgEuN16HvfHIgxKeOznGOZdqlc7FnLgqsq5BxrxsxOom8VUpU5jf/U5eHTp3OAfiuu5lPPAZ/SlPHhKXoKCYL6lj8APPCgM7rWYCzjM7zwxtZ+s4UrIsSUOgS956st976UWdnL8PPDhb2qa2q/KRinB/z8kpiSoM2qvCgsLU05OTr/72X/bwoR+/fqpfv36/bs2LxAIBIK/AcQ7TiAQCARpBnkJCQQCgSDN8G/ThP4s/q81IbuzXBZeJeMT8B2BzANyacGl5T1u3gVf+KKGqW01iw4HNmOZn7HFi3klrfP4gq98chK8Tz3miqTO+A+uTE3o8kTmKNV91Bj86RNqWVnP8e8Np+5aV/i4ndqJ66n34PX3MuP+UA1qDM97UlMY0nYP+I6etU3tLqv3o299N+osO7YtA6/0wzDwFMOfTdlnMu/EIo+HqW0WE4e+0Ir0Ldwwew74oIqtwWN8soLbPeT1LbBf88L2dETYUjIfDzQv9x2dk5nk70szWdXckJ9jE6rbScxFVe4/UyMqv49uGb/0KM19leO+rSL5CEg9xpcmLkZfybl0uHD7hVpXQANu2y6Yx+q2JwA8+ZPWnH5+eg59vmMH88uGJ1X6jVfAzQz6VERunU/1tgbzA62DqUQ4l/gIXjTTW/BzAXnAK+SiJjs7+xG9ryRey6Z7/MDz+jEX8e+IP6IJyS8hgUAgEKQZ5CUkEAgEgjTDv2113NeCjz94gK/Ox5/VMfnoc5auFH86TlpaFNzplf5Z/6Ip95W/C3+Gt0tuCf65DcNeSz4xXOQ/mjY/+frppdJHxq9HX8NaXcEt7OkfVuAjw4pZt3Kp7albBXT7u9noq1GIIbCnd6qB58lIe5tch8LB1z7lkuGMU7T9zdQNDHmlz86lzFaKY5jxPpf0vqnMKRu8h2Evi3166bPBOk6NH7sWvOO3w8FrHeSS7J3bcoG3ncelzUfHVzW197XhUuS8ToxDPejLMFVm9xDwXA257ZyXGHPLYas/f2Ykk75T7Og4cjrYG3zWttXgRaw5xkVWMcTWsZkOE5dYyL6sV2PAzWJ4feI9eD/ZBfM8Ch1gmOtmP31/FdrLe83OzRCS/Mx4XEyjkuCxzjyvz7X1sZytvAh9fcvTqqjlSaYKTNvdHHx3Wy7//jagGfiHJH2s3ScyHWL12BXg3ax6guefznDq+3oM3Vbry/DdrcHFTO2fNjNc2tkQ0o93peeW5UkaAH8NkF9CAoFAIEgzyEtIIBAIBGkGeQkJBAKBIM3wt1+ircwZJw7tUAY8MjvjztkuMaYd70INwjJSCw2rVy9EX6xh/XCj7dRWnOh4ouw/Ug/JNpQfyGqrtZbLC2k/lGwYmt0TqOt06Mu4tO1b2uU89tM6QonctMoZm+Mg+Oj8VcFrXaPecfg9l2xbd6BO8HCGtrRxvEGNoGsvWrNcC/MADwjnMninIbyeDwfS/iafzxtT2/8prXQal7wJfuhJAXC7SyyR0KQbC83tWcdxiCuvlyc7HuZ3f57I6+F7mmU7LN5Tx8kzjvZRj1dyTJdX3GhqD9jeA33Zz3K8jdZUySmc481dqH90vd0ZPCFBz/nc33IJ9tOuXPpfpx631TcTx6zHCM5D6wjOeZsPel5azqVuacRurwPg+XbTkSWHF+dlUIjW4bJt4ryzO05ro8crCoGnRHOe2WamFrai5EbwlUG+pvb9nzivIg1lVGY0+Ql86StfcNsuFDMDW1GbzLFBW1MlerNEyKfC1JWTbHjto3LwWI615jxtsGIk+PSu60ztqnaf0DftQ3nwqyO0RpeYGKsunJ4oS7QFAoFA8HVDXkICgUAgSDPIS0ggEAgEaY
a/vSZkXozx2T0HmW+zNyoT+KxZ7Xg8TCtSucdpC5smd2jjsrcs7WuC2zHOvOU72sR0HUHNqP7Y0+Cv47Te8cJgjeP
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"## dummy transform: random rotation in [-45, 45] degrees, then a vertical flip with probability 0.5\n",
"dummy2_transform = transforms.Compose(\n",
"    [transforms.RandomRotation(45), transforms.RandomVerticalFlip()])\n",
"\n",
"dummy2_result = dummy2_transform(image)\n",
"\n",
"## the image is single-channel ('L'), so render it in grayscale instead of the default colormap\n",
"plt.imshow(dummy2_result, cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9ad649e7-7617-4658-aa61-34348836427a",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAEqCAYAAAA/LasTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAD5u0lEQVR4nOydd5Bc13Xmv8455zw5YDDIORAgKYCkKAZRlERKXslZXllea7VV9nK1QXJpRVvldWmrpKWXVtkWV9bSsqlAURSTSDAARAYGEzB5ejp3T+ece//A3suehEAOprsH71fVRc5MY+a9fu/de+653/kOq1ar1cDAwMDAwMDA0ESwG30ADAwMDAwMDAxLYQIUBgYGBgYGhqaDCVAYGBgYGBgYmg4mQGFgYGBgYGBoOpgAhYGBgYGBgaHpYAIUBgYGBgYGhqaDCVAYGBgYGBgYmg4mQGFgYGBgYGBoOpgAhYGBgYGBgaHpYAIUBgYGBgYGhqajoQHK//pf/wvt7e0QCoXYuXMn3n333UYeDgMDAwMDA0OT0LAA5Z//+Z/x1a9+FV//+tdx6dIlHD58GA888ABcLlejDomBgYGBgYGhSWA1qlng3r17sWPHDjzzzDP0e/39/Xj00Ufx9NNPX/ffVqtV+Hw+yGQysFis232oDAwMDAwMDGtArVZDKpWC2WwGm339HAl3nY5pEcViERcuXMB//I//cdH3jx8/jlOnTi17f6FQQKFQoF97vV5s2rTpth8nAwMDAwMDw9rjdrthtVqv+56GbPGEw2FUKhUYDIZF3zcYDAgEAsve//TTT0OhUNAXE5wwMDAwMDC0LjKZ7IbvaahIdun2TK1WW3HL5qmnnkIikaAvt9u9XofIwMDAwMDAsMbcjDyjIVs8Wq0WHA5nWbYkFAoty6oAgEAggEAgWK/DY2BgYGBgYGgwDcmg8Pl87Ny5E6+//vqi77/++us4cOBAIw6JgYGBgYGBoYloSAYFAL72ta/h3/ybf4Ndu3Zh//79ePbZZ+FyufBHf/RHjTokBgYGBgYGhiahYQHKZz/7WUQiEfzFX/wF/H4/Nm/ejJdffhkOh6NRh8TAwMDAwMDQJDTMB+WjkEwmoVAoGn0YDAwMDAwMDB+CRCIBuVx+3fc0LIPCwMDAwMDQCpCKE/JfNpsNFouFWq0Gssav1WqoVqsNO8aNCBOg3KGwWCywWCzweLwVy72q1SpqtRrK5TJaMMnGwMDA8JEg4yOPx4NcLgePx4NQKIRIJEJ/fz+EQiFisRjy+TxSqRTy+TxmZmaQSCQafegbBiZAWQOW2vU2exTNYrHAZrPB4XAgEomWHX+tVkOlUqHnsdr5kNVDqwQwJCir/7oech7177vRZ8CwcSDXnTwP5N5mrv3Gh1z7+kwJm82GUCgEn8+HWq2GQCCAVCqFVCrF5s2bIZFI4PV6kclkEA6HkU6nW8Kji2R/CGQx2owwAcotQm5cFosFDocDnU6H3bt3g8PhIJPJIJlMYmRkBPl8nmYfmmWAY7FY4PP5kMvl2Lx5M/R6PQ4fPrzM0a9arSIejyOXy8HpdCKdTi/7XSSAGRoagtPpRLlcbprzXAoZePr7+9HW1kYfSJ1OB51Oh3w+T69dNBqF1WrFli1bEI1GMT09DZ/Ph/PnzwO45slTrVaRy+Wa9qFmuHWkUil0Oh1sNhv27NmDWq2GbDYLj8eD1157bVGrDYaNhVKpREdHB5RKJdrb2yEQCCCRSMDlciESiSASiWCz2cDj8WhWORAIIJFI4MKFCwiFQshmsyiVSkilUo0+nVXhcrkQCoW47777YLPZ6Lm8/vrrmJycbPThrQgToNwiLBYLXC4XbDYbfD4fGo0GO3fuBJfLRSwWQygUgtPppJMXmQybYTJjs9kQCASQy+Xo6+tDR0cHHn/8cWi12kXvIw9gKpXC0NAQotHost9VLpdRLpexsLAAn88H4FqPpWaEBJVWqxVbt25FtVpFtVpFe3
s7Ojo6kEwm6bVzu90YHBzEAw88AI/Hg/fffx9jY2MYGhpCrVaDWCxGuVxGoVBAtVqlK5H6fWiG1kMoFEKv16Onpwf3338/qtUqYrEYRkZG8NZbbzEBygajPlMilUrR3t4Ok8mEXbt2QSqVQqVSgcfjQSwWQywWw+FwgMvlIpvNIpFI4Ne//jWy2Szm5ubg8XhQqVSa/tnncDgQCoXYuXMntm7dilqthmKxiJGRESZAaXXIloher8fg4CBMJhN27twJtVqNrq4usNls5PN55HI5HDp0COFwGCdOnMDCwgImJiaQzWZRqVQaeg4qlQqHDx+GzWbD/fffD71eD4lEsux9bDYbSqUSYrEYO3bsQD6fX/YekhlSKBTYs2cPrl69irm5OQSDQSwsLKzH6dw0drsder0eR48exZEjR2jAKJfLIZfLUSwWUSgUkEqlEI1GYTAYoNFoIBAIIBaLYbVakclkIJFIsGnTJhSLRXi9XgDX+knk83lMTU0hk8lgYWEB+Xwefr+/aQM2huUolUoMDAygv78fHR0dKJVKEIlE8Pv9N+y4ytAacLlcCAQC2Gw2bN26FSqVCjabDUqlEm1tbZBIJNDpdOByueBwOKhUKshms0ilUnjnnXeQTCZx+fJlxGIxTE5OIh6PIxKJtERwwuPxMDg4CIvFgi1btqC/vx+VSgWFQgGdnZ1wuVwIh8NNlwFiApSbhAimtFottm3bhk2bNuHxxx+HUChc8f0+nw/JZBIzMzNwu910xd3IG1kqlWLHjh3o7OzE/v37IZVKASxf9bNYLEgkEkgkEqhUqhV/F1mBkIGd/K5CodBUAQqLxYJer0dXVxe2bt2Kffv2rfpeEmDy+XyIxWLI5XKYzWbI5XL4fD5oNBocOXIEhUIB09PTYLPZ0Ol0SKVSOHHiBCKRCBXJhcNhJkBpIaRSKdra2mC322E2m1EoFFCpVCCXy2+qZwhD80O2bBwOB44ePQq73Y7t27dDJBJBJpPRqpxqtYpCoYBcLod8Po9sNouhoSH4fD689NJLCIfDyGQyDV9w3gpcLhcdHR3o6upCR0cH7HY7yuUy8vk8zGYzTCYTMpkME6C0CiRjwufzIRQKoVar4XA40N/fj2PHjkGv14PLXf3jk8vlOH78OFwuF906cLvdK2Yj1hsigk0mkxgeHkY6nUYikUC1WoVMJoNQKITD4YBUKoVCoQCPx0MqlUKxWIRAIKDiWh6PR8Vjhw8fRkdHB4aHh3H16lXMzs5iamoKlUoF5XK5oefq8/lQKpUQjUZXbUgJfDCAcTicRd/XaDS46667IBKJoFarUalUaPWTWCxGsVgEn89HJpOB3++H3++Hz+dDJpNZj1NcE8gKS6vVYsuWLdBoNIs+p/otrDfeeAMXL15ELpfbMFsfyWQSU1NTUCqVTaulWivYbDZUKhXEYjF6e3uhUqmg1+shFoshlUrB5/MRCoWQSCRw9uxZTE5ONrXG7GYxGo0YHBzE9u3bsXv3bigUCkilUlSrVUSjUWSzWYRCIUSjUYyMjCCZTFIRrN/vRzqdRjQaRT6fb7nPgsvloqenB4ODg9R7hMPhgMfjQSaTQaVSNWW/OyZAWQU2m033IBUKBex2O3bs2IHBwUEcPHgQfD7/uv9eKpXi8OHD8Pv9eO+991CtVhEMBpsqQMnlcrhw4QKCwSDdRzWZTPQG1uv19KaNx+PIZrOQSqUQCAS0/E6hUEAul0Oj0aBYLMJsNsNiseDEiRNwuVwoFAoNDVCAa00o4/E44vE4gNV1IhwOZ1FwQt6nUqmwe/fuRe9Vq9WLvrbb7ahUKohGo5ifn8ePf/zjNTyD2w+Px8PAwAB6e3vx6U9/Gl1dXYv0NfWaqmw2i+npabrS3AhkMhnMz8/D4XA0fbr+o0ICFLVajYMHD6KtrQ19fX3QarUwGAwQi8UYGxuD1+tFPB7H3NxcU4n9Pyw6nQ7bt2/H9u3bsWXLFrp1l81m6XbN1atXMT8/j5deegnRaB
Rut7vh49dawOFw0N7ejv7+fkilUlo4wOVyIZFIoFAobjinNQImQFkFuVwOg8GA/v5+HDp0CCqVCmazGQaDYdkKu1W
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"## helper to show an image tensor\n",
"def imshow(img):\n",
"    '''Display a (C, H, W) tensor; matplotlib expects channels last, i.e. (H, W, C).'''\n",
"    # no un-normalisation step: the transform pipeline only applies ToTensor\n",
"    channels_last = img.permute(1, 2, 0).numpy()\n",
"    plt.imshow(channels_last)\n",
"\n",
"## draw one shuffled batch of training images\n",
"dataiter = iter(trainloader)\n",
"images, labels = next(dataiter)\n",
"\n",
"## tile the batch into a single grid image and show it\n",
"imshow(torchvision.utils.make_grid(images))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "21260ecc-202e-40a7-ad12-02d49055bdc6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Image batch dimensions: torch.Size([32, 1, 28, 28])\n",
"Image label dimensions: torch.Size([32])\n"
]
}
],
"source": [
"## inspect the shapes of the first training batch only\n",
"images, labels = next(iter(trainloader))\n",
"print(\"Image batch dimensions:\", images.shape)\n",
"print(\"Image label dimensions:\", labels.shape)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "00450688-b5a3-4df7-abf5-1428dd7a3792",
"metadata": {},
"outputs": [],
"source": [
"## the model: a one-hidden-layer MLP for 28x28 grayscale digits\n",
"class MyModel(nn.Module):\n",
"    '''Flatten -> Linear -> ReLU -> Dropout -> Linear.\n",
"\n",
"    forward() returns raw class scores (logits). nn.CrossEntropyLoss applies\n",
"    log-softmax itself, so returning softmax probabilities here would apply\n",
"    softmax twice, compressing the loss range and weakening the gradients.\n",
"    Use predict_proba() when actual class probabilities are needed.\n",
"    '''\n",
"\n",
"    def __init__(self, in_features=28 * 28, hidden_units=128, num_classes=10, dropout=0.2):\n",
"        super().__init__()\n",
"        self.d1 = nn.Linear(in_features, hidden_units)\n",
"        self.dropout = nn.Dropout(p=dropout)\n",
"        self.d2 = nn.Linear(hidden_units, num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        '''Map a batch (N, ...) whose trailing dims flatten to in_features to logits (N, num_classes).'''\n",
"        x = x.flatten(start_dim=1)\n",
"        x = F.relu(self.d1(x))\n",
"        x = self.dropout(x)\n",
"        return self.d2(x)\n",
"\n",
"    def predict_proba(self, x):\n",
"        '''Class probabilities: softmax over the logits returned by forward().'''\n",
"        return F.softmax(self(x), dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "20abb31a-584c-4efc-850a-308853300b14",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch size: torch.Size([32, 1, 28, 28])\n",
"torch.Size([32, 10])\n"
]
}
],
"source": [
"## smoke-test a freshly constructed (untrained, training-mode) model on one batch\n",
"model = MyModel()\n",
"images, labels = next(iter(trainloader))\n",
"print(\"batch size:\", images.shape)\n",
"out = model(images)\n",
"print(out.shape)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "71f3b1ca-6507-48d2-bde5-3d7a437fb108",
"metadata": {},
"outputs": [],
"source": [
"## optimisation hyper-parameters\n",
"learning_rate = 0.001\n",
"num_epochs = 5\n",
"\n",
"## train on the first GPU when one is available, otherwise on the CPU\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"model = MyModel().to(device)\n",
"## nn.CrossEntropyLoss applies log-softmax itself, so the model output should be raw logits\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f3aedefc-1e05-4952-a390-2d7fae26bd2c",
"metadata": {},
"outputs": [],
"source": [
"## utility function to compute accuracy\n",
"def get_accuracy(output, target, batch_size=None):\n",
"    '''Return the percentage (0-100) of rows of `output` whose argmax equals `target`.\n",
"\n",
"    output:     per-class scores with the class dimension at dim 1 (e.g. logits of shape (N, 10)).\n",
"    target:     true class indices, one per row of `output`.\n",
"    batch_size: denominator of the percentage; defaults to the number of targets,\n",
"                which stays correct for a short final batch (10000 % 32 = 16 test images).\n",
"    '''\n",
"    if batch_size is None:\n",
"        batch_size = target.size(0)\n",
"    predicted = output.argmax(dim=1).view(target.size())\n",
"    corrects = (predicted == target).sum().item()\n",
"    return 100.0 * corrects / batch_size"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "498cbc09-c3bd-4e57-9aab-ab192c49198a",
"metadata": {},
"outputs": [],
"source": [
"## train the model\n",
"for epoch in range(num_epochs):\n",
"    train_running_loss = 0.0\n",
"    train_acc = 0.0\n",
"    num_batches = 0\n",
"\n",
"    ## training mode: dropout active\n",
"    model.train()\n",
"\n",
"    ## training step\n",
"    for images, labels in trainloader:\n",
"        images = images.to(device)\n",
"        labels = labels.to(device)\n",
"\n",
"        ## forward + loss + backprop\n",
"        predictions = model(images)\n",
"        loss = criterion(predictions, labels)\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"\n",
"        ## update model params\n",
"        optimizer.step()\n",
"\n",
"        train_running_loss += loss.item()\n",
"        ## use the real batch length: the final batch may be smaller than BATCH_SIZE\n",
"        train_acc += get_accuracy(predictions, labels, labels.size(0))\n",
"        num_batches += 1\n",
"\n",
"    ## average over every batch of the epoch (guarding against an empty loader)\n",
"    num_batches = max(num_batches, 1)\n",
"    print('Epoch: %d | Loss: %.4f | Train Accuracy: %.2f'\n",
"          % (epoch, train_running_loss / num_batches, train_acc / num_batches))\n",
"\n",
"## inference mode (dropout off) for the evaluation that follows\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f4df7d5-b244-4a68-b06a-bb750c975cf3",
"metadata": {},
"outputs": [],
"source": [
"## evaluate on the held-out test set\n",
"model.eval()  # disable dropout regardless of which cells ran before\n",
"test_correct = 0\n",
"test_total = 0\n",
"with torch.no_grad():  # inference only: no autograd graph needed\n",
"    for images, labels in testloader:\n",
"        images = images.to(device)\n",
"        labels = labels.to(device)\n",
"        outputs = model(images)\n",
"        test_correct += (outputs.argmax(dim=1) == labels).sum().item()\n",
"        test_total += labels.size(0)\n",
"\n",
"## percentage over every test image (exact even though the last batch holds only 16)\n",
"test_acc = 100.0 * test_correct / max(test_total, 1)\n",
"print('Test Accuracy: %.2f' % test_acc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76669cbb-29a0-402e-88ed-9b0ce913945f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}