ve-gans / space.py
Zai
updated the model loader and readme
c6494da
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import numpy as np
# GAN model architecture
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# TO DO: implement generator architecture
def forward(self, z):
# TO DO: implement generator forward pass
pass
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
# TO DO: implement discriminator architecture
def forward(self, x):
# TO DO: implement discriminator forward pass
pass
# Floorplan generator model architecture
class FloorplanGenerator(nn.Module):
def __init__(self):
super(FloorplanGenerator, self).__init__()
# TO DO: implement floorplan generator architecture
def forward(self, img):
# TO DO: implement floorplan generator forward pass
pass
# Load pre-trained models
@st.cache
def load_models():
generator = Generator()
discriminator = Discriminator()
floorplan_generator = FloorplanGenerator()
# TO DO: load pre-trained model weights
return generator, discriminator, floorplan_generator
# Streamlit app
st.title("GAN Image Generation and Floorplan App")
# Load models
generator, discriminator, floorplan_generator = load_models()
# Image generation
st.header("Image Generation")
z_dim = 100
noise = torch.randn(1, z_dim)
generated_img = generator(noise)
st.image(generated_img, caption="Generated Image")
# Floorplan generation
st.header("Floorplan Generation")
img_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if img_file:
img = Image.open(img_file)
img = np.array(img) / 255.0
floorplan = floorplan_generator(torch.tensor(img))
st.image(floorplan, caption="Generated Floorplan")
# Run the app
if __name__ == "__main__":
st.write("App is running!")