import torch import cv2 import gradio as gr import numpy as np from transformers import OwlViTProcessor, OwlViTForObjectDetection import requests # 如果GPU可用,就使用GPU,否则使用CPU if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") # 从预训练模型"google/owlvit-large-patch14"加载OWL-ViT模型,并将其放置到适当的设备上 model = OwlViTForObjectDetection.from_pretrained("google/owlvit-large-patch14").to(device) model.eval() # 从同一预训练模型中加载处理器 processor = OwlViTProcessor.from_pretrained("google/owlvit-large-patch14") # 定义一个函数来处理图像URL,文本查询和分数阈值 def query_image(img_url, text_queries, score_threshold): # 使用requests库从URL中获取图像 response = requests.get(img_url) response.raise_for_status() arr = np.asarray(bytearray(response.content), dtype=np.uint8) img = cv2.imdecode(arr, -1) # 使用-1来加载原始图像 text_queries = text_queries.split(",") # 将文本查询分割成独立的查询 target_sizes = torch.Tensor([img.shape[:2]]) inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device) # 使用处理器创建模型的输入 with torch.no_grad(): outputs = model(**inputs) # 获取模型的输出 # 将输出转移到CPU上 outputs.logits = outputs.logits.cpu() outputs.pred_boxes = outputs.pred_boxes.cpu() # 使用处理器进行后处理 results = processor.post_process(outputs=outputs, target_sizes=target_sizes) boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"] font = cv2.FONT_HERSHEY_SIMPLEX # 在图像上绘制边界框并添加标签 for box, score, label in zip(boxes, scores, labels): box = [int(i) for i in box.tolist()] if score >= score_threshold: img = cv2.rectangle(img, box[:2], box[2:], (255,0,0), 5) y = box[3] - 10 if box[3] + 25 > 768 else box[3] + 25 img = cv2.putText( img, text_queries[label], (box[0], y), font, 1, (255,0,0), 2, cv2.LINE_AA ) return img description = """ Gradio demo for OWL-ViT. You can use OWL-ViT to query images with text descriptions of any object. To use it, simply provide an image URL and enter comma separated text descriptions of objects you want to query the image for. You can also use the score threshold slider to set a threshold to filter out low probability predictions. """ # 创建一个Gradio界面 demo = gr.Interface( query_image, inputs=["text", "text", gr.Slider(0, 1, value=0.1)], # 修改输入,使其接受URL而不是图像 outputs="image", title="Zero-Shot Object Detection with OWL-ViT", description=description, examples=[], # 设置为一个空列表 ) demo.launch()