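# Spaces: Sleeping
# (The line above appears to be hosting-page status text captured together with
# the code; it is not part of the program.)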
import gradio as gr | |
import ibbi | |
import numpy as np | |
from PIL import Image, ImageDraw, ImageFont | |
import matplotlib.pyplot as plt | |
import io | |
# --- Model Management --- | |
# Model-name stem for each selectable detector architecture. Insertion order
# is significant: the UI uses the first key as the default dropdown value.
_DETECTOR_STEMS = {
    "yolov10": "yolov10x",
    "yolov11": "yolov11x",
    "yolov9": "yolov9e",
    "yolov8": "yolov8x",
    "rtdetr": "rtdetrx",
}

# Task label -> {architecture key -> model name handed to ibbi.create_model}.
MODEL_REGISTRY = {
    "Single-Class Detection": {
        arch: f"{stem}_bb_detect_model"
        for arch, stem in _DETECTOR_STEMS.items()
    },
    "Multi-Class Detection": {
        arch: f"{stem}_bb_multi_class_detect_model"
        for arch, stem in _DETECTOR_STEMS.items()
    },
    "Zero-Shot Detection": {"grounding_dino": "grounding_dino_detect_model"},
}
# --- CORRECTED MODEL MANAGEMENT --- | |
# Caching is removed to prevent errors from stateful models. | |
# This function now loads a fresh model for each analysis request. | |
def get_model(task, architecture):
    """
    Load a fresh model instance for the selected task and architecture.

    A new model is created on every call (nothing is cached) so that state
    left on a model by one run cannot affect the next.

    Args:
        task: Task label; expected to be a key of ``MODEL_REGISTRY``.
        architecture: Architecture key within that task. Ignored for
            "Zero-Shot Detection", which always uses "grounding_dino".

    Returns:
        Whatever ``ibbi.create_model(model_name, pretrained=True)`` returns.

    Raises:
        gr.Error: If the task/architecture pair is not in the registry, or
            if creating the model fails for any reason. The underlying
            exception is attached as ``__cause__``.
    """
    # For Zero-Shot, the architecture is always 'grounding_dino'
    if task == "Zero-Shot Detection":
        architecture = "grounding_dino"

    # Guard only the registry lookup here, so that a KeyError raised later
    # while the model is being created is not misreported as a failed lookup.
    try:
        model_name = MODEL_REGISTRY[task][architecture]
    except (KeyError, TypeError) as e:
        raise gr.Error(
            f"Model lookup failed. Task: '{task}', Arch: '{architecture}'. Error: {e}"
        ) from e

    print(f"Loading a fresh model instance: {model_name}")
    try:
        model = ibbi.create_model(model_name, pretrained=True)
    except Exception as e:
        # UI boundary: surface any load failure as a Gradio error message.
        raise gr.Error(
            f"Failed to load model. Please check the model name and your connection. Error: {e}"
        ) from e
    print("Model loaded successfully!")
    return model
# --- Visualization and Drawing Functions --- | |
def _draw_labeled_box(draw, coords, label_text, font, color):
    """Draw one box outline and a filled label tag on an ImageDraw canvas.

    Args:
        draw: Drawing context providing ``rectangle``, ``textbbox`` and ``text``.
        coords: Box corners as ``[x0, y0, x1, y1]``.
        label_text: Text to show in the tag.
        font: Font used to measure and render the label.
        color: Outline colour of the box and fill colour of the tag.
    """
    x0, y0 = coords[0], coords[1]
    draw.rectangle(coords, outline=color, width=3)

    left, top, right, bottom = draw.textbbox((x0, y0), label_text, font=font)
    text_width = right - left
    text_height = bottom - top

    # Place the tag directly above the box; when there is not enough room
    # above it, pin the tag to y = 0 instead.
    tag_top = y0 - text_height if y0 > text_height else 0

    # The tag is always one full text-height tall. Using the box's own top
    # edge as the tag's bottom edge would give a background shorter than the
    # text (or an inverted rectangle for y0 < 0) whenever the tag is pinned.
    tag_coords = (x0, tag_top, x0 + text_width, tag_top + text_height)
    draw.rectangle(tag_coords, fill=color)
    draw.text((x0, tag_top), label_text, fill="white", font=font)


def draw_yolo_predictions(image, results, font, color="red"):
    """Draws YOLO predictions on an image with a dynamically sized font.

    Args:
        image: Image to annotate; it is copied, never modified in place.
        results: Sequence whose first element exposes ``boxes`` and ``names``
            (presumably an Ultralytics-style result list; confirm with caller).
        font: Font used for the labels.
        color: Colour of the box outlines and label tags.

    Returns:
        An annotated copy of ``image``; an unannotated copy when ``results``
        is empty or contains no boxes.
    """
    img_copy = image.copy()
    # Nothing to draw: return before a drawing context is created.
    if not results or not results[0].boxes:
        return img_copy

    draw = ImageDraw.Draw(img_copy)
    res_for_img = results[0]
    class_names = res_for_img.names
    for box in res_for_img.boxes:
        # Skip boxes that carry no class or no confidence value.
        if box.cls.numel() == 0 or box.conf.numel() == 0:
            continue
        coords = box.xyxy[0].tolist()
        score = box.conf[0].item()
        class_id = int(box.cls[0].item())
        label_text = f"{class_names.get(class_id, f'Unknown-{class_id}')}: {score:.2f}"
        _draw_labeled_box(draw, coords, label_text, font, color)
    return img_copy


def draw_dino_predictions(image, results, font, color="green"):
    """Draws Grounding DINO predictions on an image with a dynamically sized font.

    Args:
        image: Image to annotate; it is copied, never modified in place.
        results: Mapping read through the keys "boxes", "scores" and
            "text_labels"; entries are consumed in lockstep.
        font: Font used for the labels.
        color: Colour of the box outlines and label tags.

    Returns:
        An annotated copy of ``image``; an unannotated copy when ``results``
        is empty.
    """
    img_copy = image.copy()
    # Nothing to draw: return before a drawing context is created.
    if not results:
        return img_copy

    draw = ImageDraw.Draw(img_copy)
    detections = zip(
        results.get("boxes", []),
        results.get("scores", []),
        results.get("text_labels", []),
    )
    for box, score, label in detections:
        _draw_labeled_box(draw, box.tolist(), f"{label}: {score:.2f}", font, color)
    return img_copy
def visualize_embedding(embedding):
    """Visualizes a feature embedding as a heat-map image.

    Args:
        embedding: Tensor-like object exposing ``cpu()``, ``detach()`` and
            ``numpy()``. Any number of dimensions is accepted: fewer than two
            are shown as a single row, and more than two are flattened to
            ``(-1, last_dim)`` rows.

    Returns:
        A PIL image of the rendered plot, or ``None`` when ``embedding`` is
        ``None`` or has no ``cpu`` attribute.
    """
    if embedding is None or not hasattr(embedding, "cpu"):
        return None

    values = embedding.cpu().detach().numpy()
    if values.ndim < 2:
        values = values.reshape(1, -1)
    elif values.ndim > 2:
        # imshow only takes 2-D scalar data and would reject other shapes or
        # read a trailing axis of 3/4 as colour channels, so collapse every
        # leading axis (e.g. a batch axis) into rows.
        values = values.reshape(-1, values.shape[-1])

    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        ax.imshow(values, cmap="viridis", aspect="auto")
        ax.set_title("Feature Embedding Visualization")
        ax.set_xlabel("Feature Dimension")
        ax.set_yticks([])
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# --- CORRECTED Main Processing Function --- | |
def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold, text_threshold):
    """Run detection and feature extraction on an image and build the UI outputs.

    Args:
        image: PIL image to analyse (``None`` when nothing was uploaded).
        task: One of the task labels in ``MODEL_REGISTRY``. Anything other
            than the two detection tasks is handled as Zero-Shot Detection.
        architecture: Architecture key for the detection tasks.
        text_prompt: Prompt for Zero-Shot Detection; unused otherwise.
        box_threshold: Passed to ``model.predict`` for Zero-Shot Detection.
        text_threshold: Passed to ``model.predict`` for Zero-Shot Detection.

    Returns:
        Tuple ``(annotated_image, model_info, classes_info, embedding_plot)``.

    Raises:
        gr.Error: If no image was supplied, if Zero-Shot Detection is selected
            without a non-blank prompt, or if the model cannot be loaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first!")

    is_detector_task = task in ["Single-Class Detection", "Multi-Class Detection"]

    # Validate the prompt before the model is loaded: loading comes first
    # otherwise, and its cost would be paid only to reject the request.
    # A whitespace-only prompt counts as missing.
    if not is_detector_task and not (text_prompt and text_prompt.strip()):
        raise gr.Error("Please provide a text prompt for Zero-Shot Detection.")

    # Calculate a dynamic font size based on image width.
    dynamic_font_size = max(15, int(image.width * 0.04))
    try:
        font = ImageFont.truetype("arial.ttf", dynamic_font_size)
    except IOError:
        # Font file not available; fall back to Pillow's built-in font.
        font = ImageFont.load_default(size=dynamic_font_size)

    # Get a fresh model instance to avoid stateful errors
    model = get_model(task, architecture)

    if is_detector_task:
        results = model.predict(image)
        annotated_image = draw_yolo_predictions(image, results, font=font)
        features = model.extract_features(image)
        model_info = f"Architecture: {architecture.upper()}\nTask: {task}\nDevice: {model.device}"
        classes_info = f"Classes: {model.get_classes()}"
    else:  # Zero-Shot Detection
        results = model.predict(
            image,
            text_prompt=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )
        annotated_image = draw_dino_predictions(image, results, font=font)
        features = model.extract_features(image, text_prompt=text_prompt)
        model_info = f"Architecture: GROUNDING_DINO\nTask: {task}\nDevice: {model.device}\nHF Model ID: {model.model.config._name_or_path}"
        classes_info = f"Prompt: '{text_prompt}'"

    # Process features for visualization: a dict result is read through its
    # 'last_hidden_state' entry, anything else is plotted directly.
    embedding = features.get('last_hidden_state') if isinstance(features, dict) else features
    embedding_plot = visualize_embedding(embedding)

    return annotated_image, model_info, classes_info, embedding_plot
# --- Gradio UI --- | |
def update_ui_for_task(task):
    """Return the component updates that adapt the input panel to ``task``.

    For the two detection tasks the architecture dropdown is shown and the
    prompt box (cleared) and both threshold sliders are hidden. For any other
    task (Zero-Shot Detection) the dropdown is hidden and the prompt box and
    sliders are shown. The dropdown choices always come from
    ``MODEL_REGISTRY[task]`` with the first entry selected.
    """
    detector_task = task in ("Single-Class Detection", "Multi-Class Detection")
    architectures = list(MODEL_REGISTRY[task].keys())
    default_arch = architectures[0]

    if detector_task:
        dropdown_update = gr.update(
            choices=architectures, value=default_arch, visible=True, interactive=True
        )
        prompt_update = gr.update(visible=False, value="")
    else:  # Zero-Shot Detection
        dropdown_update = gr.update(
            choices=architectures, value=default_arch, visible=False
        )
        prompt_update = gr.update(visible=True)

    # The zero-shot controls are visible exactly when the dropdown is not.
    show_thresholds = not detector_task
    return {
        arch_dropdown: dropdown_update,
        prompt_textbox: prompt_update,
        box_threshold_slider: gr.update(visible=show_thresholds),
        text_threshold_slider: gr.update(visible=show_thresholds),
    }
# Page layout and event wiring for the app.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened -- confirm which components belong inside
# the "Details" accordion and that the handlers sit at Blocks level.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# IBBI - Intelligent Bark Beetle Identifier")
    gr.Markdown("An all-in-one interface to analyze images using the `ibbi` library. Upload an image, select a task and model, and view the complete analysis.")

    with gr.Row():
        # Left column: everything the user supplies.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Inputs")
            image_input = gr.Image(type="pil", label="Upload Image")
            task_selector = gr.Radio(
                label="Choose Task",
                choices=[
                    "Single-Class Detection",
                    "Multi-Class Detection",
                    "Zero-Shot Detection",
                ],
                value="Single-Class Detection",
            )
            arch_dropdown = gr.Dropdown(
                label="Choose Model Architecture",
                choices=list(MODEL_REGISTRY["Single-Class Detection"]),
                value="yolov10",
            )
            # The three zero-shot controls start hidden; update_ui_for_task
            # toggles their visibility when the task changes.
            prompt_textbox = gr.Textbox(
                label="Enter Text Prompt (for Zero-Shot)",
                placeholder="e.g., insect . circle . metal ball",
                visible=False,
            )
            box_threshold_slider = gr.Slider(
                label="Box Threshold (Zero-Shot)",
                info="Lower to detect more objects, even with low confidence.",
                minimum=0.05,
                maximum=1.0,
                value=0.25,
                step=0.05,
                visible=False,
            )
            text_threshold_slider = gr.Slider(
                label="Text Threshold (Zero-Shot)",
                info="Lower to allow more labels to match detected objects.",
                minimum=0.05,
                maximum=1.0,
                value=0.25,
                step=0.05,
                visible=False,
            )
            analyze_btn = gr.Button("Analyze Image", variant="primary")

        # Right column: everything the analysis produces.
        with gr.Column(scale=2):
            gr.Markdown("### 2. Analysis Results")
            output_image = gr.Image(label="Annotated Image")
            with gr.Accordion("Details", open=True):
                model_details_output = gr.Textbox(label="Model Details", lines=4)
                classes_output = gr.Textbox(label="Classes / Prompt")
                embedding_output = gr.Image(label="Feature Embedding Visualization")

    # --- Event Handlers ---
    task_selector.change(
        fn=update_ui_for_task,
        inputs=task_selector,
        outputs=[
            arch_dropdown,
            prompt_textbox,
            box_threshold_slider,
            text_threshold_slider,
        ],
    )
    analyze_btn.click(
        fn=comprehensive_analysis,
        inputs=[
            image_input,
            task_selector,
            arch_dropdown,
            prompt_textbox,
            box_threshold_slider,
            text_threshold_slider,
        ],
        outputs=[
            output_image,
            model_details_output,
            classes_output,
            embedding_output,
        ],
    )

    gr.Markdown("---")
    gr.Markdown("### 3. Or Start with an Example Image")
    # One single-column row per bundled example image (example1..example5).
    example_list = [
        [f"example_images/example{index}.jpg"] for index in range(1, 6)
    ]
    gr.Examples(
        examples=example_list,
        inputs=image_input,
        label="Select an image to load it",
    )
# Start the app only when this file is run directly, not when it is imported.
if __name__ == "__main__":
    demo.launch(
        share=True,
        inline=True,
        debug=True,
        show_error=True,
    )