|
from argparse import ArgumentParser
from functools import lru_cache
from typing import Dict

import gradio as gr
import torch
from PIL import Image

import modules.transforms as transforms
from modules.backbone.dino.dinovit import DinoFeaturizerv2
from modules.primaps import PriMaPs
from modules.visualization import visualize_demo
|
|
|
# Seed PyTorch's random number generators with a fixed value at import time.
# torch.cuda.manual_seed is safe to call without CUDA (the call is deferred/ignored).
_SEED = 0
torch.manual_seed(_SEED)
torch.cuda.manual_seed(_SEED)
|
|
|
|
|
|
|
@lru_cache(maxsize=4)
def _load_featurizer(architecture, patch_size, device):
    '''
    Build the featurizer for ``architecture`` on ``device`` once and reuse it.

    The Gradio handler runs on every request; without this cache the backbone
    would be constructed and moved to the device each time. The cache size
    matches the four backbones offered in the demo's dropdown.
    '''
    net = DinoFeaturizerv2(architecture, patch_size)
    net.to(device)
    # Inference only: switch layers such as dropout/normalization to eval mode.
    net.eval()
    return net


def gradio_primaps(image_path, threshold, architecture):
    '''
    Gradio demo to visualize PriMaPs for a single image.

    Args:
        image_path: Path of the input image (Gradio ``type="filepath"`` input).
        threshold: Threshold passed to ``PriMaPs``.
        architecture: Backbone name for ``DinoFeaturizerv2``. Names containing
            ``'v2'`` use patch size 14 and a 322px crop; all others use patch
            size 8 and a 320px crop.

    Returns:
        The output of ``visualize_demo`` for the preprocessed image and its
        PriMaPs.
    '''
    # Use the GPU when present; fall back to the CPU instead of crashing.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    is_v2 = 'v2' in architecture
    # Both crop sizes are exact multiples of their patch size
    # (320 = 40 * 8, 322 = 23 * 14).
    resize_to = 322 if is_v2 else 320
    patch_size = 14 if is_v2 else 8

    net = _load_featurizer(architecture, patch_size, device)
    primaps_module = PriMaPs(threshold=threshold,
                             ignore_id=255)

    demo_transforms = transforms.Compose([transforms.ToTensor(),
                                          transforms.Resize(resize_to),
                                          transforms.CenterCrop([resize_to, resize_to]),
                                          transforms.Normalize()])

    # Convert to 3-channel RGB so RGBA / grayscale / palette uploads do not
    # produce a tensor with the wrong channel count; the context manager
    # closes the file handle that Image.open keeps open.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert('RGB')

    # The transforms take and return an (image, label) pair; the label here is
    # a dummy all-zero map and is discarded. PIL's ``size`` is (width, height),
    # so it is swapped to give an (H, W) map like a real label would have.
    width, height = rgb_image.size
    image, _ = demo_transforms(rgb_image, torch.zeros(height, width))

    # Only the batched copy fed to the backbone is moved to ``device``;
    # ``image`` itself stays on the CPU for ``_get_pseudo`` and
    # ``visualize_demo``. Pure inference, so no autograd graph is built.
    with torch.no_grad():
        # unsqueeze(0) adds a batch dimension; squeeze() drops size-1 dims again.
        feats = net(image.unsqueeze(0).to(device), n=1).squeeze()

    # Dummy all-zero label with the image's trailing (spatial) shape
    # (assumes ``image`` is (C, H, W) — TODO confirm against transforms.ToTensor).
    primaps = primaps_module._get_pseudo(image, feats, torch.zeros(image.shape[1:]))

    return visualize_demo(image, primaps)
|
|
|
|
|
if __name__ == '__main__':
    # Bundled sample images offered as one-click examples in the UI.
    example_images = [
        "assets/demo_examples/cityscapes_example.png",
        "assets/demo_examples/coco_example.jpg",
        "assets/demo_examples/potsdam_example.png"
    ]

    # Input widgets, in the same order as gradio_primaps' parameters
    # (image_path, threshold, architecture).
    image_input = gr.Image(type="filepath", label="Image")
    threshold_input = gr.Slider(0.0, 1.0, step=0.05, value=0.35, label="Threshold")
    backbone_input = gr.Dropdown(
        choices=['dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb'],
        value='dino_vitb',
        label="SSL Features",
    )

    interface = gr.Interface(
        fn=gradio_primaps,
        inputs=[image_input, threshold_input, backbone_input],
        outputs=gr.Image(label="PriMaPs"),
        title="PriMaPs Demo",
        description="Upload an image and adjust the threshold to visualize PriMaPs.",
        # Every example uses the default threshold and the ViT-B DINO backbone.
        examples=[[example_path, 0.35, 'dino_vitb'] for example_path in example_images],
    )

    interface.launch(debug=True)
|
|