File size: 2,277 Bytes
1dc26c7 4f7ac5e 1dc26c7 5f99e47 1dc26c7 04513cf 3398a0c 1dc26c7 3398a0c ccbf06e 51f04f9 1dc26c7 6dcd6be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from argparse import ArgumentParser
from typing import Dict
import torch
from PIL import Image
import modules.transforms as transforms
from modules.primaps import PriMaPs
from modules.backbone.dino.dinovit import DinoFeaturizerv2
from modules.visualization import visualize_demo
import gradio as gr
def gradio_primaps(image_path, threshold, architecture):
    '''
    Gradio demo to visualize PriMaPs for a single image.

    Args:
        image_path: Path of the image file; it is opened with PIL.
        threshold: Value forwarded to ``PriMaPs(threshold=...)``.
        architecture: Backbone name forwarded to ``DinoFeaturizerv2``. Names
            containing 'v2' use patch size 14 and a 322 px input; all other
            names use patch size 8 and a 320 px input.

    Returns:
        The result of ``visualize_demo(image, primaps)``.
    '''
    device = 'cpu'
    resize_to = 320 if 'v2' not in architecture else 322
    patch_size = 8 if 'v2' not in architecture else 14

    # get SSL image encoder and primaps module
    net = DinoFeaturizerv2(architecture, patch_size)
    net.to(device)
    primaps_module = PriMaPs(threshold=threshold,
                             ignore_id=255)

    # get transforms
    demo_transforms = transforms.Compose([transforms.ToTensor(),
                                          transforms.Resize(resize_to),
                                          transforms.CenterCrop([resize_to, resize_to]),
                                          transforms.Normalize()])

    # load image and apply transforms
    # Force three channels so grayscale, palette and RGBA uploads are handled
    # like plain RGB files (presumably what the backbone and Normalize expect
    # -- TODO confirm). convert() loads the pixel data and returns a new image,
    # so the file handle can be released immediately.
    with Image.open(image_path) as image_file:
        image = image_file.convert('RGB')
    # NOTE(review): PIL's .size is (width, height). The dummy label built from
    # it is discarded right after the transforms; confirm the transforms do not
    # rely on it being (height, width).
    image, _ = demo_transforms(image, torch.zeros(image.size))
    # Tensor.to() is not in-place: keep the returned tensor, otherwise the
    # image silently stays where it was created.
    image = image.to(device)

    # get SSL features
    feats = net(image.unsqueeze(0), n=1).squeeze()

    # get primaps pseudo labels (dummy all-zero prior, created on the same
    # device as the image and features)
    primaps = primaps_module._get_pseudo(
        image, feats, torch.zeros(image.shape[1:], device=device))

    # visualize overlay
    return visualize_demo(image, primaps)
if __name__ == '__main__':
    # Input widgets, listed in the positional order of gradio_primaps'
    # parameters: image path, threshold, architecture.
    image_input = gr.Image(type="filepath", label="Image")
    threshold_input = gr.Slider(0.0, 1.0, step=0.05, value=0.4, label="Threshold")
    backbone_input = gr.Dropdown(
        choices=['dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb'],
        value='dino_vits',
        label="SSL Features",
    )

    # Output widget that displays the value returned by gradio_primaps.
    overlay_output = gr.Image(label="PriMaPs")

    # One pre-filled example row, in the same order as the inputs.
    example_rows = [
        ["example.jpg", 0.4, 'dino_vits'],
    ]

    # Assemble the Gradio interface
    interface = gr.Interface(
        fn=gradio_primaps,
        inputs=[image_input, threshold_input, backbone_input],
        outputs=overlay_output,
        title="PriMaPs Demo",
        description="Upload an image and adjust the threshold to visualize PriMaPs.",
        examples=example_rows,
        article="For more details, visit the [project page](https://visinf.github.io/primaps).",
    )

    # Launch the app
    interface.launch()
|