hussamalafandi commited on
Commit
b25711e
·
verified ·
1 Parent(s): 9812f13

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. LICENSE +21 -0
  2. README.md +64 -0
  3. gan.py +41 -0
  4. gan_mnist.png +0 -0
  5. gan_mnist.pth +3 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Hussam Alafandi
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - gan
4
+ - mnist
5
+ - pytorch
6
+ - generative-model
7
+ - deep-learning
8
+ license: mit
9
+ datasets:
10
+ - mnist
11
+ library_name: pytorch
12
+ ---
13
+
14
+ # GAN for MNIST Digit Generation
15
+
16
+ This repository contains a Generative Adversarial Network (GAN) trained on the MNIST dataset to generate realistic handwritten digits. The model was trained as part of the Generative AI course.
17
+
18
+ ## Model Details
19
+
20
+ - **Model Type**: GAN
21
+ - **Dataset**: MNIST (handwritten digits)
22
+ - **Generator Input**: Latent vector of size 100
23
+ - **Output**: 28x28 grayscale images
24
+ - **Framework**: PyTorch
25
+
26
+ ## Training Details
27
+
28
+ - **Optimizer**: Adam
29
+ - **Learning Rate**: 0.0002
30
+ - **Beta1**: 0.5
31
+ - **Epochs**: 50
32
+ - **Batch Size**: 64
33
+ - **Weight Decay**: 0.0001
34
+ - **Logging**: [Weights & Biases](https://wandb.ai/hussam-alafandi/GAN_MNIST/runs/6ehnzhm0?nw=nwuserhussamalafandi)
35
+
36
+ ## Usage
37
+
38
+ ### Loading the Model
39
+
40
+ To load the trained model, use the following code snippet:
41
+ ```python
42
+ from gan import Generator
43
+ import torch
44
+
45
+ latent_dim = 100
46
+ generator = Generator(latent_dim)
47
+ generator.load_state_dict(torch.load("./gan_mnist.pth"))
48
+ generator.eval()
49
+
50
+ # Generate samples
51
+ z = torch.randn(16, latent_dim)
52
+ samples = generator(z)
53
+ ```
54
+ ## Example Results
55
+ ![generated images](./gan_mnist.png)
56
+
57
+ ## References
58
+
59
+ - [Generative AI Course Repository](https://github.com/hussamalafandi/Generative_AI)
60
+ - [Weights & Biases Training Logs](https://wandb.ai/hussam-alafandi/GAN_MNIST/runs/6ehnzhm0?nw=nwuserhussamalafandi)
61
+
62
+ ## License
63
+
64
+ This project is licensed under the MIT License.
gan.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ class Generator(nn.Module):
4
+ def __init__(self, latent_dim):
5
+ super(Generator, self).__init__()
6
+ self.latent_dim = latent_dim
7
+ self.net = nn.Sequential(
8
+ nn.Linear(self.latent_dim, 128),
9
+ nn.ReLU(True),
10
+ nn.Linear(128, 256),
11
+ nn.BatchNorm1d(256),
12
+ nn.ReLU(True),
13
+ nn.Linear(256, 512),
14
+ nn.BatchNorm1d(512),
15
+ nn.ReLU(True),
16
+ nn.Linear(512, 28*28),
17
+ nn.Tanh()
18
+ )
19
+
20
+ def forward(self, z):
21
+ img = self.net(z)
22
+ img = img.view(z.size(0), 1, 28, 28)
23
+ return img
24
+
25
+
26
+ class Discriminator(nn.Module):
27
+ def __init__(self):
28
+ super(Discriminator, self).__init__()
29
+ self.net = nn.Sequential(
30
+ nn.Linear(28*28, 512),
31
+ nn.LeakyReLU(0.2, inplace=True),
32
+ nn.Linear(512, 256),
33
+ nn.LeakyReLU(0.2, inplace=True),
34
+ nn.Linear(256, 1),
35
+ nn.Sigmoid()
36
+ )
37
+
38
+ def forward(self, img):
39
+ img_flat = img.view(img.size(0), -1)
40
+ validity = self.net(img_flat)
41
+ return validity
gan_mnist.png ADDED
gan_mnist.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21a826687c374ae29109ff7916b70ecfc530524dd0aa9a8c21621521880f8c21
3
+ size 4477638