PriMaPs / app.py
Oliver Hahn
add demo
80165f3
raw
history blame
2.61 kB
from argparse import ArgumentParser
from functools import lru_cache
from typing import Dict

import gradio as gr
import torch
from PIL import Image

import modules.transforms as transforms
from modules.backbone.dino.dinovit import DinoFeaturizerv2
from modules.primaps import PriMaPs
from modules.visualization import visualize_demo
# Seed torch's RNGs so repeated demo runs are reproducible.
# torch.manual_seed seeds the default generator on every device (CPU and all
# CUDA devices), so a separate torch.cuda.manual_seed call is redundant.
torch.manual_seed(0)
@lru_cache(maxsize=4)
def _load_featurizer(architecture, patch_size, device):
    '''
    Build the SSL image encoder for `architecture` and move it to `device`.

    Cached so repeated demo calls (e.g. every threshold-slider change) reuse the
    already-built network instead of reconstructing it each time. maxsize=4
    matches the number of backbones offered in the UI dropdown.
    '''
    net = DinoFeaturizerv2(architecture, patch_size)
    net.to(device)
    return net


def gradio_primaps(image_path, threshold, architecture):
    '''
    Gradio demo to visualize PriMaPs for a single image.

    Args:
        image_path: path to the input image file (the Gradio Image input uses
            type="filepath").
        threshold: forwarded as ``PriMaPs(threshold=...)``.
        architecture: SSL backbone name passed to ``DinoFeaturizerv2``. Names
            containing 'v2' use patch size 14 and input size 322; all others
            use patch size 8 and input size 320.

    Returns:
        The result of ``visualize_demo(image, primaps)``, which the Gradio
        interface displays as the output image.
    '''
    device = 'cpu'
    is_v2 = 'v2' in architecture
    # Input side is a multiple of the patch size: 320 = 40 * 8, 322 = 23 * 14.
    resize_to = 322 if is_v2 else 320
    patch_size = 14 if is_v2 else 8

    # get SSL image encoder (cached across calls) and primaps module
    net = _load_featurizer(architecture, patch_size, device)
    primaps_module = PriMaPs(threshold=threshold,
                             ignore_id=255)

    # get transforms
    demo_transforms = transforms.Compose([transforms.ToTensor(),
                                          transforms.Resize(resize_to),
                                          transforms.CenterCrop([resize_to, resize_to]),
                                          transforms.Normalize()])

    # Load the image. Converting to RGB gives the 3 channels the encoder needs
    # even for RGBA/greyscale/palette files, and the `with` closes the file.
    with Image.open(image_path) as img:
        pil_image = img.convert('RGB')
    # The transforms take an (image, label) pair; pass an unused all-zero label
    # shaped (height, width), like the dummy label given to _get_pseudo below.
    dummy_label = torch.zeros(pil_image.height, pil_image.width)
    image, _ = demo_transforms(pil_image, dummy_label)
    # Tensor.to returns a new tensor rather than moving in place.
    image = image.to(device)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        # get SSL features
        feats = net(image.unsqueeze(0), n=1).squeeze()
        # get primaps pseudo labels
        primaps = primaps_module._get_pseudo(image, feats, torch.zeros(image.shape[1:]))

    # visualize overlay
    return visualize_demo(image, primaps)
if __name__ == '__main__':
    # Bundled sample images; every example runs with threshold 0.4 and the
    # 'dino_vitb' backbone.
    example_images = [
        "assets/demo_examples/cityscapes_example.png",
        "assets/demo_examples/coco_example.jpg",
        "assets/demo_examples/potsdam_example.png",
    ]
    examples = [[path, 0.4, 'dino_vitb'] for path in example_images]

    # UI controls, listed in the order gradio_primaps takes its arguments:
    # (image_path, threshold, architecture).
    inputs = [
        gr.Image(type="filepath", label="Image"),
        gr.Slider(0.0, 1.0, step=0.05, value=0.35, label="Threshold"),
        gr.Dropdown(choices=['dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb'],
                    value='dino_vitb', label="SSL Features"),
    ]

    interface = gr.Interface(
        fn=gradio_primaps,
        inputs=inputs,
        outputs=gr.Image(label="PriMaPs"),
        title="PriMaPs Demo",
        description="Upload an image and adjust the threshold to visualize PriMaPs.",
        examples=examples,
    )

    # Launch the app in Gradio's debug mode.
    interface.launch(debug=True)