File size: 1,374 Bytes
3126b1e cbe5ac9 3126b1e cbe5ac9 3126b1e 51b21f6 3126b1e cbe5ac9 3126b1e f933a60 626e1ed a64f813 f933a60 626e1ed 3126b1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import tensorflow as tf
import gradio as gr
import gcvit
from gcvit.utils import get_gradcam_model, get_gradcam_prediction
# Cache of GradCAM-ready models keyed by the gcvit constructor name, so each
# model is built (and its pretrained weights loaded) once per process rather
# than once per request.
_GRADCAM_MODELS = {}


def _load_gradcam_model(model_name):
    """Return the GradCAM-wrapped model for *model_name*, building it on first use."""
    try:
        return _GRADCAM_MODELS[model_name]
    except KeyError:
        # Look up the model constructor on the gcvit package by name.
        model = getattr(gcvit, model_name)(pretrain=True)
        gradcam_model = _GRADCAM_MODELS[model_name] = get_gradcam_model(model)
        return gradcam_model


def predict_fn(image, model_name):
    """Classify *image* with the selected GCViT model and render a GradCAM overlay.

    This is the callback invoked by the gradio interface.

    Args:
        image: The input image as delivered by the gradio image component.
        model_name: Name of a model constructor exposed by the ``gcvit``
            package (e.g. ``'GCViTTiny'``).

    Returns:
        A two-element list ``[preds, overlay]`` where ``preds`` maps
        ``x[1]`` to ``float(x[2])`` for every prediction entry returned by
        ``get_gradcam_prediction`` (presumably class label -> score; confirm
        against ``gcvit.utils``) and ``overlay`` is the GradCAM image
        returned by the same call.

    Raises:
        AttributeError: If ``gcvit`` has no attribute named *model_name*.
    """
    gradcam_model = _load_gradcam_model(model_name)
    preds, overlay = get_gradcam_prediction(
        image, gradcam_model, cmap='jet', alpha=0.4, pred_index=None
    )
    # Cast scores to plain floats so the label output receives native numbers.
    preds = {x[1]: float(x[2]) for x in preds}
    return [preds, overlay]
# Gradio UI: an image plus a model-size selector go in; the class
# predictions and the GradCAM overlay come out (matching the two values
# returned by predict_fn).
demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Radio(['GCViTTiny', 'GCViTSmall', 'GCViTBase'], value='GCViTTiny', label='Model Size'),
    ],
    outputs=[
        gr.Label(label="Prediction"),
        # An output slot: use the unified top-level component rather than
        # one taken from the deprecated ``gr.inputs`` namespace.
        gr.Image(label="GradCAM"),
    ],
    title="Global Context Vision Transformer (GCViT) Demo",
    description="Image Classification with ImageNet Pretrain Weights.",
    # Each example row is [image path, model name], in the order of `inputs`.
    examples=[
        ["example/hot_air_ballon.jpg", 'GCViTTiny'],
        ["example/chelsea.png", 'GCViTTiny'],
        ["example/german_shepherd.jpg", 'GCViTTiny'],
        ["example/panda.jpg", 'GCViTTiny'],
        ["example/jellyfish.jpg", 'GCViTTiny'],
        ["example/penguin.JPG", 'GCViTTiny'],
        ["example/bus.jpg", 'GCViTTiny'],
        ["example/cat_dog.JPG", 'GCViTTiny'],
    ],
)
# Start the web server for the demo.
demo.launch()