from PIL import Image
import torch
import gradio as gr

import modules.transforms as transforms
from modules.primaps import PriMaPs
from modules.backbone.dino.dinovit import DinoFeaturizerv2
from modules.visualization import visualize_demo

# set seeds
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# np.random.seed(0)
# random.seed(0)


def gradio_primaps(image_path, threshold, architecture):
    '''
    Gradio demo to visualize PriMaPs for a single image.
    '''
    device = 'cpu'
    # DINOv2 backbones use a patch size of 14 and hence a different input resolution
    resize_to = 320 if 'v2' not in architecture else 322
    patch_size = 8 if 'v2' not in architecture else 14
    # get SSL image encoder and PriMaPs module
    net = DinoFeaturizerv2(architecture, patch_size)
    net.to(device)
    primaps_module = PriMaPs(threshold=threshold, ignore_id=255)
    # get transforms
    demo_transforms = transforms.Compose([transforms.ToTensor(),
                                          transforms.Resize(resize_to),
                                          transforms.CenterCrop([resize_to, resize_to]),
                                          transforms.Normalize()])
    # load image and apply transforms
    image = Image.open(image_path)
    image, _ = demo_transforms(image, torch.zeros(image.size))
    image = image.to(device)
    # get SSL features
    feats = net(image.unsqueeze(0).to(device), n=1).squeeze()
    # get PriMaPs pseudo labels
    primaps = primaps_module._get_pseudo(image, feats, torch.zeros(image.shape[1:]))
    # visualize overlay
    return visualize_demo(image, primaps)


if __name__ == '__main__':
    # Gradio interface
    interface = gr.Interface(
        fn=gradio_primaps,
        inputs=[
            gr.Image(type="filepath", label="Image"),
            gr.Slider(0.0, 1.0, step=0.05, value=0.4, label="Threshold"),
            gr.Dropdown(choices=['dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb'],
                        value='dino_vitb', label="SSL Features"),
        ],
        outputs=gr.Image(label="PriMaPs"),
        title="PriMaPs Demo",
        description="Upload an image and adjust the threshold to visualize PriMaPs.",
        examples=[
            ["example.jpg", 0.4, 'dino_vitb'],
        ]
    )
    # Launch the app
    interface.launch(debug=True)