import gradio as gr
import torch
import spaces
import json
import base64
from io import BytesIO
from transformers import SamHQModel, SamHQProcessor, SamModel, SamProcessor
import os
import pandas as pd
from utils import *
from PIL import Image

# Model loading: both checkpoints and their processors are loaded once, at import time.
sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")


@spaces.GPU
def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
    """Run a SAM-family model on one image with point and/or box prompts.

    Args:
        model: A ``SamModel`` or ``SamHQModel`` instance.
        processor: The processor matching ``model``.
        raw_image: The PIL image to segment.
        input_points: Nested list of ``[x, y]`` prompts, or None.
        input_boxes: Nested list of ``[x0, y0, x1, y1]`` prompts, or None.

    Returns:
        ``(masks, scores)``: the output of ``post_process_masks`` (computed from
        the processor's original/reshaped sizes) and the model's ``iou_scores``.
    """
    if input_boxes is not None:
        # Kept from the original call pattern: adds one more nesting level
        # around the per-image box lists before calling the processor.
        input_boxes = [input_boxes]
    inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # Move scores to CPU too, so masks and scores always live on the same device.
    scores = outputs.iou_scores.cpu()
    return masks, scores


def encode_pil_to_base64(pil_image):
    """Return ``pil_image`` encoded as PNG, as a base64 ASCII string (no data-URI prefix)."""
    buffer = BytesIO()
    pil_image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _as_data_uri(image_b64):
    """Turn base64 image text into a value usable as an ``<img src>``.

    NOTE(review): assumes ``show_all_annotations_on_image_base64`` (from utils)
    returns base64 PNG data, either bare or already as a data URI -- confirm.
    """
    if isinstance(image_b64, bytes):
        image_b64 = image_b64.decode("ascii")
    if image_b64.startswith("data:"):
        return image_b64
    return f"data:image/png;base64,{image_b64}"


def _is_blank(cell):
    """Return True for an empty Dataframe cell (None, blank string, or NaN)."""
    if cell is None:
        return True
    if isinstance(cell, str):
        return not cell.strip()
    return bool(pd.isna(cell))


def _rows_to_int_lists(frame):
    """Convert Dataframe rows into lists of ints.

    Rows that contain a blank/NaN cell, or whose values are all zero, are
    skipped. Previously a blank cell raised inside ``int()``. The all-zero
    rule preserves the original ``any(row)`` filter.
    """
    if frame is None:
        return []
    rows = frame.values.tolist() if isinstance(frame, pd.DataFrame) else list(frame)
    parsed = []
    for row in rows:
        if any(_is_blank(cell) for cell in row):
            continue
        values = [float(cell) for cell in row]
        if any(values):
            parsed.append([int(value) for value in values])
    return parsed


def _find_example_by_size(image):
    """Return the first ``example_data_map`` entry whose ``size`` equals ``image.size``, else None."""
    if image is None:
        return None
    size = list(image.size)
    for example_data in example_data_map.values():
        if example_data["size"] == size:
            return example_data
    return None


def compare_images_points_and_masks(user_image, input_boxes, input_points):
    """Segment the image with SAM and SAM-HQ using the same prompts and return HTML.

    Args:
        user_image: PIL image shown in the UI. If its size matches a
            registered example, that example's ``original_image_path`` is
            loaded and used instead.
        input_boxes: Dataframe value with columns x0, y0, x1, y1.
        input_points: Dataframe value with columns x, y.

    Returns:
        An HTML string showing both annotated results side by side, or a short
        message when no image or no prompt was provided.
    """
    if user_image is None:
        return "<p>Please load an example image first.</p>"

    # Stop at the first match. The old loop kept iterating after swapping the
    # image, so it compared later examples against the *new* image size.
    example = _find_example_by_size(user_image)
    if example is not None:
        # Load eagerly and close the file; a bare Image.open keeps the handle open.
        with Image.open(example["original_image_path"]) as original:
            user_image = original.copy()

    boxes = _rows_to_int_lists(input_boxes)
    points = _rows_to_int_lists(input_points)
    if not boxes and not points:
        # Previously this fell through and crashed with a NameError on img1_b64.
        return "<p>Please provide at least one point or box.</p>"

    # One outer list per image, as in the original call shape.
    input_boxes = [boxes] if boxes else None
    input_points = [points] if points else None

    sam_masks, sam_scores = predict_masks_and_scores(
        sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points
    )
    sam_hq_masks, sam_hq_scores = predict_masks_and_scores(
        sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points
    )

    # A missing prompt type is passed as None, exactly like the former if/elif branches.
    boxes_arg = boxes or None
    points_arg = points or None
    img1_b64 = show_all_annotations_on_image_base64(
        user_image, sam_masks[0][0], sam_scores[:, 0, :], boxes_arg, points_arg, model_name='SAM'
    )
    img2_b64 = show_all_annotations_on_image_base64(
        user_image, sam_hq_masks[0][0], sam_hq_scores[:, 0, :], boxes_arg, points_arg, model_name='SAM_HQ'
    )

    # The previous f-string contained no markup, so the results were never displayed.
    html_code = f"""
<div style="display: flex; flex-wrap: wrap; gap: 16px; justify-content: center;">
  <figure style="flex: 1 1 45%; margin: 0; text-align: center;">
    <img src="{_as_data_uri(img1_b64)}" alt="SAM result" style="max-width: 100%; height: auto;">
    <figcaption>SAM</figcaption>
  </figure>
  <figure style="flex: 1 1 45%; margin: 0; text-align: center;">
    <img src="{_as_data_uri(img2_b64)}" alt="SAM-HQ result" style="max-width: 100%; height: auto;">
    <figcaption>SAM-HQ</figcaption>
  </figure>
</div>
"""
    return html_code


def load_examples(json_file="examples.json"):
    """Load and return the parsed contents of the examples JSON file."""
    with open(json_file, "r", encoding="utf-8") as f:
        examples = json.load(f)
    return examples


examples = load_examples()
example_paths = [example["image_path"] for example in examples]
# Maps each preview image path to its full-size original, default prompts and recorded size.
example_data_map = {
    example["image_path"]: {
        "original_image_path": example["original_image_path"],
        "points": example["points"],
        "boxes": example["boxes"],
        "size": example["size"],
    }
    for example in examples
}

theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="emerald")

with gr.Blocks(theme=theme, title="🔍 Compare SAM vs SAM-HQ") as demo:
    image_path_box = gr.Textbox(visible=False)
    gr.Markdown("## 🔍 Compare SAM vs SAM-HQ")
    gr.Markdown("Compare the performance of SAM and SAM-HQ on various images. Click on an example to load it")
    gr.Markdown("[SAM-HQ](https://huggingface.co/syscv-community/sam-hq-vit-huge) - [SAM](https://huggingface.co/facebook/sam-vit-huge)")

    with gr.Row():
        image_input = gr.Image(
            type="pil",
            label="Example image (click below to load)",
            interactive=False,
            height=500,
            show_label=True,
        )

    gr.Examples(
        examples=example_paths,
        inputs=[image_input],
        label="Click an example to try 👇",
    )

    result_html = gr.HTML(elem_id="result-html")

    with gr.Row():
        points_input = gr.Dataframe(
            headers=["x", "y"],
            label="Points",
            datatype=["number", "number"],
            col_count=(2, "fixed"),
        )
        boxes_input = gr.Dataframe(
            headers=["x0", "y0", "x1", "y1"],
            label="Boxes",
            datatype=["number", "number", "number", "number"],
            col_count=(4, "fixed"),
        )

    def on_image_change(image):
        """Fill the prompt tables with the stored prompts of the matching example, if any."""
        example_data = _find_example_by_size(image)
        if example_data is not None:
            return example_data["points"], example_data["boxes"]
        return [], []

    image_input.change(
        fn=on_image_change,
        inputs=[image_input],
        outputs=[points_input, boxes_input],
    )

    compare_button = gr.Button("Compare points and masks")
    compare_button.click(
        fn=compare_images_points_and_masks,
        inputs=[image_input, boxes_input, points_input],
        outputs=result_html,
    )

    gr.HTML(""" """)
# Start the Gradio server for the `demo` Blocks app. This is a top-level
# statement, so it runs whenever this module is executed or imported.
demo.launch()