Upload folder using huggingface_hub
Browse files- LICENSE +21 -0
- README.md +64 -0
- gan.py +41 -0
- gan_mnist.png +0 -0
- gan_mnist.pth +3 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Hussam Alafandi
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
tags:
|
3 |
+
- gan
|
4 |
+
- mnist
|
5 |
+
- pytorch
|
6 |
+
- generative-model
|
7 |
+
- deep-learning
|
8 |
+
license: mit
|
9 |
+
datasets:
|
10 |
+
- mnist
|
11 |
+
library_name: pytorch
|
12 |
+
---
|
13 |
+
|
14 |
+
# GAN for MNIST Digit Generation
|
15 |
+
|
16 |
+
This repository contains a Generative Adversarial Network (GAN) trained on the MNIST dataset to generate realistic handwritten digits. The model was trained as part of the Generative AI course.
|
17 |
+
|
18 |
+
## Model Details
|
19 |
+
|
20 |
+
- **Model Type**: GAN
|
21 |
+
- **Dataset**: MNIST (handwritten digits)
|
22 |
+
- **Generator Input**: Latent vector of size 100
|
23 |
+
- **Output**: 28x28 grayscale images
|
24 |
+
- **Framework**: PyTorch
|
25 |
+
|
26 |
+
## Training Details
|
27 |
+
|
28 |
+
- **Optimizer**: Adam
|
29 |
+
- **Learning Rate**: 0.0002
|
30 |
+
- **Beta1**: 0.5
|
31 |
+
- **Epochs**: 50
|
32 |
+
- **Batch Size**: 64
|
33 |
+
- **Weight Decay**: 0.0001
|
34 |
+
- **Logging**: [Weights & Biases](https://wandb.ai/hussam-alafandi/GAN_MNIST/runs/6ehnzhm0?nw=nwuserhussamalafandi)
|
35 |
+
|
36 |
+
## Usage
|
37 |
+
|
38 |
+
### Loading the Model
|
39 |
+
|
40 |
+
To load the trained model, use the following code snippet:
|
41 |
+
```python
|
42 |
+
from gan import Generator
|
43 |
+
import torch
|
44 |
+
|
45 |
+
latent_dim = 100
|
46 |
+
generator = Generator(latent_dim)
|
47 |
+
generator.load_state_dict(torch.load("./gan_mnist.pth"))
|
48 |
+
generator.eval()
|
49 |
+
|
50 |
+
# Generate samples
|
51 |
+
z = torch.randn(16, latent_dim)
|
52 |
+
samples = generator(z)
|
53 |
+
```
|
54 |
+
## Example Results
|
55 |
+

|
56 |
+
|
57 |
+
## References
|
58 |
+
|
59 |
+
- [Generative AI Course Repository](https://github.com/hussamalafandi/Generative_AI)
|
60 |
+
- [Weights & Biases Training Logs](https://wandb.ai/hussam-alafandi/GAN_MNIST/runs/6ehnzhm0?nw=nwuserhussamalafandi)
|
61 |
+
|
62 |
+
## License
|
63 |
+
|
64 |
+
This project is licensed under the MIT License.
|
gan.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
class Generator(nn.Module):
|
4 |
+
def __init__(self, latent_dim):
|
5 |
+
super(Generator, self).__init__()
|
6 |
+
self.latent_dim = latent_dim
|
7 |
+
self.net = nn.Sequential(
|
8 |
+
nn.Linear(self.latent_dim, 128),
|
9 |
+
nn.ReLU(True),
|
10 |
+
nn.Linear(128, 256),
|
11 |
+
nn.BatchNorm1d(256),
|
12 |
+
nn.ReLU(True),
|
13 |
+
nn.Linear(256, 512),
|
14 |
+
nn.BatchNorm1d(512),
|
15 |
+
nn.ReLU(True),
|
16 |
+
nn.Linear(512, 28*28),
|
17 |
+
nn.Tanh()
|
18 |
+
)
|
19 |
+
|
20 |
+
def forward(self, z):
|
21 |
+
img = self.net(z)
|
22 |
+
img = img.view(z.size(0), 1, 28, 28)
|
23 |
+
return img
|
24 |
+
|
25 |
+
|
26 |
+
class Discriminator(nn.Module):
|
27 |
+
def __init__(self):
|
28 |
+
super(Discriminator, self).__init__()
|
29 |
+
self.net = nn.Sequential(
|
30 |
+
nn.Linear(28*28, 512),
|
31 |
+
nn.LeakyReLU(0.2, inplace=True),
|
32 |
+
nn.Linear(512, 256),
|
33 |
+
nn.LeakyReLU(0.2, inplace=True),
|
34 |
+
nn.Linear(256, 1),
|
35 |
+
nn.Sigmoid()
|
36 |
+
)
|
37 |
+
|
38 |
+
def forward(self, img):
|
39 |
+
img_flat = img.view(img.size(0), -1)
|
40 |
+
validity = self.net(img_flat)
|
41 |
+
return validity
|
gan_mnist.png
ADDED
![]() |
gan_mnist.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21a826687c374ae29109ff7916b70ecfc530524dd0aa9a8c21621521880f8c21
|
3 |
+
size 4477638
|