File size: 1,374 Bytes
3126b1e
cbe5ac9
3126b1e
 
cbe5ac9
3126b1e
 
 
 
 
51b21f6
3126b1e
cbe5ac9
3126b1e
 
 
 
 
 
 
 
 
 
 
f933a60
626e1ed
a64f813
 
 
 
f933a60
 
 
 
 
626e1ed
3126b1e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import functools

import tensorflow as tf
import gradio as gr
import gcvit
from gcvit.utils import get_gradcam_model, get_gradcam_prediction

@functools.lru_cache(maxsize=None)
def _load_gradcam_model(model_name):
    """Build the GradCAM wrapper for ``model_name`` once and reuse it afterwards.

    Looks up the constructor called ``model_name`` on the ``gcvit`` package,
    instantiates it with ``pretrain=True`` and wraps the result with
    ``get_gradcam_model``. The wrapped model is memoized per name so repeated
    requests do not rebuild the network each time.

    Raises:
        AttributeError: if ``gcvit`` has no attribute called ``model_name``.
            Failed lookups are not cached, so a later call retries.
    """
    # The cache is unbounded on purpose: only names that resolve to a gcvit
    # constructor ever get stored, so the number of entries stays small.
    model = getattr(gcvit, model_name)(pretrain=True)
    return get_gradcam_model(model)


def predict_fn(image, model_name):
    """Classify ``image`` and produce a GradCAM overlay; invoked by gradio.

    Args:
        image: The input image as delivered by the gradio image component.
        model_name: Name of the ``gcvit`` model constructor to use
            (e.g. ``'GCViTTiny'``).

    Returns:
        A two-element list ``[preds, overlay]`` where ``preds`` maps a label to
        its score as a Python ``float`` and ``overlay`` is the second value
        returned by ``get_gradcam_prediction``.
    """
    gradcam_model = _load_gradcam_model(model_name)
    preds, overlay = get_gradcam_prediction(
        image, gradcam_model, cmap='jet', alpha=0.4, pred_index=None
    )
    # NOTE(review): each entry is indexed as x[1] -> label, x[2] -> score,
    # presumably (class_id, label, score) tuples -- confirm against gcvit.utils.
    # float() gives gradio plain Python numbers it can serialize.
    preds = {x[1]: float(x[2]) for x in preds}
    return [preds, overlay]

# Gradio UI: an image plus a model-size selector go in; the class scores and
# the GradCAM overlay returned by predict_fn come out.
# Components use the top-level gradio namespace (gr.Image / gr.Label), the same
# API generation as gr.Radio(value=...), instead of the deprecated
# gr.inputs / gr.outputs wrappers. The GradCAM output was previously declared
# with an *input* component class.
demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Radio(['GCViTTiny', 'GCViTSmall', 'GCViTBase'], value='GCViTTiny', label='Model Size'),
    ],
    outputs=[
        gr.Label(label="Prediction"),
        gr.Image(label="GradCAM"),
    ],
    title="Global Context Vision Transformer (GCViT) Demo",
    description="Image Classification with ImageNet Pretrain Weights.",
    # Each example row supplies one value per input: [image path, model name].
    examples=[
        ["example/hot_air_ballon.jpg", 'GCViTTiny'],
        ["example/chelsea.png", 'GCViTTiny'],
        ["example/german_shepherd.jpg", 'GCViTTiny'],
        ["example/panda.jpg", 'GCViTTiny'],
        ["example/jellyfish.jpg", 'GCViTTiny'],
        ["example/penguin.JPG", 'GCViTTiny'],
        ["example/bus.jpg", 'GCViTTiny'],
        ["example/cat_dog.JPG", 'GCViTTiny'],
    ],
)
demo.launch()