"""Gradio demo that turns satellite photos into map renderings.

The generator weights, its JSON config and the ``model.py`` defining the
``Generator`` class are downloaded from the Hugging Face Hub repository
``Kiwinicki/sat2map-generator``. Example inputs are read from a local
``Photos`` folder and shown in a clickable gallery.
"""

import json
import os
import sys

import gradio as gr
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image

# Folder scanned for example images shown in the gallery (created if missing).
photos_folder = "Photos"

# Download model weights, config and model definition from the Hub.
repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

# Put the downloaded model.py FIRST on the import path so that ``model``
# resolves to it, not to some unrelated ``model`` module that is already
# importable (e.g. a model.py in the working directory).
sys.path.insert(0, os.path.dirname(model_path))
from model import Generator  # noqa: E402  (depends on the sys.path change above)

# Load configuration.
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)

# Initialize model on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(cfg).to(device)
# NOTE(security): torch.load unpickles the checkpoint, which can execute code
# embedded in it; only load weights from a repository you trust. If the
# installed torch supports it and the checkpoint holds plain tensors,
# consider passing weights_only=True.
generator.load_state_dict(torch.load(generator_path, map_location=device))
generator.eval()

# Preprocessing: resize to 256x256, convert to a [0, 1] tensor, then
# normalize each of the 3 channels to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def process_image(image):
    """Run the generator on a single PIL image.

    Args:
        image: Input PIL image, or None when the input component is empty.

    Returns:
        The generated PIL image, or None if ``image`` is None.
    """
    if image is None:
        return None

    # The 3-channel Normalize above needs RGB input; RGBA, greyscale or
    # palette images would otherwise make it fail.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        output_tensor = generator(image_tensor)

    # Undo the [-1, 1] normalization. Clamp so that out-of-range values do not
    # wrap around when ToPILImage casts the float tensor to 8-bit.
    output_image = (output_tensor.squeeze(0).cpu() * 0.5 + 0.5).clamp(0.0, 1.0)
    return transforms.ToPILImage()(output_image)


def load_images_from_folder(folder):
    """Load every .png/.jpg/.jpeg file in ``folder``.

    Creates ``folder`` if it does not exist. Files that fail to load are
    reported on stdout and skipped.

    Args:
        folder: Path of the directory to scan.

    Returns:
        A list of ``(PIL.Image, filename)`` tuples, ordered by filename.
    """
    os.makedirs(folder, exist_ok=True)

    images = []
    # Sort so the gallery order is stable (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        img_path = os.path.join(folder, filename)
        try:
            # Decode inside ``with`` and keep a detached copy, so the file
            # handle is released now instead of staying open for the app's
            # lifetime (Image.open alone loads lazily).
            with Image.open(img_path) as img:
                images.append((img.copy(), filename))
        except Exception as e:  # best-effort: one bad file must not stop the app
            print(f"Error loading {filename}: {e}")
    return images


def app():
    """Build the Gradio interface and launch it."""
    images = load_images_from_folder(photos_folder)
    gallery_images = [img for img, _ in images]

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                clear_button = gr.Button("Clear")
            with gr.Column():
                gallery = gr.Gallery(
                    label="Image Gallery",
                    value=gallery_images,
                    columns=3,  # Set number of columns directly in the constructor
                    rows=2,
                    height="auto",
                )
            with gr.Column():
                output_image = gr.Image(label="Result Image", type="pil")

        def on_select(evt: gr.SelectData):
            """Return the clicked gallery image so it becomes the input."""
            if 0 <= evt.index < len(images):
                return images[evt.index][0]
            return None

        # Handle gallery selection.
        gallery.select(fn=on_select, outputs=input_image)

        # Re-run inference whenever the input image changes.
        input_image.change(
            fn=process_image,
            inputs=input_image,
            outputs=output_image,
        )

        # Clear button resets the input; process_image(None) then yields None.
        clear_button.click(fn=lambda: None, outputs=input_image)

    demo.launch()


if __name__ == "__main__":
    app()