|
import torch |
|
import torchvision.transforms as T |
|
from timm import create_model |
|
from safetensors.torch import load_model |
|
import numpy as np |
|
from pathlib import Path |
|
import gradio as gr |
|
|
|
# Example images offered in the demo gallery.
# Path.glob yields entries in arbitrary (filesystem-dependent) order, so the
# paths are sorted to keep the gallery order stable between runs and machines.
# Only regular files are kept so stray sub-directories are not offered as
# example inputs.
examples = sorted(str(p) for p in Path('./examples').glob('*') if p.is_file())
|
|
|
# Preprocessing applied to every input image before inference:
# resize to a fixed 224x224, convert to a float tensor, then normalize each
# of the three channels with mean 0.5 / std 0.5 (for uint8 PIL input this
# maps ToTensor's [0, 1] range onto [-1, 1]).
_norm_stats = (0.5, 0.5, 0.5)
valid_tfms = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=_norm_stats, std=_norm_stats),
    ]
)
|
|
|
|
|
# Build the Swin-S3 architecture with a 20-output head and no downloaded
# pretrained weights, load the fine-tuned weights from the local safetensors
# checkpoint, then put the network in evaluation mode.
model_path = 'model/swin_s3_base_224-pascal/model.safetensors'
model = create_model('swin_s3_base_224', num_classes=20, pretrained=False)
load_model(model, model_path)
model.train(False)  # same as model.eval(): disables training-only behaviour
|
|
|
# The 20 PASCAL VOC categories, listed in the index order that `predict`
# uses to map output positions to names.
class_names = [
    "Aeroplane", "Bicycle", "Bird", "Boat", "Bottle",
    "Bus", "Car", "Cat", "Chair", "Cow", "Diningtable",
    "Dog", "Horse", "Motorbike", "Person",
    "Potted plant", "Sheep", "Sofa", "Train", "Tv/monitor",
]

# Lookups in both directions between output index and class name.
id2label = dict(enumerate(class_names))
label2id = {name: idx for idx, name in id2label.items()}
|
|
|
|
|
def predict(im, threshold=0.5):
    """Run multi-label classification on a single image.

    Args:
        im: Input image as delivered by ``gr.Image(type="pil")``. It may be
            ``None`` when the user submits without uploading an image.
        threshold: Minimum per-class sigmoid probability for a class to be
            reported (strictly greater than). Defaults to 0.5, the value
            that used to be hard-coded.

    Returns:
        Dict mapping each predicted class name to its confidence as a plain
        Python float, in class-index order. The dict is empty when no image
        is given or no class clears ``threshold``.
    """
    if im is None:
        return {}

    # Normalize() is configured with exactly three means/stds, so RGBA or
    # greyscale inputs must be coerced to 3-channel RGB first.
    im = im.convert("RGB")
    batch = valid_tfms(im).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    # Independent sigmoid per class: an image may contain several objects,
    # so this is multi-label rather than a softmax over classes.
    probs = logits.sigmoid().flatten().tolist()

    # .tolist() yields plain floats, which gr.Label serializes directly.
    return {
        id2label[idx]: prob
        for idx, prob in enumerate(probs)
        if prob > threshold
    }
|
|
|
# Web UI: one PIL image in, the thresholded multi-label predictions out,
# with the files collected in `examples` offered as clickable samples.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(label='the image contains:'),
    examples=examples,
)

# Start the (blocking) server only when run as a script, so that importing
# this module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.queue().launch()