Spaces:
Sleeping
Sleeping
File size: 2,993 Bytes
83bc5a3 aade841 83bc5a3 78b651b 83bc5a3 777b90a 83bc5a3 777b90a 83bc5a3 777b90a 83bc5a3 d87e64d 78b651b 83bc5a3 16e22fb 83bc5a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import gradio as gr
import torch
from PIL import Image
import matplotlib.pyplot as plt
import io
from model import load_model, get_val_transform # Import functions from model.py
import numpy as np
# Load the model on GPU if available
# device=0 selects the first CUDA device; -1 falls back to CPU
# (Hugging Face pipeline device convention).
model = load_model(device=0 if torch.cuda.is_available() else -1)
# NOTE(review): val_transform is never used in this file — presumably kept
# for parity with model.py or imported elsewhere; confirm before removing.
val_transform = get_val_transform()
# Define colors for bounding boxes
# RGB triplets in [0, 1] (matplotlib format); a detection's label is hashed
# into this palette in get_output_figure.
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
def get_output_figure(pil_img, results, threshold):
    """Render detection results on top of an image and return it as a PIL image.

    Args:
        pil_img: Input image (PIL.Image) to draw on.
        results: Iterable of detection dicts, each with 'score' (float),
            'label' (str) and 'box' (dict whose values are, in order,
            xmin, ymin, xmax, ymax — the Hugging Face pipeline format).
        threshold: Detections with score <= threshold are skipped.

    Returns:
        A PIL.Image of the rendered matplotlib figure (PNG, tight bbox).
    """
    fig = plt.figure(figsize=(12, 8))
    buf = io.BytesIO()
    try:
        plt.imshow(pil_img)
        ax = plt.gca()
        for result in results:
            score = result['score']
            # Filter first so low-confidence detections do no extra work.
            if score <= threshold:
                continue
            label = result['label']
            # NOTE: relies on the box dict preserving xmin, ymin, xmax, ymax
            # insertion order — TODO confirm against the model pipeline output.
            box = list(result['box'].values())
            # BUGFIX: the original used hash(label), which is randomized per
            # process (PYTHONHASHSEED), so colors changed on every run. An
            # ord-sum is deterministic and keeps the same label -> same color.
            color = COLORS[sum(ord(ch) for ch in label) % len(COLORS)]
            ax.add_patch(
                plt.Rectangle(
                    (box[0], box[1]), box[2] - box[0], box[3] - box[1],
                    fill=False, color=color, linewidth=2
                )
            )
            text = f'{label}: {score:.2f}'
            ax.text(
                box[0], box[1] - 5, text, fontsize=10,
                bbox=dict(facecolor='yellow', alpha=0.5, edgecolor='none')
            )
        plt.axis('off')
        plt.savefig(buf, bbox_inches='tight', dpi=100)
    finally:
        # Always close this specific figure, even if rendering raises,
        # so repeated calls don't accumulate open matplotlib figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def detect(image, threshold=0.5):
    """Run the detection model on *image* and return the annotated output.

    Args:
        image: Input PIL image.
        threshold: Minimum confidence score passed through to the renderer.

    Returns:
        A PIL image with predicted boxes and labels drawn on it.
    """
    predictions = model(image)
    return get_output_figure(image, predictions, threshold)
# Build the Gradio app
# Component creation order inside the Blocks context defines the page layout.
with gr.Blocks() as demo:
    # Page title and usage instructions.
    gr.Markdown(
        """
        # Fashion Object Detection
        Detect fashion-related objects in images using a fine-tuned DETR model.
        You can load or select an image then adjust the detection threshold using the slider for better results.
        """
    )
    # Inputs side by side: the image and the confidence threshold slider.
    with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        threshold_slider = gr.Slider(
            minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Detection Threshold"
        )
    # Annotated prediction rendered by detect() via get_output_figure().
    output_image = gr.Image(label="Output Prediction", type="pil")
    detect_button = gr.Button("Run Detection")
    # Wire the button: detect(image, threshold) -> annotated PIL image.
    detect_button.click(detect, inputs=[image_input, threshold_slider], outputs=output_image)
    gr.Markdown(
        """
        ### About the Model
        This app uses the DETR model fine-tuned on the Fashionpedia dataset, which includes diverse fashion-related objects.
        """
    )
    gr.Markdown(
        """
        ### Created by Kelechi Osuji.
        """
    )
    # Add example images
    # Paths are relative to the app's working directory (the Space repo root).
    example_images = [
        "examples/fashion_image_223.jpg",
        "examples/fashion_image_1094.jpg",
        "examples/fashion_image_1113.jpg",
        "examples/fashion_image_508.jpg"
    ]
    gr.Examples(
        examples=example_images,
        inputs=[image_input]
    )
# show_error=True surfaces Python exceptions in the UI instead of failing silently.
demo.launch(show_error=True)
|