Zai committed on
Commit 34253f7
1 Parent(s): 47fbacb

setup.py added

.github/workflows/python-package-conda.yml ADDED
@@ -0,0 +1,34 @@
1
+ name: Python Package using Conda
2
+
3
+ on: [push]
4
+
5
+ jobs:
6
+ build-linux:
7
+ runs-on: ubuntu-latest
8
+ strategy:
9
+ max-parallel: 5
10
+
11
+ steps:
12
+ - uses: actions/checkout@v3
13
+ - name: Set up Python 3.10
14
+ uses: actions/setup-python@v3
15
+ with:
16
+ python-version: '3.10'
17
+ - name: Add conda to system path
18
+ run: |
19
+ # $CONDA is an environment variable pointing to the root of the miniconda directory
20
+ echo $CONDA/bin >> $GITHUB_PATH
21
+ - name: Install dependencies
22
+ run: |
23
+ conda env update --file environment.yml --name base
24
+ - name: Lint with flake8
25
+ run: |
26
+ conda install flake8
27
+ # stop the build if there are Python syntax errors or undefined names
28
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
29
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
30
+ flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
31
+ - name: Test with pytest
32
+ run: |
33
+ conda install pytest
34
+ pytest
.github/workflows/test.yaml ADDED
File without changes
.github/workflows/test.yml DELETED
@@ -1,25 +0,0 @@
1
- name: Run Python Tests
2
-
3
- on:
4
- push:
5
- branches:
6
- - main
7
- - master
8
- jobs:
9
- test:
10
- runs-on: ubuntu-latest
11
-
12
- steps:
13
- - name: Checkout code
14
- uses: actions/checkout@v2
15
-
16
- - name: Set up Python
17
- uses: actions/setup-python@v2
18
- with:
19
- python-version: 3.8
20
-
21
- - name: Install dependencies
22
- run: pip install -r requirements.txt # Adjust this based on your project structure
23
-
24
- - name: Run tests
25
- run: python -m unittest discover tests # Adjust this based on your test discovery method
CONTRIBUTING.md ADDED
@@ -0,0 +1,56 @@
1
+ # Contributing to ve-gans
2
+
3
+ Thank you for considering contributing to ve-gans! Please take a moment to review the following guidelines.
4
+
5
+ ## Code of Conduct
6
+
7
+ This project and everyone participating in it are governed by the [Code of Conduct](CODE_OF_CONDUCT.md). By participating, you agree to uphold this code. Please report unacceptable behavior to [your email or a dedicated email for issues].
8
+
9
+ ## How to Contribute
10
+
11
+ 1. Fork the repository.
12
+
13
+ 2. Clone the forked repository to your local machine:
14
+
15
+ ```bash
16
+ git clone https://github.com/zaibutcooler/ve-gans.git
17
+ ```
18
+
19
+ 3. Create a new branch for your feature or bug fix:
20
+
21
+ ```bash
22
+ git checkout -b feature-name
23
+ ```
24
+
25
+ 4. Make your changes and commit them with a descriptive commit message:
26
+
27
+ ```bash
28
+ git add .
29
+ git commit -m "Add your descriptive message here"
30
+ ```
31
+
32
+ 5. Push the changes to your fork:
33
+
34
+ ```bash
35
+ git push origin feature-name
36
+ ```
37
+
38
+ 6. Create a pull request (PR) from your fork to the main repository.
39
+
40
+ 7. Ensure your PR title and description are clear and concise.
41
+
42
+ ## Reporting Issues
43
+
44
+ If you find any issues or have suggestions, please open an issue on the [Issue Tracker](https://github.com/zaibutcooler/ve-gans/issues).
45
+
46
+ ## Style Guide
47
+
48
+ - Follow the existing coding style.
49
+ - Use meaningful variable and function names.
50
+ - Write clear and concise documentation.
51
+
52
+ ## License
53
+
54
+ By contributing, you agree that your contributions will be licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
55
+
56
+ Thank you for contributing to ve-gans!
README.md CHANGED
@@ -0,0 +1,72 @@
1
+ # ve-gans: Image Generation with GANs using PyTorch
2
+
3
+ ## Overview
4
+
5
+ ve-gans is a project for image generation using Generative Adversarial Networks (GANs) implemented in PyTorch.
6
+
7
+ ## Features
8
+
9
+ - GAN model for image generation.
10
+ - Separate scripts for training and generating images.
11
+ - Easy-to-use command-line interface.
12
+
13
+ ## Installation
14
+
15
+ 1. **Clone the repository:**
16
+
17
+ ```bash
18
+ git clone https://github.com/zaibutcooler/ve-gans.git
19
+ cd ve-gans
20
+ ```
21
+
22
+ 2. **Install dependencies:**
23
+
24
+ ```bash
25
+ pip install -r requirements.txt
26
+ ```
27
+
28
+ ## Usage
29
+
30
+ ### Training
31
+
32
+ To train the GAN model, use the following command:
33
+
34
+ ```bash
35
+ ve-gans-train
36
+ ```
37
+
38
+ ### Generating Images
39
+
40
+ To generate images with the trained model, use the following command:
41
+
42
+ ```bash
43
+ ve-gans-generate
44
+ ```
45
+
46
+ ## Project Structure
47
+
48
+ - `ve_gans/`: Python package containing GAN implementation and utilities.
49
+ - `generator.py`: Implementation of the GAN generator.
50
+ - `discriminator.py`: Implementation of the GAN discriminator.
51
+ - `utils.py`: Utility functions.
52
+ - `requirements.txt`: List of project dependencies.
53
+ - `setup.py`: Setup script for installing the package.
54
+ - `main.py`: Example script for using the ve-gans package.
55
+
56
+ ## Contributing
57
+
58
+ Contributions are welcome! Please follow the [Contribution Guidelines](CONTRIBUTING.md).
59
+
60
+ ## License
61
+
62
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
63
+
64
+ ## Acknowledgments
65
+
66
+ Mention any contributors or libraries that you used or were inspired by.
67
+
68
+ ## Contact
69
+
70
+ - Zai
71
72
+ - Project Link: [https://github.com/zaibutcooler/ve-gans](https://github.com/zaibutcooler/ve-gans)
main.py ADDED
File without changes
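Note: the README added in this commit describes `main.py` as an "Example script for using the ve-gans package", but the file itself is added empty. A minimal sketch of what such a script could look like, assuming the MLP generator architecture from `notebooks/vanilla-gans.ipynb`; since the `vegans/` modules are empty in this commit, the class is declared inline here rather than imported:

```python
# Hypothetical sketch of an example script; the vegans package modules are
# added empty in this commit, so a small generator is defined inline, modeled
# on the vanilla GAN generator in notebooks/vanilla-gans.ipynb.
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Maps a 100-dim noise vector to a 1x28x28 image (MNIST-sized)."""

    def __init__(self, input_size=100, num_classes=784):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, num_classes),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.layers(x).view(x.size(0), 1, 28, 28)


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = Generator().to(device).eval()

    noise = torch.randn(16, 100, device=device)
    with torch.no_grad():
        images = generator(noise)  # shape: (16, 1, 28, 28)
    print(f"Generated a batch of images with shape {tuple(images.shape)}")


if __name__ == "__main__":
    main()
```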
notebooks/dcgan.ipynb ADDED
@@ -0,0 +1,277 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 74,
22
+ "metadata": {
23
+ "id": "xNiydKOa0oFk"
24
+ },
25
+ "outputs": [],
26
+ "source": [
27
+ "#project gans"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "source": [
33
+ "import torch\n",
34
+ "import torchvision\n",
35
+ "import torch.nn as nn\n",
36
+ "import torch.nn.functional as F\n",
37
+ "from torch.utils.data import DataLoader\n",
38
+ "from torchvision import datasets, transforms\n",
39
+ "from torchvision.utils import save_image\n",
40
+ "import numpy as np\n",
41
+ "\n",
42
+ "# Check if GPU is available and set the device accordingly\n",
43
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
44
+ ],
45
+ "metadata": {
46
+ "id": "SCS7gRJQ0tyS"
47
+ },
48
+ "execution_count": 75,
49
+ "outputs": []
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "source": [
54
+ "def get_sample_image(generator, noise_dim):\n",
55
+ " \"\"\"\n",
56
+ " Save sample 100 images\n",
57
+ " \"\"\"\n",
58
+ " noise = torch.randn(100, noise_dim).to(device)\n",
59
+ " generated_images = generator(noise).view(100, 28, 28) # (100, 28, 28)\n",
60
+ " result = generated_images.cpu().data.numpy()\n",
61
+ " img = np.zeros([280, 280])\n",
62
+ " for j in range(10):\n",
63
+ " img[j * 28:(j + 1) * 28] = np.concatenate([x for x in result[j * 10:(j + 1) * 10]], axis=-1)\n",
64
+ " return img"
65
+ ],
66
+ "metadata": {
67
+ "id": "sacBbf_LwZx-"
68
+ },
69
+ "execution_count": 76,
70
+ "outputs": []
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "source": [
75
+ "class Discriminator(nn.Module):\n",
76
+ " def __init__(self, in_channels=1, num_classes=1):\n",
77
+ " super(Discriminator, self).__init__()\n",
78
+ " self.conv = nn.Sequential(\n",
79
+ " nn.Conv2d(in_channels, 512, 3, stride=2, padding=1, bias=False),\n",
80
+ " nn.BatchNorm2d(512),\n",
81
+ " nn.LeakyReLU(0.2),\n",
82
+ "\n",
83
+ " nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),\n",
84
+ " nn.BatchNorm2d(256),\n",
85
+ " nn.LeakyReLU(0.2),\n",
86
+ "\n",
87
+ " nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),\n",
88
+ " nn.BatchNorm2d(128),\n",
89
+ " nn.LeakyReLU(0.2),\n",
90
+ " nn.AvgPool2d(4),\n",
91
+ " )\n",
92
+ " self.fc = nn.Sequential(\n",
93
+ " nn.Linear(128, 1),\n",
94
+ " nn.Sigmoid(),\n",
95
+ " )\n",
96
+ "\n",
97
+ " def forward(self, x, y=False):\n",
98
+ " features = self.conv(x)\n",
99
+ " features = features.view(features.size(0), -1)\n",
100
+ " output = self.fc(features)\n",
101
+ " return output\n"
102
+ ],
103
+ "metadata": {
104
+ "id": "e9n-wD7dwZ7n"
105
+ },
106
+ "execution_count": 77,
107
+ "outputs": []
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "source": [
112
+ "class Generator(nn.Module):\n",
113
+ " def __init__(self, input_size=100, num_classes=784):\n",
114
+ " super(Generator, self).__init__()\n",
115
+ " self.fc = nn.Sequential(\n",
116
+ " nn.Linear(input_size, 4 * 4 * 512),\n",
117
+ " nn.ReLU(),\n",
118
+ " )\n",
119
+ " self.conv = nn.Sequential(\n",
120
+ " nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),\n",
121
+ " nn.BatchNorm2d(256),\n",
122
+ " nn.ReLU(),\n",
123
+ "\n",
124
+ " nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),\n",
125
+ " nn.BatchNorm2d(128),\n",
126
+ " nn.ReLU(),\n",
127
+ "\n",
128
+ " nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),\n",
129
+ " nn.Tanh(),\n",
130
+ " )\n",
131
+ "\n",
132
+ " def forward(self, x, y=None):\n",
133
+ " x = x.view(x.size(0), -1)\n",
134
+ " features = self.fc(x)\n",
135
+ " features = features.view(features.size(0), 512, 4, 4)\n",
136
+ " output = self.conv(features)\n",
137
+ " return output\n"
138
+ ],
139
+ "metadata": {
140
+ "id": "_8-E4605wZ-e"
141
+ },
142
+ "execution_count": 78,
143
+ "outputs": []
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "source": [
148
+ "# Instantiate the Generator and Discriminator\n",
149
+ "generator = Generator().to(device)\n",
150
+ "discriminator = Discriminator().to(device)"
151
+ ],
152
+ "metadata": {
153
+ "id": "OSDpsaYBypVA"
154
+ },
155
+ "execution_count": 79,
156
+ "outputs": []
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "source": [
161
+ "transform = transforms.Compose([transforms.ToTensor(),\n",
162
+ " transforms.Normalize(mean=[0.5],\n",
163
+ " std=[0.5])]\n",
164
+ ")"
165
+ ],
166
+ "metadata": {
167
+ "id": "yQ8QdKuCz2_a"
168
+ },
169
+ "execution_count": 79,
170
+ "outputs": []
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "source": [
175
+ "batch_size = 64\n",
176
+ "\n",
177
+ "data = torchvision.datasets.FashionMNIST(root='./data/', train=True, transform=transform, download=True)\n",
178
+ "data_loader = DataLoader(dataset=data, batch_size=batch_size, shuffle=True, drop_last=True)\n",
179
+ "\n",
180
+ "loss_fn = nn.BCELoss()\n",
181
+ "d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.5, 0.999))\n",
182
+ "g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001, betas=(0.5, 0.999))\n"
183
+ ],
184
+ "metadata": {
185
+ "id": "8mOTuoih-3ep"
186
+ },
187
+ "execution_count": null,
188
+ "outputs": []
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "source": [
193
+ "max_epochs = 50\n",
194
+ "step = 0\n",
195
+ "n_critic = 1\n",
196
+ "n_noise = 100\n",
197
+ "\n",
198
+ "d_labels = torch.ones([batch_size, 1]).to(device)\n",
199
+ "d_fakes = torch.zeros([batch_size, 1]).to(device)"
200
+ ],
201
+ "metadata": {
202
+ "id": "kHJ0B3mk-4Bt"
203
+ },
204
+ "execution_count": null,
205
+ "outputs": []
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "source": [
210
+ "# Training loop\n",
211
+ "for epoch in range(max_epochs):\n",
212
+ " for idx, (images, labels) in enumerate(data_loader):\n",
213
+ " real_images = images.to(device)\n",
214
+ "\n",
215
+ " # Discriminator training\n",
216
+ " real_outputs = discriminator(real_images)\n",
217
+ " d_real_loss = loss_fn(real_outputs, d_labels)\n",
218
+ "\n",
219
+ " fake_noise = torch.randn(batch_size, n_noise).to(device)\n",
220
+ " fake_images = generator(fake_noise)\n",
221
+ " fake_outputs = discriminator(fake_images.detach())\n",
222
+ " d_fake_loss = loss_fn(fake_outputs, d_fakes)\n",
223
+ "\n",
224
+ " d_loss = d_real_loss + d_fake_loss\n",
225
+ "\n",
226
+ " discriminator.zero_grad()\n",
227
+ " d_loss.backward()\n",
228
+ " d_optimizer.step()\n",
229
+ "\n",
230
+ " # Generator training (every n_critic iterations)\n",
231
+ " if step % n_critic == 0:\n",
232
+ " fake_outputs = discriminator(fake_images)\n",
233
+ " g_loss = loss_fn(fake_outputs, d_labels)\n",
234
+ "\n",
235
+ " generator.zero_grad()\n",
236
+ " discriminator.zero_grad()\n",
237
+ " g_loss.backward()\n",
238
+ " g_optimizer.step()\n",
239
+ "\n",
240
+ " if step % 500 == 0:\n",
241
+ " print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epochs, step, d_loss.item(), g_loss.item()))\n",
242
+ "\n",
243
+ " if step % 1000 == 0:\n",
244
+ " generator.eval()\n",
245
+ " img = get_sample_image(generator, n_noise)\n",
246
+ " # imsave('samples/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')\n",
247
+ " generator.train()\n",
248
+ " step += 1"
249
+ ],
250
+ "metadata": {
251
+ "id": "1V9EfSBD-8E9"
252
+ },
253
+ "execution_count": null,
254
+ "outputs": []
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "source": [
259
+ "# neeed to test"
260
+ ],
261
+ "metadata": {
262
+ "id": "1g4ATYOD-9LY"
263
+ },
264
+ "execution_count": null,
265
+ "outputs": []
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "source": [],
270
+ "metadata": {
271
+ "id": "UPye6Ktu--Ph"
272
+ },
273
+ "execution_count": null,
274
+ "outputs": []
275
+ }
276
+ ]
277
+ }
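The training loop in this notebook builds a 280x280 sample grid with `get_sample_image` every 1000 steps but leaves the `imsave` call commented out. A small sketch of how the grid could be written to disk, assuming matplotlib is available in the runtime (it is not imported in this notebook) and reusing the notebook's `generator`, `get_sample_image`, `n_noise`, and `step` names:

```python
# Hedged sketch: persisting the sample grid produced by get_sample_image.
# Assumes matplotlib is installed; generator, get_sample_image, n_noise and
# step come from the notebook cells above.
import os

import matplotlib.pyplot as plt

os.makedirs("samples", exist_ok=True)

generator.eval()
sample_grid = get_sample_image(generator, n_noise)  # numpy array, shape (280, 280)
plt.imsave(f"samples/dcgan_step{step:06d}.png", sample_grid, cmap="gray")
generator.train()
```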
notebooks/prototype.ipynb ADDED
File without changes
notebooks/sam/sam_1.ipynb DELETED
@@ -1,20 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "# notebook using sam model"
10
- ]
11
- }
12
- ],
13
- "metadata": {
14
- "language_info": {
15
- "name": "python"
16
- }
17
- },
18
- "nbformat": 4,
19
- "nbformat_minor": 2
20
- }
notebooks/vanilla-gans.ipynb ADDED
@@ -0,0 +1,187 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import torchvision\n",
11
+ "import torch.nn as nn\n",
12
+ "import torch.nn.functional as F\n",
13
+ "\n",
14
+ "from torch.utils.data import DataLoader\n",
15
+ "from torchvision import datasets\n",
16
+ "from torchvision import transforms\n",
17
+ "from torchvision.utils import save_image\n",
18
+ "\n",
19
+ "import numpy as np\n",
20
+ "import datetime\n",
21
+ "\n",
22
+ "from matplotlib.pyplot import imshow, imsave\n",
23
+ "# %matplotlib inline\n",
24
+ "\n",
25
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "def get_sample_image(generator, noise_dim):\n",
35
+ " z = torch.randn(100, noise_dim).to(device)\n",
36
+ " generated_images = generator(z).view(100, 28, 28)\n",
37
+ " result = generated_images.cpu().data.numpy()\n",
38
+ " img = np.zeros([280, 280])\n",
39
+ " for j in range(10):\n",
40
+ " img[j * 28:(j + 1) * 28] = np.concatenate([x for x in result[j * 10:(j + 1) * 10]], axis=-1)\n",
41
+ " return img\n",
42
+ "\n",
43
+ "class Discriminator(nn.Module):\n",
44
+ " def __init__(self, input_size=784, num_classes=1):\n",
45
+ " super(Discriminator, self).__init__()\n",
46
+ " self.layers = nn.Sequential(\n",
47
+ " nn.Linear(input_size, 512),\n",
48
+ " nn.LeakyReLU(0.2),\n",
49
+ " nn.Linear(512, 256),\n",
50
+ " nn.LeakyReLU(0.2),\n",
51
+ " nn.Linear(256, num_classes),\n",
52
+ " nn.Sigmoid(),\n",
53
+ " )\n",
54
+ "\n",
55
+ " def forward(self, x):\n",
56
+ " x = x.view(x.size(0), -1)\n",
57
+ " x = self.layers(x)\n",
58
+ " return x\n",
59
+ "\n",
60
+ "class Generator(nn.Module):\n",
61
+ " def __init__(self, input_size=100, num_classes=784):\n",
62
+ " super(Generator, self).__init__()\n",
63
+ " self.layers = nn.Sequential(\n",
64
+ " nn.Linear(input_size, 128),\n",
65
+ " nn.LeakyReLU(0.2),\n",
66
+ " nn.Linear(128, 256),\n",
67
+ " nn.BatchNorm1d(256),\n",
68
+ " nn.LeakyReLU(0.2),\n",
69
+ " nn.Linear(256, 512),\n",
70
+ " nn.BatchNorm1d(512),\n",
71
+ " nn.LeakyReLU(0.2),\n",
72
+ " nn.Linear(512, 1024),\n",
73
+ " nn.BatchNorm1d(1024),\n",
74
+ " nn.LeakyReLU(0.2),\n",
75
+ " nn.Linear(1024, num_classes),\n",
76
+ " nn.Tanh()\n",
77
+ " )\n",
78
+ "\n",
79
+ " def forward(self, x):\n",
80
+ " x = self.layers(x)\n",
81
+ " x = x.view(x.size(0), 1, 28, 28)\n",
82
+ " return x\n"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": []
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "n_noise = 100\n",
99
+ "\n",
100
+ "discriminator = Discriminator().to(device)\n",
101
+ "generator = Generator().to(device)\n",
102
+ "\n",
103
+ "transform = transforms.Compose([transforms.ToTensor(),\n",
104
+ " transforms.Normalize(mean=[0.5],\n",
105
+ " std=[0.5])]\n",
106
+ ")\n",
107
+ "\n",
108
+ "mnist = datasets.MNIST(root='../data/', train=True, transform=transform, download=True)\n",
109
+ "\n",
110
+ "batch_size = 64\n",
111
+ "\n",
112
+ "data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True, drop_last=True)\n",
113
+ "\n",
114
+ "loss_fn = nn.BCELoss()\n",
115
+ "d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
116
+ "g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
117
+ "\n",
118
+ "max_epoch = 50\n",
119
+ "step = 0\n",
120
+ "n_critic = 1\n"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "d_labels = torch.ones(batch_size, 1).to(device)\n",
130
+ "d_fakes = torch.zeros(batch_size, 1).to(device)\n",
131
+ "\n",
132
+ "# Training loop\n",
133
+ "for epoch in range(max_epoch):\n",
134
+ " for idx, (images, _) in enumerate(data_loader):\n",
135
+ " real_images = images.to(device)\n",
136
+ " real_outputs = discriminator(real_images)\n",
137
+ " d_real_loss = loss_fn(real_outputs, d_labels)\n",
138
+ "\n",
139
+ " fake_noise = torch.randn(batch_size, n_noise).to(device)\n",
140
+ " fake_images = generator(fake_noise)\n",
141
+ " fake_outputs = discriminator(fake_images.detach())\n",
142
+ " d_fake_loss = loss_fn(fake_outputs, d_fakes)\n",
143
+ "\n",
144
+ " d_loss = d_real_loss + d_fake_loss\n",
145
+ "\n",
146
+ " discriminator.zero_grad()\n",
147
+ " d_loss.backward()\n",
148
+ " d_optimizer.step()\n",
149
+ "\n",
150
+ " if step % n_critic == 0:\n",
151
+ " fake_outputs = discriminator(generator(fake_noise))\n",
152
+ " g_loss = loss_fn(fake_outputs, d_labels)\n",
153
+ "\n",
154
+ " generator.zero_grad()\n",
155
+ " g_loss.backward()\n",
156
+ " g_optimizer.step()\n",
157
+ "\n",
158
+ " if step % 1000 == 0:\n",
159
+ " generator.eval()\n",
160
+ " img = get_sample_image(generator, n_noise)\n",
161
+ " # imsave('samples/{}_step{}.jpg'.format('gans', str(step).zfill(3)), img, cmap='gray')\n",
162
+ " generator.train()\n",
163
+ " step += 1\n"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "metadata": {},
170
+ "outputs": [],
171
+ "source": [
172
+ "generator.eval()\n",
173
+ "imshow(get_sample_image(generator, n_noise), cmap='gray')\n",
174
+ "\n",
175
+ "torch.save(discriminator.state_dict(), 'discriminator.pth')\n",
176
+ "torch.save(generator.state_dict(), 'generator.pth')\n"
177
+ ]
178
+ }
179
+ ],
180
+ "metadata": {
181
+ "language_info": {
182
+ "name": "python"
183
+ }
184
+ },
185
+ "nbformat": 4,
186
+ "nbformat_minor": 2
187
+ }
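The final cell of this notebook saves `discriminator.pth` and `generator.pth` with `torch.save`. A short sketch of reloading the generator for inference in a fresh session, assuming the same `Generator` class definition from the notebook is available (re-declared or imported) and that `generator.pth` is in the working directory:

```python
# Hedged sketch: reloading the weights saved by the last cell of this notebook.
# Generator refers to the MLP generator class defined in the cells above.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = Generator().to(device)
generator.load_state_dict(torch.load("generator.pth", map_location=device))
generator.eval()  # switch BatchNorm layers to inference mode

with torch.no_grad():
    samples = generator(torch.randn(16, 100, device=device))  # (16, 1, 28, 28)
```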
prototype/.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
prototype/README.md DELETED
@@ -1,13 +0,0 @@
1
- ---
2
- title: Pearl Prototype
3
- emoji: 💻
4
- colorFrom: blue
5
- colorTo: red
6
- sdk: streamlit
7
- sdk_version: 1.29.0
8
- app_file: app.py
9
- pinned: false
10
- license: openrail
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
prototype/app.py DELETED
@@ -1,40 +0,0 @@
1
- import gradio as gr
2
- from io import BytesIO
3
-
4
- from torch import autocast
5
- import requests
6
- import PIL
7
- import torch
8
- from diffusers import StableDiffusionInpaintPipeline as StableDiffusionInpaintPipeline
9
-
10
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
11
- "CompVis/stable-diffusion-v1-4",
12
- revision="fp16",
13
- torch_dtype=torch.float16,
14
- use_auth_token=True,
15
- )
16
-
17
-
18
- def process_image(dict, prompt):
19
- init_img = dict["image"].convert("RGB").resize((512, 512))
20
- mask_img = dict["mask"].convert("RGB").resize((512, 512))
21
- images = pipe(
22
- prompt=prompt, init_image=init_img, mask_image=mask_img, strength=0.75
23
- )["sample"]
24
- return images[0]
25
-
26
-
27
- iface = gr.Interface(
28
- fn=process_image,
29
- title="Stable Diffusion In-Painting Tool on Colab with Gradio",
30
- inputs=[
31
- gr.Image(source="upload", tool="sketch", type="pil"),
32
- gr.Textbox(label="prompt"),
33
- ],
34
- outputs=[gr.Image()],
35
- description="Choose a feature and upload an image to see the processed result.",
36
- article="<p style='text-align: center;'>Built with Gradio</p>",
37
- )
38
-
39
-
40
- iface.launch()
prototype/inpainting.py DELETED
@@ -1,227 +0,0 @@
1
- # credit : Hugging Face Team
2
- import inspect
3
- from typing import List, Optional, Union
4
-
5
- import numpy as np
6
- import torch
7
-
8
- import PIL
9
- from diffusers import (
10
- AutoencoderKL,
11
- DDIMScheduler,
12
- DiffusionPipeline,
13
- PNDMScheduler,
14
- UNet2DConditionModel,
15
- )
16
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
17
- from tqdm.auto import tqdm
18
- from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
19
-
20
-
21
- def preprocess_image(image):
22
- w, h = image.size
23
- w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
24
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
25
- image = np.array(image).astype(np.float32) / 255.0
26
- image = image[None].transpose(0, 3, 1, 2)
27
- image = torch.from_numpy(image)
28
- return 2.0 * image - 1.0
29
-
30
-
31
- def preprocess_mask(mask):
32
- mask = mask.convert("L")
33
- w, h = mask.size
34
- w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
35
- mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
36
- mask = np.array(mask).astype(np.float32) / 255.0
37
- mask = np.tile(mask, (4, 1, 1))
38
- mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
39
- mask = 1 - mask # repaint white, keep black
40
- mask = torch.from_numpy(mask)
41
- return mask
42
-
43
-
44
- class StableDiffusionInpaintingPipeline(DiffusionPipeline):
45
- def __init__(
46
- self,
47
- vae: AutoencoderKL,
48
- text_encoder: CLIPTextModel,
49
- tokenizer: CLIPTokenizer,
50
- unet: UNet2DConditionModel,
51
- scheduler: Union[DDIMScheduler, PNDMScheduler],
52
- safety_checker: StableDiffusionSafetyChecker,
53
- feature_extractor: CLIPFeatureExtractor,
54
- ):
55
- super().__init__()
56
- scheduler = scheduler.set_format("pt")
57
- self.register_modules(
58
- vae=vae,
59
- text_encoder=text_encoder,
60
- tokenizer=tokenizer,
61
- unet=unet,
62
- scheduler=scheduler,
63
- safety_checker=safety_checker,
64
- feature_extractor=feature_extractor,
65
- )
66
-
67
- @torch.no_grad()
68
- def __call__(
69
- self,
70
- prompt: Union[str, List[str]],
71
- init_image: torch.FloatTensor,
72
- mask_image: torch.FloatTensor,
73
- strength: float = 0.8,
74
- num_inference_steps: Optional[int] = 50,
75
- guidance_scale: Optional[float] = 7.5,
76
- eta: Optional[float] = 0.0,
77
- generator: Optional[torch.Generator] = None,
78
- output_type: Optional[str] = "pil",
79
- ):
80
- if isinstance(prompt, str):
81
- batch_size = 1
82
- elif isinstance(prompt, list):
83
- batch_size = len(prompt)
84
- else:
85
- raise ValueError(
86
- f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
87
- )
88
-
89
- if strength < 0 or strength > 1:
90
- raise ValueError(
91
- f"The value of strength should in [0.0, 1.0] but is {strength}"
92
- )
93
-
94
- # set timesteps
95
- accepts_offset = "offset" in set(
96
- inspect.signature(self.scheduler.set_timesteps).parameters.keys()
97
- )
98
- extra_set_kwargs = {}
99
- offset = 0
100
- if accepts_offset:
101
- offset = 1
102
- extra_set_kwargs["offset"] = 1
103
-
104
- self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
105
-
106
- # preprocess image
107
- init_image = preprocess_image(init_image).to(self.device)
108
-
109
- # encode the init image into latents and scale the latents
110
- init_latents = self.vae.encode(init_image).sample()
111
- init_latents = 0.18215 * init_latents
112
-
113
- # prepare init_latents noise to latents
114
- init_latents = torch.cat([init_latents] * batch_size)
115
- init_latents_orig = init_latents
116
-
117
- # preprocess mask
118
- mask = preprocess_mask(mask_image).to(self.device)
119
- mask = torch.cat([mask] * batch_size)
120
-
121
- # check sizes
122
- if not mask.shape == init_latents.shape:
123
- raise ValueError(f"The mask and init_image should be the same size!")
124
-
125
- # get the original timestep using init_timestep
126
- init_timestep = int(num_inference_steps * strength) + offset
127
- init_timestep = min(init_timestep, num_inference_steps)
128
- timesteps = self.scheduler.timesteps[-init_timestep]
129
- timesteps = torch.tensor(
130
- [timesteps] * batch_size, dtype=torch.long, device=self.device
131
- )
132
-
133
- # add noise to latents using the timesteps
134
- noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
135
- init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
136
-
137
- # get prompt text embeddings
138
- text_input = self.tokenizer(
139
- prompt,
140
- padding="max_length",
141
- max_length=self.tokenizer.model_max_length,
142
- truncation=True,
143
- return_tensors="pt",
144
- )
145
- text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
146
-
147
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
148
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
149
- # corresponds to doing no classifier free guidance.
150
- do_classifier_free_guidance = guidance_scale > 1.0
151
- # get unconditional embeddings for classifier free guidance
152
- if do_classifier_free_guidance:
153
- max_length = text_input.input_ids.shape[-1]
154
- uncond_input = self.tokenizer(
155
- [""] * batch_size,
156
- padding="max_length",
157
- max_length=max_length,
158
- return_tensors="pt",
159
- )
160
- uncond_embeddings = self.text_encoder(
161
- uncond_input.input_ids.to(self.device)
162
- )[0]
163
-
164
- # For classifier free guidance, we need to do two forward passes.
165
- # Here we concatenate the unconditional and text embeddings into a single batch
166
- # to avoid doing two forward passes
167
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
168
-
169
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
170
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
171
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
172
- # and should be between [0, 1]
173
- accepts_eta = "eta" in set(
174
- inspect.signature(self.scheduler.step).parameters.keys()
175
- )
176
- extra_step_kwargs = {}
177
- if accepts_eta:
178
- extra_step_kwargs["eta"] = eta
179
-
180
- latents = init_latents
181
- t_start = max(num_inference_steps - init_timestep + offset, 0)
182
- for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
183
- # expand the latents if we are doing classifier free guidance
184
- latent_model_input = (
185
- torch.cat([latents] * 2) if do_classifier_free_guidance else latents
186
- )
187
-
188
- # predict the noise residual
189
- noise_pred = self.unet(
190
- latent_model_input, t, encoder_hidden_states=text_embeddings
191
- )["sample"]
192
-
193
- # perform guidance
194
- if do_classifier_free_guidance:
195
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
196
- noise_pred = noise_pred_uncond + guidance_scale * (
197
- noise_pred_text - noise_pred_uncond
198
- )
199
-
200
- # compute the previous noisy sample x_t -> x_t-1
201
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[
202
- "prev_sample"
203
- ]
204
-
205
- # masking
206
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
207
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
208
-
209
- # scale and decode the image latents with vae
210
- latents = 1 / 0.18215 * latents
211
- image = self.vae.decode(latents)
212
-
213
- image = (image / 2 + 0.5).clamp(0, 1)
214
- image = image.cpu().permute(0, 2, 3, 1).numpy()
215
-
216
- # run safety checker
217
- safety_cheker_input = self.feature_extractor(
218
- self.numpy_to_pil(image), return_tensors="pt"
219
- ).to(self.device)
220
- image, has_nsfw_concept = self.safety_checker(
221
- images=image, clip_input=safety_cheker_input.pixel_values
222
- )
223
-
224
- if output_type == "pil":
225
- image = self.numpy_to_pil(image)
226
-
227
- return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
prototype/requirements.txt DELETED
@@ -1,7 +0,0 @@
1
- torch
2
- requests
3
- pillow
4
- diffusers
5
- gradio
6
- numpy
7
- tqdm
prototype/test.py DELETED
@@ -1,12 +0,0 @@
1
- import gradio as gr
2
-
3
- def greet(name, intensity):
4
- return "Hello " * intensity + name + "!"
5
-
6
- demo = gr.Interface(
7
- fn=greet,
8
- inputs=["text", "slider"],
9
- outputs=["text"],
10
- )
11
-
12
- demo.launch()
prototype/utils.py DELETED
@@ -1,13 +0,0 @@
1
- def add_feature(image):
2
- # inpainting features
3
- pass
4
-
5
-
6
- def remove_feature(image):
7
- # inpainting features
8
- pass
9
-
10
-
11
- def enhance_feature(image):
12
- # inpainting features
13
- pass
setup.py CHANGED
@@ -1,44 +1,27 @@
1
- # need to change
2
 
3
- # import platform
4
- # import sys
5
- # from pathlib import Path
6
 
7
- # import pkg_resources
8
- # from setuptools import find_packages, setup
9
-
10
-
11
- # def read_version(fname="version.py"):
12
- # exec(compile(open(fname, encoding="utf-8").read(), fname, "exec"))
13
- # return locals()["__version__"]
14
-
15
-
16
- # requirements = []
17
- # if sys.platform.startswith("linux") and platform.machine() == "x86_64":
18
- # requirements.append("triton>=2.0.0,<3")
19
-
20
- # setup(
21
- # name="openai-whisper",
22
- # py_modules=["whisper"],
23
- # version=read_version(),
24
- # description="Robust Speech Recognition via Large-Scale Weak Supervision",
25
- # long_description=open("README.md", encoding="utf-8").read(),
26
- # long_description_content_type="text/markdown",
27
- # readme="README.md",
28
- # python_requires=">=3.8",
29
- # author="OpenAI",
30
- # url="https://github.com/openai/whisper",
31
- # license="MIT",
32
- # packages=find_packages(exclude=["tests*"]),
33
- # install_requires=[
34
- # str(r)
35
- # for r in pkg_resources.parse_requirements(
36
- # Path(__file__).with_name("requirements.txt").open()
37
- # )
38
- # ],
39
- # entry_points={
40
- # "console_scripts": ["whisper=whisper.transcribe:cli"],
41
- # },
42
- # include_package_data=True,
43
- # extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
44
- # )
 
1
+ from setuptools import setup, find_packages
2
 
3
+ with open('requirements.txt') as f:
4
+ requirements = f.read().splitlines()
 
5
 
6
+ setup(
7
+ name='ve-gans',
8
+ version='0.1',
9
+ packages=find_packages(),
10
+ install_requires=requirements,
11
+ entry_points={
12
+ 'console_scripts': [
13
+ 've-gans-train=ve_gans.train:main',
14
+ 've-gans-generate=ve_gans.generate:main',
15
+ ],
16
+ },
17
+ author='Zai',
18
+ author_email='[email protected]',
19
+ description='Image generation with GANs using PyTorch',
20
+ long_description='Detailed description of your project',
21
+ url='https://github.com/zaibutcooler/ve-gans',
22
+ classifiers=[
23
+ 'Programming Language :: Python :: 3',
24
+ 'License :: OSI Approved :: MIT License',
25
+ 'Operating System :: OS Independent',
26
+ ],
27
+ )
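The `console_scripts` entries above point at `ve_gans.train:main` and `ve_gans.generate:main`, but no `ve_gans` package is part of this commit (the added package directory is `vegans/`, and its modules are empty). A hedged sketch of the minimal module shape the `ve-gans-train` entry point expects — module path, function name, and the flags shown are assumptions, not code from this repository:

```python
# Hypothetical ve_gans/train.py: the smallest module that would satisfy the
# 've-gans-train=ve_gans.train:main' entry point declared in setup.py.
# Argument names and defaults are illustrative only.
import argparse


def main():
    parser = argparse.ArgumentParser(description="Train the ve-gans GAN model")
    parser.add_argument("--epochs", type=int, default=50,
                        help="number of training epochs (illustrative default)")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="mini-batch size (illustrative default)")
    args = parser.parse_args()

    # The actual training loop (see notebooks/dcgan.ipynb) would be invoked here.
    print(f"Would train for {args.epochs} epochs with batch size {args.batch_size}")


if __name__ == "__main__":
    main()
```

After `pip install -e .`, setuptools generates a `ve-gans-train` executable that calls this `main()`.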
space/app.py ADDED
@@ -0,0 +1,12 @@
1
+ import streamlit as st
2
+ import torch
3
+ from torchvision.utils import make_grid
4
+ from torchvision.transforms import ToPILImage
5
+
6
+ def main():
7
+ st.title("Image Generation")
8
+ st.write("Made with GANS from scratch")
9
+
10
+
11
+ if __name__ == '__main__':
12
+ main()
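`space/app.py` imports `make_grid` and `ToPILImage` but only renders a title. A hedged sketch of how those imports could be used to show generated samples in the Streamlit page — the `generator` argument is an assumption, since no model code or trained weights ship in this commit:

```python
# Hedged sketch extending the Streamlit stub above. The generator argument is
# assumed to be a torch.nn.Module mapping (N, 100) noise to (N, 1, 28, 28)
# images; nothing in this commit provides one.
import streamlit as st
import torch
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage


def show_samples(generator, noise_dim=100, n_samples=16):
    """Render a 4x4 grid of generated images in the Streamlit page."""
    generator.eval()
    with torch.no_grad():
        images = generator(torch.randn(n_samples, noise_dim))
    grid = make_grid(images, nrow=4, normalize=True)  # single (C, H, W) tensor
    st.image(ToPILImage()(grid), caption="Generated samples", width=280)
```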
tests/{demo.py → test_demo.py} RENAMED
File without changes
vegans/__init__.py ADDED
File without changes
vegans/discriminator.py ADDED
File without changes
vegans/generator.py ADDED
File without changes
vegans/utils.py ADDED
File without changes