from argparse import ArgumentParser
from typing import Dict
import functools

import torch
from PIL import Image

import modules.transforms as transforms
from modules.primaps import PriMaPs
from modules.backbone.dino.dinovit import DinoFeaturizerv2
from modules.visualization import visualize_demo
import gradio as gr

# set seeds
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# Backbones offered in the demo dropdown; also bounds the featurizer cache size.
ARCHITECTURES = ('dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb')


@functools.lru_cache(maxsize=len(ARCHITECTURES))
def _load_featurizer(architecture, patch_size, device):
    '''
    Build the SSL featurizer for one (architecture, patch_size, device) and reuse it.

    Constructing the backbone and moving it to the device is the expensive part of a
    request, so it is done once per configuration instead of on every Gradio call.
    '''
    net = DinoFeaturizerv2(architecture, patch_size)
    net.to(device)
    # Inference only: make sure any dropout / normalization layers run in eval mode.
    net.eval()
    return net


def gradio_primaps(image_path, threshold, architecture):
    '''
    Gradio demo to visualize PriMaPs for a single image.

    Args:
        image_path: path to the input image (Gradio ``gr.Image(type="filepath")``).
        threshold: value forwarded to ``PriMaPs(threshold=...)``.
        architecture: backbone name passed to ``DinoFeaturizerv2``. Names containing
            ``'v2'`` use patch size 14 and a 322px crop; all others use patch size 8
            and a 320px crop.

    Returns:
        Whatever ``visualize_demo(image, primaps)`` returns (displayed by the
        ``gr.Image`` output component).

    Raises:
        gr.Error: if no image was provided.
    '''
    if image_path is None:
        raise gr.Error('Please upload an image.')

    # Fall back to CPU when no GPU is visible instead of failing on 'cuda:0'.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    is_v2 = 'v2' in architecture
    # 320 = 40 * 8 and 322 = 23 * 14, so the square crop is a whole number of
    # patches for each backbone family.
    resize_to = 322 if is_v2 else 320
    patch_size = 14 if is_v2 else 8

    # get SSL image encoder (cached across calls) and primaps module
    net = _load_featurizer(architecture, patch_size, device)
    primaps_module = PriMaPs(threshold=threshold, ignore_id=255)

    # get transforms
    demo_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(resize_to),
        transforms.CenterCrop([resize_to, resize_to]),
        transforms.Normalize(),
    ])

    # Load the image as 3-channel RGB (RGBA PNGs / grayscale inputs would otherwise
    # yield the wrong channel count) and close the file handle once decoded.
    with Image.open(image_path) as img:
        pil_image = img.convert('RGB')

    # The transforms take an (image, label) pair; the label is a dummy and is
    # discarded. PIL's .size is (W, H), so reverse it to get an (H, W) label.
    image, _ = demo_transforms(pil_image, torch.zeros(pil_image.size[::-1]))
    # NOTE: `image` deliberately stays on CPU (as before); only the copy fed to the
    # backbone is moved to `device`.

    with torch.no_grad():
        # get SSL features
        feats = net(image.unsqueeze(0).to(device), n=1).squeeze()
        # get primaps pseudo labels
        primaps = primaps_module._get_pseudo(image, feats, torch.zeros(image.shape[1:]))

    # visualize overlay
    return visualize_demo(image, primaps)


if __name__ == '__main__':
    # Example image paths
    example_images = [
        "assets/demo_examples/cityscapes_example.png",
        "assets/demo_examples/coco_example.jpg",
        "assets/demo_examples/potsdam_example.png",
    ]
    default_threshold = 0.35
    default_architecture = 'dino_vitb'

    # Gradio interface
    interface = gr.Interface(
        fn=gradio_primaps,
        inputs=[
            gr.Image(type="filepath", label="Image"),
            gr.Slider(0.0, 1.0, step=0.05, value=default_threshold, label="Threshold"),
            gr.Dropdown(choices=list(ARCHITECTURES), value=default_architecture, label="SSL Features"),
        ],
        outputs=gr.Image(label="PriMaPs"),
        title="PriMaPs Demo",
        description="Upload an image and adjust the threshold to visualize PriMaPs.",
        examples=[[path, default_threshold, default_architecture] for path in example_images],
    )

    # Launch the app
    interface.launch(debug=True)