# --------------------------------------------------------
# Based on yolov10
# https://github.com/THU-MIG/yolov10/app.py
# --------------------------------------------------------
import gradio as gr
import cv2
import tempfile
from ultralytics import YOLO
import os
from PIL import Image
from gtts import gTTS
from io import BytesIO
from pydub import AudioSegment
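# Note: pydub decodes the MP3 produced by gTTS, which requires an ffmpeg
# (or libav) binary to be available on the system PATH.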
# Create the audio cache directory
CACHE_DIR = "audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
def get_audio_filename(text):
"""根據偵測到的物件名稱生成對應的音檔路徑"""
return os.path.join(CACHE_DIR, f"{text}.wav")
def generate_or_get_audio(text):
"""如果語音檔已存在,則直接返回,否則生成新的語音檔"""
audio_path = get_audio_filename(text)
if not os.path.exists(audio_path):
        # Generate the speech file (Traditional Chinese)
tts = gTTS(text=text, lang='zh-tw')
temp_bytes = BytesIO()
tts.write_to_fp(temp_bytes)
temp_bytes.seek(0)
        # Convert to WAV format and save
audio = AudioSegment.from_file(temp_bytes, format="mp3")
audio.export(audio_path, format="wav")
    return audio_path  # path to the cached WAV file
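# Usage: generate_or_get_audio("偵測到:person") writes (or reuses)
# "audio_cache/偵測到:person.wav", so the same object set is only
# synthesized through the gTTS API once.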
def yolov12_inference(image, video, model_id, image_size, conf_threshold):
    # When model_id is "best.pt", use the dedicated weights path below
if model_id == "best.pt":
model_path = "motaer0206/YOLOv12-Audio-Assistant/best.pt" # 使用您提供的路徑
else:
        model_path = model_id  # other model names: keep the original logic and let ultralytics resolve them
model = YOLO(model_path)
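    # The model is re-instantiated on every request; a minimal cache sketch
    # (hypothetical, not wired in here) would avoid the reload cost:
    #   _MODELS = {}
    #   def get_model(path):
    #       if path not in _MODELS:
    #           _MODELS[path] = YOLO(path)
    #       return _MODELS[path]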
    detected_objects = set()  # class names seen across all detections
if image:
results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
annotated_image = results[0].plot()
        # Collect the detected class names
for result in results:
for box in result.boxes:
cls_id = int(box.cls[0])
cls_name = model.names[cls_id]
detected_objects.add(cls_name)
        # Generate or fetch the audio feedback
if detected_objects:
description = "偵測到:" + "、".join(detected_objects)
audio_path = generate_or_get_audio(description)
else:
audio_path = None
        return annotated_image[:, :, ::-1], audio_path  # BGR -> RGB image, plus the audio path
else:
        # tempfile.mktemp is deprecated (race-prone); NamedTemporaryFile is the safer equivalent
        video_path = tempfile.NamedTemporaryFile(suffix=".webm", delete=False).name
with open(video_path, "wb") as f:
with open(video, "rb") as g:
f.write(g.read())
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        output_video_path = tempfile.NamedTemporaryFile(suffix=".webm", delete=False).name
out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height))
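        # 'vp80' selects the VP8 codec to match the .webm container; assumed
        # available here, but it depends on how this OpenCV build was compiled.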
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
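            # Each frame goes through the detector; for long videos, one could
            # run detection on every Nth frame to trade accuracy for speed.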
results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
annotated_frame = results[0].plot()
out.write(annotated_frame)
            # Collect detections from this frame
for result in results:
for box in result.boxes:
cls_id = int(box.cls[0])
cls_name = model.names[cls_id]
detected_objects.add(cls_name)
cap.release()
out.release()
        # Generate or fetch the audio feedback (for the whole video)
if detected_objects:
description = "偵測到:" + "、".join(detected_objects)
audio_path = generate_or_get_audio(description)
else:
audio_path = None
        return None, output_video_path, audio_path  # the video branch also returns the audio path
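# Note: the image branch returns (image, audio) while the video branch
# returns (None, video, audio); the callers below unpack accordingly.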
def yolov12_inference_for_examples(image, model_path, image_size, conf_threshold):
annotated_image, audio_path = yolov12_inference(image, None, model_path, image_size, conf_threshold)
return annotated_image, audio_path
def app():
with gr.Blocks():
with gr.Row():
with gr.Column():
image = gr.Image(type="pil", label="Image", visible=True)
video = gr.Video(label="Video", visible=False)
input_type = gr.Radio(
choices=["Image", "Video"],
value="Image",
label="Input Type",
)
model_id = gr.Dropdown(
label="Model",
                choices=[
                    "yolov12n.pt",
                    "yolov12s.pt",
                    "yolov12m.pt",
                    "yolov12l.pt",
                    "yolov12x.pt",
                    "best.pt",  # custom weights, resolved specially in yolov12_inference
                ],
value="yolov12n.pt", # 設定預設模型
)
image_size = gr.Slider(
label="Image Size",
minimum=320,
maximum=1280,
step=32,
value=640,
)
conf_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.25,
)
yolov12_infer = gr.Button(value="Detect Objects")
with gr.Column():
output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
output_video = gr.Video(label="Annotated Video", visible=False)
                output_audio = gr.Audio(label="Audio Feedback", visible=True)  # audio output component
def update_visibility(input_type):
image_vis = input_type == "Image"
video_vis = not image_vis
return (
gr.update(visible=image_vis),
gr.update(visible=video_vis),
gr.update(visible=image_vis),
gr.update(visible=video_vis),
                gr.update(visible=True),  # the audio output is always visible
)
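        # The tuple above matches the order of the `outputs` list below.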
input_type.change(
fn=update_visibility,
inputs=[input_type],
outputs=[image, video, output_image, output_video, output_audio],
)
def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
if input_type == "Image":
annotated_image, audio_path = yolov12_inference(image, None, model_id, image_size, conf_threshold)
return annotated_image, None, audio_path
else:
_, annotated_video, audio_path = yolov12_inference(None, video, model_id, image_size, conf_threshold)
return None, annotated_video, audio_path
yolov12_infer.click(
fn=run_inference,
inputs=[image, video, model_id, image_size, conf_threshold, input_type],
            outputs=[output_image, output_video, output_audio],  # audio output included
)
gr.Examples(
examples=[
[
"ultralytics/assets/bus.jpg",
"yolov12n.pt",
640,
0.25,
],
[
"ultralytics/assets/zidane.jpg",
"yolov12n.pt",
640,
0.25,
],
],
            fn=yolov12_inference_for_examples,  # image-only wrapper defined above
inputs=[
image,
model_id,
image_size,
conf_threshold,
],
            outputs=[output_image, output_audio],  # examples also emit audio
cache_examples='lazy',
)
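        # cache_examples='lazy' defers example inference to the first request
        # instead of precomputing it at startup.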
gradio_app = gr.Blocks()
with gradio_app:
gr.HTML(
"""
<h1 style='text-align: center'>
YOLOv12: Attention-Centric Real-Time Object Detectors
</h1>
""")
with gr.Row():
with gr.Column():
app()
if __name__ == '__main__':
gradio_app.launch()