Zai committed on
Commit
58893eb
1 Parent(s): bbddfb0

DCGAN training notebook

notebooks/dcgan.ipynb DELETED
@@ -1,277 +0,0 @@
- {
-   "nbformat": 4,
-   "nbformat_minor": 0,
-   "metadata": {
-     "colab": {
-       "provenance": [],
-       "gpuType": "T4"
-     },
-     "kernelspec": {
-       "name": "python3",
-       "display_name": "Python 3"
-     },
-     "language_info": {
-       "name": "python"
-     },
-     "accelerator": "GPU"
-   },
-   "cells": [
-     {
-       "cell_type": "code",
-       "execution_count": 74,
-       "metadata": {
-         "id": "xNiydKOa0oFk"
-       },
-       "outputs": [],
-       "source": [
-         "# project gans"
-       ]
-     },
-     {
-       "cell_type": "code",
-       "source": [
-         "import torch\n",
-         "import torchvision\n",
-         "import torch.nn as nn\n",
-         "import torch.nn.functional as F\n",
-         "from torch.utils.data import DataLoader\n",
-         "from torchvision import datasets, transforms\n",
-         "from torchvision.utils import save_image\n",
-         "import numpy as np\n",
-         "\n",
-         "# Check if GPU is available and set the device accordingly\n",
-         "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
-       ],
-       "metadata": {
-         "id": "SCS7gRJQ0tyS"
-       },
-       "execution_count": 75,
-       "outputs": []
-     },
-     {
-       "cell_type": "code",
-       "source": [
-         "def get_sample_image(generator, noise_dim):\n",
-         "    \"\"\"\n",
-         "    Return a 10x10 grid image built from 100 generated samples\n",
-         "    \"\"\"\n",
-         "    noise = torch.randn(100, noise_dim).to(device)\n",
-         "    generated_images = generator(noise).view(100, 28, 28)  # (100, 28, 28)\n",
-         "    result = generated_images.cpu().data.numpy()\n",
-         "    img = np.zeros([280, 280])\n",
-         "    for j in range(10):\n",
-         "        img[j * 28:(j + 1) * 28] = np.concatenate([x for x in result[j * 10:(j + 1) * 10]], axis=-1)\n",
-         "    return img"
-       ],
-       "metadata": {
-         "id": "sacBbf_LwZx-"
-       },
-       "execution_count": 76,
-       "outputs": []
-     },
-     {
-       "cell_type": "code",
-       "source": [
-         "class Discriminator(nn.Module):\n",
-         "    def __init__(self, in_channels=1, num_classes=1):\n",
-         "        super(Discriminator, self).__init__()\n",
-         "        self.conv = nn.Sequential(\n",
-         "            nn.Conv2d(in_channels, 512, 3, stride=2, padding=1, bias=False),\n",
-         "            nn.BatchNorm2d(512),\n",
-         "            nn.LeakyReLU(0.2),\n",
-         "\n",
-         "            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),\n",
-         "            nn.BatchNorm2d(256),\n",
-         "            nn.LeakyReLU(0.2),\n",
-         "\n",
-         "            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),\n",
-         "            nn.BatchNorm2d(128),\n",
-         "            nn.LeakyReLU(0.2),\n",
-         "            nn.AvgPool2d(4),\n",
-         "        )\n",
-         "        self.fc = nn.Sequential(\n",
-         "            nn.Linear(128, 1),\n",
-         "            nn.Sigmoid(),\n",
-         "        )\n",
-         "\n",
-         "    def forward(self, x, y=False):\n",
-         "        features = self.conv(x)\n",
-         "        features = features.view(features.size(0), -1)\n",
-         "        output = self.fc(features)\n",
-         "        return output\n"
-       ],
-       "metadata": {
-         "id": "e9n-wD7dwZ7n"
-       },
-       "execution_count": 77,
-       "outputs": []
-     },
-     {
-       "cell_type": "code",
-       "source": [
-         "class Generator(nn.Module):\n",
-         "    def __init__(self, input_size=100, num_classes=784):\n",
-         "        super(Generator, self).__init__()\n",
-         "        self.fc = nn.Sequential(\n",
-         "            nn.Linear(input_size, 4 * 4 * 512),\n",
-         "            nn.ReLU(),\n",
-         "        )\n",
-         "        self.conv = nn.Sequential(\n",
-         "            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),\n",
-         "            nn.BatchNorm2d(256),\n",
-         "            nn.ReLU(),\n",
-         "\n",
-         "            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),\n",
-         "            nn.BatchNorm2d(128),\n",
-         "            nn.ReLU(),\n",
-         "\n",
-         "            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),\n",
-         "            nn.Tanh(),\n",
-         "        )\n",
-         "\n",
-         "    def forward(self, x, y=None):\n",
-         "        x = x.view(x.size(0), -1)\n",
-         "        features = self.fc(x)\n",
-         "        features = features.view(features.size(0), 512, 4, 4)\n",
-         "        output = self.conv(features)\n",
-         "        return output\n"
-       ],
-       "metadata": {
-         "id": "_8-E4605wZ-e"
-       },
-       "execution_count": 78,
-       "outputs": []
-     },
-     {
-       "cell_type": "code",
-       "source": [
-         "# Instantiate the Generator and Discriminator\n",
-         "generator = Generator().to(device)\n",
-         "discriminator = Discriminator().to(device)"
-       ],
-       "metadata": {
-         "id": "OSDpsaYBypVA"
-       },
-       "execution_count": 79,
-       "outputs": []
-     },
-     {
-       "cell_type": "code",
-       "source": [
-         "transform = transforms.Compose([transforms.ToTensor(),\n",
-         "                                transforms.Normalize(mean=[0.5],\n",
-         "                                                     std=[0.5])]\n",
-         ")"
-       ],
-       "metadata": {
-         "id": "yQ8QdKuCz2_a"
-       },
-       "execution_count": 79,
-       "outputs": []
-     },
-     {
-       "cell_type": "code",
-       "source": [
-         "batch_size = 64\n",
-         "\n",
-         "data = torchvision.datasets.FashionMNIST(root='./data/', train=True, transform=transform, download=True)\n",
-         "data_loader = DataLoader(dataset=data, batch_size=batch_size, shuffle=True, drop_last=True)\n",
-         "\n",
-         "loss_fn = nn.BCELoss()\n",
-         "d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.5, 0.999))\n",
-         "g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001, betas=(0.5, 0.999))\n"
-       ],
-       "metadata": {
-         "id": "8mOTuoih-3ep"
-       },
-       "execution_count": null,
-       "outputs": []
-     },
-     {
-       "cell_type": "code",
-       "source": [
-         "max_epochs = 50\n",
-         "step = 0\n",
-         "n_critic = 1\n",
-         "n_noise = 100\n",
-         "\n",
-         "d_labels = torch.ones([batch_size, 1]).to(device)\n",
-         "d_fakes = torch.zeros([batch_size, 1]).to(device)"
-       ],
-       "metadata": {
-         "id": "kHJ0B3mk-4Bt"
-       },
-       "execution_count": null,
-       "outputs": []
-     },
-     {
-       "cell_type": "code",
-       "source": [
-         "# Training loop\n",
-         "for epoch in range(max_epochs):\n",
-         "    for idx, (images, labels) in enumerate(data_loader):\n",
-         "        real_images = images.to(device)\n",
-         "\n",
-         "        # Discriminator training\n",
-         "        real_outputs = discriminator(real_images)\n",
-         "        d_real_loss = loss_fn(real_outputs, d_labels)\n",
-         "\n",
-         "        fake_noise = torch.randn(batch_size, n_noise).to(device)\n",
-         "        fake_images = generator(fake_noise)\n",
-         "        fake_outputs = discriminator(fake_images.detach())\n",
-         "        d_fake_loss = loss_fn(fake_outputs, d_fakes)\n",
-         "\n",
-         "        d_loss = d_real_loss + d_fake_loss\n",
-         "\n",
-         "        discriminator.zero_grad()\n",
-         "        d_loss.backward()\n",
-         "        d_optimizer.step()\n",
-         "\n",
-         "        # Generator training (every n_critic iterations)\n",
-         "        if step % n_critic == 0:\n",
-         "            fake_outputs = discriminator(fake_images)\n",
-         "            g_loss = loss_fn(fake_outputs, d_labels)\n",
-         "\n",
-         "            generator.zero_grad()\n",
-         "            discriminator.zero_grad()\n",
-         "            g_loss.backward()\n",
-         "            g_optimizer.step()\n",
-         "\n",
-         "        if step % 500 == 0:\n",
-         "            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epochs, step, d_loss.item(), g_loss.item()))\n",
-         "\n",
-         "        if step % 1000 == 0:\n",
-         "            generator.eval()\n",
-         "            img = get_sample_image(generator, n_noise)\n",
-         "            # imsave('samples/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')\n",
-         "            generator.train()\n",
-         "        step += 1"
-       ],
-       "metadata": {
-         "id": "1V9EfSBD-8E9"
-       },
-       "execution_count": null,
-       "outputs": []
-     },
-     {
-       "cell_type": "code",
-       "source": [
-         "# need to test"
-       ],
-       "metadata": {
-         "id": "1g4ATYOD-9LY"
-       },
-       "execution_count": null,
-       "outputs": []
-     },
-     {
-       "cell_type": "code",
-       "source": [],
-       "metadata": {
-         "id": "UPye6Ktu--Ph"
-       },
-       "execution_count": null,
-       "outputs": []
-     }
-   ]
- }
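
The deleted notebook builds a 10x10 sample grid every 1000 steps but never writes it anywhere: the `imsave` call is commented out and `MODEL_NAME` is never defined. A minimal sketch of that save step, assuming matplotlib is available and using a hypothetical `MODEL_NAME` (run after the cells above):

```python
# Sketch only: MODEL_NAME is an assumption; the original notebook never defines it.
import os
import matplotlib.pyplot as plt

MODEL_NAME = 'dcgan'                          # hypothetical name
os.makedirs('samples', exist_ok=True)         # ensure the output folder exists

generator.eval()
img = get_sample_image(generator, n_noise)    # 280x280 grid of 10x10 samples
plt.imsave('samples/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)),
           img, cmap='gray')
generator.train()
```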
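The notebook also ends with an empty "# need to test" cell. One way that check could look, as a sketch under the same assumptions (generator, discriminator, data_loader, and the hyperparameters from the cells above):

```python
# Sketch of a quick post-training sanity check: the discriminator's mean
# score on real images should sit above its score on generated ones.
import torch

generator.eval()
discriminator.eval()
with torch.no_grad():
    z = torch.randn(batch_size, n_noise).to(device)
    fake = generator(z)                       # (64, 1, 28, 28), values in [-1, 1]
    real, _ = next(iter(data_loader))
    print('D(real): {:.3f}'.format(discriminator(real.to(device)).mean().item()))
    print('D(fake): {:.3f}'.format(discriminator(fake).mean().item()))
```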
notebooks/simple_dcgans.ipynb ADDED
The diff for this file is too large to render. See raw diff