{"cells":[{"cell_type":"code","execution_count":1,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:03:57.784108Z","iopub.status.busy":"2023-11-27T08:03:57.783742Z","iopub.status.idle":"2023-11-27T08:04:02.510978Z","shell.execute_reply":"2023-11-27T08:04:02.509977Z","shell.execute_reply.started":"2023-11-27T08:03:57.784070Z"},"trusted":true},"outputs":[],"source":["import json\n","import os\n","import shutil\n","\n","import numpy as np\n","import pandas as pd\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","from PIL import Image\n","from sklearn.metrics import f1_score, accuracy_score, classification_report\n","from torch.utils.data import DataLoader, Dataset, random_split\n","from torchvision import datasets, transforms, models\n","from tqdm import tqdm"]},
{"cell_type":"code","execution_count":2,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:02.513433Z","iopub.status.busy":"2023-11-27T08:04:02.513036Z","iopub.status.idle":"2023-11-27T08:04:02.517597Z","shell.execute_reply":"2023-11-27T08:04:02.516631Z","shell.execute_reply.started":"2023-11-27T08:04:02.513406Z"},"trusted":true},"outputs":[],"source":["TRAIN_FOLDER = \"icon_data/train/\"\n","# TEST_FOLDER = \"/kaggle/input/sadat-icons/icon_data/test/\"\n","\n","# Seed NumPy (pandas .sample() falls back to its global RNG when no random_state is given)\n","# and torch (new-layer init, DataLoader shuffling) so reruns are reproducible.\n","SEED = 42\n","np.random.seed(SEED)\n","torch.manual_seed(SEED);"]},
{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:02.520149Z","iopub.status.busy":"2023-11-27T08:04:02.519847Z","iopub.status.idle":"2023-11-27T08:04:03.479542Z","shell.execute_reply":"2023-11-27T08:04:03.478422Z","shell.execute_reply.started":"2023-11-27T08:04:02.520125Z"},"trusted":true},"outputs":[],"source":["# Output folder for the preprocessed images. Portable replacement for `!mkdir images`;\n","# exist_ok=True keeps reruns quiet instead of failing with \"already exists\".\n","os.makedirs(\"images\", exist_ok=True)"]},
{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:03.481576Z","iopub.status.busy":"2023-11-27T08:04:03.481188Z","iopub.status.idle":"2023-11-27T08:04:03.495143Z","shell.execute_reply":"2023-11-27T08:04:03.494255Z","shell.execute_reply.started":"2023-11-27T08:04:03.481534Z"},"trusted":true},"outputs":[],"source":["# One class per sub-folder of TRAIN_FOLDER. os.listdir order is filesystem-dependent, so sort\n","# case-insensitively to get the same ids on every machine (this matches the mapping displayed\n","# below). Stray files (e.g. .DS_Store) are ignored so only class folders receive an id.\n","class_names = sorted(\n","    (name for name in os.listdir(TRAIN_FOLDER) if os.path.isdir(os.path.join(TRAIN_FOLDER, name))),\n","    key=lambda name: (name.lower(), name),\n",")\n","class_2_id = {name: idx for idx, name in enumerate(class_names)}\n","id_2_class = {idx: name for name, idx in class_2_id.items()}"]},
{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:03.498319Z","iopub.status.busy":"2023-11-27T08:04:03.498024Z","iopub.status.idle":"2023-11-27T08:04:03.506912Z","shell.execute_reply":"2023-11-27T08:04:03.506063Z","shell.execute_reply.started":"2023-11-27T08:04:03.498293Z"},"trusted":true},"outputs":[{"data":{"text/plain":["{0: 'back',\n"," 1: 'Briefcase',\n"," 2: 'Call',\n"," 3: 'Camera',\n"," 4: 'Circle',\n"," 5: 'Cloud',\n"," 6: 'delete',\n"," 7: 'Down',\n"," 8: 'edit',\n"," 9: 'Export',\n"," 10: 'Face',\n"," 11: 'Folder',\n"," 12: 'Globe',\n"," 13: 'Google',\n"," 14: 'Heart',\n"," 15: 'Home',\n"," 16: 'Image',\n"," 17: 'Import',\n"," 18: 'Info',\n"," 19: 'Link',\n"," 20: 'Location',\n"," 21: 'Mail',\n"," 22: 'menu',\n"," 23: 'Merge',\n"," 24: 'Message',\n"," 25: 'Microphone',\n"," 26: 'more',\n"," 27: 'Music',\n"," 28: 'Mute',\n"," 29: 'Person',\n"," 30: 'Phone',\n"," 31: 'plus',\n"," 32: 'QRCODE',\n"," 33: 'Refresh',\n"," 34: 'search',\n"," 35: 'settings',\n"," 36: 'share',\n"," 37: 'Star',\n"," 38: 'Tick',\n"," 39: 'Up',\n"," 40: 'vidCam',\n"," 41: 'Video',\n"," 42: 'Volume'}"]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["id_2_class"]},
{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[],"source":["# Save the id-to-class-name mapping for inference (json.dump turns the int keys into strings).\n","with open('id_2_class.json', 'w') as fp:\n","    json.dump(id_2_class, 
fp)"]},
{"cell_type":"code","execution_count":7,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:03.521040Z","iopub.status.busy":"2023-11-27T08:04:03.520311Z","iopub.status.idle":"2023-11-27T08:04:37.185102Z","shell.execute_reply":"2023-11-27T08:04:37.184220Z","shell.execute_reply.started":"2023-11-27T08:04:03.521004Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["c:\\Users\\User\\miniconda3\\envs\\textgen\\lib\\site-packages\\PIL\\Image.py:970: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n"," warnings.warn(\n"]}],"source":["# Re-encode every training image as an IMG_SIZE x IMG_SIZE grayscale JPEG named images/{id}.jpg\n","# and record its label. Convert *before* resizing: PIL always resizes palette (\"P\") images with\n","# nearest-neighbour resampling, which aliases thin icon strokes.\n","for_csv = {\n","    \"id\": [],\n","    \"label\": [],\n","    \"text_label\": []\n","}\n","IMG_SIZE = 224\n","i = 0\n","skipped = []\n","\n","FOLDER = TRAIN_FOLDER\n","# Iterate class_2_id (class sub-folders only) so every folder visited has a label id.\n","for folder, label in class_2_id.items():\n","    folder_path = os.path.join(FOLDER, folder)\n","    for image in sorted(os.listdir(folder_path)):\n","        image_path = os.path.join(folder_path, image)\n","        try:\n","            # `with` guarantees the source file handle is released, even on error.\n","            with Image.open(image_path) as src:\n","                img = src.convert(\"L\").resize((IMG_SIZE, IMG_SIZE))\n","            img.save(os.path.join(\"images\", f\"{i}.jpg\"))\n","        except Exception as err:\n","            # Unreadable / non-image files are still skipped, but recorded instead of silently swallowed.\n","            skipped.append((image_path, repr(err)))\n","            continue\n","        for_csv[\"id\"].append(i)\n","        for_csv[\"label\"].append(label)\n","        for_csv[\"text_label\"].append(folder)\n","        i += 1\n","\n","if skipped:\n","    print(f\"Skipped {len(skipped)} unreadable file(s); first: {skipped[0]}\")\n","\n","# FOLDER = TEST_FOLDER\n","# for folder in os.listdir(FOLDER):\n","# folder_path = os.path.join(FOLDER, folder)\n","# for image in os.listdir(folder_path):\n","# image_path = os.path.join(folder_path,image)\n","# try:\n","# img = Image.open(image_path).resize((IMG_SIZE,IMG_SIZE)).convert(\"RGB\")\n","# img.save(\"images/\"+str(i)+\".jpg\")\n","# for_csv[\"id\"].append(i)\n","# for_csv[\"label\"].append(class_2_id[folder])\n","# for_csv[\"text_label\"].append(folder)\n","# i+=1\n","# except:\n","# 
continue"]},
{"cell_type":"code","execution_count":8,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:37.187018Z","iopub.status.busy":"2023-11-27T08:04:37.186379Z","iopub.status.idle":"2023-11-27T08:04:37.203381Z","shell.execute_reply":"2023-11-27T08:04:37.202604Z","shell.execute_reply.started":"2023-11-27T08:04:37.186980Z"},"trusted":true},"outputs":[],"source":["# Build the label table in memory (one row per image written to images/; no CSV is read).\n","labels_df = pd.DataFrame(for_csv)"]},
{"cell_type":"code","execution_count":9,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:37.204627Z","iopub.status.busy":"2023-11-27T08:04:37.204384Z","iopub.status.idle":"2023-11-27T08:04:37.223011Z","shell.execute_reply":"2023-11-27T08:04:37.222243Z","shell.execute_reply.started":"2023-11-27T08:04:37.204605Z"},"trusted":true},"outputs":[],"source":["# Shuffle once with a fixed seed so the split below is reproducible across runs.\n","labels_df = labels_df.sample(frac=1, random_state=42).reset_index(drop=True)"]},
{"cell_type":"code","execution_count":10,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:37.224272Z","iopub.status.busy":"2023-11-27T08:04:37.223964Z","iopub.status.idle":"2023-11-27T08:04:37.244182Z","shell.execute_reply":"2023-11-27T08:04:37.243499Z","shell.execute_reply.started":"2023-11-27T08:04:37.224243Z"},"trusted":true},"outputs":[],"source":["from sklearn.model_selection import train_test_split\n","\n","\n","def split_dataframe_by_label(df, frac, random_state=42):\n","    \"\"\"Split df into two frames, stratified by its 'label' column.\n","\n","    Each label's rows are split separately with train_test_split, so the first frame\n","    gets ~frac of every class and the second frame gets the rest.\n","\n","    Args:\n","        df: DataFrame with a 'label' column.\n","        frac: share of each class placed in the first frame; must satisfy 0 < frac < 1\n","            (train_test_split reads an int 1 as \"one sample\" and rejects 0 and 1.0).\n","        random_state: seed forwarded to train_test_split (42, as before).\n","\n","    Returns:\n","        (df_first, df_second), keeping the original index and column dtypes.\n","\n","    Raises:\n","        ValueError: if frac is outside (0, 1) or train_test_split cannot split a class.\n","    \"\"\"\n","    if not (0 < frac < 1):\n","        raise ValueError(\"The fraction must be strictly between 0 and 1.\")\n","    if df.empty:\n","        return df.copy(), df.copy()\n","\n","    firsts, seconds = [], []\n","    for label in df['label'].unique():\n","        label_df = df[df['label'] == label]\n","        df1, df2 = train_test_split(label_df, train_size=frac, random_state=random_state)\n","        firsts.append(df1)\n","        seconds.append(df2)\n","\n","    # Concatenate once: linear instead of quadratic, and no empty seed frame that would\n","    # up-cast every column to object dtype.\n","    return pd.concat(firsts), pd.concat(seconds)\n"]},
{"cell_type":"code","execution_count":11,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:37.245394Z","iopub.status.busy":"2023-11-27T08:04:37.245130Z","iopub.status.idle":"2023-11-27T08:04:37.343110Z","shell.execute_reply":"2023-11-27T08:04:37.342176Z","shell.execute_reply.started":"2023-11-27T08:04:37.245370Z"},"trusted":true},"outputs":[],"source":["train_df, test_df = split_dataframe_by_label(labels_df, 0.8)"]},
{"cell_type":"code","execution_count":12,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:37.344433Z","iopub.status.busy":"2023-11-27T08:04:37.344161Z","iopub.status.idle":"2023-11-27T08:04:37.350672Z","shell.execute_reply":"2023-11-27T08:04:37.349742Z","shell.execute_reply.started":"2023-11-27T08:04:37.344408Z"},"trusted":true},"outputs":[],"source":["def undersample_dataframe(df, max_samples, random_state=None):\n","    \"\"\"Cap every class at max_samples rows, sampling without replacement.\n","\n","    Classes with at most max_samples rows are kept whole (in shuffled order). The result\n","    is ordered by label and has a fresh RangeIndex.\n","\n","    Raises:\n","        ValueError: if max_samples is not a positive integer.\n","    \"\"\"\n","    if not isinstance(max_samples, int) or max_samples <= 0:\n","        raise ValueError(\"max_samples must be a positive integer.\")\n","\n","    # Iterate the groups directly rather than via GroupBy.apply, which recent pandas warns\n","    # about when the grouping column is part of the applied frame.\n","    parts = [group.sample(min(len(group), max_samples), random_state=random_state)\n","             for _, group in df.groupby('label')]\n","    if not parts:\n","        return df.reset_index(drop=True)\n","    return pd.concat(parts, ignore_index=True)\n"]},
{"cell_type":"code","execution_count":13,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:37.352120Z","iopub.status.busy":"2023-11-27T08:04:37.351814Z","iopub.status.idle":"2023-11-27T08:04:37.360944Z","shell.execute_reply":"2023-11-27T08:04:37.360110Z","shell.execute_reply.started":"2023-11-27T08:04:37.352096Z"},"trusted":true},"outputs":[],"source":["def upsample_dataframe(df, min_samples, random_state=None):\n","    \"\"\"Bring every class up to at least min_samples rows.\n","\n","    A class with fewer rows keeps all of its original rows plus the missing number of rows\n","    drawn with replacement from that class. (Redrawing the whole class with replacement,\n","    as before, could leave out a sizable share of a small class's distinct images.)\n","    Larger classes are unchanged. The result is ordered by label with a fresh RangeIndex.\n","\n","    Raises:\n","        ValueError: if min_samples is not a positive integer.\n","    \"\"\"\n","    if not isinstance(min_samples, int) or min_samples <= 0:\n","        raise ValueError(\"min_samples must be a positive integer.\")\n","\n","    parts = []\n","    for _, group in df.groupby('label'):\n","        parts.append(group)\n","        deficit = min_samples - len(group)\n","        if deficit > 0:\n","            parts.append(group.sample(deficit, replace=True, random_state=random_state))\n","    if not parts:\n","        return df.reset_index(drop=True)\n","    return pd.concat(parts, ignore_index=True)\n"]},
{"cell_type":"code","execution_count":14,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:37.364825Z","iopub.status.busy":"2023-11-27T08:04:37.364488Z","iopub.status.idle":"2023-11-27T08:04:37.390356Z","shell.execute_reply":"2023-11-27T08:04:37.389532Z","shell.execute_reply.started":"2023-11-27T08:04:37.364795Z"},"trusted":true},"outputs":[],"source":["# Balance only the training split; the test split keeps its natural class distribution.\n","train_df = upsample_dataframe(train_df, 100)\n","# test_df = undersample_dataframe(test_df,100)\n","# train_df = undersample_dataframe(train_df,150)"]},
{"cell_type":"code","execution_count":15,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:37.391599Z","iopub.status.busy":"2023-11-27T08:04:37.391334Z","iopub.status.idle":"2023-11-27T08:04:37.837887Z","shell.execute_reply":"2023-11-27T08:04:37.836930Z","shell.execute_reply.started":"2023-11-27T08:04:37.391575Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["0 127 32\n","1 174 44\n","2 236 60\n","3 100 22\n","4 120 31\n","5 198 50\n","6 130 33\n","7 122 31\n","8 103 26\n","9 182 46\n","10 100 22\n","11 198 50\n","12 141 36\n","13 198 50\n","14 198 50\n","15 198 50\n","16 160 41\n","17 157 40\n","18 100 22\n","19 173 44\n","20 113 29\n","21 152 39\n","22 109 28\n","23 197 50\n","24 100 17\n","25 180 45\n","26 101 26\n","27 116 29\n","28 199 50\n","29 151 38\n","30 100 25\n","31 100 24\n","32 198 50\n","33 199 50\n","34 100 25\n","35 100 23\n","36 100 17\n","37 197 50\n","38 205 52\n","39 204 51\n","40 100 14\n","41 104 27\n","42 192 48\n"]}],"source":["# Rows per class: label id, train count, test count (0 if a class is missing from a split).\n","train_counts = train_df['label'].value_counts()\n","test_counts = test_df['label'].value_counts()\n","for label_id in range(len(class_2_id)):\n","    print(label_id, train_counts.get(label_id, 0), test_counts.get(label_id, 0))"]},
{"cell_type":"code","execution_count":16,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:37.839384Z","iopub.status.busy":"2023-11-27T08:04:37.839091Z","iopub.status.idle":"2023-11-27T08:04:37.843670Z","shell.execute_reply":"2023-11-27T08:04:37.842728Z","shell.execute_reply.started":"2023-11-27T08:04:37.839357Z"},"trusted":true},"outputs":[],"source":["FOLDER = \"images/\""]},
{"cell_type":"code","execution_count":17,"metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.execute_input":"2023-11-27T08:04:37.845227Z","iopub.status.busy":"2023-11-27T08:04:37.844946Z","iopub.status.idle":"2023-11-27T08:04:42.565747Z","shell.execute_reply":"2023-11-27T08:04:42.564920Z","shell.execute_reply.started":"2023-11-27T08:04:37.845187Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["c:\\Users\\User\\miniconda3\\envs\\textgen\\lib\\site-packages\\torch\\functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at ..\\aten\\src\\ATen\\native\\TensorShape.cpp:3484.)\n"," return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n"]}],"source":["# ToTensor maps pixels to [0, 1]; Normalize with mean = std = 0.5 rescales them to [-1, 1].\n","# Images on disk are already IMG_SIZE x IMG_SIZE (see the preprocessing cell), so no Resize.\n","test_transform = transforms.Compose([\n","    transforms.ToTensor(),\n","    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n","])\n","\n","train_transform = transforms.Compose([\n","    # ColorJitter(brightness, contrast, saturation, hue): only brightness/contrast vary,\n","    # applied to half of the samples.\n","    transforms.RandomApply([transforms.ColorJitter(0.3, 0.3, 0.0, 0.0)], p=0.5),\n","    transforms.ToTensor(),\n","    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n","])\n","\n","\n","class IconDataset(Dataset):\n","    \"\"\"(image, label) pairs for the rows of a frame with 'id' and 'label' columns.\n","\n","    The image for a row is read from root_dir/{id}.jpg and converted to 3-channel RGB.\n","    \"\"\"\n","\n","    def __init__(self, labels_frame, root_dir, transform=None):\n","        self.labels_frame = labels_frame\n","        self.root_dir = root_dir\n","        self.transform = transform\n","\n","    def __len__(self):\n","        return len(self.labels_frame)\n","\n","    def __getitem__(self, idx):\n","        row = self.labels_frame.iloc[idx]\n","        # Look columns up by name, not position, so a reordered frame cannot swap id and label.\n","        img_name = os.path.join(self.root_dir, f\"{row['id']}.jpg\")\n","        # `with` releases the file handle; convert() returns an independent in-memory image.\n","        with Image.open(img_name) as img:\n","            image = img.convert(\"RGB\")\n","        label = int(row['label'])\n","\n","        if self.transform:\n","            image = self.transform(image)\n","\n","        return image, label\n","\n","\n","train_dataset = IconDataset(labels_frame=train_df, root_dir=FOLDER, transform=train_transform)\n","test_dataset = IconDataset(labels_frame=test_df, root_dir=FOLDER, transform=test_transform)\n","\n","train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n","test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n","\n","\n","class SimpleCNN(nn.Module):\n","    \"\"\"Two-conv baseline classifier (kept for reference; MaxViT is what gets trained).\n","\n","    forward() returns raw logits of shape (batch, num_classes): nn.CrossEntropyLoss\n","    applies log-softmax itself, so no softmax belongs here (the old softmax was also\n","    taken over dim=0, i.e. across the batch).\n","    \"\"\"\n","\n","    def __init__(self, num_classes=None, img_size=224):\n","        super().__init__()\n","        if num_classes is None:\n","            num_classes = len(class_2_id)\n","        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)\n","        self.bn1 = nn.BatchNorm2d(16)\n","        self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)\n","        self.bn2 = nn.BatchNorm2d(32)\n","        # 3x3 convs with stride 1 and padding 1 keep the spatial size, so the flattened\n","        # feature map has 32 * img_size * img_size entries.\n","        self.fc1 = nn.Linear(32 * img_size * img_size, num_classes)\n","\n","    def forward(self, x):\n","        x = F.relu(self.bn1(self.conv1(x)))\n","        x = F.relu(self.bn2(self.conv2(x)))\n","        x = torch.flatten(x, 1)\n","        return self.fc1(x)\n","\n","\n","class MaxViT(nn.Module):\n","    \"\"\"torchvision MaxViT-T with pretrained weights and a num_classes-way classifier[5].\"\"\"\n","\n","    def __init__(self, num_classes=None):\n","        super().__init__()\n","        if num_classes is None:\n","            num_classes = len(class_2_id)\n","        model = models.maxvit_t(weights=\"DEFAULT\")\n","        # Re-size classifier[5] for our classes; all other pretrained weights are kept.\n","        num_ftrs = model.classifier[5].in_features\n","        model.classifier[5] = nn.Linear(num_ftrs, num_classes)\n","        self.model = model\n","\n","    def forward(self, x):\n","        return self.model(x)\n","\n","\n","# Fall back to CPU so the notebook still runs without a GPU.\n","DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","model = MaxViT().to(DEVICE)\n","\n","criterion = nn.CrossEntropyLoss()\n","optimizer = optim.AdamW(model.parameters(), lr=0.001)\n","# LR factor 1 - step/100 per scheduler.step(), clamped at 0: past 100 steps the raw lambda\n","# goes negative, which would turn AdamW into gradient ascent.\n","scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.0, 1 - epoch / 100))\n"]},
{"cell_type":"code","execution_count":18,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T08:04:42.567156Z","iopub.status.busy":"2023-11-27T08:04:42.566879Z","iopub.status.idle":"2023-11-27T08:36:43.560704Z","shell.execute_reply":"2023-11-27T08:36:43.559618Z","shell.execute_reply.started":"2023-11-27T08:04:42.567132Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:48<00:00, 4.17it/s]\n","100%|██████████| 50/50 [00:07<00:00, 6.99it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 1/20, Loss: 0.6631521582603455, Accuracy: 80.97038437303088% F1: 0.8143231843539931\n","Model weights saved.\n"," precision recall f1-score support\n","\n"," 0 0.93 0.88 0.90 32\n"," 1 0.95 0.93 0.94 44\n"," 2 0.98 0.90 0.94 60\n"," 3 0.81 1.00 0.90 22\n"," 4 0.73 0.71 0.72 31\n"," 5 0.64 0.92 0.75 50\n"," 6 0.89 1.00 0.94 33\n"," 7 0.78 0.81 0.79 31\n"," 8 0.93 1.00 0.96 26\n"," 9 0.76 0.76 0.76 46\n"," 10 0.88 0.68 0.77 22\n"," 11 0.78 0.86 0.82 50\n"," 12 0.85 0.94 0.89 36\n"," 13 1.00 0.50 0.67 50\n"," 14 1.00 0.50 0.67 50\n"," 15 0.75 0.92 0.83 50\n"," 16 0.80 0.90 0.85 
41\n"," 17 0.66 0.78 0.71 40\n"," 18 0.90 0.82 0.86 22\n"," 19 0.82 0.64 0.72 44\n"," 20 0.88 1.00 0.94 29\n"," 21 0.84 0.92 0.88 39\n"," 22 0.83 0.89 0.86 28\n"," 23 0.83 0.40 0.54 50\n"," 24 0.81 0.76 0.79 17\n"," 25 0.84 0.84 0.84 45\n"," 26 0.95 0.73 0.83 26\n"," 27 0.92 0.76 0.83 29\n"," 28 0.95 0.74 0.83 50\n"," 29 0.56 0.97 0.71 38\n"," 30 0.89 0.64 0.74 25\n"," 31 0.82 0.96 0.88 24\n"," 32 0.75 0.94 0.83 50\n"," 33 0.69 0.74 0.71 50\n"," 34 0.96 0.96 0.96 25\n"," 35 0.86 0.83 0.84 23\n"," 36 0.94 0.88 0.91 17\n"," 37 0.95 0.84 0.89 50\n"," 38 0.84 0.81 0.82 52\n"," 39 0.63 0.63 0.63 51\n"," 40 0.64 1.00 0.78 14\n"," 41 0.62 0.74 0.68 27\n"," 42 0.85 0.92 0.88 48\n","\n"," accuracy 0.81 1587\n"," macro avg 0.83 0.82 0.81 1587\n","weighted avg 0.83 0.81 0.81 1587\n","\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.34it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.72it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 2/20, Loss: 0.5732803344726562, Accuracy: 84.7511027095148% F1: 0.8471072856598753\n","Model weights saved.\n"," precision recall f1-score support\n","\n"," 0 0.79 0.97 0.87 32\n"," 1 1.00 0.86 0.93 44\n"," 2 0.97 0.97 0.97 60\n"," 3 0.91 0.91 0.91 22\n"," 4 0.82 0.74 0.78 31\n"," 5 0.84 0.86 0.85 50\n"," 6 0.92 1.00 0.96 33\n"," 7 0.54 0.90 0.67 31\n"," 8 0.84 1.00 0.91 26\n"," 9 0.71 0.80 0.76 46\n"," 10 0.82 0.64 0.72 22\n"," 11 0.90 0.94 0.92 50\n"," 12 0.78 0.97 0.86 36\n"," 13 0.83 0.70 0.76 50\n"," 14 0.88 0.72 0.79 50\n"," 15 0.83 0.86 0.84 50\n"," 16 0.90 0.93 0.92 41\n"," 17 0.80 0.40 0.53 40\n"," 18 0.89 0.77 0.83 22\n"," 19 0.81 0.77 0.79 44\n"," 20 0.90 0.97 0.93 29\n"," 21 0.88 0.97 0.93 39\n"," 22 0.83 0.89 0.86 28\n"," 23 0.62 0.70 0.66 50\n"," 24 1.00 0.71 0.83 17\n"," 25 0.89 0.87 0.88 45\n"," 26 0.96 0.88 0.92 26\n"," 27 0.71 0.76 0.73 29\n"," 28 0.95 0.82 0.88 50\n"," 29 0.83 0.92 0.88 38\n"," 30 0.94 0.68 0.79 25\n"," 31 0.88 0.96 0.92 24\n"," 32 0.82 0.90 0.86 
50\n"," 33 0.85 0.80 0.82 50\n"," 34 1.00 1.00 1.00 25\n"," 35 0.95 0.83 0.88 23\n"," 36 0.84 0.94 0.89 17\n"," 37 0.88 0.92 0.90 50\n"," 38 1.00 0.92 0.96 52\n"," 39 0.93 0.73 0.81 51\n"," 40 0.74 1.00 0.85 14\n"," 41 0.95 0.74 0.83 27\n"," 42 0.75 0.94 0.83 48\n","\n"," accuracy 0.85 1587\n"," macro avg 0.86 0.85 0.85 1587\n","weighted avg 0.86 0.85 0.85 1587\n","\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.37it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.99it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 3/20, Loss: 0.7223578095436096, Accuracy: 85.82230623818525% F1: 0.8563987625960884\n","Model weights saved.\n"," precision recall f1-score support\n","\n"," 0 0.84 0.97 0.90 32\n"," 1 1.00 0.91 0.95 44\n"," 2 0.97 0.97 0.97 60\n"," 3 0.95 0.91 0.93 22\n"," 4 0.79 0.74 0.77 31\n"," 5 0.65 0.94 0.77 50\n"," 6 0.97 1.00 0.99 33\n"," 7 0.86 0.77 0.81 31\n"," 8 1.00 0.92 0.96 26\n"," 9 0.81 0.76 0.79 46\n"," 10 0.67 0.82 0.73 22\n"," 11 0.88 0.92 0.90 50\n"," 12 0.92 0.97 0.95 36\n"," 13 0.90 0.86 0.88 50\n"," 14 0.92 0.68 0.78 50\n"," 15 0.88 0.84 0.86 50\n"," 16 0.95 0.90 0.92 41\n"," 17 0.73 0.80 0.76 40\n"," 18 0.85 0.77 0.81 22\n"," 19 0.94 0.66 0.77 44\n"," 20 0.88 1.00 0.94 29\n"," 21 0.91 1.00 0.95 39\n"," 22 0.81 0.89 0.85 28\n"," 23 0.74 0.68 0.71 50\n"," 24 0.81 0.76 0.79 17\n"," 25 0.92 0.80 0.86 45\n"," 26 0.92 0.85 0.88 26\n"," 27 0.86 0.62 0.72 29\n"," 28 0.89 0.94 0.91 50\n"," 29 0.97 0.84 0.90 38\n"," 30 0.95 0.72 0.82 25\n"," 31 0.92 0.96 0.94 24\n"," 32 0.82 0.98 0.89 50\n"," 33 0.80 0.80 0.80 50\n"," 34 0.73 0.96 0.83 25\n"," 35 0.71 0.87 0.78 23\n"," 36 0.84 0.94 0.89 17\n"," 37 0.88 0.92 0.90 50\n"," 38 0.92 0.88 0.90 52\n"," 39 0.84 0.73 0.78 51\n"," 40 0.82 1.00 0.90 14\n"," 41 0.84 0.78 0.81 27\n"," 42 0.83 0.94 0.88 48\n","\n"," accuracy 0.86 1587\n"," macro avg 0.86 0.86 0.86 1587\n","weighted avg 0.87 0.86 0.86 
1587\n","\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.38it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.83it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 4/20, Loss: 0.5389899015426636, Accuracy: 83.30182734719597% F1: 0.8397030664461608\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.37it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.84it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 5/20, Loss: 0.26058411598205566, Accuracy: 86.7044738500315% F1: 0.8652079255655145\n","Model weights saved.\n"," precision recall f1-score support\n","\n"," 0 0.94 0.97 0.95 32\n"," 1 0.93 0.95 0.94 44\n"," 2 0.98 0.95 0.97 60\n"," 3 0.95 0.91 0.93 22\n"," 4 0.69 0.77 0.73 31\n"," 5 0.83 0.90 0.87 50\n"," 6 1.00 0.94 0.97 33\n"," 7 0.69 0.94 0.79 31\n"," 8 0.96 0.96 0.96 26\n"," 9 0.65 0.85 0.74 46\n"," 10 0.94 0.77 0.85 22\n"," 11 0.98 0.84 0.90 50\n"," 12 0.94 0.92 0.93 36\n"," 13 0.91 0.84 0.87 50\n"," 14 0.92 0.66 0.77 50\n"," 15 0.91 0.84 0.87 50\n"," 16 0.93 0.90 0.91 41\n"," 17 0.88 0.70 0.78 40\n"," 18 0.68 0.77 0.72 22\n"," 19 0.89 0.77 0.83 44\n"," 20 0.94 1.00 0.97 29\n"," 21 0.93 0.95 0.94 39\n"," 22 0.83 0.89 0.86 28\n"," 23 0.71 0.74 0.73 50\n"," 24 0.68 0.76 0.72 17\n"," 25 0.86 0.98 0.92 45\n"," 26 0.81 0.81 0.81 26\n"," 27 0.69 0.83 0.75 29\n"," 28 0.96 0.88 0.92 50\n"," 29 0.94 0.84 0.89 38\n"," 30 1.00 0.68 0.81 25\n"," 31 0.96 0.96 0.96 24\n"," 32 0.87 0.92 0.89 50\n"," 33 0.86 0.84 0.85 50\n"," 34 1.00 1.00 1.00 25\n"," 35 0.95 0.87 0.91 23\n"," 36 0.80 0.94 0.86 17\n"," 37 0.96 0.90 0.93 50\n"," 38 0.91 0.92 0.91 52\n"," 39 0.83 0.75 0.78 51\n"," 40 0.67 1.00 0.80 14\n"," 41 0.82 0.85 0.84 27\n"," 42 0.82 0.94 0.87 48\n","\n"," accuracy 0.87 1587\n"," macro avg 0.87 0.87 0.87 1587\n","weighted avg 0.88 0.87 0.87 1587\n","\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 
5.37it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.78it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 6/20, Loss: 0.11516380310058594, Accuracy: 86.32640201638311% F1: 0.8651042180451826\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.37it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.75it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 7/20, Loss: 0.2884931266307831, Accuracy: 84.05797101449275% F1: 0.8479193552665635\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.33it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.63it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 8/20, Loss: 0.4003789722919464, Accuracy: 85.1291745431632% F1: 0.8538440496150018\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.38it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.92it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 9/20, Loss: 0.17860029637813568, Accuracy: 87.77567737870196% F1: 0.8782015659184909\n","Model weights saved.\n"," precision recall f1-score support\n","\n"," 0 0.90 0.88 0.89 32\n"," 1 1.00 0.93 0.96 44\n"," 2 0.95 0.97 0.96 60\n"," 3 0.84 0.95 0.89 22\n"," 4 0.79 0.74 0.77 31\n"," 5 0.77 0.92 0.84 50\n"," 6 0.97 1.00 0.99 33\n"," 7 0.83 0.81 0.82 31\n"," 8 1.00 0.96 0.98 26\n"," 9 0.85 0.87 0.86 46\n"," 10 0.95 0.82 0.88 22\n"," 11 0.94 0.88 0.91 50\n"," 12 0.90 1.00 0.95 36\n"," 13 1.00 0.66 0.80 50\n"," 14 0.80 0.82 0.81 50\n"," 15 0.75 0.84 0.79 50\n"," 16 0.81 0.93 0.86 41\n"," 17 0.82 0.82 0.82 40\n"," 18 0.86 0.82 0.84 22\n"," 19 0.80 0.80 0.80 44\n"," 20 0.90 0.97 0.93 29\n"," 21 0.97 0.97 0.97 39\n"," 22 0.77 0.96 0.86 28\n"," 23 0.93 0.74 0.82 50\n"," 24 0.92 0.65 0.76 17\n"," 25 0.91 0.93 0.92 45\n"," 26 0.92 0.88 0.90 26\n"," 27 0.95 0.72 0.82 29\n"," 28 0.92 0.90 0.91 50\n"," 29 0.92 0.92 0.92 38\n"," 30 0.76 0.88 0.81 25\n"," 31 0.96 0.96 0.96 24\n"," 
32 0.81 0.96 0.88 50\n"," 33 0.90 0.74 0.81 50\n"," 34 0.96 1.00 0.98 25\n"," 35 0.88 0.91 0.89 23\n"," 36 1.00 0.94 0.97 17\n"," 37 0.94 0.90 0.92 50\n"," 38 0.87 0.92 0.90 52\n"," 39 0.84 0.84 0.84 51\n"," 40 0.81 0.93 0.87 14\n"," 41 0.81 0.78 0.79 27\n"," 42 0.85 0.96 0.90 48\n","\n"," accuracy 0.88 1587\n"," macro avg 0.88 0.88 0.88 1587\n","weighted avg 0.88 0.88 0.88 1587\n","\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.40it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.80it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 10/20, Loss: 0.2623906135559082, Accuracy: 85.94833018273472% F1: 0.8568683109835779\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.39it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.82it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 11/20, Loss: 0.03642905130982399, Accuracy: 87.46061751732829% F1: 0.8752476652110971\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.38it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.74it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 12/20, Loss: 0.04971063882112503, Accuracy: 88.84688090737241% F1: 0.8874316680196108\n","Model weights saved.\n"," precision recall f1-score support\n","\n"," 0 0.91 0.97 0.94 32\n"," 1 0.98 1.00 0.99 44\n"," 2 0.98 0.95 0.97 60\n"," 3 1.00 0.95 0.98 22\n"," 4 0.70 0.84 0.76 31\n"," 5 0.80 0.94 0.86 50\n"," 6 1.00 1.00 1.00 33\n"," 7 0.85 0.90 0.88 31\n"," 8 0.96 1.00 0.98 26\n"," 9 0.72 0.89 0.80 46\n"," 10 1.00 0.64 0.78 22\n"," 11 0.87 0.90 0.88 50\n"," 12 0.97 1.00 0.99 36\n"," 13 1.00 0.84 0.91 50\n"," 14 0.95 0.74 0.83 50\n"," 15 0.90 0.88 0.89 50\n"," 16 0.97 0.90 0.94 41\n"," 17 0.88 0.75 0.81 40\n"," 18 0.84 0.73 0.78 22\n"," 19 0.84 0.82 0.83 44\n"," 20 0.93 0.97 0.95 29\n"," 21 0.95 0.95 0.95 39\n"," 22 0.64 0.96 0.77 28\n"," 23 0.97 0.74 0.84 50\n"," 24 0.92 0.65 0.76 
17\n"," 25 0.95 0.89 0.92 45\n"," 26 0.92 0.88 0.90 26\n"," 27 0.91 0.69 0.78 29\n"," 28 0.90 0.92 0.91 50\n"," 29 0.78 0.95 0.86 38\n"," 30 0.81 0.84 0.82 25\n"," 31 1.00 0.96 0.98 24\n"," 32 0.91 0.98 0.94 50\n"," 33 0.81 0.88 0.85 50\n"," 34 1.00 0.96 0.98 25\n"," 35 0.95 0.87 0.91 23\n"," 36 1.00 0.94 0.97 17\n"," 37 0.92 0.96 0.94 50\n"," 38 0.94 0.90 0.92 52\n"," 39 0.77 0.78 0.78 51\n"," 40 0.82 1.00 0.90 14\n"," 41 0.88 0.81 0.85 27\n"," 42 0.84 0.96 0.89 48\n","\n"," accuracy 0.89 1587\n"," macro avg 0.90 0.89 0.89 1587\n","weighted avg 0.90 0.89 0.89 1587\n","\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.33it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.69it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 13/20, Loss: 0.24502374231815338, Accuracy: 88.78386893509767% F1: 0.8914100815186707\n","Model weights saved.\n"," precision recall f1-score support\n","\n"," 0 0.86 0.97 0.91 32\n"," 1 0.98 0.93 0.95 44\n"," 2 0.97 0.95 0.96 60\n"," 3 0.95 0.86 0.90 22\n"," 4 0.64 0.81 0.71 31\n"," 5 0.85 0.92 0.88 50\n"," 6 0.92 1.00 0.96 33\n"," 7 0.83 0.94 0.88 31\n"," 8 1.00 0.88 0.94 26\n"," 9 0.67 0.87 0.75 46\n"," 10 1.00 0.77 0.87 22\n"," 11 0.91 0.82 0.86 50\n"," 12 0.95 0.97 0.96 36\n"," 13 0.95 0.80 0.87 50\n"," 14 0.93 0.80 0.86 50\n"," 15 0.81 0.88 0.85 50\n"," 16 0.97 0.88 0.92 41\n"," 17 0.75 0.82 0.79 40\n"," 18 0.90 0.86 0.88 22\n"," 19 0.97 0.82 0.89 44\n"," 20 1.00 1.00 1.00 29\n"," 21 0.95 1.00 0.97 39\n"," 22 0.83 0.89 0.86 28\n"," 23 0.77 0.86 0.81 50\n"," 24 0.92 0.71 0.80 17\n"," 25 0.94 0.98 0.96 45\n"," 26 0.96 0.88 0.92 26\n"," 27 0.78 0.86 0.82 29\n"," 28 0.98 0.90 0.94 50\n"," 29 0.92 0.87 0.89 38\n"," 30 0.95 0.80 0.87 25\n"," 31 1.00 0.96 0.98 24\n"," 32 0.89 0.94 0.91 50\n"," 33 0.86 0.84 0.85 50\n"," 34 1.00 0.96 0.98 25\n"," 35 0.95 0.87 0.91 23\n"," 36 1.00 0.88 0.94 17\n"," 37 0.76 1.00 0.86 50\n"," 38 1.00 0.92 0.96 52\n"," 39 0.86 0.71 0.77 51\n"," 40 0.88 1.00 0.93 
14\n"," 41 0.88 0.85 0.87 27\n"," 42 0.92 0.92 0.92 48\n","\n"," accuracy 0.89 1587\n"," macro avg 0.90 0.89 0.89 1587\n","weighted avg 0.90 0.89 0.89 1587\n","\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.35it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.94it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 14/20, Loss: 0.12690715491771698, Accuracy: 87.0825456836799% F1: 0.8691777714430017\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.37it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.89it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 15/20, Loss: 0.039787568151950836, Accuracy: 87.52362948960302% F1: 0.8785891393604705\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.40it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.79it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 16/20, Loss: 0.09257231652736664, Accuracy: 87.96471329552615% F1: 0.8818484762362288\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.37it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.83it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 17/20, Loss: 0.14385418593883514, Accuracy: 87.0825456836799% F1: 0.870063776256168\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.39it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.77it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 18/20, Loss: 0.1985616385936737, Accuracy: 87.58664146187776% F1: 0.8737545801811116\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.36it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.60it/s]\n"]},{"name":"stdout","output_type":"stream","text":["Epoch 19/20, Loss: 0.007958111353218555, Accuracy: 88.21676118462508% F1: 
0.8797355118295431\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 201/201 [00:37<00:00, 5.36it/s]\n","100%|██████████| 50/50 [00:04<00:00, 11.94it/s]"]},{"name":"stdout","output_type":"stream","text":["Epoch 20/20, Loss: 0.3914642035961151, Accuracy: 86.7044738500315% F1: 0.8674478996558217\n"]},{"name":"stderr","output_type":"stream","text":["\n"]}],"source":["best_f1 = 0\n","# Training loop\n","num_epochs = 20\n","for epoch in range(num_epochs):\n"," model.train()\n"," for inputs, labels in tqdm(train_loader):\n"," optimizer.zero_grad()\n"," outputs = model(inputs.to(\"cuda\"))\n"," loss = criterion(outputs, labels.to(\"cuda\"))\n"," loss.backward()\n"," optimizer.step()\n"," scheduler.step()\n","\n"," # Validation loop\n"," model.eval()\n"," total = 0\n"," all_labels = []\n"," all_predicted = []\n"," with torch.no_grad():\n"," for inputs, labels in tqdm(test_loader):\n"," outputs = model(inputs.to(\"cuda\"))\n"," _, predicted = torch.max(outputs.data, 1)\n"," total += labels.size(0)\n"," all_labels.append(labels.numpy())\n"," all_predicted.append(predicted.detach().to(\"cpu\").numpy())\n"," \n"," all_labels = np.concatenate(all_labels)\n"," all_predicted = np.concatenate(all_predicted)\n"," f1 = f1_score(all_labels,all_predicted, average='macro')\n"," acc = accuracy_score(all_labels,all_predicted)\n"," print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}, Accuracy: {acc*100}% F1: {f1}\")\n","\n"," if f1>best_f1:\n"," best_f1 = f1\n"," torch.save(model.state_dict(), 'best_model.pth')\n"," print('Model weights saved.')\n"," print(classification_report(all_labels,all_predicted))"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/plain":[""]},"execution_count":20,"metadata":{},"output_type":"execute_result"}],"source":["#load best 
model\n","model.load_state_dict(torch.load('best_model.pth'))"]},{"cell_type":"code","execution_count":21,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T09:14:06.192390Z","iopub.status.busy":"2023-11-27T09:14:06.192039Z","iopub.status.idle":"2023-11-27T09:14:15.612115Z","shell.execute_reply":"2023-11-27T09:14:15.611319Z","shell.execute_reply.started":"2023-11-27T09:14:06.192362Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["C:\\Users\\User\\AppData\\Local\\Temp\\ipykernel_33268\\1528050369.py:8: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"," res = F.softmax(model(image_tensor.unsqueeze(0)))\n","c:\\Users\\User\\miniconda3\\envs\\textgen\\lib\\site-packages\\PIL\\Image.py:970: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n"," warnings.warn(\n"]}],"source":["# Score every image in the UNKNOWN folder and keep the model's top-class confidence\n","# (max softmax probability) for each one.\n","model.eval()  # explicit inference mode rather than relying on the state left by the training cell\n","list_ = []\n","PATH = \"icon_data/UNKNOWN/\"\n","with torch.no_grad():\n","    for image_name in os.listdir(PATH):\n","        # The `with` block closes the file deterministically once the resized copy is made.\n","        with Image.open(os.path.join(PATH, image_name)) as img:\n","            image = img.resize((224,224)).convert(\"RGB\")\n","        image_tensor = test_transform(image).to(\"cuda\")\n","\n","        # dim=1 is what the deprecated implicit choice picks for 2-D input (presumably\n","        # (1, num_classes) logits here); passing it explicitly silences the warning.\n","        res = F.softmax(model(image_tensor.unsqueeze(0)), dim=1)\n","\n","        list_.append(torch.max(res[0]).item())\n","\n","# print(tensor_entropy(res[0]))\n","# print([id_2_class[i] for i in torch.topk(res,5,dim=1).indices[0].to(\"cpu\").tolist()])"]},{"cell_type":"code","execution_count":22,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T09:14:15.613800Z","iopub.status.busy":"2023-11-27T09:14:15.613493Z","iopub.status.idle":"2023-11-27T09:14:16.293542Z","shell.execute_reply":"2023-11-27T09:14:16.292537Z","shell.execute_reply.started":"2023-11-27T09:14:15.613773Z"},"trusted":true},"outputs":[{"data":{"text/plain":[""]},"execution_count":22,"metadata":{},"output_type":"execute_result"},{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["from matplotlib import pyplot as plt\n","\n","# Sort a copy so `list_` keeps the per-image order it was built in.\n","unknown_conf_sorted = sorted(list_)\n","\n","fig, ax = plt.subplots()\n","ax.bar(range(len(unknown_conf_sorted)), unknown_conf_sorted)\n","ax.set(title=\"UNKNOWN images: max softmax confidence (sorted)\",\n","       xlabel=\"image rank\", ylabel=\"max softmax probability\");"]},{"cell_type":"code","execution_count":23,"metadata":{"execution":{"iopub.execute_input":"2023-11-27T09:16:28.779463Z","iopub.status.busy":"2023-11-27T09:16:28.778543Z","iopub.status.idle":"2023-11-27T09:16:37.212765Z","shell.execute_reply":"2023-11-27T09:16:37.211798Z","shell.execute_reply.started":"2023-11-27T09:16:28.779423Z"},"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":[" 0%| | 0/50 [00:00"]},"execution_count":23,"metadata":{},"output_type":"execute_result"},{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["# Same confidence measure on the labelled test set, for comparison with the UNKNOWN images.\n","model.eval()\n","all_labels = []\n","all_predicted = []\n","all_vals = []\n","with torch.no_grad():\n","    for inputs, labels in tqdm(test_loader):\n","        # dim=1 is what the deprecated implicit choice picks for 2-D input (presumably\n","        # (batch, num_classes) logits); passing it explicitly silences the warning.\n","        outputs = F.softmax(model(inputs.to(\"cuda\")), dim=1)\n","        val, predicted = torch.max(outputs, 1)\n","        all_predicted += predicted.cpu().tolist()\n","        all_vals += val.cpu().tolist()\n","        # .tolist() stores plain ints rather than one 0-d tensor per sample.\n","        all_labels += labels.tolist()\n","\n","# Plot a sorted copy of the confidences so all_labels / all_predicted / all_vals stay\n","# aligned per sample (sorting all_predicted in place broke that and left the plot unsorted).\n","test_conf_sorted = sorted(all_vals)\n","\n","fig, ax = plt.subplots()\n","ax.bar(range(len(test_conf_sorted)), test_conf_sorted)\n","ax.set(title=\"Test set: max softmax confidence (sorted)\",\n","       xlabel=\"sample rank\", ylabel=\"max softmax probability\");"]}],"metadata":{"kaggle":{"accelerator":"gpu","dataSources":[{"datasetId":4060149,"sourceId":7063206,"sourceType":"datasetVersion"}],"dockerImageVersionId":30588,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.9"}},"nbformat":4,"nbformat_minor":4}