Commit 34253f7 · committed by Zai
1 Parent(s): 47fbacb

setup.py added
Files changed:

- .github/workflows/python-package-conda.yml +34 -0
- .github/workflows/test.yaml +0 -0
- .github/workflows/test.yml +0 -25
- CONTRIBUTING.md +56 -0
- README.md +72 -0
- main.py +0 -0
- notebooks/dcgan.ipynb +277 -0
- notebooks/prototype.ipynb +0 -0
- notebooks/sam/sam_1.ipynb +0 -20
- notebooks/vanilla-gans.ipynb +187 -0
- prototype/.gitattributes +0 -35
- prototype/README.md +0 -13
- prototype/app.py +0 -40
- prototype/inpainting.py +0 -227
- prototype/requirements.txt +0 -7
- prototype/test.py +0 -12
- prototype/utils.py +0 -13
- setup.py +25 -42
- space/app.py +12 -0
- tests/{demo.py → test_demo.py} +0 -0
- vegans/__init__.py +0 -0
- vegans/discriminator.py +0 -0
- vegans/generator.py +0 -0
- vegans/utils.py +0 -0
.github/workflows/python-package-conda.yml
ADDED
@@ -0,0 +1,34 @@
+name: Python Package using Conda
+
+on: [push]
+
+jobs:
+  build-linux:
+    runs-on: ubuntu-latest
+    strategy:
+      max-parallel: 5
+
+    steps:
+    - uses: actions/checkout@v3
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v3
+      with:
+        python-version: '3.10'
+    - name: Add conda to system path
+      run: |
+        # $CONDA is an environment variable pointing to the root of the miniconda directory
+        echo $CONDA/bin >> $GITHUB_PATH
+    - name: Install dependencies
+      run: |
+        conda env update --file environment.yml --name base
+    - name: Lint with flake8
+      run: |
+        conda install flake8
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Test with pytest
+      run: |
+        conda install pytest
+        pytest
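Note: the `Install dependencies` step assumes an `environment.yml` at the repository root, which is not touched by this commit, and the `pytest` step relies on pytest's default discovery of `test_*.py` files — presumably why `tests/demo.py` is renamed to `tests/test_demo.py` further down. The actual contents of `tests/test_demo.py` are not shown in this diff; a minimal hypothetical test that the workflow would pick up could look like:

```python
# tests/test_demo.py -- hypothetical placeholder; the real file's contents are not shown here.
def test_smoke():
    # Trivial assertion so the workflow's pytest step has at least one discoverable test.
    assert 1 + 1 == 2
```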
.github/workflows/test.yaml
ADDED
File without changes
.github/workflows/test.yml
DELETED
@@ -1,25 +0,0 @@
-name: Run Python Tests
-
-on:
-  push:
-    branches:
-      - main
-      - master
-jobs:
-  test:
-    runs-on: ubuntu-latest
-
-    steps:
-    - name: Checkout code
-      uses: actions/checkout@v2
-
-    - name: Set up Python
-      uses: actions/setup-python@v2
-      with:
-        python-version: 3.8
-
-    - name: Install dependencies
-      run: pip install -r requirements.txt  # Adjust this based on your project structure
-
-    - name: Run tests
-      run: python -m unittest discover tests  # Adjust this based on your test discovery method
CONTRIBUTING.md
ADDED
@@ -0,0 +1,56 @@
+# Contributing to ve-gans
+
+Thank you for considering contributing to ve-gans! Please take a moment to review the following guidelines.
+
+## Code of Conduct
+
+This project and everyone participating in it are governed by the [Code of Conduct](CODE_OF_CONDUCT.md). By participating, you agree to uphold this code. Please report unacceptable behavior to [your email or a dedicated email for issues].
+
+## How to Contribute
+
+1. Fork the repository.
+
+2. Clone the forked repository to your local machine:
+
+    ```bash
+    git clone https://github.com/zaibutcooler/ve-gans.git
+    ```
+
+3. Create a new branch for your feature or bug fix:
+
+    ```bash
+    git checkout -b feature-name
+    ```
+
+4. Make your changes and commit them with a descriptive commit message:
+
+    ```bash
+    git add .
+    git commit -m "Add your descriptive message here"
+    ```
+
+5. Push the changes to your fork:
+
+    ```bash
+    git push origin feature-name
+    ```
+
+6. Create a pull request (PR) from your fork to the main repository.
+
+7. Ensure your PR title and description are clear and concise.
+
+## Reporting Issues
+
+If you find any issues or have suggestions, please open an issue on the [Issue Tracker](https://github.com/zaibutcooler/ve-gans/issues).
+
+## Style Guide
+
+- Follow the existing coding style.
+- Use meaningful variable and function names.
+- Write clear and concise documentation.
+
+## License
+
+By contributing, you agree that your contributions will be licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
+
+Thank you for contributing to ve-gans!
README.md
CHANGED
@@ -0,0 +1,72 @@
+# ve-gans: Image Generation with GANs using PyTorch
+
+## Overview
+
+ve-gans is a project for image generation using Generative Adversarial Networks (GANs) implemented in PyTorch.
+
+## Features
+
+- GAN model for image generation.
+- Separate scripts for training and generating images.
+- Easy-to-use command-line interface.
+
+## Installation
+
+1. **Clone the repository:**
+
+    ```bash
+    git clone https://github.com/zaibutcooler/ve-gans.git
+    cd ve-gans
+    ```
+
+2. **Install dependencies:**
+
+    ```bash
+    pip install -r requirements.txt
+    ```
+
+## Usage
+
+### Training
+
+To train the GAN model, use the following command:
+
+```bash
+ve-gans-train
+```
+
+## Generating Images
+
+To generate images with the trained model, use the following command:
+
+```bash
+ve-gans-generate
+```
+
+## Project Structure
+
+- `ve_gans/`: Python package containing GAN implementation and utilities.
+  - `generator.py`: Implementation of the GAN generator.
+  - `discriminator.py`: Implementation of the GAN discriminator.
+  - `utils.py`: Utility functions.
+- `requirements.txt`: List of project dependencies.
+- `setup.py`: Setup script for installing the package.
+- `main.py`: Example script for using the ve-gans package.
+
+## Contributing
+
+Contributions are welcome! Please follow the [Contribution Guidelines](CONTRIBUTING.md).
+
+## License
+
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+## Acknowledgments
+
+Mention any contributors or libraries that you used or were inspired by.
+
+## Contact
+
+- Zai
+
+- Project Link: [https://github.com/zaibutcooler/ve-gans](https://github.com/zaibutcooler/ve-gans)
main.py
ADDED
File without changes
notebooks/dcgan.ipynb
ADDED
@@ -0,0 +1,277 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "gpuType": "T4"
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": 74,
+      "metadata": {
+        "id": "xNiydKOa0oFk"
+      },
+      "outputs": [],
+      "source": [
+        "#project gans"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import torch\n",
+        "import torchvision\n",
+        "import torch.nn as nn\n",
+        "import torch.nn.functional as F\n",
+        "from torch.utils.data import DataLoader\n",
+        "from torchvision import datasets, transforms\n",
+        "from torchvision.utils import save_image\n",
+        "import numpy as np\n",
+        "\n",
+        "# Check if GPU is available and set the device accordingly\n",
+        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+      ],
+      "metadata": {
+        "id": "SCS7gRJQ0tyS"
+      },
+      "execution_count": 75,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def get_sample_image(generator, noise_dim):\n",
+        "    \"\"\"\n",
+        "    Save sample 100 images\n",
+        "    \"\"\"\n",
+        "    noise = torch.randn(100, noise_dim).to(device)\n",
+        "    generated_images = generator(noise).view(100, 28, 28)  # (100, 28, 28)\n",
+        "    result = generated_images.cpu().data.numpy()\n",
+        "    img = np.zeros([280, 280])\n",
+        "    for j in range(10):\n",
+        "        img[j * 28:(j + 1) * 28] = np.concatenate([x for x in result[j * 10:(j + 1) * 10]], axis=-1)\n",
+        "    return img"
+      ],
+      "metadata": {
+        "id": "sacBbf_LwZx-"
+      },
+      "execution_count": 76,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class Discriminator(nn.Module):\n",
+        "    def __init__(self, in_channels=1, num_classes=1):\n",
+        "        super(Discriminator, self).__init__()\n",
+        "        self.conv = nn.Sequential(\n",
+        "            nn.Conv2d(in_channels, 512, 3, stride=2, padding=1, bias=False),\n",
+        "            nn.BatchNorm2d(512),\n",
+        "            nn.LeakyReLU(0.2),\n",
+        "\n",
+        "            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),\n",
+        "            nn.BatchNorm2d(256),\n",
+        "            nn.LeakyReLU(0.2),\n",
+        "\n",
+        "            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),\n",
+        "            nn.BatchNorm2d(128),\n",
+        "            nn.LeakyReLU(0.2),\n",
+        "            nn.AvgPool2d(4),\n",
+        "        )\n",
+        "        self.fc = nn.Sequential(\n",
+        "            nn.Linear(128, 1),\n",
+        "            nn.Sigmoid(),\n",
+        "        )\n",
+        "\n",
+        "    def forward(self, x, y=False):\n",
+        "        features = self.conv(x)\n",
+        "        features = features.view(features.size(0), -1)\n",
+        "        output = self.fc(features)\n",
+        "        return output\n"
+      ],
+      "metadata": {
+        "id": "e9n-wD7dwZ7n"
+      },
+      "execution_count": 77,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class Generator(nn.Module):\n",
+        "    def __init__(self, input_size=100, num_classes=784):\n",
+        "        super(Generator, self).__init__()\n",
+        "        self.fc = nn.Sequential(\n",
+        "            nn.Linear(input_size, 4 * 4 * 512),\n",
+        "            nn.ReLU(),\n",
+        "        )\n",
+        "        self.conv = nn.Sequential(\n",
+        "            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),\n",
+        "            nn.BatchNorm2d(256),\n",
+        "            nn.ReLU(),\n",
+        "\n",
+        "            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),\n",
+        "            nn.BatchNorm2d(128),\n",
+        "            nn.ReLU(),\n",
+        "\n",
+        "            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),\n",
+        "            nn.Tanh(),\n",
+        "        )\n",
+        "\n",
+        "    def forward(self, x, y=None):\n",
+        "        x = x.view(x.size(0), -1)\n",
+        "        features = self.fc(x)\n",
+        "        features = features.view(features.size(0), 512, 4, 4)\n",
+        "        output = self.conv(features)\n",
+        "        return output\n"
+      ],
+      "metadata": {
+        "id": "_8-E4605wZ-e"
+      },
+      "execution_count": 78,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Instantiate the Generator and Discriminator\n",
+        "generator = Generator().to(device)\n",
+        "discriminator = Discriminator().to(device)"
+      ],
+      "metadata": {
+        "id": "OSDpsaYBypVA"
+      },
+      "execution_count": 79,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "transform = transforms.Compose([transforms.ToTensor(),\n",
+        "                                transforms.Normalize(mean=[0.5],\n",
+        "                                                     std=[0.5])]\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "yQ8QdKuCz2_a"
+      },
+      "execution_count": 79,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "batch_size = 64\n",
+        "\n",
+        "data = torchvision.datasets.FashionMNIST(root='./data/', train=True, transform=transform, download=True)\n",
+        "data_loader = DataLoader(dataset=data, batch_size=batch_size, shuffle=True, drop_last=True)\n",
+        "\n",
+        "loss_fn = nn.BCELoss()\n",
+        "d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.5, 0.999))\n",
+        "g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001, betas=(0.5, 0.999))\n"
+      ],
+      "metadata": {
+        "id": "8mOTuoih-3ep"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "max_epochs = 50\n",
+        "step = 0\n",
+        "n_critic = 1\n",
+        "n_noise = 100\n",
+        "\n",
+        "d_labels = torch.ones([batch_size, 1]).to(device)\n",
+        "d_fakes = torch.zeros([batch_size, 1]).to(device)"
+      ],
+      "metadata": {
+        "id": "kHJ0B3mk-4Bt"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Training loop\n",
+        "for epoch in range(max_epochs):\n",
+        "    for idx, (images, labels) in enumerate(data_loader):\n",
+        "        real_images = images.to(device)\n",
+        "\n",
+        "        # Discriminator training\n",
+        "        real_outputs = discriminator(real_images)\n",
+        "        d_real_loss = loss_fn(real_outputs, d_labels)\n",
+        "\n",
+        "        fake_noise = torch.randn(batch_size, n_noise).to(device)\n",
+        "        fake_images = generator(fake_noise)\n",
+        "        fake_outputs = discriminator(fake_images.detach())\n",
+        "        d_fake_loss = loss_fn(fake_outputs, d_fakes)\n",
+        "\n",
+        "        d_loss = d_real_loss + d_fake_loss\n",
+        "\n",
+        "        discriminator.zero_grad()\n",
+        "        d_loss.backward()\n",
+        "        d_optimizer.step()\n",
+        "\n",
+        "        # Generator training (every n_critic iterations)\n",
+        "        if step % n_critic == 0:\n",
+        "            fake_outputs = discriminator(fake_images)\n",
+        "            g_loss = loss_fn(fake_outputs, d_labels)\n",
+        "\n",
+        "            generator.zero_grad()\n",
+        "            discriminator.zero_grad()\n",
+        "            g_loss.backward()\n",
+        "            g_optimizer.step()\n",
+        "\n",
+        "        if step % 500 == 0:\n",
+        "            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epochs, step, d_loss.item(), g_loss.item()))\n",
+        "\n",
+        "        if step % 1000 == 0:\n",
+        "            generator.eval()\n",
+        "            img = get_sample_image(generator, n_noise)\n",
+        "            # imsave('samples/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')\n",
+        "            generator.train()\n",
+        "        step += 1"
+      ],
+      "metadata": {
+        "id": "1V9EfSBD-8E9"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# neeed to test"
+      ],
+      "metadata": {
+        "id": "1g4ATYOD-9LY"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [],
+      "metadata": {
+        "id": "UPye6Ktu--Ph"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
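Note: the DCGAN generator above reshapes the noise to a (512, 4, 4) feature map and upsamples it with three transposed convolutions to the 28x28 FashionMNIST resolution. A short standalone sketch, using the standard transposed-convolution size formula out = (in - 1) * stride - 2 * padding + kernel with the layer parameters copied from the cell above, makes the spatial sizes explicit:

```python
# Sketch: verify the spatial sizes produced by the notebook's Generator conv stack.
def conv_transpose_out(size, kernel, stride, padding):
    # Output size of nn.ConvTranspose2d with dilation=1 and output_padding=0.
    return (size - 1) * stride - 2 * padding + kernel

size = 4  # the fc layer reshapes the noise to (512, 4, 4)
for kernel, stride, padding in [(3, 2, 1), (4, 2, 1), (4, 2, 1)]:
    size = conv_transpose_out(size, kernel, stride, padding)
    print(size)  # 7, 14, 28
```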
notebooks/prototype.ipynb
ADDED
File without changes
notebooks/sam/sam_1.ipynb
DELETED
@@ -1,20 +0,0 @@
-{
-  "cells": [
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {},
-      "outputs": [],
-      "source": [
-        "# notebook using sam model"
-      ]
-    }
-  ],
-  "metadata": {
-    "language_info": {
-      "name": "python"
-    }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 2
-}
notebooks/vanilla-gans.ipynb
ADDED
@@ -0,0 +1,187 @@
+{
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "import torch\n",
+        "import torchvision\n",
+        "import torch.nn as nn\n",
+        "import torch.nn.functional as F\n",
+        "\n",
+        "from torch.utils.data import DataLoader\n",
+        "from torchvision import datasets\n",
+        "from torchvision import transforms\n",
+        "from torchvision.utils import save_image\n",
+        "\n",
+        "import numpy as np\n",
+        "import datetime\n",
+        "\n",
+        "from matplotlib.pyplot import imshow, imsave\n",
+        "# %matplotlib inline\n",
+        "\n",
+        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "def get_sample_image(generator, noise_dim):\n",
+        "    z = torch.randn(100, noise_dim).to(device)\n",
+        "    generated_images = generator(z).view(100, 28, 28)\n",
+        "    result = generated_images.cpu().data.numpy()\n",
+        "    img = np.zeros([280, 280])\n",
+        "    for j in range(10):\n",
+        "        img[j * 28:(j + 1) * 28] = np.concatenate([x for x in result[j * 10:(j + 1) * 10]], axis=-1)\n",
+        "    return img\n",
+        "\n",
+        "class Discriminator(nn.Module):\n",
+        "    def __init__(self, input_size=784, num_classes=1):\n",
+        "        super(Discriminator, self).__init__()\n",
+        "        self.layers = nn.Sequential(\n",
+        "            nn.Linear(input_size, 512),\n",
+        "            nn.LeakyReLU(0.2),\n",
+        "            nn.Linear(512, 256),\n",
+        "            nn.LeakyReLU(0.2),\n",
+        "            nn.Linear(256, num_classes),\n",
+        "            nn.Sigmoid(),\n",
+        "        )\n",
+        "\n",
+        "    def forward(self, x):\n",
+        "        x = x.view(x.size(0), -1)\n",
+        "        x = self.layers(x)\n",
+        "        return x\n",
+        "\n",
+        "class Generator(nn.Module):\n",
+        "    def __init__(self, input_size=100, num_classes=784):\n",
+        "        super(Generator, self).__init__()\n",
+        "        self.layers = nn.Sequential(\n",
+        "            nn.Linear(input_size, 128),\n",
+        "            nn.LeakyReLU(0.2),\n",
+        "            nn.Linear(128, 256),\n",
+        "            nn.BatchNorm1d(256),\n",
+        "            nn.LeakyReLU(0.2),\n",
+        "            nn.Linear(256, 512),\n",
+        "            nn.BatchNorm1d(512),\n",
+        "            nn.LeakyReLU(0.2),\n",
+        "            nn.Linear(512, 1024),\n",
+        "            nn.BatchNorm1d(1024),\n",
+        "            nn.LeakyReLU(0.2),\n",
+        "            nn.Linear(1024, num_classes),\n",
+        "            nn.Tanh()\n",
+        "        )\n",
+        "\n",
+        "    def forward(self, x):\n",
+        "        x = self.layers(x)\n",
+        "        x = x.view(x.size(0), 1, 28, 28)\n",
+        "        return x\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": []
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "n_noise = 100\n",
+        "\n",
+        "discriminator = Discriminator().to(device)\n",
+        "generator = Generator().to(device)\n",
+        "\n",
+        "transform = transforms.Compose([transforms.ToTensor(),\n",
+        "                                transforms.Normalize(mean=[0.5],\n",
+        "                                                     std=[0.5])]\n",
+        ")\n",
+        "\n",
+        "mnist = datasets.MNIST(root='../data/', train=True, transform=transform, download=True)\n",
+        "\n",
+        "batch_size = 64\n",
+        "\n",
+        "data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True, drop_last=True)\n",
+        "\n",
+        "loss_fn = nn.BCELoss()\n",
+        "d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
+        "g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
+        "\n",
+        "max_epoch = 50\n",
+        "step = 0\n",
+        "n_critic = 1\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "d_labels = torch.ones(batch_size, 1).to(device)\n",
+        "d_fakes = torch.zeros(batch_size, 1).to(device)\n",
+        "\n",
+        "# Training loop\n",
+        "for epoch in range(max_epoch):\n",
+        "    for idx, (images, _) in enumerate(data_loader):\n",
+        "        real_images = images.to(device)\n",
+        "        real_outputs = discriminator(real_images)\n",
+        "        d_real_loss = loss_fn(real_outputs, d_labels)\n",
+        "\n",
+        "        fake_noise = torch.randn(batch_size, n_noise).to(device)\n",
+        "        fake_images = generator(fake_noise)\n",
+        "        fake_outputs = discriminator(fake_images.detach())\n",
+        "        d_fake_loss = loss_fn(fake_outputs, d_fakes)\n",
+        "\n",
+        "        d_loss = d_real_loss + d_fake_loss\n",
+        "\n",
+        "        discriminator.zero_grad()\n",
+        "        d_loss.backward()\n",
+        "        d_optimizer.step()\n",
+        "\n",
+        "        if step % n_critic == 0:\n",
+        "            fake_outputs = discriminator(generator(fake_noise))\n",
+        "            g_loss = loss_fn(fake_outputs, d_labels)\n",
+        "\n",
+        "            generator.zero_grad()\n",
+        "            g_loss.backward()\n",
+        "            g_optimizer.step()\n",
+        "\n",
+        "        if step % 1000 == 0:\n",
+        "            generator.eval()\n",
+        "            img = get_sample_image(generator, n_noise)\n",
+        "            # imsave('samples/{}_step{}.jpg'.format('gans', str(step).zfill(3)), img, cmap='gray')\n",
+        "            generator.train()\n",
+        "        step += 1\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "generator.eval()\n",
+        "imshow(get_sample_image(generator, n_noise), cmap='gray')\n",
+        "\n",
+        "torch.save(discriminator.state_dict(), 'discriminator.pth')\n",
+        "torch.save(generator.state_dict(), 'generator.pth')\n"
+      ]
+    }
+  ],
+  "metadata": {
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+}
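Note: both notebooks preview training progress by tiling 100 generated 28x28 digits into one 280x280 image inside get_sample_image. A standalone sketch of that tiling, with random data standing in for generator output so it runs without a trained model:

```python
import numpy as np

# Stand-in for the 100 generated images produced by the generator.
result = np.random.rand(100, 28, 28)
grid = np.zeros([280, 280])
for j in range(10):
    # Row j of the grid: images 10*j .. 10*j+9 concatenated side by side.
    grid[j * 28:(j + 1) * 28] = np.concatenate(list(result[j * 10:(j + 1) * 10]), axis=-1)
print(grid.shape)  # (280, 280)
```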
prototype/.gitattributes
DELETED
@@ -1,35 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
prototype/README.md
DELETED
@@ -1,13 +0,0 @@
----
-title: Pearl Prototype
-emoji: 💻
-colorFrom: blue
-colorTo: red
-sdk: streamlit
-sdk_version: 1.29.0
-app_file: app.py
-pinned: false
-license: openrail
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
prototype/app.py
DELETED
@@ -1,40 +0,0 @@
-import gradio as gr
-from io import BytesIO
-
-from torch import autocast
-import requests
-import PIL
-import torch
-from diffusers import StableDiffusionInpaintPipeline as StableDiffusionInpaintPipeline
-
-pipe = StableDiffusionInpaintPipeline.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
-    revision="fp16",
-    torch_dtype=torch.float16,
-    use_auth_token=True,
-)
-
-
-def process_image(dict, prompt):
-    init_img = dict["image"].convert("RGB").resize((512, 512))
-    mask_img = dict["mask"].convert("RGB").resize((512, 512))
-    images = pipe(
-        prompt=prompt, init_image=init_img, mask_image=mask_img, strength=0.75
-    )["sample"]
-    return images[0]
-
-
-iface = gr.Interface(
-    fn=process_image,
-    title="Stable Diffusion In-Painting Tool on Colab with Gradio",
-    inputs=[
-        gr.Image(source="upload", tool="sketch", type="pil"),
-        gr.Textbox(label="prompt"),
-    ],
-    outputs=[gr.Image()],
-    description="Choose a feature and upload an image to see the processed result.",
-    article="<p style='text-align: center;'>Built with Gradio</p>",
-)
-
-
-iface.launch()
prototype/inpainting.py
DELETED
@@ -1,227 +0,0 @@
-# credit : Hugging Face Team
-import inspect
-from typing import List, Optional, Union
-
-import numpy as np
-import torch
-
-import PIL
-from diffusers import (
-    AutoencoderKL,
-    DDIMScheduler,
-    DiffusionPipeline,
-    PNDMScheduler,
-    UNet2DConditionModel,
-)
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-
-
-def preprocess_image(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
-
-
-def preprocess_mask(mask):
-    mask = mask.convert("L")
-    w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
-    mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask, (4, 1, 1))
-    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-    mask = 1 - mask  # repaint white, keep black
-    mask = torch.from_numpy(mask)
-    return mask
-
-
-class StableDiffusionInpaintingPipeline(DiffusionPipeline):
-    def __init__(
-        self,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModel,
-        tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler],
-        safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
-    ):
-        super().__init__()
-        scheduler = scheduler.set_format("pt")
-        self.register_modules(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-        )
-
-    @torch.no_grad()
-    def __call__(
-        self,
-        prompt: Union[str, List[str]],
-        init_image: torch.FloatTensor,
-        mask_image: torch.FloatTensor,
-        strength: float = 0.8,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
-        eta: Optional[float] = 0.0,
-        generator: Optional[torch.Generator] = None,
-        output_type: Optional[str] = "pil",
-    ):
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(
-                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
-            )
-
-        if strength < 0 or strength > 1:
-            raise ValueError(
-                f"The value of strength should in [0.0, 1.0] but is {strength}"
-            )
-
-        # set timesteps
-        accepts_offset = "offset" in set(
-            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
-        )
-        extra_set_kwargs = {}
-        offset = 0
-        if accepts_offset:
-            offset = 1
-            extra_set_kwargs["offset"] = 1
-
-        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
-
-        # preprocess image
-        init_image = preprocess_image(init_image).to(self.device)
-
-        # encode the init image into latents and scale the latents
-        init_latents = self.vae.encode(init_image).sample()
-        init_latents = 0.18215 * init_latents
-
-        # prepare init_latents noise to latents
-        init_latents = torch.cat([init_latents] * batch_size)
-        init_latents_orig = init_latents
-
-        # preprocess mask
-        mask = preprocess_mask(mask_image).to(self.device)
-        mask = torch.cat([mask] * batch_size)
-
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError(f"The mask and init_image should be the same size!")
-
-        # get the original timestep using init_timestep
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor(
-            [timesteps] * batch_size, dtype=torch.long, device=self.device
-        )
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
-        # get prompt text embeddings
-        text_input = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
-
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
-            max_length = text_input.input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                [""] * batch_size,
-                padding="max_length",
-                max_length=max_length,
-                return_tensors="pt",
-            )
-            uncond_embeddings = self.text_encoder(
-                uncond_input.input_ids.to(self.device)
-            )[0]
-
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        accepts_eta = "eta" in set(
-            inspect.signature(self.scheduler.step).parameters.keys()
-        )
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        latents = init_latents
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
-        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = (
-                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            )
-
-            # predict the noise residual
-            noise_pred = self.unet(
-                latent_model_input, t, encoder_hidden_states=text_embeddings
-            )["sample"]
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (
-                    noise_pred_text - noise_pred_uncond
-                )
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[
-                "prev_sample"
-            ]
-
-            # masking
-            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
-            latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-        # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
-
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
-
-        # run safety checker
-        safety_cheker_input = self.feature_extractor(
-            self.numpy_to_pil(image), return_tensors="pt"
-        ).to(self.device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_cheker_input.pixel_values
-        )
-
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
-
-        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
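Note: the deleted pipeline keeps the unmasked region faithful to the input by re-noising the original latents at every denoising step and blending them with the freshly denoised latents through the mask. A minimal tensor-level sketch of that blend (shapes chosen purely for illustration; the real pipeline operates on (batch, 4, H/8, W/8) latents):

```python
import torch

latents = torch.randn(1, 4, 64, 64)              # denoised latents at step t
init_latents_proper = torch.randn(1, 4, 64, 64)  # original latents re-noised to step t
mask = (torch.rand(1, 4, 64, 64) > 0.5).float()  # 1 = preserve original, 0 = repaint

# Keep the original content where mask == 1, the repainted content where mask == 0.
blended = init_latents_proper * mask + latents * (1 - mask)
print(blended.shape)  # torch.Size([1, 4, 64, 64])
```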
prototype/requirements.txt
DELETED
@@ -1,7 +0,0 @@
-torch
-requests
-pillow
-diffusers
-gradio
-numpy
-tqdm
prototype/test.py
DELETED
@@ -1,12 +0,0 @@
-import gradio as gr
-
-def greet(name, intensity):
-    return "Hello " * intensity + name + "!"
-
-demo = gr.Interface(
-    fn=greet,
-    inputs=["text", "slider"],
-    outputs=["text"],
-)
-
-demo.launch()
prototype/utils.py
DELETED
@@ -1,13 +0,0 @@
-def add_feature(image):
-    # inpainting features
-    pass
-
-
-def remove_feature(image):
-    # inpainting features
-    pass
-
-
-def enhance_feature(image):
-    # inpainting features
-    pass
setup.py
CHANGED
@@ -1,44 +1,27 @@
-

-
-
-# from pathlib import Path

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# author="OpenAI",
-# url="https://github.com/openai/whisper",
-# license="MIT",
-# packages=find_packages(exclude=["tests*"]),
-# install_requires=[
-#     str(r)
-#     for r in pkg_resources.parse_requirements(
-#         Path(__file__).with_name("requirements.txt").open()
-#     )
-# ],
-# entry_points={
-#     "console_scripts": ["whisper=whisper.transcribe:cli"],
-# },
-# include_package_data=True,
-# extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
-# )
+from setuptools import setup, find_packages

+with open('requirements.txt') as f:
+    requirements = f.read().splitlines()

+setup(
+    name='ve-gans',
+    version='0.1',
+    packages=find_packages(),
+    install_requires=requirements,
+    entry_points={
+        'console_scripts': [
+            've-gans-train=ve_gans.train:main',
+            've-gans-generate=ve_gans.generate:main',
+        ],
+    },
+    author='Zai',
+    author_email='[email protected]',
+    description='Image generation with GANs using PyTorch',
+    long_description='Detailed description of your project',
+    url='https://github.com/zaibutcooler/ve-gans',
+    classifiers=[
+        'Programming Language :: Python :: 3',
+        'License :: OSI Approved :: MIT License',
+        'Operating System :: OS Independent',
+    ],
+)
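Note: the console_scripts entries point at ve_gans.train:main and ve_gans.generate:main, while the package files added in this commit live under vegans/ and are empty, so those modules are not part of this diff. Purely as an assumption, a minimal sketch of what such an entry-point module could look like:

```python
# ve_gans/train.py -- hypothetical sketch; this module is referenced by the
# 've-gans-train' console script but is not included in this commit.
import argparse


def main():
    # Parse a couple of plausible training options and report them; actual training
    # would build the Generator/Discriminator and run the loop shown in notebooks/dcgan.ipynb.
    parser = argparse.ArgumentParser(description="Train the ve-gans GAN")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=64)
    args = parser.parse_args()
    print(f"Training for {args.epochs} epochs with batch size {args.batch_size}")


if __name__ == "__main__":
    main()
```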
space/app.py
ADDED
@@ -0,0 +1,12 @@
+import streamlit as st
+import torch
+from torchvision.utils import make_grid
+from torchvision.transforms import ToPILImage
+
+def main():
+    st.title("Image Generation")
+    st.write("Made with GANS from scratch")
+
+
+if __name__ == '__main__':
+    main()
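Note: the Streamlit app added here only renders a title and a caption, although it already imports torch, make_grid, and ToPILImage. One way it could be extended to display samples from a saved generator checkpoint is sketched below; the Generator import and the 'generator.pth' path are assumptions, since vegans/generator.py is empty in this commit and no checkpoint is included.

```python
# Hypothetical extension of space/app.py -- a sketch, not part of this commit.
import streamlit as st
import torch
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage

from vegans.generator import Generator  # assumed API; the module is empty in this commit


def main():
    st.title("Image Generation")
    st.write("Made with GANS from scratch")

    # Load a trained generator checkpoint (assumed to exist as 'generator.pth').
    generator = Generator()
    generator.load_state_dict(torch.load("generator.pth", map_location="cpu"))
    generator.eval()

    with torch.no_grad():
        noise = torch.randn(16, 100)   # 16 samples of 100-dim noise
        images = generator(noise)      # expected shape (16, 1, 28, 28)

    grid = make_grid(images, nrow=4, normalize=True)
    st.image(ToPILImage()(grid), caption="Generated samples")


if __name__ == "__main__":
    main()
```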
tests/{demo.py → test_demo.py}
RENAMED
File without changes
vegans/__init__.py
ADDED
File without changes
vegans/discriminator.py
ADDED
File without changes
vegans/generator.py
ADDED
File without changes
vegans/utils.py
ADDED
File without changes