#!/usr/bin/env python3
import os
import re
import tempfile
import gc
from collections.abc import Iterator
from threading import Thread, Lock
import json
import requests
import cv2
import gradio as gr
import spaces
import torch
import numpy as np
from loguru import logger
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer, pipeline
import time
import warnings
from typing import Dict, List, Optional, Union
import librosa
import scipy.signal as sps
import queue
# CSV/TXT 분석
import pandas as pd
# PDF 텍스트 추출
import PyPDF2
warnings.filterwarnings('ignore')
# 로깅 설정
logger.remove()
logger.add(lambda msg: print(msg, flush=True), level="INFO")
print("🎮 로봇 시각 시스템 초기화 (Gemma3-R1984-4B + Whisper)...")
##############################################################################
# 상수 정의
##############################################################################
MAX_CONTENT_CHARS = 2000
MAX_INPUT_LENGTH = 2096
MAX_NUM_IMAGES = 5
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
##############################################################################
# 전역 변수
##############################################################################
model = None
processor = None
whisper_model = None
model_loaded = False
whisper_loaded = False
model_name = "Gemma3-R1984-4B"
# 오디오 관련 전역 변수
audio_lock = Lock()
last_audio_data = None
last_transcription = ""
##############################################################################
# 메모리 관리
##############################################################################
def clear_cuda_cache():
"""CUDA 캐시를 명시적으로 비웁니다."""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
##############################################################################
# 키워드 추출 함수
##############################################################################
def extract_keywords(text: str, top_k: int = 5) -> str:
"""키워드 추출"""
text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
tokens = text.split()
seen = set()
unique_tokens = []
for token in tokens:
if token not in seen and len(token) > 1:
seen.add(token)
unique_tokens.append(token)
key_tokens = unique_tokens[:top_k]
return " ".join(key_tokens)
##############################################################################
# Whisper 모델 로드
##############################################################################
@spaces.GPU(duration=60)
def load_whisper():
global whisper_model, whisper_loaded
if whisper_loaded:
logger.info("Whisper 모델이 이미 로드되어 있습니다.")
return True
try:
logger.info("Whisper 모델 로딩 시작...")
# 파이프라인 방식으로 로드
device = 0 if torch.cuda.is_available() else "cpu"
whisper_model = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-base",
chunk_length_s=30,
device=device,
)
whisper_loaded = True
logger.info("✅ Whisper 모델 로딩 완료!")
return True
except Exception as e:
logger.error(f"Whisper 모델 로딩 실패: {e}")
return False
##############################################################################
# 오디오 처리 함수 (간소화)
##############################################################################
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int = 16000) -> np.ndarray:
"""오디오 리샘플링"""
if orig_sr == target_sr:
return audio.astype(np.float32)
# scipy를 사용한 리샘플링
number_of_samples = round(len(audio) * float(target_sr) / orig_sr)
audio_resampled = sps.resample(audio, number_of_samples)
return audio_resampled.astype(np.float32)
@spaces.GPU(duration=30)
def transcribe_audio_whisper(audio_array: np.ndarray, sr: int = 16000):
"""Whisper를 사용한 오디오 전사"""
global whisper_model, whisper_loaded
if not whisper_loaded:
if not load_whisper():
return None
try:
# 오디오가 너무 조용한지 체크
if np.max(np.abs(audio_array)) < 0.01:
logger.warning("오디오가 너무 조용함")
return None
# 음성 인식
result = whisper_model({"array": audio_array, "sampling_rate": sr})
transcription = result["text"].strip()
logger.info(f"Whisper 전사 성공: {transcription[:50]}...")
return transcription if transcription else None
except Exception as e:
logger.error(f"Whisper 오디오 전사 오류: {e}")
import traceback
logger.error(traceback.format_exc())
return None
def process_audio_recording(audio_data):
"""녹음된 오디오 처리"""
global last_audio_data, last_transcription, audio_lock
if audio_data is None:
return None
try:
# 오디오 데이터 추출
if isinstance(audio_data, tuple) and len(audio_data) == 2:
sr, audio = audio_data
else:
logger.warning(f"예상치 못한 오디오 형식: {type(audio_data)}")
return None
if audio is None or len(audio) == 0:
return None
# numpy 배열로 변환
if not isinstance(audio, np.ndarray):
audio = np.array(audio)
# 스테레오를 모노로 변환
if audio.ndim > 1:
audio = audio.mean(axis=1)
# 16kHz로 리샘플링
if sr != 16000:
audio = resample_audio(audio, sr, 16000)
# 저장
with audio_lock:
last_audio_data = (audio, 16000)
logger.info(f"오디오 저장 완료: {len(audio)/16000:.1f}초")
# 전사 시도
transcription = transcribe_audio_whisper(audio, 16000)
if transcription:
with audio_lock:
last_transcription = transcription
return transcription
except Exception as e:
logger.error(f"오디오 처리 오류: {e}")
import traceback
logger.error(traceback.format_exc())
return None
##############################################################################
# 웹 검색 함수
##############################################################################
def do_web_search(query: str) -> str:
"""SerpHouse API를 사용한 웹 검색"""
try:
url = "https://api.serphouse.com/serp/live"
params = {
"q": query,
"domain": "google.com",
"serp_type": "web",
"device": "desktop",
"lang": "ko", # 한국어 우선
"num": "10" # 10개로 제한
}
headers = {
"Authorization": f"Bearer {SERPHOUSE_API_KEY}"
}
logger.info(f"웹 검색 중... 검색어: {query}")
response = requests.get(url, headers=headers, params=params, timeout=60)
response.raise_for_status()
data = response.json()
results = data.get("results", {})
organic = results.get("organic", []) if isinstance(results, dict) else []
if not organic:
return "검색 결과를 찾을 수 없습니다."
max_results = min(10, len(organic))
limited_organic = organic[:max_results]
summary_lines = []
for idx, item in enumerate(limited_organic, start=1):
title = item.get("title", "제목 없음")
link = item.get("link", "#")
snippet = item.get("snippet", "설명 없음")
displayed_link = item.get("displayed_link", link)
summary_lines.append(
f"### 결과 {idx}: {title}\n\n"
f"{snippet}\n\n"
f"**출처**: [{displayed_link}]({link})\n\n"
f"---\n"
)
instructions = """# 웹 검색 결과
아래는 검색 결과입니다. 답변 시 이 정보를 활용하세요:
1. 각 결과의 제목, 내용, 출처 링크를 참조하세요
2. 관련 출처를 명시적으로 인용하세요
3. 여러 출처의 정보를 종합하여 답변하세요
"""
search_results = instructions + "\n".join(summary_lines)
return search_results
except Exception as e:
logger.error(f"웹 검색 실패: {e}")
return f"웹 검색 실패: {str(e)}"
##############################################################################
# 문서 처리 함수
##############################################################################
def analyze_csv_file(path: str) -> str:
"""CSV 파일 분석"""
try:
df = pd.read_csv(path)
if df.shape[0] > 50 or df.shape[1] > 10:
df = df.iloc[:50, :10]
df_str = df.to_string()
if len(df_str) > MAX_CONTENT_CHARS:
df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(중략)..."
return f"**[CSV 파일: {os.path.basename(path)}]**\n\n{df_str}"
except Exception as e:
return f"CSV 읽기 실패 ({os.path.basename(path)}): {str(e)}"
def analyze_txt_file(path: str) -> str:
"""TXT 파일 분석"""
try:
with open(path, "r", encoding="utf-8") as f:
text = f.read()
if len(text) > MAX_CONTENT_CHARS:
text = text[:MAX_CONTENT_CHARS] + "\n...(중략)..."
return f"**[TXT 파일: {os.path.basename(path)}]**\n\n{text}"
except Exception as e:
return f"TXT 읽기 실패 ({os.path.basename(path)}): {str(e)}"
def pdf_to_markdown(pdf_path: str) -> str:
"""PDF를 마크다운으로 변환"""
text_chunks = []
try:
with open(pdf_path, "rb") as f:
reader = PyPDF2.PdfReader(f)
max_pages = min(5, len(reader.pages))
for page_num in range(max_pages):
page = reader.pages[page_num]
page_text = page.extract_text() or ""
page_text = page_text.strip()
if page_text:
if len(page_text) > MAX_CONTENT_CHARS // max_pages:
page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(중략)"
text_chunks.append(f"## 페이지 {page_num+1}\n\n{page_text}\n")
if len(reader.pages) > max_pages:
text_chunks.append(f"\n...({max_pages}/{len(reader.pages)} 페이지 표시)...")
except Exception as e:
return f"PDF 읽기 실패 ({os.path.basename(pdf_path)}): {str(e)}"
full_text = "\n".join(text_chunks)
if len(full_text) > MAX_CONTENT_CHARS:
full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(중략)..."
return f"**[PDF 파일: {os.path.basename(pdf_path)}]**\n\n{full_text}"
##############################################################################
# 모델 로드
##############################################################################
@spaces.GPU(duration=120)
def load_model():
global model, processor, model_loaded
if model_loaded:
logger.info("모델이 이미 로드되어 있습니다.")
return True
try:
logger.info("Gemma3-R1984-4B 모델 로딩 시작...")
clear_cuda_cache()
model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="eager"
)
model_loaded = True
logger.info(f"✅ {model_name} 로딩 완료!")
return True
except Exception as e:
logger.error(f"모델 로딩 실패: {e}")
return False
##############################################################################
# 이미지 분석 (로봇 태스크 중심)
##############################################################################
@spaces.GPU(duration=60)
def analyze_image_for_robot(
image: Union[np.ndarray, Image.Image],
prompt: str,
task_type: str = "general",
use_web_search: bool = False,
enable_thinking: bool = False,
max_new_tokens: int = 300,
audio_transcript: Optional[str] = None
) -> str:
"""로봇 작업을 위한 이미지 분석 (오디오 정보 포함)"""
global model, processor
if not model_loaded:
if not load_model():
return "❌ 모델 로딩 실패"
try:
# numpy 배열을 PIL 이미지로 변환
if isinstance(image, np.ndarray):
image = Image.fromarray(image).convert('RGB')
# 태스크별 시스템 프롬프트 구성
system_prompts = {
"general": "당신은 로봇 시각 시스템입니다. 먼저 장면을 1-2줄로 설명하고, 핵심 내용을 간결하게 분석하세요.",
"planning": """당신은 로봇 작업 계획 AI입니다.
먼저 장면 이해를 1-2줄로 설명하고, 그 다음 작업 계획을 작성하세요.
형식:
[장면 이해] 현재 보이는 장면을 1-2줄로 설명
[작업 계획]
Step_1: xxx
Step_2: xxx
Step_n: xxx""",
"grounding": "당신은 객체 위치 시스템입니다. 먼저 보이는 객체들을 한 줄로 설명하고, 요청된 객체 위치를 [x1, y1, x2, y2]로 반환하세요.",
"affordance": "당신은 파지점 분석 AI입니다. 먼저 대상 객체를 한 줄로 설명하고, 파지 영역을 [x1, y1, x2, y2]로 반환하세요.",
"trajectory": "당신은 경로 계획 AI입니다. 먼저 환경을 한 줄로 설명하고, 경로를 [(x1,y1), (x2,y2), ...]로 제시하세요.",
"pointing": "당신은 지점 지정 시스템입니다. 먼저 참조점들을 한 줄로 설명하고, 위치를 [(x1,y1), (x2,y2), ...]로 반환하세요."
}
# 오디오 정보가 있으면 프롬프트 수정
if audio_transcript and task_type == "planning":
system_prompts["planning"] = """당신은 로봇 작업 계획 AI입니다.
먼저 장면 이해를 1-2줄로 설명하고, 주변 소리를 인식했다면 그것도 설명한 후, 작업 계획을 작성하세요.
형식:
[장면 이해] 현재 보이는 장면을 1-2줄로 설명
[주변 소리 인식] 들리는 소리나 음성을 1줄로 설명
[작업 계획]
Step_1: xxx
Step_2: xxx
Step_n: xxx"""
system_prompt = system_prompts.get(task_type, system_prompts["general"])
# Chain-of-Thought 추가 (선택적)
if enable_thinking:
system_prompt += "\n\n추론 과정을
⚡ 멀티모달 AI로 로봇 작업 분석!