Rahuletto commited on
Commit
6fdc829
·
1 Parent(s): ffb1866

refactor: changing structure

Browse files
model/cnn.py → cnn/__init__.py RENAMED
@@ -1,4 +1,3 @@
1
-
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
 
 
1
  import torch.nn as nn
2
  import torch.nn.functional as F
3
 
{model → cnn}/model-old.pt RENAMED
File without changes
{model → cnn}/model.pt RENAMED
File without changes
testbench.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": null,
6
  "id": "c831c34c",
7
  "metadata": {},
8
  "outputs": [
@@ -20,10 +20,10 @@
20
  "import torch\n",
21
  "from torch.utils.data import DataLoader\n",
22
  "from torchvision import datasets, transforms\n",
23
- "from model.cnn import CNN\n",
24
  "\n",
25
  "model = CNN()\n",
26
- "model.load_state_dict(torch.load(\"model.pt\"))\n",
27
  "\n",
28
  "check_gpu = torch.cuda.is_available()\n",
29
  "device = torch.device(\"cpu\")\n",
@@ -41,7 +41,7 @@
41
  },
42
  {
43
  "cell_type": "code",
44
- "execution_count": 10,
45
  "id": "cd2d6928",
46
  "metadata": {},
47
  "outputs": [],
@@ -57,7 +57,7 @@
57
  },
58
  {
59
  "cell_type": "code",
60
- "execution_count": 11,
61
  "id": "f7bb207f",
62
  "metadata": {},
63
  "outputs": [],
@@ -72,7 +72,7 @@
72
  },
73
  {
74
  "cell_type": "code",
75
- "execution_count": 12,
76
  "id": "9ca78681",
77
  "metadata": {},
78
  "outputs": [],
@@ -82,7 +82,7 @@
82
  },
83
  {
84
  "cell_type": "code",
85
- "execution_count": 13,
86
  "id": "9c5c7fae",
87
  "metadata": {},
88
  "outputs": [
@@ -137,7 +137,7 @@
137
  },
138
  {
139
  "cell_type": "code",
140
- "execution_count": 33,
141
  "id": "1e171b86",
142
  "metadata": {},
143
  "outputs": [
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 42,
6
  "id": "c831c34c",
7
  "metadata": {},
8
  "outputs": [
 
20
  "import torch\n",
21
  "from torch.utils.data import DataLoader\n",
22
  "from torchvision import datasets, transforms\n",
23
+ "from cnn import CNN\n",
24
  "\n",
25
  "model = CNN()\n",
26
+ "model.load_state_dict(torch.load(\"cnn/model.pt\"))\n",
27
  "\n",
28
  "check_gpu = torch.cuda.is_available()\n",
29
  "device = torch.device(\"cpu\")\n",
 
41
  },
42
  {
43
  "cell_type": "code",
44
+ "execution_count": 43,
45
  "id": "cd2d6928",
46
  "metadata": {},
47
  "outputs": [],
 
57
  },
58
  {
59
  "cell_type": "code",
60
+ "execution_count": 44,
61
  "id": "f7bb207f",
62
  "metadata": {},
63
  "outputs": [],
 
72
  },
73
  {
74
  "cell_type": "code",
75
+ "execution_count": 45,
76
  "id": "9ca78681",
77
  "metadata": {},
78
  "outputs": [],
 
82
  },
83
  {
84
  "cell_type": "code",
85
+ "execution_count": 46,
86
  "id": "9c5c7fae",
87
  "metadata": {},
88
  "outputs": [
 
137
  },
138
  {
139
  "cell_type": "code",
140
+ "execution_count": 47,
141
  "id": "1e171b86",
142
  "metadata": {},
143
  "outputs": [
train.ipynb CHANGED
@@ -9,7 +9,7 @@
9
  },
10
  {
11
  "cell_type": "code",
12
- "execution_count": null,
13
  "metadata": {},
14
  "outputs": [
15
  {
@@ -28,7 +28,7 @@
28
  "import numpy as np\n",
29
  "import matplotlib.pyplot as plt\n",
30
  "import torch.nn as nn\n",
31
- "from model.cnn import CNN\n",
32
  "from tabulate import tabulate\n",
33
  "\n",
34
  "\n",
@@ -465,7 +465,7 @@
465
  " if valid_loss < min_valid_loss:\n",
466
  " saved = \"*\"\n",
467
  " min_valid_loss = valid_loss\n",
468
- " torch.save(model.state_dict(), \"model.pt\")\n",
469
  "\n",
470
  " row = [\n",
471
  " epoch + 1,\n",
 
9
  },
10
  {
11
  "cell_type": "code",
12
+ "execution_count": 19,
13
  "metadata": {},
14
  "outputs": [
15
  {
 
28
  "import numpy as np\n",
29
  "import matplotlib.pyplot as plt\n",
30
  "import torch.nn as nn\n",
31
+ "from cnn import CNN\n",
32
  "from tabulate import tabulate\n",
33
  "\n",
34
  "\n",
 
465
  " if valid_loss < min_valid_loss:\n",
466
  " saved = \"*\"\n",
467
  " min_valid_loss = valid_loss\n",
468
+ " torch.save(model.state_dict(), \"cnn/model.pt\")\n",
469
  "\n",
470
  " row = [\n",
471
  " epoch + 1,\n",