import gradio as gr
import ibbi
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import io
# --- Model Management ---
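# Keys are the task labels shown in the UI; values map architecture names to
# the ibbi model identifiers passed to ibbi.create_model().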
MODEL_REGISTRY = {
"Single-Class Detection": {
"yolov10": "yolov10x_bb_detect_model",
"yolov11": "yolov11x_bb_detect_model",
"yolov9": "yolov9e_bb_detect_model",
"yolov8": "yolov8x_bb_detect_model",
"rtdetr": "rtdetrx_bb_detect_model",
},
"Multi-Class Detection": {
"yolov10": "yolov10x_bb_multi_class_detect_model",
"yolov11": "yolov11x_bb_multi_class_detect_model",
"yolov9": "yolov9e_bb_multi_class_detect_model",
"yolov8": "yolov8x_bb_multi_class_detect_model",
"rtdetr": "rtdetrx_bb_multi_class_detect_model",
},
"Zero-Shot Detection": {
"grounding_dino": "grounding_dino_detect_model"
}
}
# --- Model Loading ---
# Models are deliberately not cached: a cached instance can carry state from a
# previous run, so a fresh model is loaded for every analysis request.
def get_model(task, architecture):
"""
Loads a fresh model instance based on user selection.
This prevents stateful changes from one run affecting the next.
"""
try:
# For Zero-Shot, the architecture is always 'grounding_dino'
if task == "Zero-Shot Detection":
architecture = "grounding_dino"
model_name = MODEL_REGISTRY[task][architecture]
print(f"Loading a fresh model instance: {model_name}")
model = ibbi.create_model(model_name, pretrained=True)
print("Model loaded successfully!")
return model
except KeyError as e:
raise gr.Error(f"Model lookup failed. Task: '{task}', Arch: '{architecture}'. Error: {e}")
except Exception as e:
raise gr.Error(f"Failed to load model. Please check the model name and your connection. Error: {e}")
# --- Visualization and Drawing Functions ---
def draw_yolo_predictions(image, results, font, color="red"):
"""Draws YOLO predictions on an image with a dynamically sized font."""
img_copy = image.copy()
draw = ImageDraw.Draw(img_copy)
if not results or not results[0].boxes:
return img_copy
res_for_img = results[0]
class_names = res_for_img.names
for box in res_for_img.boxes:
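        # Skip malformed detections that lack a class id or confidence score.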
if box.cls.numel() == 0 or box.conf.numel() == 0: continue
coords = box.xyxy[0].tolist()
score = box.conf[0].item()
class_id = int(box.cls[0].item())
label_text = f"{class_names.get(class_id, f'Unknown-{class_id}')}: {score:.2f}"
draw.rectangle(coords, outline=color, width=3)
        # Place the label background directly above the box, clamped to the
        # top edge of the image.
        text_bbox = draw.textbbox((coords[0], coords[1]), label_text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        text_bg_y1 = max(coords[1] - text_height, 0)
        draw.rectangle((coords[0], text_bg_y1, coords[0] + text_width, coords[1]), fill=color)
        draw.text((coords[0], text_bg_y1), label_text, fill="white", font=font)
return img_copy
def draw_dino_predictions(image, results, font, color="green"):
"""Draws Grounding DINO predictions on an image with a dynamically sized font."""
img_copy = image.copy()
draw = ImageDraw.Draw(img_copy)
if not results: return img_copy
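    # Grounding DINO returns parallel lists: one box, score, and text label per detection.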
for box, score, label in zip(results.get("boxes", []), results.get("scores", []), results.get("text_labels", [])):
coords = box.tolist()
label_text = f"{label}: {score:.2f}"
draw.rectangle(coords, outline=color, width=3)
        # Same label placement as the YOLO helper: background above the box,
        # clamped to the top edge of the image.
        text_bbox = draw.textbbox((coords[0], coords[1]), label_text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        text_bg_y1 = max(coords[1] - text_height, 0)
        draw.rectangle((coords[0], text_bg_y1, coords[0] + text_width, coords[1]), fill=color)
        draw.text((coords[0], text_bg_y1), label_text, fill="white", font=font)
return img_copy
def visualize_embedding(embedding):
"""Visualizes a feature embedding as an image."""
    if embedding is None or not hasattr(embedding, 'cpu'):
        return None
    if len(embedding.shape) == 1:
        embedding = embedding.unsqueeze(0)
    # Collapse any leading batch dimensions (e.g. a [1, seq_len, dim] hidden
    # state) so imshow receives a 2-D array.
    while len(embedding.shape) > 2:
        embedding = embedding[0]
fig, ax = plt.subplots(figsize=(10, 2))
ax.imshow(embedding.cpu().detach().numpy(), cmap='viridis', aspect='auto')
ax.set_title("Feature Embedding Visualization")
ax.set_xlabel("Feature Dimension")
ax.set_yticks([])
fig.tight_layout()
buf = io.BytesIO()
fig.savefig(buf, format='png')
plt.close(fig)
buf.seek(0)
return Image.open(buf)
# --- Main Processing Function ---
def comprehensive_analysis(image, task, architecture, text_prompt, box_threshold, text_threshold):
"""Performs the main analysis with corrected logic."""
if image is None:
raise gr.Error("Please upload an image first!")
# Calculate a dynamic font size based on image width.
dynamic_font_size = max(15, int(image.width * 0.04))
try:
font = ImageFont.truetype("arial.ttf", dynamic_font_size)
except IOError:
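        # The size argument to load_default() requires Pillow >= 10.1.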
font = ImageFont.load_default(size=dynamic_font_size)
# Get a fresh model instance to avoid stateful errors
model = get_model(task, architecture)
outputs = {"annotated_image": None, "model_info": "", "classes_info": "", "embedding_plot": None}
if task in ["Single-Class Detection", "Multi-Class Detection"]:
results = model.predict(image)
outputs["annotated_image"] = draw_yolo_predictions(image, results, font=font)
features = model.extract_features(image)
outputs["model_info"] = f"Architecture: {architecture.upper()}\nTask: {task}\nDevice: {model.device}"
outputs["classes_info"] = f"Classes: {model.get_classes()}"
else: # Zero-Shot Detection
if not text_prompt:
raise gr.Error("Please provide a text prompt for Zero-Shot Detection.")
results = model.predict(
image,
text_prompt=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold
)
outputs["annotated_image"] = draw_dino_predictions(image, results, font=font)
features = model.extract_features(image, text_prompt=text_prompt)
outputs["model_info"] = f"Architecture: GROUNDING_DINO\nTask: {task}\nDevice: {model.device}\nHF Model ID: {model.model.config._name_or_path}"
outputs["classes_info"] = f"Prompt: '{text_prompt}'"
# Process features for visualization
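    # extract_features may return a dict of tensors (e.g. transformer hidden
    # states) or a single tensor, so handle both shapes here.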
if isinstance(features, dict):
outputs["embedding_plot"] = visualize_embedding(features.get('last_hidden_state'))
else:
outputs["embedding_plot"] = visualize_embedding(features)
    # Always return all four outputs so every Gradio component is updated.
return outputs["annotated_image"], outputs["model_info"], outputs["classes_info"], outputs["embedding_plot"]
# --- Gradio UI ---
def update_ui_for_task(task):
"""Updates the UI components based on the selected task."""
if task in ["Single-Class Detection", "Multi-Class Detection"]:
arch_choices = list(MODEL_REGISTRY[task].keys())
return {
arch_dropdown: gr.update(choices=arch_choices, value=arch_choices[0], visible=True, interactive=True),
prompt_textbox: gr.update(visible=False, value=""),
box_threshold_slider: gr.update(visible=False),
text_threshold_slider: gr.update(visible=False)
}
else: # Zero-Shot Detection
arch_choices = list(MODEL_REGISTRY[task].keys())
return {
arch_dropdown: gr.update(choices=arch_choices, value=arch_choices[0], visible=False),
prompt_textbox: gr.update(visible=True),
box_threshold_slider: gr.update(visible=True),
text_threshold_slider: gr.update(visible=True)
}
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# IBBI - Intelligent Bark Beetle Identifier")
gr.Markdown("An all-in-one interface to analyze images using the `ibbi` library. Upload an image, select a task and model, and view the complete analysis.")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### 1. Inputs")
image_input = gr.Image(type="pil", label="Upload Image")
task_selector = gr.Radio(
choices=["Single-Class Detection", "Multi-Class Detection", "Zero-Shot Detection"],
value="Single-Class Detection",
label="Choose Task"
)
arch_dropdown = gr.Dropdown(
choices=list(MODEL_REGISTRY["Single-Class Detection"].keys()),
value="yolov10",
label="Choose Model Architecture"
)
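            # Grounding DINO expects the prompt as period-separated phrases,
            # as in the placeholder below.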
prompt_textbox = gr.Textbox(
label="Enter Text Prompt (for Zero-Shot)",
placeholder="e.g., insect . circle . metal ball",
visible=False
)
box_threshold_slider = gr.Slider(
minimum=0.05, maximum=1.0, value=0.25, step=0.05,
label="Box Threshold (Zero-Shot)",
info="Lower to detect more objects, even with low confidence.",
visible=False
)
text_threshold_slider = gr.Slider(
minimum=0.05, maximum=1.0, value=0.25, step=0.05,
label="Text Threshold (Zero-Shot)",
info="Lower to allow more labels to match detected objects.",
visible=False
)
analyze_btn = gr.Button("Analyze Image", variant="primary")
with gr.Column(scale=2):
gr.Markdown("### 2. Analysis Results")
output_image = gr.Image(label="Annotated Image")
with gr.Accordion("Details", open=True):
model_details_output = gr.Textbox(label="Model Details", lines=4)
classes_output = gr.Textbox(label="Classes / Prompt")
embedding_output = gr.Image(label="Feature Embedding Visualization")
# --- Event Handlers ---
task_selector.change(
fn=update_ui_for_task,
inputs=task_selector,
outputs=[arch_dropdown, prompt_textbox, box_threshold_slider, text_threshold_slider]
)
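    # Inputs a task does not use (e.g. the prompt for YOLO tasks) are passed
    # through but ignored by comprehensive_analysis.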
analyze_btn.click(
fn=comprehensive_analysis,
inputs=[image_input, task_selector, arch_dropdown, prompt_textbox, box_threshold_slider, text_threshold_slider],
outputs=[output_image, model_details_output, classes_output, embedding_output]
)
gr.Markdown("---")
gr.Markdown("### 3. Or Start with an Example Image")
example_list = [
["example_images/example1.jpg"],
["example_images/example2.jpg"],
["example_images/example3.jpg"],
["example_images/example4.jpg"],
["example_images/example5.jpg"],
]
gr.Examples(
examples=example_list,
inputs=image_input,
label="Select an image to load it"
)
if __name__ == "__main__":
    # inline=True only applies inside notebooks; share=True is ignored on
    # Hugging Face Spaces but creates a public link when run locally.
    demo.launch(share=True, debug=True, show_error=True)