Spaces:
Sleeping
Sleeping
Zai
commited on
Commit
•
c6494da
1
Parent(s):
43c67e7
updated the model loader and readme
Browse files- .github/workflows/hugging-face.yaml +21 -0
- README.md +0 -12
- requirements.txt +5 -0
- sample.py +2 -2
- setup.py +11 -11
- space.py +68 -12
- tests/test_dataset.py +10 -3
- tests/test_generating.py +13 -3
- tests/test_space.py +11 -0
- tests/test_training.py +8 -3
- train.py +0 -1
- vegans/config.py +13 -0
- vegans/dataset.py +23 -13
- vegans/discriminator.py +9 -8
- vegans/generator.py +13 -10
- vegans/utils.py +23 -2
- vegans/vegans.py +70 -29
- version.py +1 -1
.github/workflows/hugging-face.yaml
CHANGED
@@ -12,6 +12,27 @@ jobs:
|
|
12 |
with:
|
13 |
fetch-depth: 0
|
14 |
lfs: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
- name: Push to hub
|
16 |
env:
|
17 |
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
|
|
12 |
with:
|
13 |
fetch-depth: 0
|
14 |
lfs: true
|
15 |
+
|
16 |
+
- name: Update README.md
|
17 |
+
run: |
|
18 |
+
tmp_file=$(mktemp)
|
19 |
+
echo "---" >> $tmp_file
|
20 |
+
echo "title: Ve Gans" >> $tmp_file
|
21 |
+
echo "emoji: 📈" >> $tmp_file
|
22 |
+
echo "colorFrom: yellow" >> $tmp_file
|
23 |
+
echo "colorTo: gray" >> $tmp_file
|
24 |
+
echo "sdk: streamlit" >> $tmp_file
|
25 |
+
echo "sdk_version: 1.29.0" >> $tmp_file
|
26 |
+
echo "app_file: space.py" >> $tmp_file
|
27 |
+
echo "pinned: false" >> $tmp_file
|
28 |
+
echo "license: openrail" >> $tmp_file
|
29 |
+
echo "---" >> $tmp_file
|
30 |
+
echo "" >> $tmp_file
|
31 |
+
cat README.md >> $tmp_file
|
32 |
+
mv $tmp_file README.md
|
33 |
+
git add README.md
|
34 |
+
git commit -m "Updated README.md"
|
35 |
+
|
36 |
- name: Push to hub
|
37 |
env:
|
38 |
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
README.md
CHANGED
@@ -1,15 +1,3 @@
|
|
1 |
-
---
|
2 |
-
title: Ve Gans
|
3 |
-
emoji: 👀
|
4 |
-
colorFrom: yellow
|
5 |
-
colorTo: green
|
6 |
-
sdk: streamlit
|
7 |
-
sdk_version: 1.29.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: openrail
|
11 |
-
---
|
12 |
-
|
13 |
# ve-gans: Image Generation with GANs using PyTorch
|
14 |
|
15 |
## Overview
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# ve-gans: Image Generation with GANs using PyTorch
|
2 |
|
3 |
## Overview
|
requirements.txt
CHANGED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
torch
|
3 |
+
torchvision
|
4 |
+
Pillow
|
5 |
+
numpy
|
sample.py
CHANGED
@@ -4,6 +4,6 @@ vegans = Vegans
|
|
4 |
|
5 |
vegans.load_pretrained()
|
6 |
|
7 |
-
text =
|
8 |
|
9 |
-
vegans.generate(text)
|
|
|
4 |
|
5 |
vegans.load_pretrained()
|
6 |
|
7 |
+
text = "something"
|
8 |
|
9 |
+
vegans.generate(text)
|
setup.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1 |
from setuptools import setup, find_packages
|
2 |
|
3 |
-
with open(
|
4 |
requirements = f.read().splitlines()
|
5 |
|
6 |
setup(
|
7 |
-
name=
|
8 |
-
version=
|
9 |
packages=find_packages(),
|
10 |
install_requires=requirements,
|
11 |
-
author=
|
12 |
-
author_email=
|
13 |
-
description=
|
14 |
-
long_description=
|
15 |
-
url=
|
16 |
classifiers=[
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
],
|
21 |
)
|
|
|
1 |
from setuptools import setup, find_packages
|
2 |
|
3 |
+
with open("requirements.txt") as f:
|
4 |
requirements = f.read().splitlines()
|
5 |
|
6 |
setup(
|
7 |
+
name="ve-gans",
|
8 |
+
version="0.1",
|
9 |
packages=find_packages(),
|
10 |
install_requires=requirements,
|
11 |
+
author="Zai",
|
12 |
+
author_email="[email protected]",
|
13 |
+
description="Floorplan generation model with pytorch",
|
14 |
+
long_description="Detailed description of your project",
|
15 |
+
url="https://github.com/zaibutcooler/ve-gans",
|
16 |
classifiers=[
|
17 |
+
"Programming Language :: Python :: 3",
|
18 |
+
"License :: OSI Approved :: MIT License",
|
19 |
+
"Operating System :: OS Independent",
|
20 |
],
|
21 |
)
|
space.py
CHANGED
@@ -1,14 +1,70 @@
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
-
|
4 |
-
|
5 |
-
from
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.optim as optim
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
# GAN model architecture
|
9 |
+
class Generator(nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super(Generator, self).__init__()
|
12 |
+
# TO DO: implement generator architecture
|
13 |
+
|
14 |
+
def forward(self, z):
|
15 |
+
# TO DO: implement generator forward pass
|
16 |
+
pass
|
17 |
+
|
18 |
+
class Discriminator(nn.Module):
|
19 |
+
def __init__(self):
|
20 |
+
super(Discriminator, self).__init__()
|
21 |
+
# TO DO: implement discriminator architecture
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
# TO DO: implement discriminator forward pass
|
25 |
+
pass
|
26 |
+
|
27 |
+
# Floorplan generator model architecture
|
28 |
+
class FloorplanGenerator(nn.Module):
|
29 |
+
def __init__(self):
|
30 |
+
super(FloorplanGenerator, self).__init__()
|
31 |
+
# TO DO: implement floorplan generator architecture
|
32 |
+
|
33 |
+
def forward(self, img):
|
34 |
+
# TO DO: implement floorplan generator forward pass
|
35 |
+
pass
|
36 |
+
|
37 |
+
# Load pre-trained models
|
38 |
+
@st.cache
|
39 |
+
def load_models():
|
40 |
+
generator = Generator()
|
41 |
+
discriminator = Discriminator()
|
42 |
+
floorplan_generator = FloorplanGenerator()
|
43 |
+
# TO DO: load pre-trained model weights
|
44 |
+
return generator, discriminator, floorplan_generator
|
45 |
+
|
46 |
+
# Streamlit app
|
47 |
+
st.title("GAN Image Generation and Floorplan App")
|
48 |
+
|
49 |
+
# Load models
|
50 |
+
generator, discriminator, floorplan_generator = load_models()
|
51 |
+
|
52 |
+
# Image generation
|
53 |
+
st.header("Image Generation")
|
54 |
+
z_dim = 100
|
55 |
+
noise = torch.randn(1, z_dim)
|
56 |
+
generated_img = generator(noise)
|
57 |
+
st.image(generated_img, caption="Generated Image")
|
58 |
+
|
59 |
+
# Floorplan generation
|
60 |
+
st.header("Floorplan Generation")
|
61 |
+
img_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
|
62 |
+
if img_file:
|
63 |
+
img = Image.open(img_file)
|
64 |
+
img = np.array(img) / 255.0
|
65 |
+
floorplan = floorplan_generator(torch.tensor(img))
|
66 |
+
st.image(floorplan, caption="Generated Floorplan")
|
67 |
+
|
68 |
+
# Run the app
|
69 |
+
if __name__ == "__main__":
|
70 |
+
st.write("App is running!")
|
tests/test_dataset.py
CHANGED
@@ -1,7 +1,14 @@
|
|
1 |
import unittest
|
|
|
|
|
2 |
|
3 |
class TestTraining(unittest.TestCase):
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
if __name__ ==
|
7 |
-
unittest.main()
|
|
|
1 |
import unittest
|
2 |
+
from vegans.dataset import TrainingSet
|
3 |
+
|
4 |
|
5 |
class TestTraining(unittest.TestCase):
|
6 |
+
def test_download_dataset(self):
|
7 |
+
data = TrainingSet()
|
8 |
+
data.download_dataset()
|
9 |
+
assert data.images is not None
|
10 |
+
assert data.labels is not None
|
11 |
+
|
12 |
|
13 |
+
if __name__ == "__main__":
|
14 |
+
unittest.main()
|
tests/test_generating.py
CHANGED
@@ -1,7 +1,17 @@
|
|
1 |
import unittest
|
|
|
|
|
2 |
|
3 |
class TestTraining(unittest.TestCase):
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
if __name__ ==
|
7 |
-
unittest.main()
|
|
|
1 |
import unittest
|
2 |
+
from vegans import Vegans
|
3 |
+
|
4 |
|
5 |
class TestTraining(unittest.TestCase):
|
6 |
+
def test_loading_pretrained(self):
|
7 |
+
model = Vegans()
|
8 |
+
model.load_pretrained()
|
9 |
+
|
10 |
+
def test_generator(self):
|
11 |
+
model = Vegans()
|
12 |
+
print("TODO, make generator")
|
13 |
+
pass
|
14 |
+
|
15 |
|
16 |
+
if __name__ == "__main__":
|
17 |
+
unittest.main()
|
tests/test_space.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import space
|
3 |
+
|
4 |
+
|
5 |
+
class TestTraining(unittest.TestCase):
|
6 |
+
def test_running(self):
|
7 |
+
pass
|
8 |
+
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
unittest.main()
|
tests/test_training.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1 |
import unittest
|
|
|
|
|
2 |
|
3 |
class TestTraining(unittest.TestCase):
|
4 |
-
|
|
|
|
|
|
|
5 |
|
6 |
-
if __name__ ==
|
7 |
-
unittest.main()
|
|
|
1 |
import unittest
|
2 |
+
from vegans import Vegans
|
3 |
+
|
4 |
|
5 |
class TestTraining(unittest.TestCase):
|
6 |
+
def test_training(self):
|
7 |
+
model = Vegans()
|
8 |
+
model.train(upload=False, real=False)
|
9 |
+
|
10 |
|
11 |
+
if __name__ == "__main__":
|
12 |
+
unittest.main()
|
train.py
CHANGED
@@ -8,4 +8,3 @@ vegans.train()
|
|
8 |
|
9 |
# or you can simply just
|
10 |
# vegans.load_pretrained()
|
11 |
-
|
|
|
8 |
|
9 |
# or you can simply just
|
10 |
# vegans.load_pretrained()
|
|
vegans/config.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
|
4 |
+
class Config(BaseModel):
|
5 |
+
learning_rate: float = 0.0002 # learning rate for generator and discriminator
|
6 |
+
num_epoch: int = 100 # number of epochs to train
|
7 |
+
betas: tuple[float, float] = (0.5, 0.999) # beta1 and beta2 for Adam optimizer
|
8 |
+
batch_size: int = 32 # batch size for training
|
9 |
+
image_size: int = 64 # size of input images
|
10 |
+
channels: int = 3 # number of color channels in input images
|
11 |
+
n_critic: int = 5 # number of critic iterations per generator iteration
|
12 |
+
save_interval: int = 100 # interval for saving model checkpoints
|
13 |
+
sample_interval: int = 100 # interval for generating sample images
|
vegans/dataset.py
CHANGED
@@ -1,26 +1,36 @@
|
|
1 |
from torch.utils.data import Dataset
|
2 |
from torchvision import transforms as tran
|
3 |
from datasets import load_dataset
|
|
|
|
|
4 |
|
5 |
|
|
|
|
|
|
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
])
|
16 |
-
|
17 |
self.images = None
|
18 |
self.labels = None
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def __getitem__(self, index):
|
21 |
image = self.transforms(self.images[index])
|
22 |
label = self.labels[index]
|
23 |
-
return image,label
|
24 |
-
|
25 |
def __len__(self):
|
26 |
return len(self.images)
|
|
|
1 |
from torch.utils.data import Dataset
|
2 |
from torchvision import transforms as tran
|
3 |
from datasets import load_dataset
|
4 |
+
from .utils import log
|
5 |
+
from .config import Config
|
6 |
|
7 |
|
8 |
+
class TrainData(Dataset):
|
9 |
+
def __init__(self, config: Config, gray_scale=True):
|
10 |
+
super().__init__()
|
11 |
|
12 |
+
self.transforms = tran.Compose(
|
13 |
+
[
|
14 |
+
tran.Resize((config.image_size, config.image_size)),
|
15 |
+
tran.ToTensor(),
|
16 |
+
tran.Normalize([0.5], [0.5]),
|
17 |
+
tran.Grayscale() if gray_scale else None,
|
18 |
+
]
|
19 |
+
)
|
|
|
|
|
20 |
self.images = None
|
21 |
self.labels = None
|
22 |
+
self.download_dataset()
|
23 |
+
|
24 |
+
def download_dataset(self):
|
25 |
+
dataset = load_dataset("zaibutcooler/archi-vault")
|
26 |
+
self.images = dataset["train"]["images"]
|
27 |
+
self.labels = dataset["train"]["labels"]
|
28 |
+
log("Successfully loaded the dataset")
|
29 |
+
|
30 |
def __getitem__(self, index):
|
31 |
image = self.transforms(self.images[index])
|
32 |
label = self.labels[index]
|
33 |
+
return image, label
|
34 |
+
|
35 |
def __len__(self):
|
36 |
return len(self.images)
|
vegans/discriminator.py
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
import torch.nn as nn
|
2 |
import torch.nn.functional as F
|
3 |
-
import
|
|
|
4 |
|
5 |
-
class Discriminator(nn.Module):
|
6 |
-
def __init__(self):
|
7 |
-
super(Discriminator, self).__init__()
|
8 |
-
self.model = nn.Sequential(
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
|
13 |
-
|
|
|
1 |
import torch.nn as nn
|
2 |
import torch.nn.functional as F
|
3 |
+
from .config import Config
|
4 |
+
from huggingface_hub import PyTorchModelHubMixin
|
5 |
|
|
|
|
|
|
|
|
|
6 |
|
7 |
+
class Discriminator(nn.Module, PyTorchModelHubMixin):
|
8 |
+
def __init__(self, config: Config):
|
9 |
+
super(Discriminator, self).__init__()
|
10 |
+
self.config = config
|
11 |
+
self.model = nn.Sequential(nn.Linear())
|
12 |
|
13 |
+
def forward(self, x, y):
|
14 |
+
return x
|
vegans/generator.py
CHANGED
@@ -2,15 +2,18 @@ import torch.nn as nn
|
|
2 |
import torch.nn.functional as F
|
3 |
import torch
|
4 |
|
5 |
-
class Generator(nn.Module):
|
6 |
-
def __init__(self):
|
7 |
-
super(Generator, self).__init__()
|
8 |
-
def block(in_feat,out_feat,norm=False):
|
9 |
-
pass
|
10 |
-
|
11 |
-
self.model = nn.Sequential(
|
12 |
|
13 |
-
|
14 |
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch.nn.functional as F
|
3 |
import torch
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
+
from huggingface_hub import PyTorchModelHubMixin
|
7 |
|
8 |
+
|
9 |
+
class Generator(nn.Module, PyTorchModelHubMixin):
|
10 |
+
def __init__(self):
|
11 |
+
super(Generator, self).__init__()
|
12 |
+
|
13 |
+
def block(in_feat, out_feat, norm=False):
|
14 |
+
pass
|
15 |
+
|
16 |
+
self.model = nn.Sequential()
|
17 |
+
|
18 |
+
def forward(self, x, y):
|
19 |
+
pass
|
vegans/utils.py
CHANGED
@@ -1,10 +1,31 @@
|
|
|
|
|
|
|
|
|
|
1 |
def display_image():
|
2 |
pass
|
3 |
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def log(text):
|
|
|
8 |
print("#############################################\n")
|
9 |
print(f"{text}\n")
|
10 |
print("#############################################\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import os
|
3 |
+
|
4 |
+
|
5 |
def display_image():
|
6 |
pass
|
7 |
|
8 |
+
|
9 |
+
def save_image(tensor, location="./"):
|
10 |
+
|
11 |
+
image = Image.fromarray(tensor)
|
12 |
+
|
13 |
+
if not os.path.exists(location):
|
14 |
+
os.makedirs(location)
|
15 |
+
|
16 |
+
image.save(os.path.join(location, "image.jpg"))
|
17 |
+
|
18 |
|
19 |
def log(text):
|
20 |
+
print("")
|
21 |
print("#############################################\n")
|
22 |
print(f"{text}\n")
|
23 |
print("#############################################\n")
|
24 |
+
|
25 |
+
|
26 |
+
import torch
|
27 |
+
|
28 |
+
|
29 |
+
def generate_noise(batch_size, z_dim):
|
30 |
+
|
31 |
+
return torch.randn(batch_size, z_dim)
|
vegans/vegans.py
CHANGED
@@ -1,54 +1,95 @@
|
|
1 |
# initialize vegans
|
2 |
import torch
|
3 |
-
|
4 |
-
|
5 |
-
from
|
6 |
-
from
|
7 |
-
from
|
|
|
|
|
8 |
|
9 |
|
10 |
class Vegans:
|
11 |
-
def __init__(self,
|
12 |
-
|
13 |
-
self.
|
14 |
-
self.
|
15 |
-
self.device =
|
16 |
self.generator = Generator().to(self.device)
|
17 |
self.discriminator = Discriminator().to(self.device)
|
18 |
-
self.
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
def train(self):
|
23 |
loss_fn = torch.nn.BCELoss()
|
24 |
log("Started Training Loop")
|
|
|
|
|
|
|
|
|
25 |
|
26 |
for epoch in range(self.num_epoch):
|
27 |
-
for i,(image,label) in enumerate(
|
28 |
-
|
29 |
-
self.g_optim.zero_grad()
|
30 |
|
31 |
-
|
|
|
32 |
|
|
|
33 |
|
|
|
34 |
self.d_optim.zero_grad()
|
35 |
|
36 |
-
|
37 |
|
38 |
log("Finish Training")
|
|
|
|
|
|
|
39 |
|
40 |
-
def
|
41 |
-
|
42 |
-
noise = 0
|
43 |
-
output = self.generator(noise,label)
|
44 |
|
45 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
|
|
47 |
# TODO
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
52 |
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# initialize vegans
|
2 |
import torch
|
3 |
+
import huggingface_hub
|
4 |
+
|
5 |
+
from .generator import Generator
|
6 |
+
from .discriminator import Discriminator
|
7 |
+
from .utils import save_image, display_image, log, generate_noise
|
8 |
+
from .dataset import TrainingSet
|
9 |
+
from .config import Config
|
10 |
|
11 |
|
12 |
class Vegans:
|
13 |
+
def __init__(self, config: Config, dataset):
|
14 |
+
assert config is not None
|
15 |
+
self.config = config
|
16 |
+
self.dataset = dataset
|
17 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
self.generator = Generator().to(self.device)
|
19 |
self.discriminator = Discriminator().to(self.device)
|
20 |
+
self.g_optim = torch.optim.Adam(
|
21 |
+
self.generator.parameters(),
|
22 |
+
lr=self.config.learning_rate,
|
23 |
+
betas=self.config.betas,
|
24 |
+
)
|
25 |
+
self.d_optim = torch.optim.Adam(
|
26 |
+
self.discriminator.parameters(),
|
27 |
+
lr=self.config.learning_rate,
|
28 |
+
betas=self.config.betas,
|
29 |
+
)
|
30 |
+
|
31 |
+
def train(self, upload=False, real=True):
|
32 |
|
|
|
33 |
loss_fn = torch.nn.BCELoss()
|
34 |
log("Started Training Loop")
|
35 |
+
if real:
|
36 |
+
dataset = self.dataset
|
37 |
+
else:
|
38 |
+
dataset = self._dummy_dataset()
|
39 |
|
40 |
for epoch in range(self.num_epoch):
|
41 |
+
for i, (image, label) in enumerate(dataset):
|
|
|
|
|
42 |
|
43 |
+
# Train the generator
|
44 |
+
self.g_optim.zero_grad()
|
45 |
|
46 |
+
loss = loss_fn(0, 0)
|
47 |
|
48 |
+
# Train the discriminator
|
49 |
self.d_optim.zero_grad()
|
50 |
|
51 |
+
log(f"Epoch {epoch} done. Loss is {loss.item()}")
|
52 |
|
53 |
log("Finish Training")
|
54 |
+
if upload:
|
55 |
+
self.generator.save_pretrained("ve-gans")
|
56 |
+
self.generator.push_to_hub("ve-gans")
|
57 |
|
58 |
+
def eval(self):
|
59 |
+
pass
|
|
|
|
|
60 |
|
61 |
+
def _dummy_dataset(self):
|
62 |
+
dummy_images = []
|
63 |
+
dummy_labels = []
|
64 |
+
for i in range(100):
|
65 |
+
image = torch.randn(
|
66 |
+
1, self.config.channels, self.config.image_size, self.config.image_size
|
67 |
+
)
|
68 |
+
label = torch.randint(0, 2, (1,))
|
69 |
+
dummy_images.append(image)
|
70 |
+
dummy_labels.append(label)
|
71 |
+
return [(image, label) for image, label in zip(dummy_images, dummy_labels)]
|
72 |
|
73 |
+
def generate(self, label, save=True):
|
74 |
# TODO
|
75 |
+
noise = generate_noise()
|
76 |
+
output = self.generator(noise, label)
|
77 |
+
if save:
|
78 |
+
save_image(output)
|
79 |
+
return output
|
80 |
|
81 |
+
def save_pretrained(self, name="ve-gans"):
|
82 |
+
print("Uploading model...")
|
83 |
+
self.model.save_pretrained(name)
|
84 |
+
print(f"Model saved locally as '{name}'")
|
85 |
+
self.model.push_to_hub(name)
|
86 |
+
print(f"Model '{name}' uploaded to the Hugging Face Model Hub")
|
87 |
|
88 |
+
def load_pretrained(self, model_id="zaibutcooler/ve-gans"):
|
89 |
+
print("Loading model...")
|
90 |
+
model = model.from_pretrained(model_id)
|
91 |
+
print(f"Model '{model_id}' loaded successfully")
|
92 |
+
return model
|
93 |
+
|
94 |
+
def huggingface_login(self,token):
|
95 |
+
huggingface_hub.login(token)
|
version.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
__version__ = "20231117"
|
|
|
1 |
+
__version__ = "20231117"
|