refactor: changing structure
Browse files- model/cnn.py → cnn/__init__.py +0 -1
- {model → cnn}/model-old.pt +0 -0
- {model → cnn}/model.pt +0 -0
- testbench.ipynb +8 -8
- train.ipynb +3 -3
model/cnn.py → cnn/__init__.py
RENAMED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
|
|
|
|
|
1 |
import torch.nn as nn
|
2 |
import torch.nn.functional as F
|
3 |
|
{model → cnn}/model-old.pt
RENAMED
File without changes
|
{model → cnn}/model.pt
RENAMED
File without changes
|
testbench.ipynb
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"id": "c831c34c",
|
7 |
"metadata": {},
|
8 |
"outputs": [
|
@@ -20,10 +20,10 @@
|
|
20 |
"import torch\n",
|
21 |
"from torch.utils.data import DataLoader\n",
|
22 |
"from torchvision import datasets, transforms\n",
|
23 |
-
"from
|
24 |
"\n",
|
25 |
"model = CNN()\n",
|
26 |
-
"model.load_state_dict(torch.load(\"model.pt\"))\n",
|
27 |
"\n",
|
28 |
"check_gpu = torch.cuda.is_available()\n",
|
29 |
"device = torch.device(\"cpu\")\n",
|
@@ -41,7 +41,7 @@
|
|
41 |
},
|
42 |
{
|
43 |
"cell_type": "code",
|
44 |
-
"execution_count":
|
45 |
"id": "cd2d6928",
|
46 |
"metadata": {},
|
47 |
"outputs": [],
|
@@ -57,7 +57,7 @@
|
|
57 |
},
|
58 |
{
|
59 |
"cell_type": "code",
|
60 |
-
"execution_count":
|
61 |
"id": "f7bb207f",
|
62 |
"metadata": {},
|
63 |
"outputs": [],
|
@@ -72,7 +72,7 @@
|
|
72 |
},
|
73 |
{
|
74 |
"cell_type": "code",
|
75 |
-
"execution_count":
|
76 |
"id": "9ca78681",
|
77 |
"metadata": {},
|
78 |
"outputs": [],
|
@@ -82,7 +82,7 @@
|
|
82 |
},
|
83 |
{
|
84 |
"cell_type": "code",
|
85 |
-
"execution_count":
|
86 |
"id": "9c5c7fae",
|
87 |
"metadata": {},
|
88 |
"outputs": [
|
@@ -137,7 +137,7 @@
|
|
137 |
},
|
138 |
{
|
139 |
"cell_type": "code",
|
140 |
-
"execution_count":
|
141 |
"id": "1e171b86",
|
142 |
"metadata": {},
|
143 |
"outputs": [
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 42,
|
6 |
"id": "c831c34c",
|
7 |
"metadata": {},
|
8 |
"outputs": [
|
|
|
20 |
"import torch\n",
|
21 |
"from torch.utils.data import DataLoader\n",
|
22 |
"from torchvision import datasets, transforms\n",
|
23 |
+
"from cnn import CNN\n",
|
24 |
"\n",
|
25 |
"model = CNN()\n",
|
26 |
+
"model.load_state_dict(torch.load(\"cnn/model.pt\"))\n",
|
27 |
"\n",
|
28 |
"check_gpu = torch.cuda.is_available()\n",
|
29 |
"device = torch.device(\"cpu\")\n",
|
|
|
41 |
},
|
42 |
{
|
43 |
"cell_type": "code",
|
44 |
+
"execution_count": 43,
|
45 |
"id": "cd2d6928",
|
46 |
"metadata": {},
|
47 |
"outputs": [],
|
|
|
57 |
},
|
58 |
{
|
59 |
"cell_type": "code",
|
60 |
+
"execution_count": 44,
|
61 |
"id": "f7bb207f",
|
62 |
"metadata": {},
|
63 |
"outputs": [],
|
|
|
72 |
},
|
73 |
{
|
74 |
"cell_type": "code",
|
75 |
+
"execution_count": 45,
|
76 |
"id": "9ca78681",
|
77 |
"metadata": {},
|
78 |
"outputs": [],
|
|
|
82 |
},
|
83 |
{
|
84 |
"cell_type": "code",
|
85 |
+
"execution_count": 46,
|
86 |
"id": "9c5c7fae",
|
87 |
"metadata": {},
|
88 |
"outputs": [
|
|
|
137 |
},
|
138 |
{
|
139 |
"cell_type": "code",
|
140 |
+
"execution_count": 47,
|
141 |
"id": "1e171b86",
|
142 |
"metadata": {},
|
143 |
"outputs": [
|
train.ipynb
CHANGED
@@ -9,7 +9,7 @@
|
|
9 |
},
|
10 |
{
|
11 |
"cell_type": "code",
|
12 |
-
"execution_count":
|
13 |
"metadata": {},
|
14 |
"outputs": [
|
15 |
{
|
@@ -28,7 +28,7 @@
|
|
28 |
"import numpy as np\n",
|
29 |
"import matplotlib.pyplot as plt\n",
|
30 |
"import torch.nn as nn\n",
|
31 |
-
"from
|
32 |
"from tabulate import tabulate\n",
|
33 |
"\n",
|
34 |
"\n",
|
@@ -465,7 +465,7 @@
|
|
465 |
" if valid_loss < min_valid_loss:\n",
|
466 |
" saved = \"*\"\n",
|
467 |
" min_valid_loss = valid_loss\n",
|
468 |
-
" torch.save(model.state_dict(), \"model.pt\")\n",
|
469 |
"\n",
|
470 |
" row = [\n",
|
471 |
" epoch + 1,\n",
|
|
|
9 |
},
|
10 |
{
|
11 |
"cell_type": "code",
|
12 |
+
"execution_count": 19,
|
13 |
"metadata": {},
|
14 |
"outputs": [
|
15 |
{
|
|
|
28 |
"import numpy as np\n",
|
29 |
"import matplotlib.pyplot as plt\n",
|
30 |
"import torch.nn as nn\n",
|
31 |
+
"from cnn import CNN\n",
|
32 |
"from tabulate import tabulate\n",
|
33 |
"\n",
|
34 |
"\n",
|
|
|
465 |
" if valid_loss < min_valid_loss:\n",
|
466 |
" saved = \"*\"\n",
|
467 |
" min_valid_loss = valid_loss\n",
|
468 |
+
" torch.save(model.state_dict(), \"cnn/model.pt\")\n",
|
469 |
"\n",
|
470 |
" row = [\n",
|
471 |
" epoch + 1,\n",
|