2024-09-27 06:50:26 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deep Learning with PyTorch\n",
"\n",
"Classical machine learning relies on using statistics to determine relationships between features and labels, and can be very effective for creating predictive models. However, a massive growth in the availability of data coupled with advances in the computing technology required to process it has led to the emergence of new machine learning techniques that mimic the way the brain processes information in a structure called an artificial neural network.\n",
"\n",
"PyTorch is a framework for creating machine learning models, including deep neural networks (DNNs). In this example, we'll use PyTorch to create a simple neural network that classifies penguins into species based on the length and depth of their culmen (bill), their flipper length, and their body mass.\n",
"\n",
"> **Citation**: The penguins dataset used in this exercise is a subset of data collected and made available by [Dr. Kristen\n",
"Gorman](https://www.uaf.edu/cfos/people/faculty/detail/kristen-gorman.php)\n",
"and the [Palmer Station, Antarctica LTER](https://pal.lternet.edu/), a\n",
"member of the [Long Term Ecological Research\n",
"Network](https://lternet.edu/).\n",
"\n",
"## Explore the Dataset\n",
"\n",
"Before we start using PyTorch to create a model, let's load the data we need from the Palmer Archipelago penguins dataset, which contains observations of three different species of penguin.\n",
"\n",
"> **Note**: In reality, you can solve the penguin classification problem easily using classical machine learning techniques without the need for a deep learning model, but it's a useful, easy-to-understand dataset with which to demonstrate the principles of neural networks in this notebook."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>CulmenLength</th>\n",
" <th>CulmenDepth</th>\n",
" <th>FlipperLength</th>\n",
" <th>BodyMass</th>\n",
" <th>Species</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>56</th>\n",
" <td>39.0</td>\n",
" <td>17.5</td>\n",
" <td>18.6</td>\n",
" <td>35.50</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>107</th>\n",
" <td>38.2</td>\n",
" <td>20.0</td>\n",
" <td>19.0</td>\n",
" <td>39.00</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>302</th>\n",
" <td>50.5</td>\n",
" <td>18.4</td>\n",
" <td>20.0</td>\n",
" <td>34.00</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>43.2</td>\n",
" <td>19.0</td>\n",
" <td>19.7</td>\n",
" <td>47.75</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>284</th>\n",
" <td>46.0</td>\n",
" <td>18.9</td>\n",
" <td>19.5</td>\n",
" <td>41.50</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <td>37.9</td>\n",
" <td>18.6</td>\n",
" <td>17.2</td>\n",
" <td>31.50</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>73</th>\n",
" <td>45.8</td>\n",
" <td>18.9</td>\n",
" <td>19.7</td>\n",
" <td>41.50</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>34.1</td>\n",
" <td>18.1</td>\n",
" <td>19.3</td>\n",
" <td>34.75</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>70</th>\n",
" <td>33.5</td>\n",
" <td>19.0</td>\n",
" <td>19.0</td>\n",
" <td>36.00</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>40.3</td>\n",
" <td>18.5</td>\n",
" <td>19.6</td>\n",
" <td>43.50</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" CulmenLength CulmenDepth FlipperLength BodyMass Species\n",
"56 39.0 17.5 18.6 35.50 0\n",
"107 38.2 20.0 19.0 39.00 0\n",
"302 50.5 18.4 20.0 34.00 2\n",
"109 43.2 19.0 19.7 47.75 0\n",
"284 46.0 18.9 19.5 41.50 2\n",
"28 37.9 18.6 17.2 31.50 0\n",
"73 45.8 18.9 19.7 41.50 0\n",
"8 34.1 18.1 19.3 34.75 0\n",
"70 33.5 19.0 19.0 36.00 0\n",
"97 40.3 18.5 19.6 43.50 0"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"# load the training dataset (excluding rows with null values)\n",
"penguins = pd.read_csv('penguins.csv').dropna()\n",
"\n",
"# Deep Learning models work best when features are on similar scales\n",
"# In a real solution, we'd implement some custom normalization for each feature, but to keep things simple\n",
"# we'll just rescale the FlipperLength and BodyMass so they're on a similar scale to the bill measurements\n",
"penguins['FlipperLength'] = penguins['FlipperLength']/10\n",
"penguins['BodyMass'] = penguins['BodyMass']/100\n",
"#penguins['CulmenLength'] = penguins['CulmenLength']/10\n",
"#penguins['CulmenDepth'] = penguins['CulmenDepth']/10\n",
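"# The dataset is too small to be useful for deep learning,\n",
"# so we'll oversample it: each pass doubles the rows, so the loop quadruples the dataset.\n",
"# Note: duplicating rows before the train/test split can put copies of the same observation\n",
"# into both sets, which makes accuracy on the test set look better than it really is.\n",
"for i in range(1,3):\n",
"    penguins = pd.concat([penguins, penguins])\n",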
"\n",
"# Display a random sample of 10 observations\n",
"sample = penguins.sample(10)\n",
"sample"
]
},
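The comment in the cell above notes that a real solution would normalize each feature rather than dividing FlipperLength and BodyMass by hand-picked constants. A minimal sketch of one common option, assuming z-score standardization (not the notebook's own method), with the mean and standard deviation computed on training rows only. `fit_standardizer` and `apply_standardizer` are hypothetical helpers written in plain Python so the example runs without pandas; the rows are taken from the sample output further down, with the /10 and /100 rescaling undone.

```python
from statistics import mean, pstdev

def fit_standardizer(rows):
    """Compute the per-column mean and population standard deviation."""
    columns = list(zip(*rows))
    means = [mean(col) for col in columns]
    # Fall back to 1.0 for a constant column so we never divide by zero
    stds = [pstdev(col) or 1.0 for col in columns]
    return means, stds

def apply_standardizer(rows, means, stds):
    """Rescale every value to (value - mean) / std for its column."""
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]

# CulmenLength, CulmenDepth, FlipperLength, BodyMass in their original units
train_rows = [
    [36.9, 18.6, 189.0, 3500.0],
    [46.5, 14.4, 217.0, 4900.0],
    [47.6, 18.3, 195.0, 3850.0],
]
means, stds = fit_standardizer(train_rows)
scaled = apply_standardizer(train_rows, means, stds)
for row in scaled:
    print([round(v, 3) for v in row])
```

The test rows would be scaled with the same `means` and `stds` fitted on the training rows, so that no information from the test set leaks into preprocessing.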
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The **Species** column is the label our model will predict. Each label value represents a class of penguin species, encoded as 0, 1, or 2. The following code shows the actual species to which these class labels correspond."
]
},
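The integer labels only carry meaning together with a lookup table. A minimal sketch, assuming the same class order as the `penguin_classes` list defined in the code cell that follows (0 = Adelie, 1 = Gentoo, 2 = Chinstrap); the dictionaries and the `one_hot` helper are illustrative only.

```python
penguin_classes = ['Adelie', 'Gentoo', 'Chinstrap']

# Integer code -> species name, plus the reverse lookup
code_to_name = dict(enumerate(penguin_classes))
name_to_code = {name: code for code, name in code_to_name.items()}

def one_hot(code, num_classes=len(penguin_classes)):
    """Return a list with 1 at the class position and 0 everywhere else."""
    return [1 if i == code else 0 for i in range(num_classes)]

labels = [1, 0, 2, 0]
print([code_to_name[c] for c in labels])  # ['Gentoo', 'Adelie', 'Chinstrap', 'Adelie']
print(name_to_code['Chinstrap'])          # 2
print(one_hot(2))                         # [0, 0, 1]
```

A loss such as PyTorch's `nn.CrossEntropyLoss` accepts integer class indices directly as targets, which is why the labels can stay as 0, 1 and 2 instead of being one-hot encoded.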
{
"attachments": {
"4d9e5ba8-736d-4535-bdf3-6f6f41dd502a.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAawAAAJ2CAYAAAADhbZ4AAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImV\nVwdYU8kWnluSkEBoAQSkhN4EkRpASggtgPQiiEpIAoQSYyCI2NFFBdcuFrChqyKKHRALithZFBv2\nRREFZV0s2JU3KaDrvvK9831z73//OfOfM+fOLQOA2imOSJSDqgOQK8wXxwT708cnJdNJPQABGKAA\nczCKw80TMaOiwgG0ofPf7d0t6A3tur1U65/9/9U0ePw8LgBIFMRpvDxuLsSHAcAruSJxPgBEKW82\nLV8kxbABLTFMEOJFUpwhx5VSnCbH+2U+cTEsiFsAUFLhcMQZAKhehTy9gJsBNVT7IXYU8gRCANTo\nEPvk5k7hQZwKsTX0EUEs1Wek/aCT8TfNtGFNDidjGMvnIjOlAEGeKIcz/f8sx/+23BzJUAxL2FQy\nxSEx0jnDut3OnhImxSoQ9wnTIiIh1oT4g4An84cYpWRKQuLl/qgBN48FawZ0IHbkcQLCIDaAOEiY\nExGu4NPSBUFsiOEKQQsF+ew4iHUhXsTPC4xV+GwRT4lRxELr0sUspoK/wBHL4kpjPZRkxzMV+q8z\n+WyFPqZalBmXCDEFYvMCQUIExKoQO+Rlx4YpfMYWZbIihnzEkhhp/uYQx/CFwf5yfawgXRwUo/Av\nzc0bmi+2JVPAjlDgg/mZcSHy+mAtXI4sfzgX7CpfyIwf0uHnjQ8fmguPHxAonzvWwxfGxyp0Pojy\n/WPkY3GKKCdK4Y+b8nOCpbwpxC55BbGKsXhCPlyQcn08XZQfFSfPEy/K4oRGyfPBl4NwwAIBgA4k\nsKWBKSALCNr66vvglbwnCHCAGGQAPrBXMEMjEmU9QniMBUXgT4j4IG94nL+slw8KIP91mJUf7UG6\nrLdANiIbPIU4F4SBHHgtkY0SDkdLAE8gI/hHdA5sXJhvDmzS/n/PD7HfGSZkwhWMZCgiXW3IkxhI\nDCCGEIOINrg+7oN74eHw6AebE87APYbm8d2f8JTQTnhMuEnoJNyZLCgW/5TlONAJ9YMUtUj7sRa4\nJdR0xf1xb6gOlXEdXB/Y4y4wDhP3hZFdIctS5C2tCv0n7b/N4Ie7ofAjO5JR8giyH9n655Gqtqqu\nwyrSWv9YH3muacP1Zg33/Byf9UP1efAc9rMntgg7hJ3HTmMXseNYPaBjTVgD1oqdkOLh1fVEtrqG\nosXI8smGOoJ/xBu6s9JK5jnWOPY6fpH35fMLpe9owJoimi4WZGTm05nwi8Cns4Vch1F0J0cnFwCk\n3xf56+tNtOy7gei0fufm/wGAd9Pg4OCx71xoEwAH3OHjf/Q7Z82Anw5lAC4c5UrEBXIOlx4I8C2h\nBp80PWAEzIA1nI8TcANewA8EglAQCeJAEpgEs8+E61wMpoGZYB4oAWVgOVgDNoDNYBvYBfaCg6Ae\nHAenwTlwGVwFN8E9uHq6wQvQD96BzwiCkBAqQkP0EGPEArFDnBAG4oMEIuFIDJKEpCIZiBCRIDOR\n+UgZshLZgGxFqpEDyFHkNHIRaUfuII+QXuQ18gnFUBVUCzVELdHRKANlomFoHDoRzUCnokXoAnQp\nug6tQvegdehp9DJ6E+1EX6ADGMCUMR3MBLPHGBgLi8SSsXRMjM3GSrFyrAqrxRrhfb6OdWJ92Eec\niNNwOm4PV3AIHo9z8an4bHwJvgHfhdfhLfh1/BHej38jUAkGBDuCJ4FNGE/IIEwjlBDKCTsIRwhn\n4bPUTXhHJBJ1iFZEd/gsJhGziDOIS4gbifuIp4jtxC7iAIlE0iPZkbxJkSQOKZ9UQlpP2kNqIl0j\ndZM+KCkrGSs5KQUpJSsJlYqVypV2K51Uuqb0TOkzWZ1sQfYkR5J55OnkZeTt5EbyFXI3+TNFg2JF\n8abEUbIo8yjrKLWUs5T7lDfKysqmyh7K0coC
5bnK65T3K19QfqT8UUVTxVaFpZKiIlFZqrJT5ZTK\nHZU3VCrVkupHTabmU5dSq6lnqA+pH1Rpqg6qbFWe6hzVCtU61WuqL9XIahZqTLVJakVq5WqH1K6o\n9amT1S3VWeoc9dnqFepH1TvUBzRoGmM0IjVyNZZo7Na4qNGjSdK01AzU5Gku0NymeUazi4bRzGgs\nGpc2n7addpbWrUXUstJia2VplWnt1WrT6tfW1HbRTtAu1K7QPqHdqYPpWOqwdXJ0lukc1Lml82mE\n4QjmCP6IxSNqR1wb8V53pK6fLl+3VHef7k3dT3p0vUC9bL0VevV6D/RxfVv9aP1p+pv0z+r3jdQa\n6TWSO7J05MGRdw1QA1uDGIMZBtsMWg0GDI0Mgw1FhusNzxj2GekY+RllGa02OmnUa0wz9jEWGK82\nbjJ+TtemM+k59HX0Fnq/iYFJiInEZKtJm8lnUyvTeNNi032mD8woZgyzdLPVZs1m/ebG5uPMZ5rX\nmN+1IFswLDIt1lqct3hvaWWZaLnQst6yx0rXim1VZFVjdd+aau1rPdW6yvqGDdGGYZNts9Hmqi1q\n62qbaVthe8UOtXOzE9httGsfRRjlMUo4qmpUh72KPdO+wL7G/pGDjkO4Q7FDvcPL0eajk0evGH1+\n9DdHV8ccx+2O98ZojgkdUzymccxrJ1snrlOF0w1nqnOQ8xznBudXLnYufJdNLrddaa7jXBe6Nrt+\ndXN3E7vVuvW6m7unule6dzC0GFGMJYwLHgQPf485Hsc9Pnq6eeZ7HvT8y8veK9trt1fPWKux/LHb\nx3Z5m3pzvLd6d/rQfVJ9tvh0+pr4cnyrfB/7mfnx/Hb4PWPaMLOYe5gv/R39xf5H/N+zPFmzWKcC\nsIDggNKAtkDNwPjADYEPg0yDMoJqgvqDXYNnBJ8KIYSEhawI6WAbsrnsanZ/qHvorNCWMJWw2LAN\nYY/DbcPF4Y3j0HGh41aNux9hESGMqI8EkezIVZEPoqyipkYdiyZGR0VXRD+NGRMzM+Z8LC12cuzu\n2Hdx/nHL4u7FW8dL4psT1BJSEqoT3icGJK5M7Bw/evys8ZeT9JMESQ3JpOSE5B3JAxMCJ6yZ0J3i\nmlKScmui1cTCiRcn6U/KmXRistpkzuRDqYTUxNTdqV84kZwqzkAaO60yrZ/L4q7lvuD58Vbzevne\n/JX8Z+ne6SvTezK8M1Zl9Gb6ZpZn9glYgg2CV1khWZuz3mdHZu/MHsxJzNmXq5SbmntUqCnMFrZM\nMZpSOKVdZCcqEXVO9Zy6Zmq/OEy8Iw/Jm5jXkK8Ff+RbJdaSXySPCnwKKgo+TEuYdqhQo1BY2Drd\ndvri6c+Kgop+m4HP4M5onmkyc97MR7OYs7bORmanzW6eYzZnwZzuucFzd82jzMue93uxY/HK4rfz\nE+c3LjBcMHdB1y/Bv9SUqJaISzoWei3cvAhfJFjUtth58frF30p5pZfKHMvKy74s4S659OuYX9f9\nOrg0fWnbMrdlm5YTlwuX31rhu2LXSo2VRSu7Vo1bVbeavrp09ds1k9dcLHcp37yWslaytnNd+LqG\n9ebrl6//siFzw80K/4p9lQaViyvfb+RtvLbJb1PtZsPNZZs/bRFsub01eGtdlWVV+TbitoJtT7cn\nbD//G+O36h36O8p2fN0p3Nm5K2ZXS7V7dfVug93LatAaSU3vnpQ9V/cG7G2ota/duk9nX9l+sF+y\n//mB1AO3DoYdbD7EOFR72OJw5RHakdI6pG56XX99Zn1nQ1JD+9HQo82NXo1Hjjkc23nc5HjFCe0T\ny05STi44OdhU1DRwSnSq73TG6a7myc33zow/c6MluqXtbNjZC+eCzp05zzzfdMH7wvGLnhePXmJc\nqr/sdrmu1bX1yO+uvx9pc2uru+J+peGqx9XG9rHtJ6/5Xjt9PeD6uRvsG5dvRtxsvxV/63ZHSkfn\nbd7tnjs5
d17dLbj7+d7c+4T7pQ/UH5Q/NHhY9YfNH/s63TpPPAp41Po49vG9Lm7Xiyd5T750L3hK\nfVr+zPhZdY9Tz/HeoN6r
},
"848c47a7-07b4-45b0-bdb5-b397a08b2145.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAzgAAAMYCAYAAAAHIEK0AAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImV\nVwdYU8kWnluSkEBoAQSkhN4EkRpASggtgPQiiEpIAoQSYyCI2NFFBdcuFrChqyKKHRALithZFBv2\nRREFZV0s2JU3KaDrvvK9831z73//OfOfM+fOLQOA2imOSJSDqgOQK8wXxwT708cnJdNJPQABGKAA\nczCKw80TMaOiwgG0ofPf7d0t6A3tur1U65/9/9U0ePw8LgBIFMRpvDxuLsSHAcAruSJxPgBEKW82\nLV8kxbABLTFMEOJFUpwhx5VSnCbH+2U+cTEsiFsAUFLhcMQZAKhehTy9gJsBNVT7IXYU8gRCANTo\nEPvk5k7hQZwKsTX0EUEs1Wek/aCT8TfNtGFNDidjGMvnIjOlAEGeKIcz/f8sx/+23BzJUAxL2FQy\nxSEx0jnDut3OnhImxSoQ9wnTIiIh1oT4g4An84cYpWRKQuLl/qgBN48FawZ0IHbkcQLCIDaAOEiY\nExGu4NPSBUFsiOEKQQsF+ew4iHUhXsTPC4xV+GwRT4lRxELr0sUspoK/wBHL4kpjPZRkxzMV+q8z\n+WyFPqZalBmXCDEFYvMCQUIExKoQO+Rlx4YpfMYWZbIihnzEkhhp/uYQx/CFwf5yfawgXRwUo/Av\nzc0bmi+2JVPAjlDgg/mZcSHy+mAtXI4sfzgX7CpfyIwf0uHnjQ8fmguPHxAonzvWwxfGxyp0Pojy\n/WPkY3GKKCdK4Y+b8nOCpbwpxC55BbGKsXhCPlyQcn08XZQfFSfPEy/K4oRGyfPBl4NwwAIBgA4k\nsKWBKSALCNr66vvglbwnCHCAGGQAPrBXMEMjEmU9QniMBUXgT4j4IG94nL+slw8KIP91mJUf7UG6\nrLdANiIbPIU4F4SBHHgtkY0SDkdLAE8gI/hHdA5sXJhvDmzS/n/PD7HfGSZkwhWMZCgiXW3IkxhI\nDCCGEIOINrg+7oN74eHw6AebE87APYbm8d2f8JTQTnhMuEnoJNyZLCgW/5TlONAJ9YMUtUj7sRa4\nJdR0xf1xb6gOlXEdXB/Y4y4wDhP3hZFdIctS5C2tCv0n7b/N4Ie7ofAjO5JR8giyH9n655Gqtqqu\nwyrSWv9YH3muacP1Zg33/Byf9UP1efAc9rMntgg7hJ3HTmMXseNYPaBjTVgD1oqdkOLh1fVEtrqG\nosXI8smGOoJ/xBu6s9JK5jnWOPY6fpH35fMLpe9owJoimi4WZGTm05nwi8Cns4Vch1F0J0cnFwCk\n3xf56+tNtOy7gei0fufm/wGAd9Pg4OCx71xoEwAH3OHjf/Q7Z82Anw5lAC4c5UrEBXIOlx4I8C2h\nBp80PWAEzIA1nI8TcANewA8EglAQCeJAEpgEs8+E61wMpoGZYB4oAWVgOVgDNoDNYBvYBfaCg6Ae\nHAenwTlwGVwFN8E9uHq6wQvQD96BzwiCkBAqQkP0EGPEArFDnBAG4oMEIuFIDJKEpCIZiBCRIDOR\n+UgZshLZgGxFqpEDyFHkNHIRaUfuII+QXuQ18gnFUBVUCzVELdHRKANlomFoHDoRzUCnokXoAnQp\nug6tQvegdehp9DJ6E+1EX6ADGMCUMR3MBLPHGBgLi8SSsXRMjM3GSrFyrAqrxRrhfb6OdWJ92Eec\niNNwOm4PV3AIHo9z8an4bHwJvgHfhdfhLfh1/BHej38jUAkGBDuCJ4FNGE/IIEwjlBDKCTsIRwhn\n4bPUTXhHJBJ1iFZEd/gsJhGziDOIS4gbifuIp4jtxC7iAIlE0iPZkbxJkSQOKZ9UQlpP2kNqIl0j\ndZM+KCkrGSs5KQUpJSsJlYqVypV2K51Uuqb0TOkzWZ1sQfYkR5J55OnkZeTt5EbyFXI3+TNFg2JF\n8abEUbIo8yjrKLWUs5T7lDfKysqmyh7K0coC
5bnK65T3K19QfqT8UUVTxVaFpZKiIlFZqrJT5ZTK\nHZU3VCrVkupHTabmU5dSq6lnqA+pH1Rpqg6qbFWe6hzVCtU61WuqL9XIahZqTLVJakVq5WqH1K6o\n9amT1S3VWeoc9dnqFepH1TvUBzRoGmM0IjVyNZZo7Na4qNGjSdK01AzU5Gku0NymeUazi4bRzGgs\nGpc2n7addpbWrUXUstJia2VplWnt1WrT6tfW1HbRTtAu1K7QPqHdqYPpWOqwdXJ0lukc1Lml82mE\n4QjmCP6IxSNqR1wb8V53pK6fLl+3VHef7k3dT3p0vUC9bL0VevV6D/RxfVv9aP1p+pv0z+r3jdQa\n6TWSO7J05MGRdw1QA1uDGIMZBtsMWg0GDI0Mgw1FhusNzxj2GekY+RllGa02OmnUa0wz9jEWGK82\nbjJ+TtemM+k59HX0Fnq/iYFJiInEZKtJm8lnUyvTeNNi032mD8woZgyzdLPVZs1m/ebG5uPMZ5rX\nmN+1IFswLDIt1lqct3hvaWWZaLnQst6yx0rXim1VZFVjdd+aau1rPdW6yvqGDdGGYZNts9Hmqi1q\n62qbaVthe8UOtXOzE9httGsfRRjlMUo4qmpUh72KPdO+wL7G/pGDjkO4Q7FDvcPL0eajk0evGH1+\n9DdHV8ccx+2O98ZojgkdUzymccxrJ1snrlOF0w1nqnOQ8xznBudXLnYufJdNLrddaa7jXBe6Nrt+\ndXN3E7vVuvW6m7unule6dzC0GFGMJYwLHgQPf485Hsc9Pnq6eeZ7HvT8y8veK9trt1fPWKux/LHb\nx3Z5m3pzvLd6d/rQfVJ9tvh0+pr4cnyrfB/7mfnx/Hb4PWPaMLOYe5gv/R39xf5H/N+zPFmzWKcC\nsIDggNKAtkDNwPjADYEPg0yDMoJqgvqDXYNnBJ8KIYSEhawI6WAbsrnsanZ/qHvorNCWMJWw2LAN\nYY/DbcPF4Y3j0HGh41aNux9hESGMqI8EkezIVZEPoqyipkYdiyZGR0VXRD+NGRMzM+Z8LC12cuzu\n2Hdx/nHL4u7FW8dL4psT1BJSEqoT3icGJK5M7Bw/evys8ZeT9JMESQ3JpOSE5B3JAxMCJ6yZ0J3i\nmlKScmui1cTCiRcn6U/KmXRistpkzuRDqYTUxNTdqV84kZwqzkAaO60yrZ/L4q7lvuD58Vbzevne\n/JX8Z+ne6SvTezK8M1Zl9Gb6ZpZn9glYgg2CV1khWZuz3mdHZu/MHsxJzNmXq5SbmntUqCnMFrZM\nMZpSOKVdZCcqEXVO9Zy6Zmq/OEy8Iw/Jm5jXkK8Ff+RbJdaSXySPCnwKKgo+TEuYdqhQo1BY2Drd\ndvri6c+Kgop+m4HP4M5onmkyc97MR7OYs7bORmanzW6eYzZnwZzuucFzd82jzMue93uxY/HK4rfz\nE+c3LjBcMHdB1y/Bv9SUqJaISzoWei3cvAhfJFjUtth58frF30p5pZfKHMvKy74s4S659OuYX9f9\nOrg0fWnbMrdlm5YTlwuX31rhu2LXSo2VRSu7Vo1bVbeavrp09ds1k9dcLHcp37yWslaytnNd+LqG\n9ebrl6//siFzw80K/4p9lQaViyvfb+RtvLbJb1PtZsPNZZs/bRFsub01eGtdlWVV+TbitoJtT7cn\nbD//G+O36h36O8p2fN0p3Nm5K2ZXS7V7dfVug93LatAaSU3vnpQ9V/cG7G2ota/duk9nX9l+sF+y\n//mB1AO3DoYdbD7EOFR72OJw5RHakdI6pG56XX99Zn1nQ1JD+9HQo82NXo1Hjjkc23nc5HjFCe0T\ny05STi44OdhU1DRwSnSq73TG6a7myc33zow/c6MluqXtbNjZC+eCzp05zzzfdMH7wvGLnhePXmJc\nqr/sdrmu1bX1yO+uvx9pc2uru+J+peGqx9XG9rHtJ6/5Xjt9PeD6uRvsG5dvRtxsvxV/63ZHSkfn\nbd7tnjs5
d17dLbj7+d7c+4T7pQ/UH5Q/NHhY9YfNH/s63TpPPAp41Po49vG9Lm7Xiyd5T750L3hK\nfVr+zPhZdY9Tz/HeoN6r
},
"da25e3d5-20c3-48f3-bc9d-6ac38633299d.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAygAAALYCAYAAACXEdZcAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImV\nVwdYU8kWnluSkEBoAQSkhN4EkRpASggtgPQiiEpIAoQSYyCI2NFFBdcuFrChqyKKHRALithZFBv2\nRREFZV0s2JU3KaDrvvK9831z73//OfOfM+fOLQOA2imOSJSDqgOQK8wXxwT708cnJdNJPQABGKAA\nczCKw80TMaOiwgG0ofPf7d0t6A3tur1U65/9/9U0ePw8LgBIFMRpvDxuLsSHAcAruSJxPgBEKW82\nLV8kxbABLTFMEOJFUpwhx5VSnCbH+2U+cTEsiFsAUFLhcMQZAKhehTy9gJsBNVT7IXYU8gRCANTo\nEPvk5k7hQZwKsTX0EUEs1Wek/aCT8TfNtGFNDidjGMvnIjOlAEGeKIcz/f8sx/+23BzJUAxL2FQy\nxSEx0jnDut3OnhImxSoQ9wnTIiIh1oT4g4An84cYpWRKQuLl/qgBN48FawZ0IHbkcQLCIDaAOEiY\nExGu4NPSBUFsiOEKQQsF+ew4iHUhXsTPC4xV+GwRT4lRxELr0sUspoK/wBHL4kpjPZRkxzMV+q8z\n+WyFPqZalBmXCDEFYvMCQUIExKoQO+Rlx4YpfMYWZbIihnzEkhhp/uYQx/CFwf5yfawgXRwUo/Av\nzc0bmi+2JVPAjlDgg/mZcSHy+mAtXI4sfzgX7CpfyIwf0uHnjQ8fmguPHxAonzvWwxfGxyp0Pojy\n/WPkY3GKKCdK4Y+b8nOCpbwpxC55BbGKsXhCPlyQcn08XZQfFSfPEy/K4oRGyfPBl4NwwAIBgA4k\nsKWBKSALCNr66vvglbwnCHCAGGQAPrBXMEMjEmU9QniMBUXgT4j4IG94nL+slw8KIP91mJUf7UG6\nrLdANiIbPIU4F4SBHHgtkY0SDkdLAE8gI/hHdA5sXJhvDmzS/n/PD7HfGSZkwhWMZCgiXW3IkxhI\nDCCGEIOINrg+7oN74eHw6AebE87APYbm8d2f8JTQTnhMuEnoJNyZLCgW/5TlONAJ9YMUtUj7sRa4\nJdR0xf1xb6gOlXEdXB/Y4y4wDhP3hZFdIctS5C2tCv0n7b/N4Ie7ofAjO5JR8giyH9n655Gqtqqu\nwyrSWv9YH3muacP1Zg33/Byf9UP1efAc9rMntgg7hJ3HTmMXseNYPaBjTVgD1oqdkOLh1fVEtrqG\nosXI8smGOoJ/xBu6s9JK5jnWOPY6fpH35fMLpe9owJoimi4WZGTm05nwi8Cns4Vch1F0J0cnFwCk\n3xf56+tNtOy7gei0fufm/wGAd9Pg4OCx71xoEwAH3OHjf/Q7Z82Anw5lAC4c5UrEBXIOlx4I8C2h\nBp80PWAEzIA1nI8TcANewA8EglAQCeJAEpgEs8+E61wMpoGZYB4oAWVgOVgDNoDNYBvYBfaCg6Ae\nHAenwTlwGVwFN8E9uHq6wQvQD96BzwiCkBAqQkP0EGPEArFDnBAG4oMEIuFIDJKEpCIZiBCRIDOR\n+UgZshLZgGxFqpEDyFHkNHIRaUfuII+QXuQ18gnFUBVUCzVELdHRKANlomFoHDoRzUCnokXoAnQp\nug6tQvegdehp9DJ6E+1EX6ADGMCUMR3MBLPHGBgLi8SSsXRMjM3GSrFyrAqrxRrhfb6OdWJ92Eec\niNNwOm4PV3AIHo9z8an4bHwJvgHfhdfhLfh1/BHej38jUAkGBDuCJ4FNGE/IIEwjlBDKCTsIRwhn\n4bPUTXhHJBJ1iFZEd/gsJhGziDOIS4gbifuIp4jtxC7iAIlE0iPZkbxJkSQOKZ9UQlpP2kNqIl0j\ndZM+KCkrGSs5KQUpJSsJlYqVypV2K51Uuqb0TOkzWZ1sQfYkR5J55OnkZeTt5EbyFXI3+TNFg2JF\n8abEUbIo8yjrKLWUs5T7lDfKysqmyh7K0coC
5bnK65T3K19QfqT8UUVTxVaFpZKiIlFZqrJT5ZTK\nHZU3VCrVkupHTabmU5dSq6lnqA+pH1Rpqg6qbFWe6hzVCtU61WuqL9XIahZqTLVJakVq5WqH1K6o\n9amT1S3VWeoc9dnqFepH1TvUBzRoGmM0IjVyNZZo7Na4qNGjSdK01AzU5Gku0NymeUazi4bRzGgs\nGpc2n7addpbWrUXUstJia2VplWnt1WrT6tfW1HbRTtAu1K7QPqHdqYPpWOqwdXJ0lukc1Lml82mE\n4QjmCP6IxSNqR1wb8V53pK6fLl+3VHef7k3dT3p0vUC9bL0VevV6D/RxfVv9aP1p+pv0z+r3jdQa\n6TWSO7J05MGRdw1QA1uDGIMZBtsMWg0GDI0Mgw1FhusNzxj2GekY+RllGa02OmnUa0wz9jEWGK82\nbjJ+TtemM+k59HX0Fnq/iYFJiInEZKtJm8lnUyvTeNNi032mD8woZgyzdLPVZs1m/ebG5uPMZ5rX\nmN+1IFswLDIt1lqct3hvaWWZaLnQst6yx0rXim1VZFVjdd+aau1rPdW6yvqGDdGGYZNts9Hmqi1q\n62qbaVthe8UOtXOzE9httGsfRRjlMUo4qmpUh72KPdO+wL7G/pGDjkO4Q7FDvcPL0eajk0evGH1+\n9DdHV8ccx+2O98ZojgkdUzymccxrJ1snrlOF0w1nqnOQ8xznBudXLnYufJdNLrddaa7jXBe6Nrt+\ndXN3E7vVuvW6m7unule6dzC0GFGMJYwLHgQPf485Hsc9Pnq6eeZ7HvT8y8veK9trt1fPWKux/LHb\nx3Z5m3pzvLd6d/rQfVJ9tvh0+pr4cnyrfB/7mfnx/Hb4PWPaMLOYe5gv/R39xf5H/N+zPFmzWKcC\nsIDggNKAtkDNwPjADYEPg0yDMoJqgvqDXYNnBJ8KIYSEhawI6WAbsrnsanZ/qHvorNCWMJWw2LAN\nYY/DbcPF4Y3j0HGh41aNux9hESGMqI8EkezIVZEPoqyipkYdiyZGR0VXRD+NGRMzM+Z8LC12cuzu\n2Hdx/nHL4u7FW8dL4psT1BJSEqoT3icGJK5M7Bw/evys8ZeT9JMESQ3JpOSE5B3JAxMCJ6yZ0J3i\nmlKScmui1cTCiRcn6U/KmXRistpkzuRDqYTUxNTdqV84kZwqzkAaO60yrZ/L4q7lvuD58Vbzevne\n/JX8Z+ne6SvTezK8M1Zl9Gb6ZpZn9glYgg2CV1khWZuz3mdHZu/MHsxJzNmXq5SbmntUqCnMFrZM\nMZpSOKVdZCcqEXVO9Zy6Zmq/OEy8Iw/Jm5jXkK8Ff+RbJdaSXySPCnwKKgo+TEuYdqhQo1BY2Drd\ndvri6c+Kgop+m4HP4M5onmkyc97MR7OYs7bORmanzW6eYzZnwZzuucFzd82jzMue93uxY/HK4rfz\nE+c3LjBcMHdB1y/Bv9SUqJaISzoWei3cvAhfJFjUtth58frF30p5pZfKHMvKy74s4S659OuYX9f9\nOrg0fWnbMrdlm5YTlwuX31rhu2LXSo2VRSu7Vo1bVbeavrp09ds1k9dcLHcp37yWslaytnNd+LqG\n9ebrl6//siFzw80K/4p9lQaViyvfb+RtvLbJb1PtZsPNZZs/bRFsub01eGtdlWVV+TbitoJtT7cn\nbD//G+O36h36O8p2fN0p3Nm5K2ZXS7V7dfVug93LatAaSU3vnpQ9V/cG7G2ota/duk9nX9l+sF+y\n//mB1AO3DoYdbD7EOFR72OJw5RHakdI6pG56XX99Zn1nQ1JD+9HQo82NXo1Hjjkc23nc5HjFCe0T\ny05STi44OdhU1DRwSnSq73TG6a7myc33zow/c6MluqXtbNjZC+eCzp05zzzfdMH7wvGLnhePXmJc\nqr/sdrmu1bX1yO+uvx9pc2uru+J+peGqx9XG9rHtJ6/5Xjt9PeD6uRvsG5dvRtxsvxV/63ZHSkfn\nbd7tnjs5
d17dLbj7+d7c+4T7pQ/UH5Q/NHhY9YfNH/s63TpPPAp41Po49vG9Lm7Xiyd5T750L3hK\nfVr+zPhZdY9Tz/HeoN6r
}
},
"cell_type": "markdown",
"metadata": {},
"source": [
"![adelie.png](attachment:da25e3d5-20c3-48f3-bc9d-6ac38633299d.png)\n",
"![gentoo.png](attachment:848c47a7-07b4-45b0-bdb5-b397a08b2145.png)\n",
"![chinstrap.png](attachment:4d9e5ba8-736d-4535-bdf3-6f6f41dd502a.png)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['CulmenLength' 'CulmenDepth' 'FlipperLength' 'BodyMass' 'Species'] SpeciesName\n",
"[ 46.5 14.4 21.7 49.0 1 ] Gentoo\n",
"[ 35.5 16.2 19.5 33.5 0 ] Adelie\n",
"[ 44.1 18.0 21.0 40.0 0 ] Adelie\n",
"[ 50.0 16.3 23.0 57.0 1 ] Gentoo\n",
"[ 50.4 15.7 22.2 57.5 1 ] Gentoo\n",
"[ 41.5 18.5 20.1 40.0 0 ] Adelie\n",
"[ 36.9 18.6 18.9 35.0 0 ] Adelie\n",
"[ 47.6 18.3 19.5 38.5 2 ] Chinstrap\n",
"[ 36.4 17.1 18.4 28.5 0 ] Adelie\n",
"[ 45.3 13.8 20.8 42.0 1 ] Gentoo\n"
]
}
],
"source": [
"penguin_classes = ['Adelie', 'Gentoo', 'Chinstrap']\n",
"print(sample.columns[0:5].values, 'SpeciesName')\n",
"for index, row in penguins.sample(10).iterrows():\n",
"    # Use .iloc for positional access: plain row[0] indexing on a labelled Series is deprecated\n",
"    print('[', row.iloc[0], row.iloc[1], row.iloc[2], row.iloc[3], int(row.iloc[4]), ']', penguin_classes[int(row.iloc[-1])])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As is common in a supervised learning problem, we'll split the dataset into a set of records with which to train the model, and a smaller set with which to validate the trained model."
]
},
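`train_test_split` shuffles the rows and holds out a fraction of them for testing. A rough plain-Python sketch of that idea, assuming a seeded shuffle and rounding the test size up (scikit-learn also rounds a fractional test size up, which is how 1,368 rows become 957 training and 411 test rows; its shuffling differs, so the selected rows will not match). `shuffle_split` is a hypothetical helper.

```python
import math
import random

def shuffle_split(rows, targets, test_size=0.30, seed=0):
    """Shuffle row indices with a fixed seed, then hold out ceil(test_size * n) rows."""
    n = len(rows)
    n_test = math.ceil(test_size * n)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    x_train = [rows[i] for i in train_idx]
    y_train = [targets[i] for i in train_idx]
    x_test = [rows[i] for i in test_idx]
    y_test = [targets[i] for i in test_idx]
    return x_train, x_test, y_train, y_test

# 1,368 placeholder rows: the size of the oversampled penguins dataset
rows = [[float(i)] for i in range(1368)]
targets = [i % 3 for i in range(1368)]
x_train, x_test, y_train, y_test = shuffle_split(rows, targets)
print('Training Set: %d, Test Set: %d' % (len(x_train), len(x_test)))  # Training Set: 957, Test Set: 411
```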
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Set: 957, Test Set: 411 \n",
"\n",
"Sample of features and labels:\n",
"[51.1 16.5 22.5 52.5] 1 (Gentoo)\n",
"[50.7 19.7 20.3 40.5] 2 (Chinstrap)\n",
"[49.5 16.2 22.9 58. ] 1 (Gentoo)\n",
"[39.3 20.6 19. 36.5] 0 (Adelie)\n",
"[42.5 20.7 19.7 45. ] 0 (Adelie)\n",
"[50. 15.3 22. 55.5] 1 (Gentoo)\n",
"[50.2 18.7 19.8 37.75] 2 (Chinstrap)\n",
"[50.7 19.7 20.3 40.5] 2 (Chinstrap)\n",
"[49.1 14.5 21.2 46.25] 1 (Gentoo)\n",
"[43.2 16.6 18.7 29. ] 2 (Chinstrap)\n",
"[38.8 17.6 19.1 32.75] 0 (Adelie)\n",
"[37.8 17.1 18.6 33. ] 0 (Adelie)\n",
"[45.8 14.2 21.9 47. ] 1 (Gentoo)\n",
"[43.8 13.9 20.8 43. ] 1 (Gentoo)\n",
"[36. 17.1 18.7 37. ] 0 (Adelie)\n",
"[43.3 13.4 20.9 44. ] 1 (Gentoo)\n",
"[36. 18.5 18.6 31. ] 0 (Adelie)\n",
"[41.1 19. 18.2 34.25] 0 (Adelie)\n",
"[33.1 16.1 17.8 29. ] 0 (Adelie)\n",
"[40.9 13.7 21.4 46.5] 1 (Gentoo)\n",
"[45.2 17.8 19.8 39.5] 2 (Chinstrap)\n",
"[48.4 14.6 21.3 58.5] 1 (Gentoo)\n",
"[43.6 13.9 21.7 49. ] 1 (Gentoo)\n",
"[38.5 17.9 19. 33.25] 0 (Adelie)\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"features = ['CulmenLength','CulmenDepth','FlipperLength','BodyMass']\n",
"label = 'Species'\n",
" \n",
"# Split data 70%-30% into training set and test set\n",
"x_train, x_test, y_train, y_test = train_test_split(penguins[features].values,\n",
" penguins[label].values,\n",
" test_size=0.30,\n",
" random_state=0)\n",
"\n",
"print ('Training Set: %d, Test Set: %d \\n' % (len(x_train), len(x_test)))\n",
"print(\"Sample of features and labels:\")\n",
"\n",
"# Take a look at the first 24 training features and corresponding labels\n",
"for n in range(0,24):\n",
"    print(x_train[n], y_train[n], '(' + penguin_classes[y_train[n]] + ')')"
]
},
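The sample above lists the row `[50.7 19.7 20.3 40.5]` twice, because the oversampling step copies every observation before the split. One way to measure how many test rows also appear verbatim in the training data is to compare rows as tuples. A minimal sketch with a handful of values from the first sample table, using a plain slice in place of the random split; `count_leaked_rows` and `dedupe` are hypothetical helpers.

```python
def count_leaked_rows(train_rows, test_rows):
    """Count test rows whose exact feature values also occur in the training rows."""
    seen = {tuple(row) for row in train_rows}
    return sum(tuple(row) in seen for row in test_rows)

def dedupe(rows):
    """Drop exact duplicates, keeping the first occurrence and the original order."""
    return list(dict.fromkeys(tuple(row) for row in rows))

base = [[39.0, 17.5], [38.2, 20.0], [50.5, 18.4], [43.2, 19.0]]
oversampled = base * 4                                   # every row now appears four times
print(count_leaked_rows(oversampled[:11], oversampled[11:]))  # 5

unique = dedupe(oversampled)
print(count_leaked_rows(unique[:3], unique[3:]))         # 0
```

Splitting first and oversampling only the training portion (or deduplicating before the split) keeps the test set made of observations the model has never seen.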
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The *features* are the measurements for each penguin observation, and the *label* is a numeric value that indicates the species of penguin that the observation represents (Adelie, Gentoo, or Chinstrap).\n",
"\n",
"## Install and import the PyTorch libraries\n",
"\n",
"Since we plan to use PyTorch to create our penguin classifier, we'll need to run the following two cells to install and import the PyTorch libraries we intend to use. The specific installation of PyTorch depends on your operating system and on whether your computer has graphics processing units (GPUs) that can be used for high-performance processing via *CUDA*. You can find detailed instructions at https://pytorch.org/get-started/locally/."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"#!pip install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting torch\n",
" Downloading torch-2.4.1-cp311-cp311-manylinux1_x86_64.whl.metadata (26 kB)\n",
"Collecting filelock (from torch)\n",
" Downloading filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.11/site-packages (from torch) (4.9.0)\n",
"Requirement already satisfied: sympy in /opt/conda/lib/python3.11/site-packages (from torch) (1.12)\n",
"Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch) (3.1.3)\n",
"Requirement already satisfied: fsspec in /opt/conda/lib/python3.11/site-packages (from torch) (2023.12.2)\n",
"Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n",
" Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n",
" Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n",
" Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-curand-cu12==10.3.2.106 (from torch)\n",
" Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n",
" Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n",
" Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-nccl-cu12==2.20.5 (from torch)\n",
" Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\n",
"Collecting nvidia-nvtx-cu12==12.1.105 (from torch)\n",
" Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\n",
"Collecting triton==3.0.0 (from torch)\n",
" Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)\n",
"Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n",
" Downloading nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.11/site-packages (from sympy->torch) (1.3.0)\n",
"Downloading torch-2.4.1-cp311-cp311-manylinux1_x86_64.whl (797.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m797.1/797.1 MB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
"\u001b[?25hUsing cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
"Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
"Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
"Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
"Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
"\u001b[?25hUsing cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
"Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
"Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
"Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
"Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n",
"Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
"Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.4/209.4 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
"\u001b[?25hDownloading filelock-3.16.1-py3-none-any.whl (16 kB)\n",
"Downloading nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl (19.7 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m19.7/19.7 MB\u001b[0m \u001b[31m41.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n",
"\u001b[?25hInstalling collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch\n",
"Successfully installed filelock-3.16.1 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.6.68 nvidia-nvtx-cu12-12.1.105 torch-2.4.1 triton-3.0.0\n",
"Libraries imported - ready to use PyTorch 2.4.1+cu121\n"
]
}
],
"source": [
"!pip install torch\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.utils.data as td\n",
"\n",
"# Set random seed for reproducibility\n",
"torch.manual_seed(0)\n",
"\n",
"print(\"Libraries imported - ready to use PyTorch\", torch.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare the data for PyTorch\n",
"\n",
"PyTorch makes use of *data loaders* to load training and validation data in batches. We've already loaded the data into numpy arrays, but we need to wrap those in PyTorch datasets (in which the data is converted to PyTorch *tensor* objects) and create loaders to read batches from those datasets."
]
},
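A data loader walks through a dataset in fixed-size batches; with PyTorch's default `drop_last=False`, the final batch is simply smaller. A plain-Python sketch of that behavior (not PyTorch's implementation; `iter_batches` is a hypothetical helper), sized like the training data below: 957 rows with `batch_size=20` gives 47 full batches plus one batch of 17.

```python
import random

def iter_batches(xs, ys, batch_size=20, shuffle=False, seed=0):
    """Yield (features, labels) batches, keeping a smaller final batch (like drop_last=False)."""
    indices = list(range(len(xs)))
    if shuffle:
        # A real DataLoader reshuffles every epoch; a fixed seed keeps this sketch repeatable
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        yield [xs[i] for i in batch_idx], [ys[i] for i in batch_idx]

xs = [[float(i)] * 4 for i in range(957)]  # 957 placeholder rows with four features each
ys = [i % 3 for i in range(957)]
sizes = [len(batch_y) for _, batch_y in iter_batches(xs, ys)]
print(len(sizes), sizes[0], sizes[-1])     # 48 20 17
```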
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ready to load data\n"
]
}
],
"source": [
"# Create a dataset and loader for the training data and labels\n",
"train_x = torch.tensor(x_train, dtype=torch.float32)\n",
"train_y = torch.tensor(y_train, dtype=torch.long)\n",
"train_ds = td.TensorDataset(train_x, train_y)\n",
"train_loader = td.DataLoader(train_ds, batch_size=20,\n",
"                             shuffle=False, num_workers=1)\n",
"\n",
"# Create a dataset and loader for the test data and labels\n",
"test_x = torch.tensor(x_test, dtype=torch.float32)\n",
"test_y = torch.tensor(y_test, dtype=torch.long)\n",
"test_ds = td.TensorDataset(test_x, test_y)\n",
"test_loader = td.DataLoader(test_ds, batch_size=20,\n",
"                            shuffle=False, num_workers=1)\n",
"print('Ready to load data')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define a neural network\n",
"\n",
"Now we're ready to define our neural network. In this case, we'll create a network that consists of three fully-connected layers:\n",
"* An input layer that receives one value for each feature (in this case, the four penguin measurements), maps them to ten outputs, and applies a *ReLU* activation function.\n",
"* A hidden layer that receives those ten values and applies a *ReLU* activation function.\n",
"* An output layer that generates a numeric score for each of the three penguin species (which the loss function will translate into classification probabilities)."
]
},
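To make the layer sizes concrete, the sketch below pushes one observation through a 4 → 10 → 10 → 3 stack of fully-connected layers in plain Python: each layer computes a weighted sum plus bias for every output node, and ReLU replaces negative values with zero. The random weights are placeholders (PyTorch initializes `nn.Linear` differently), and this sketch leaves the output layer without an activation, the usual choice when a cross-entropy loss turns the scores into probabilities.

```python
import random

def linear(x, weights, biases):
    """Fully-connected layer: one weighted sum plus bias per output node."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]

def relu(values):
    """Replace negative values with zero."""
    return [max(0.0, v) for v in values]

def init_layer(n_in, n_out, rng):
    """Placeholder initialization: small uniform random weights, zero biases."""
    weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out

rng = random.Random(0)
fc1 = init_layer(4, 10, rng)   # 4 measurements -> 10 nodes
fc2 = init_layer(10, 10, rng)  # 10 nodes -> 10 nodes
fc3 = init_layer(10, 3, rng)   # 10 nodes -> one score per species

x = [39.0, 17.5, 18.6, 35.5]   # first row of the sample table earlier in the notebook
h1 = relu(linear(x, *fc1))
h2 = relu(linear(h1, *fc2))
scores = linear(h2, *fc3)      # raw scores, no activation on the output layer
print(len(h1), len(h2), len(scores))  # 10 10 3
```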
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PenguinNet(\n",
" (fc1): Linear(in_features=4, out_features=10, bias=True)\n",
" (fc2): Linear(in_features=10, out_features=10, bias=True)\n",
" (fc3): Linear(in_features=10, out_features=3, bias=True)\n",
")\n"
]
}
],
"source": [
"# Number of hidden layer nodes\n",
"hl = 10\n",
"\n",
"# Define the neural network\n",
"class PenguinNet(nn.Module):\n",
" def __init__(self):\n",
" super(PenguinNet, self).__init__()\n",
" self.fc1 = nn.Linear(len(features), hl)\n",
" self.fc2 = nn.Linear(hl, hl)\n",
" self.fc3 = nn.Linear(hl, len(penguin_classes))\n",
"\n",
" def forward(self, x):\n",
" x = torch.relu(self.fc1(x))\n",
" x = torch.relu(self.fc2(x))\n",
"        x = self.fc3(x)  # raw scores (logits): CrossEntropyLoss applies log-softmax internally\n",
" return x\n",
"\n",
"# Create a model instance from the network\n",
"model = PenguinNet()\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model\n",
"\n",
"To train the model, we need to repeatedly feed the training values forward through the network, use a loss function to calculate the loss, backpropagate the loss to calculate the gradients, use an optimizer to adjust the weight and bias values, and validate the model using the test data we withheld.\n",
"\n",
"To do this, we'll create a function to train and optimize the model, and a function to test the model. Then we'll call these functions iteratively over 50 epochs, logging the loss and accuracy statistics for each epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def train(model, data_loader, optimizer):\n",
" # Set the model to training mode\n",
" model.train()\n",
" train_loss = 0\n",
" \n",
" for batch, tensor in enumerate(data_loader):\n",
" data, target = tensor\n",
" #feedforward\n",
" optimizer.zero_grad()\n",
" out = model(data)\n",
" loss = loss_criteria(out, target)\n",
" train_loss += loss.item()\n",
"\n",
" # backpropagate\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" #Return average loss\n",
" avg_loss = train_loss / (batch+1)\n",
" print('Training set: Average loss: {:.6f}'.format(avg_loss))\n",
" return avg_loss\n",
" \n",
" \n",
"def test(model, data_loader):\n",
"    # Switch the model to evaluation mode (torch.no_grad() below disables gradient tracking, so nothing is backpropagated)\n",
" model.eval()\n",
" test_loss = 0\n",
" correct = 0\n",
"\n",
" with torch.no_grad():\n",
" batch_count = 0\n",
" for batch, tensor in enumerate(data_loader):\n",
" batch_count += 1\n",
" data, target = tensor\n",
" # Get the predictions\n",
" out = model(data)\n",
"\n",
" # calculate the loss\n",
" test_loss += loss_criteria(out, target).item()\n",
"\n",
" # Calculate the accuracy\n",
" _, predicted = torch.max(out.data, 1)\n",
" correct += torch.sum(target==predicted).item()\n",
" \n",
" # Calculate the average loss and total accuracy for this epoch\n",
" avg_loss = test_loss/batch_count\n",
" print('Validation set: Average loss: {:.6f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
" avg_loss, correct, len(data_loader.dataset),\n",
" 100. * correct / len(data_loader.dataset)))\n",
" \n",
" # return average loss for the epoch\n",
" return avg_loss\n",
"\n",
"# Specify the loss criteria (we'll use CrossEntropyLoss for multi-class classification)\n",
"loss_criteria = nn.CrossEntropyLoss()\n",
"\n",
"# Use an \"Adam\" optimizer to adjust weights\n",
"# (see https://pytorch.org/docs/stable/optim.html#algorithms for details of supported algorithms)\n",
"learning_rate = 0.001\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"optimizer.zero_grad()\n",
"\n",
"# We'll track metrics for each epoch in these arrays\n",
"epoch_nums = []\n",
"training_loss = []\n",
"validation_loss = []\n",
"\n",
"# Train over 50 epochs\n",
"epochs = 50\n",
"for epoch in range(1, epochs + 1):\n",
"\n",
" # print the epoch number\n",
" print('Epoch: {}'.format(epoch))\n",
" \n",
" # Feed training data into the model to optimize the weights\n",
" train_loss = train(model, train_loader, optimizer)\n",
" \n",
" # Feed the test data into the model to check its performance\n",
" test_loss = test(model, test_loader)\n",
" \n",
" # Log the metrics for this epoch\n",
" epoch_nums.append(epoch)\n",
" training_loss.append(train_loss)\n",
" validation_loss.append(test_loss)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"While the training process is running, let's try to understand what's happening:\n",
"\n",
"1. In each *epoch*, the full set of training data is passed forward through the network. There are four features for each observation, and four corresponding nodes in the input layer - so the features for each observation are passed as a vector of four values to that layer. However, for efficiency, the feature vectors are grouped into batches; so actually a matrix of multiple feature vectors is fed in each time.\n",
"2. The matrix of feature values is processed by a function that performs a weighted sum using initialized weights and bias values. The result of this function is then processed by the activation function for the input layer to constrain the values passed to the nodes in the next layer.\n",
"3. The weighted sum and activation functions are repeated in each layer. Note that the functions operate on vectors and matrices rather than individual scalar values. In other words, the forward pass is essentially a series of nested linear algebra functions. This is the reason data scientists prefer to use computers with graphics processing units (GPUs), since these are optimized for matrix and vector calculations.\n",
"4. In the final layer of the network, the output vectors contain a calculated value for each possible class (in this case, classes 0, 1, and 2). This vector is processed by a *loss function* that converts these values to probabilities and determines how far they are from the expected values based on the actual classes. For example, suppose the output probabilities for a Gentoo penguin (class 1) observation are \\[0.3, 0.4, 0.3\\]. The correct prediction would be \\[0.0, 1.0, 0.0\\], so the difference between the predicted and actual values (how far away each predicted value is from what it should be) is \\[0.3, 0.6, 0.3\\]. The cross-entropy loss used in this notebook condenses this into a single number: the negative log of the probability assigned to the correct class (here -ln 0.4, approximately 0.92). The loss is calculated for each batch and maintained as a running aggregate to calculate the overall level of error (*loss*) incurred by the training data for the epoch.\n",
"5. At the end of each epoch, the validation data is passed through the network, and its loss and accuracy (proportion of correct predictions based on the highest probability value in the output vector) are also calculated. It's important to do this because it enables us to compare the performance of the model using data on which it was not trained, helping us determine if it will generalize well for new data or if it's *overfitted* to the training data.\n",
"6. During training, after each batch of training data has been passed forward through the network, the output of the loss function for the *training* data (but <u>not</u> the *validation* data) is used to adjust the model. The precise details of how the optimizer processes the loss vary depending on the specific optimization algorithm being used; but fundamentally you can think of the entire network, from the input layer to the loss function, as being one big nested (*composite*) function. PyTorch applies some differential calculus to calculate *partial derivatives* of the function with respect to each weight and bias value that was used in the network. It's possible to do this efficiently for a nested function due to something called the *chain rule*, which enables you to determine the derivative of a composite function from the derivatives of its inner and outer functions. You don't really need to worry about the details of the math here (PyTorch does it for you), but the end result is that the partial derivatives tell us about the slope (or *gradient*) of the loss function with respect to each weight and bias value - in other words, we can determine whether to increase or decrease the weight and bias values in order to decrease the loss.\n",
"7. The gradients are calculated by working backwards through the network, from the loss to the first layer, in a process called *backpropagation* (`loss.backward()` in the code). Having determined in which direction to adjust the weights and biases, the optimizer uses the *learning rate* to determine by how much to adjust them, and assigns new values to the weights and biases in each layer (`optimizer.step()`).\n",
"8. Now the next epoch repeats the whole training, validation, and backpropagation process starting with the revised weights and biases from the previous epoch - which hopefully will result in a lower level of loss.\n",
"9. The process continues like this for 50 epochs.\n",
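"\n",
"The loss and gradient steps above can be sketched for a single linear layer and a tiny made-up batch. This is an illustrative sketch, not the notebook's training loop: it assumes plain gradient descent (a simpler update rule than the Adam optimizer used here) and checks that `nn.CrossEntropyLoss` equals the mean negative log of the softmax probability assigned to each true class.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# A made-up batch: 2 observations with 4 features each, and their true classes\n",
"data = torch.randn(2, 4)\n",
"target = torch.tensor([1, 2])\n",
"\n",
"layer = nn.Linear(4, 3)  # one raw score per class\n",
"scores = layer(data)\n",
"\n",
"# Cross-entropy by hand: softmax, then mean negative log-probability of the true class\n",
"probs = torch.softmax(scores, dim=1)\n",
"manual_loss = -torch.log(probs[torch.arange(2), target]).mean()\n",
"\n",
"loss = nn.CrossEntropyLoss()(scores, target)\n",
"print(torch.allclose(loss, manual_loss))  # True\n",
"\n",
"# Backpropagation: fills layer.weight.grad and layer.bias.grad via the chain rule\n",
"loss.backward()\n",
"\n",
"# One plain gradient-descent step (Adam instead adapts the step size per parameter)\n",
"learning_rate = 0.1\n",
"with torch.no_grad():\n",
"    for p in layer.parameters():\n",
"        p -= learning_rate * p.grad\n",
"\n",
"# With a small enough learning rate, the loss on the same batch goes down\n",
"new_loss = nn.CrossEntropyLoss()(layer(data), target)\n",
"print(loss.item(), new_loss.item())\n",
"```\n",
"\n",
"In the notebook's `train` function, `loss.backward()` and `optimizer.step()` play the roles of the last two steps for every batch.\n",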
"\n",
"## Review training and validation loss\n",
"\n",
"After training is complete, we can examine the loss metrics we recorded while training and validating the model. We're really looking for two things:\n",
"* The loss should reduce with each epoch, showing that the model is learning the right weights and biases to predict the correct labels.\n",
"* The training loss and validation loss should follow a similar trend, showing that the model is not overfitting to the training data.\n",
"\n",
"Let's plot the loss metrics and see:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"from matplotlib import pyplot as plt\n",
"\n",
"plt.plot(epoch_nums, training_loss)\n",
"plt.plot(epoch_nums, validation_loss)\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('loss')\n",
"plt.legend(['training', 'validation'], loc='upper right')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## View the learned weights and biases\n",
"\n",
"The trained model consists of the final weights and biases that were determined by the optimizer during training. Based on our network model we should expect the following values for each layer:\n",
"* Layer 1: There are four input values going to ten output nodes, so there should be 10 x 4 weights and 10 bias values.\n",
"* Layer 2: There are ten input values going to ten output nodes, so there should be 10 x 10 weights and 10 bias values.\n",
"* Layer 3: There are ten input values going to three output nodes, so there should be 3 x 10 weights and 3 bias values."
]
},
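{
"cell_type": "markdown",
"metadata": {},
"source": [
"The expected shapes can be checked without any training, because they depend only on the layer sizes. The sketch below rebuilds three `nn.Linear` layers with the sizes assumed in this notebook (4 features, 10 hidden nodes, 3 classes). Note that PyTorch stores each weight matrix as *outputs x inputs*, which is why Layer 1 has a 10 x 4 weight matrix.\n",
"\n",
"```python\n",
"import torch.nn as nn\n",
"\n",
"# Assumed layer sizes, matching the network defined earlier\n",
"layers = {\n",
"    'fc1': nn.Linear(4, 10),\n",
"    'fc2': nn.Linear(10, 10),\n",
"    'fc3': nn.Linear(10, 3),\n",
"}\n",
"\n",
"total = 0\n",
"for name, layer in layers.items():\n",
"    # weight shape is (out_features, in_features); bias shape is (out_features,)\n",
"    print(name, tuple(layer.weight.shape), tuple(layer.bias.shape))\n",
"    total += layer.weight.numel() + layer.bias.numel()\n",
"\n",
"print('Total trainable parameters:', total)  # 193\n",
"```\n",
"\n",
"The total is 50 + 110 + 33 = 193 values, all of which the optimizer adjusts during training."
]
},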
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"for param_tensor in model.state_dict():\n",
" print(param_tensor, \"\\n\", model.state_dict()[param_tensor].numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate model performance\n",
"\n",
"So, is the model any good? The raw accuracy reported from the validation data would seem to indicate that it predicts pretty well; but it's typically useful to dig a little deeper and compare the predictions for each possible class. A common way to visualize the performance of a classification model is to create a *confusion matrix* that shows a crosstab of correct and incorrect predictions for each class."
]
},
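{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the crosstab contains, the sketch below builds a confusion matrix by hand for a few made-up labels (illustrative values, not results from this model) and compares it with Scikit-learn's `confusion_matrix`. Rows are actual classes and columns are predicted classes, the same orientation as the plot below, so correct predictions land on the diagonal.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# Made-up actual and predicted classes (0, 1, 2) for illustration only\n",
"actual    = np.array([0, 0, 1, 1, 1, 2, 2, 2])\n",
"predicted = np.array([0, 1, 1, 1, 1, 2, 2, 0])\n",
"\n",
"n_classes = 3\n",
"manual_cm = np.zeros((n_classes, n_classes), dtype=int)\n",
"for a, p in zip(actual, predicted):\n",
"    manual_cm[a, p] += 1  # row = actual class, column = predicted class\n",
"\n",
"print(manual_cm)\n",
"print(np.array_equal(manual_cm, confusion_matrix(actual, predicted)))  # True\n",
"```\n",
"\n",
"Here the diagonal holds 6 of the 8 observations; the two off-diagonal counts are the misclassifications."
]
},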
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PyTorch doesn't have a built-in confusion matrix metric, so we'll use Scikit-learn\n",
"from sklearn.metrics import confusion_matrix\n",
"import numpy as np\n",
"\n",
"# Set the model to evaluate mode\n",
"model.eval()\n",
"\n",
"# Get predictions for the test data\n",
"x = torch.Tensor(x_test).float()\n",
"_, predicted = torch.max(model(x).data, 1)\n",
"\n",
"# Plot the confusion matrix\n",
"cm = confusion_matrix(y_test, predicted.numpy())\n",
"plt.imshow(cm, interpolation=\"nearest\")\n",
"plt.colorbar()\n",
"tick_marks = np.arange(len(penguin_classes))\n",
"plt.xticks(tick_marks, penguin_classes, rotation=45)\n",
"plt.yticks(tick_marks, penguin_classes)\n",
"plt.xlabel(\"Predicted Species\")\n",
"plt.ylabel(\"Actual Species\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The confusion matrix should show a strong diagonal line indicating that there are more correct than incorrect predictions for each class.\n",
"\n",
"## Save the trained model\n",
"Now that we have a model we believe is reasonably accurate, we can save its trained weights for use later."
]
},
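{
"cell_type": "markdown",
"metadata": {},
"source": [
"Saving the `state_dict` stores only the learned tensors, not the network's class definition, which is why the later cell creates a new `PenguinNet()` instance before loading the weights into it. The sketch below shows that round trip with a small stand-in layer (not the notebook's model) and an in-memory buffer (`io.BytesIO`) in place of a file, so it leaves nothing on disk.\n",
"\n",
"```python\n",
"import io\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# A small stand-in model (not the notebook's PenguinNet)\n",
"original = nn.Linear(4, 3)\n",
"\n",
"# Save only the parameters (the state_dict) to an in-memory buffer\n",
"buffer = io.BytesIO()\n",
"torch.save(original.state_dict(), buffer)\n",
"buffer.seek(0)\n",
"\n",
"# Rebuild the same architecture, then load the saved parameters into it\n",
"restored = nn.Linear(4, 3)\n",
"restored.load_state_dict(torch.load(buffer))\n",
"\n",
"print(torch.equal(original.weight, restored.weight))  # True\n",
"```"
]
},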
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
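"import os\n",
"\n",
"# Save the model weights (create the models folder first if it doesn't exist)\n",
"model_file = 'models/penguin_classifier.pt'\n",
"os.makedirs(os.path.dirname(model_file), exist_ok=True)\n",
"torch.save(model.state_dict(), model_file)\n",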
"del model\n",
"print('model saved as', model_file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use the trained model\n",
"\n",
"When we have a new penguin observation, we can use the model to predict the species."
]
},
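{
"cell_type": "markdown",
"metadata": {},
"source": [
"`torch.max` returns the index of the highest raw score, which is enough to pick a species. If you also want a confidence value, you can pass the scores through `torch.softmax` first; because softmax preserves the order of the scores, the predicted class does not change. The sketch below uses made-up scores for one observation and a hypothetical class order (`example_classes`); check `penguin_classes` in this notebook for the actual order.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Hypothetical class order and made-up raw scores for one new observation\n",
"example_classes = ['Adelie', 'Gentoo', 'Chinstrap']\n",
"scores = torch.tensor([[0.5, 2.0, 1.0]])\n",
"\n",
"# Softmax turns raw scores into probabilities that sum to 1\n",
"probs = torch.softmax(scores, dim=1)\n",
"confidence, predicted = torch.max(probs, 1)\n",
"\n",
"print('Prediction:', example_classes[predicted.item()])  # Prediction: Gentoo\n",
"print('Confidence: {:.1%}'.format(confidence.item()))\n",
"```"
]
},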
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# New penguin features\n",
"x_new = [[50.4,15.3,20,50]]\n",
"print ('New sample: {}'.format(x_new))\n",
"\n",
"# Create a new model class and load weights\n",
"model = PenguinNet()\n",
"model.load_state_dict(torch.load(model_file))\n",
"\n",
"# Set model to evaluation mode\n",
"model.eval()\n",
"\n",
"# Get a prediction for the new data sample\n",
"x = torch.Tensor(x_new).float()\n",
"_, predicted = torch.max(model(x).data, 1)\n",
"\n",
"print('Prediction:',penguin_classes[predicted.item()])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learn more\n",
"\n",
"This notebook was designed to help you understand the basic concepts and principles involved in deep neural networks, using a simple PyTorch example. To learn more about PyTorch, take a look at the [tutorials on the PyTorch web site](https://pytorch.org/tutorials/)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}