Spaces:
Running
Running
File size: 3,595 Bytes
139d7b2 827021c 2e786fb 827021c 2e786fb 827021c 57fc91e fdbc146 827021c fdbc146 e49c48c fdbc146 827021c fdbc146 2e786fb 827021c fdbc146 d288725 827021c d288725 fdbc146 2e786fb fdbc146 d288725 2e786fb fdbc146 2e786fb fdbc146 d288725 2e786fb d288725 827021c 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 8088244 fdbc146 8088244 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e fdbc146 57fc91e cf8fd7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
import json
from omegaconf import OmegaConf
import sys
import os
from PIL import Image
import torchvision.transforms as transforms
# Folder (relative to the working directory) that holds the example photos
# shown in the gallery.
photos_folder = "Photos"

# Download the generator weights, its configuration and the module that
# defines the architecture from the Hugging Face Hub.
repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

# Make the downloaded model.py importable.
# NOTE(review): this imports and executes code fetched from the Hub at
# startup; consider pinning `revision=` in the downloads above so the code
# that runs cannot change silently.
sys.path.append(os.path.dirname(model_path))
from model import Generator

# Load configuration. The encoding is given explicitly because JSON text is
# UTF-8 while open() otherwise falls back to the platform default encoding.
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)

# Initialize model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(cfg).to(device)
# NOTE(review): torch.load unpickles the checkpoint; if the file only holds a
# state dict, passing weights_only=True would be safer - confirm the
# installed torch version supports it before changing.
generator.load_state_dict(torch.load(generator_path, map_location=device))
generator.eval()

# Input preprocessing: resize to 256x256, convert to a tensor and normalize
# each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def process_image(image):
    """Run the generator on one picture and return the generated picture.

    Args:
        image: PIL image coming from the UI, or ``None`` when the input
            component is empty.

    Returns:
        A PIL image built from the generator output, or ``None`` when
        ``image`` is ``None``.
    """
    if image is None:
        return None

    # The module-level ``transform`` normalizes exactly three channels, so
    # grayscale, palette or RGBA pictures are converted first instead of
    # failing inside Normalize.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Convert to a batch of one on the model's device.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Inference without gradient tracking.
    with torch.no_grad():
        output_tensor = generator(image_tensor)

    # Drop the batch dimension and bring the result back to the CPU.
    output_image = output_tensor.squeeze(0).cpu()
    # Undo the mean=0.5 / std=0.5 normalization, then clamp: ToPILImage
    # scales float tensors by 255 and casts to bytes, so values outside
    # [0, 1] would otherwise produce corrupted pixels.
    output_image = (output_image * 0.5 + 0.5).clamp(0.0, 1.0)
    output_image = transforms.ToPILImage()(output_image)
    return output_image
def load_images_from_folder(folder):
    """Load every PNG/JPEG picture found directly inside ``folder``.

    A missing folder is created and treated as empty. Files that cannot be
    read or decoded are reported on stdout and skipped.

    Args:
        folder: Path of the directory to scan (not recursive).

    Returns:
        A list of ``(image, filename)`` tuples ordered by filename, where
        ``image`` is a fully decoded PIL image.
    """
    images = []
    if not os.path.exists(folder):
        os.makedirs(folder)
        return images

    # os.listdir returns entries in arbitrary order; sorting keeps the
    # gallery order stable between runs.
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        img_path = os.path.join(folder, filename)
        try:
            # Image.open is lazy and keeps the file open. Copying inside the
            # context manager decodes the pixels now (so a corrupt file is
            # caught here) and lets the file handle be closed right away.
            with Image.open(img_path) as opened:
                img = opened.copy()
        except Exception as e:
            # Deliberately broad: loading is best effort and one bad file
            # must not prevent the app from starting.
            print(f"Error loading {filename}: {e}")
        else:
            images.append((img, filename))
    return images
def app():
    """Build the Gradio interface and launch it.

    The page is one row with three columns: the input image with a "Clear"
    button, a gallery of the pictures loaded from ``photos_folder``, and the
    result image. Selecting a gallery entry copies it into the input, any
    change of the input runs ``process_image``, and "Clear" empties the
    input.
    """
    loaded = load_images_from_folder(photos_folder)
    # The gallery only needs the pictures, not their filenames.
    thumbnails = [picture for picture, _name in loaded]

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                source_view = gr.Image(label="Input Image", type="pil")
                reset_btn = gr.Button("Clear")
            with gr.Column():
                picker = gr.Gallery(
                    label="Image Gallery",
                    value=thumbnails,
                    columns=3,
                    rows=2,
                    height="auto",
                )
            with gr.Column():
                result_view = gr.Image(label="Result Image", type="pil")

        # Gallery selection: hand the chosen picture to the input component.
        def on_select(evt: gr.SelectData):
            if evt.index < 0 or evt.index >= len(loaded):
                return None
            return loaded[evt.index][0]

        picker.select(fn=on_select, outputs=source_view)

        # Re-run the model whenever the input picture changes.
        source_view.change(
            fn=process_image,
            inputs=source_view,
            outputs=result_view,
        )

        # "Clear" resets the input component to empty.
        reset_btn.click(fn=lambda: None, outputs=source_view)

    demo.launch()
# Launch the UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    app()