Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python3 | |
import os | |
import re | |
import tempfile | |
import gc | |
from collections.abc import Iterator | |
from threading import Thread, Lock | |
import json | |
import requests | |
import cv2 | |
import gradio as gr | |
import spaces | |
import torch | |
import numpy as np | |
from loguru import logger | |
from PIL import Image | |
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer, pipeline | |
import time | |
import warnings | |
from typing import Dict, List, Optional, Union | |
import librosa | |
import scipy.signal as sps | |
import queue | |
# CSV/TXT analysis
import pandas as pd | |
# PDF text extraction
import PyPDF2 | |
warnings.filterwarnings('ignore') | |
# Logging setup
logger.remove() | |
logger.add(lambda msg: print(msg, flush=True), level="INFO") | |
print("๐ฎ ๋ก๋ด ์๊ฐ ์์คํ ์ด๊ธฐํ (Gemma3-R1984-4B + Whisper)...") | |
############################################################################## | |
# Constants
############################################################################## | |
# Character budget applied to the text extracted from each CSV/TXT/PDF file.
MAX_CONTENT_CHARS = 2000
# Prompt-token ceiling checked in analyze_image_for_robot before generation.
MAX_INPUT_LENGTH = 2096
# NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
MAX_NUM_IMAGES = 5
# SerpHouse API key read from the environment; do_web_search sends it as a Bearer token.
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
############################################################################## | |
# Global state
############################################################################## | |
# Lazily initialised vision-language model and its processor (set by load_model).
model = None
processor = None
# Lazily initialised speech-recognition pipeline (set by load_whisper).
whisper_model = None
# Flags recording whether the corresponding loader has already succeeded.
model_loaded = False
whisper_loaded = False
# Display name used in log messages.
model_name = "Gemma3-R1984-4B"
# Audio-related globals: the lock guards last_audio_data / last_transcription.
audio_lock = Lock()
last_audio_data = None
last_transcription = ""
############################################################################## | |
# Memory management
############################################################################## | |
def clear_cuda_cache():
    """Release unreferenced Python objects, then return cached CUDA memory.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are freed before ``torch.cuda.empty_cache()`` is asked
    to hand unused cached blocks back to the driver. On machines without
    CUDA only the garbage collection happens.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
############################################################################## | |
# Keyword extraction
############################################################################## | |
def extract_keywords(text: str, top_k: int = 5) -> str:
    """Build a short search query from the first distinct words of *text*.

    Every character that is not an ASCII letter, a digit, a Hangul syllable
    (U+AC00..U+D7A3) or whitespace is removed. The remaining text is split on
    whitespace, single-character tokens are dropped, duplicates are removed
    while keeping first-occurrence order, and the first ``top_k`` tokens are
    joined with single spaces.

    Args:
        text: Free-form user prompt.
        top_k: Maximum number of tokens to keep.

    Returns:
        The selected tokens joined by spaces; an empty string when nothing
        survives the filtering.
    """
    # The Hangul range is written with escapes so the pattern cannot be
    # damaged by a wrong source-file encoding.
    cleaned = re.sub(r"[^a-zA-Z0-9\uac00-\ud7a3\s]", "", text)
    # dict.fromkeys de-duplicates while preserving insertion order.
    distinct = dict.fromkeys(tok for tok in cleaned.split() if len(tok) > 1)
    return " ".join(list(distinct)[:top_k])
############################################################################## | |
# Whisper model loading
############################################################################## | |
def load_whisper():
    """Create the Whisper speech-recognition pipeline once and cache it.

    Returns:
        True when the pipeline is available (already loaded, or loaded by
        this call); False when loading raised an exception, which is logged.
    """
    global whisper_model, whisper_loaded

    if whisper_loaded:
        logger.info("Whisper ๋ชจ๋ธ์ด ์ด๋ฏธ ๋ก๋๋์ด ์์ต๋๋ค.")
        return True

    try:
        logger.info("Whisper ๋ชจ๋ธ ๋ก๋ฉ ์์...")
        # Pipeline device: CUDA device 0 when available, otherwise the CPU.
        target_device = "cpu"
        if torch.cuda.is_available():
            target_device = 0
        asr_pipeline = pipeline(
            task="automatic-speech-recognition",
            model="openai/whisper-base",
            chunk_length_s=30,
            device=target_device,
        )
        whisper_model = asr_pipeline
        whisper_loaded = True
        logger.info("โ Whisper ๋ชจ๋ธ ๋ก๋ฉ ์๋ฃ!")
        return True
    except Exception as exc:
        logger.error(f"Whisper ๋ชจ๋ธ ๋ก๋ฉ ์คํจ: {exc}")
        return False
############################################################################## | |
# Audio processing helpers (simplified)
############################################################################## | |
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int = 16000) -> np.ndarray:
    """Resample *audio* from ``orig_sr`` to ``target_sr`` and return float32.

    When the two rates are equal the samples are only cast to float32.
    Otherwise ``scipy.signal.resample`` is used with an output length of
    ``round(len(audio) * target_sr / orig_sr)``.
    """
    if orig_sr != target_sr:
        target_len = round(len(audio) * float(target_sr) / orig_sr)
        audio = sps.resample(audio, target_len)
    return audio.astype(np.float32)
def transcribe_audio_whisper(audio_array: np.ndarray, sr: int = 16000):
    """Transcribe an audio array with the cached Whisper pipeline.

    Loads the pipeline on first use. Audio whose peak absolute amplitude is
    below 0.01 is treated as silence and skipped.

    Args:
        audio_array: Audio samples handed to the pipeline.
        sr: Sampling rate reported to the pipeline.

    Returns:
        The stripped transcription, or None when the pipeline cannot be
        loaded, the audio is too quiet, the text is empty, or an error occurs
        (errors are logged with a traceback).
    """
    global whisper_model, whisper_loaded

    if not whisper_loaded and not load_whisper():
        return None

    try:
        # Skip audio that is effectively silent.
        peak = np.max(np.abs(audio_array))
        if peak < 0.01:
            logger.warning("์ค๋์ค๊ฐ ๋๋ฌด ์กฐ์ฉํจ")
            return None

        asr_input = {"array": audio_array, "sampling_rate": sr}
        transcription = whisper_model(asr_input)["text"].strip()
        logger.info(f"Whisper ์ ์ฌ ์ฑ๊ณต: {transcription[:50]}...")
        return transcription or None
    except Exception as exc:
        logger.error(f"Whisper ์ค๋์ค ์ ์ฌ ์ค๋ฅ: {exc}")
        import traceback
        logger.error(traceback.format_exc())
        return None
def process_audio_recording(audio_data):
    """Normalise a recorded clip, store it, and return its transcription.

    Args:
        audio_data: ``(sample_rate, samples)`` tuple as delivered by the
            audio component, or None.

    Returns:
        The transcription string, or None when there is no usable audio,
        the input has an unexpected format, transcription yields nothing,
        or an error occurs (errors are logged with a traceback).

    Side effects:
        Stores the processed mono clip in ``last_audio_data`` and, when
        transcription succeeds, the text in ``last_transcription`` (both
        under ``audio_lock``).
    """
    global last_audio_data, last_transcription, audio_lock

    if audio_data is None:
        return None

    try:
        # Unpack (sample_rate, samples); anything else is rejected.
        if isinstance(audio_data, tuple) and len(audio_data) == 2:
            sr, audio = audio_data
        else:
            logger.warning(f"์์์น ๋ชปํ ์ค๋์ค ํ์: {type(audio_data)}")
            return None

        if audio is None or len(audio) == 0:
            return None

        audio = np.asarray(audio)

        # Integer PCM (e.g. int16) is scaled into [-1.0, 1.0). Without this
        # the 0.01 silence threshold used by the transcriber is meaningless
        # and the recogniser receives samples far outside its expected range.
        if np.issubdtype(audio.dtype, np.signedinteger):
            full_scale = float(np.iinfo(audio.dtype).max) + 1.0
            audio = audio.astype(np.float32) / full_scale

        # Down-mix multi-channel audio, assumed (samples, channels), to mono.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        # Resample to 16 kHz.
        if sr != 16000:
            audio = resample_audio(audio, sr, 16000)

        with audio_lock:
            last_audio_data = (audio, 16000)
        logger.info(f"์ค๋์ค ์ ์ฅ ์๋ฃ: {len(audio)/16000:.1f}์ด")

        transcription = transcribe_audio_whisper(audio, 16000)
        if transcription:
            with audio_lock:
                last_transcription = transcription
        return transcription
    except Exception as e:
        logger.error(f"์ค๋์ค ์ฒ๋ฆฌ ์ค๋ฅ: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None
############################################################################## | |
# Web search
############################################################################## | |
def do_web_search(query: str) -> str:
    """Query the SerpHouse live SERP API and format the hits as Markdown.

    Args:
        query: Search terms sent as the ``q`` parameter.

    Returns:
        A Markdown document (instruction header followed by up to ten
        results), a "no results" message when the response holds no organic
        results, or a failure message when the request or parsing raises.
    """
    try:
        endpoint = "https://api.serphouse.com/serp/live"
        query_params = {
            "q": query,
            "domain": "google.com",
            "serp_type": "web",
            "device": "desktop",
            "lang": "ko",  # prefer Korean results
            "num": "10"  # limit to 10 results
        }
        auth_headers = {"Authorization": f"Bearer {SERPHOUSE_API_KEY}"}

        logger.info(f"์น ๊ฒ์ ์ค... ๊ฒ์์ด: {query}")
        response = requests.get(endpoint, headers=auth_headers, params=query_params, timeout=60)
        response.raise_for_status()

        payload = response.json()
        results = payload.get("results", {})
        organic = results.get("organic", []) if isinstance(results, dict) else []
        if not organic:
            return "๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."

        # Format at most the first ten organic hits.
        summary_lines = []
        for idx, item in enumerate(organic[:10], start=1):
            title = item.get("title", "์ ๋ชฉ ์์")
            link = item.get("link", "#")
            snippet = item.get("snippet", "์ค๋ช ์์")
            displayed_link = item.get("displayed_link", link)
            entry = (
                f"### ๊ฒฐ๊ณผ {idx}: {title}\n\n"
                f"{snippet}\n\n"
                f"**์ถ์ฒ**: [{displayed_link}]({link})\n\n"
                f"---\n"
            )
            summary_lines.append(entry)

        instructions = """# ์น ๊ฒ์ ๊ฒฐ๊ณผ
์๋๋ ๊ฒ์ ๊ฒฐ๊ณผ์ ๋๋ค. ๋ต๋ณ ์ ์ด ์ ๋ณด๋ฅผ ํ์ฉํ์ธ์:
1. ๊ฐ ๊ฒฐ๊ณผ์ ์ ๋ชฉ, ๋ด์ฉ, ์ถ์ฒ ๋งํฌ๋ฅผ ์ฐธ์กฐํ์ธ์
2. ๊ด๋ จ ์ถ์ฒ๋ฅผ ๋ช ์์ ์ผ๋ก ์ธ์ฉํ์ธ์
3. ์ฌ๋ฌ ์ถ์ฒ์ ์ ๋ณด๋ฅผ ์ข ํฉํ์ฌ ๋ต๋ณํ์ธ์
"""
        return instructions + "\n".join(summary_lines)
    except Exception as exc:
        logger.error(f"์น ๊ฒ์ ์คํจ: {exc}")
        return f"์น ๊ฒ์ ์คํจ: {str(exc)}"
############################################################################## | |
# Document processing helpers
############################################################################## | |
def analyze_csv_file(path: str) -> str:
    """Render the top-left corner of a CSV file as text for the prompt.

    At most 50 rows and 10 columns are kept, and the rendered table is cut
    to ``MAX_CONTENT_CHARS`` characters.

    Returns:
        A Markdown snippet headed by the file name, or a failure message
        when reading or rendering raises.
    """
    try:
        frame = pd.read_csv(path)
        # Slicing is a no-op for frames already within the 50 x 10 limit.
        frame = frame.iloc[:50, :10]
        rendered = frame.to_string()
        if len(rendered) > MAX_CONTENT_CHARS:
            rendered = rendered[:MAX_CONTENT_CHARS] + "\n...(์ค๋ต)..."
        file_name = os.path.basename(path)
        return f"**[CSV ํ์ผ: {file_name}]**\n\n{rendered}"
    except Exception as exc:
        return f"CSV ์ฝ๊ธฐ ์คํจ ({os.path.basename(path)}): {str(exc)}"
def analyze_txt_file(path: str) -> str:
    """Read a UTF-8 text file and return its (possibly truncated) content.

    Content longer than ``MAX_CONTENT_CHARS`` characters is cut and marked
    with an omission suffix.

    Returns:
        A Markdown snippet headed by the file name, or a failure message
        when the file cannot be read or decoded.
    """
    try:
        with open(path, "r", encoding="utf-8") as handle:
            content = handle.read()
        too_long = len(content) > MAX_CONTENT_CHARS
        if too_long:
            content = content[:MAX_CONTENT_CHARS] + "\n...(์ค๋ต)..."
        file_name = os.path.basename(path)
        return f"**[TXT ํ์ผ: {file_name}]**\n\n{content}"
    except Exception as exc:
        return f"TXT ์ฝ๊ธฐ ์คํจ ({os.path.basename(path)}): {str(exc)}"
def pdf_to_markdown(pdf_path: str) -> str:
    """Extract text from the first pages of a PDF as a Markdown snippet.

    Up to five pages are read. Each page gets an equal share of
    ``MAX_CONTENT_CHARS``; the joined result is capped at
    ``MAX_CONTENT_CHARS`` again.

    Returns:
        A Markdown snippet headed by the file name, or a failure message
        when opening or parsing the PDF raises.
    """
    sections = []
    try:
        with open(pdf_path, "rb") as handle:
            reader = PyPDF2.PdfReader(handle)
            total_pages = len(reader.pages)
            max_pages = min(5, total_pages)
            for page_num in range(max_pages):
                # Computed inside the loop so an empty PDF never divides by zero.
                per_page_limit = MAX_CONTENT_CHARS // max_pages
                page_text = (reader.pages[page_num].extract_text() or "").strip()
                if not page_text:
                    continue
                if len(page_text) > per_page_limit:
                    page_text = page_text[:per_page_limit] + "...(์ค๋ต)"
                sections.append(f"## ํ์ด์ง {page_num+1}\n\n{page_text}\n")
            if len(reader.pages) > max_pages:
                sections.append(f"\n...({max_pages}/{len(reader.pages)} ํ์ด์ง ํ์)...")
    except Exception as exc:
        return f"PDF ์ฝ๊ธฐ ์คํจ ({os.path.basename(pdf_path)}): {str(exc)}"

    full_text = "\n".join(sections)
    if len(full_text) > MAX_CONTENT_CHARS:
        full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(์ค๋ต)..."
    return f"**[PDF ํ์ผ: {os.path.basename(pdf_path)}]**\n\n{full_text}"
############################################################################## | |
# Model loading
############################################################################## | |
def load_model():
    """Load the Gemma3 processor and model once and cache them globally.

    The checkpoint id comes from the ``MODEL_ID`` environment variable and
    defaults to ``VIDraft/Gemma-3-R1984-4B``.

    Returns:
        True when the model is available (already loaded, or loaded by this
        call); False when loading raised an exception, which is logged.
    """
    global model, processor, model_loaded

    if model_loaded:
        logger.info("๋ชจ๋ธ์ด ์ด๋ฏธ ๋ก๋๋์ด ์์ต๋๋ค.")
        return True

    try:
        logger.info("Gemma3-R1984-4B ๋ชจ๋ธ ๋ก๋ฉ ์์...")
        clear_cuda_cache()

        checkpoint = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
        processor = AutoProcessor.from_pretrained(checkpoint, padding_side="left")
        load_options = dict(
            device_map="auto",
            torch_dtype=torch.bfloat16,
            attn_implementation="eager"
        )
        model = Gemma3ForConditionalGeneration.from_pretrained(checkpoint, **load_options)

        model_loaded = True
        logger.info(f"โ {model_name} ๋ก๋ฉ ์๋ฃ!")
        return True
    except Exception as exc:
        logger.error(f"๋ชจ๋ธ ๋ก๋ฉ ์คํจ: {exc}")
        return False
############################################################################## | |
# Image analysis (robot-task oriented)
############################################################################## | |
def analyze_image_for_robot(
    image: Union[np.ndarray, Image.Image],
    prompt: str,
    task_type: str = "general",
    use_web_search: bool = False,
    enable_thinking: bool = False,
    max_new_tokens: int = 300,
    audio_transcript: Optional[str] = None
) -> str:
    """Analyse one image for a robot task, optionally using audio context.

    Args:
        image: Frame to analyse; NumPy arrays are converted to RGB PIL images.
        prompt: User task description.
        task_type: Key selecting the system prompt ("general", "planning",
            "grounding", "affordance", "trajectory", "pointing"); unknown
            keys fall back to "general".
        use_web_search: When True, keywords from *prompt* are searched and
            the results are prepended to the system prompt.
        enable_thinking: When True, a chain-of-thought instruction is
            appended to the system prompt.
        max_new_tokens: Generation budget passed to ``model.generate``.
        audio_transcript: Recognised speech/sound text appended to the user
            prompt; for "planning" it also switches to an audio-aware system
            prompt.

    Returns:
        The decoded model response, or an error message (including the
        traceback) when anything raises. The CUDA cache is cleared on exit.
    """
    global model, processor
    if not model_loaded:
        if not load_model():
            return "โ ๋ชจ๋ธ ๋ก๋ฉ ์คํจ"
    try:
        # Convert NumPy frames to PIL images.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')

        # Task-specific system prompts.
        system_prompts = {
            "general": "๋น์ ์ ๋ก๋ด ์๊ฐ ์์คํ ์ ๋๋ค. ๋จผ์ ์ฅ๋ฉด์ 1-2์ค๋ก ์ค๋ช ํ๊ณ , ํต์ฌ ๋ด์ฉ์ ๊ฐ๊ฒฐํ๊ฒ ๋ถ์ํ์ธ์.",
            "planning": """๋น์ ์ ๋ก๋ด ์์ ๊ณํ AI์ ๋๋ค.
๋จผ์ ์ฅ๋ฉด ์ดํด๋ฅผ 1-2์ค๋ก ์ค๋ช ํ๊ณ , ๊ทธ ๋ค์ ์์ ๊ณํ์ ์์ฑํ์ธ์.
ํ์:
[์ฅ๋ฉด ์ดํด] ํ์ฌ ๋ณด์ด๋ ์ฅ๋ฉด์ 1-2์ค๋ก ์ค๋ช
[์์ ๊ณํ]
Step_1: xxx
Step_2: xxx
Step_n: xxx""",
            "grounding": "๋น์ ์ ๊ฐ์ฒด ์์น ์์คํ ์ ๋๋ค. ๋จผ์ ๋ณด์ด๋ ๊ฐ์ฒด๋ค์ ํ ์ค๋ก ์ค๋ช ํ๊ณ , ์์ฒญ๋ ๊ฐ์ฒด ์์น๋ฅผ [x1, y1, x2, y2]๋ก ๋ฐํํ์ธ์.",
            "affordance": "๋น์ ์ ํ์ง์ ๋ถ์ AI์ ๋๋ค. ๋จผ์ ๋์ ๊ฐ์ฒด๋ฅผ ํ ์ค๋ก ์ค๋ช ํ๊ณ , ํ์ง ์์ญ์ [x1, y1, x2, y2]๋ก ๋ฐํํ์ธ์.",
            "trajectory": "๋น์ ์ ๊ฒฝ๋ก ๊ณํ AI์ ๋๋ค. ๋จผ์ ํ๊ฒฝ์ ํ ์ค๋ก ์ค๋ช ํ๊ณ , ๊ฒฝ๋ก๋ฅผ [(x1,y1), (x2,y2), ...]๋ก ์ ์ํ์ธ์.",
            "pointing": "๋น์ ์ ์ง์ ์ง์ ์์คํ ์ ๋๋ค. ๋จผ์ ์ฐธ์กฐ์ ๋ค์ ํ ์ค๋ก ์ค๋ช ํ๊ณ , ์์น๋ฅผ [(x1,y1), (x2,y2), ...]๋ก ๋ฐํํ์ธ์."
        }

        # With audio context, planning uses a prompt that also asks for a
        # description of the recognised sound.
        if audio_transcript and task_type == "planning":
            system_prompts["planning"] = """๋น์ ์ ๋ก๋ด ์์ ๊ณํ AI์ ๋๋ค.
๋จผ์ ์ฅ๋ฉด ์ดํด๋ฅผ 1-2์ค๋ก ์ค๋ช ํ๊ณ , ์ฃผ๋ณ ์๋ฆฌ๋ฅผ ์ธ์ํ๋ค๋ฉด ๊ทธ๊ฒ๋ ์ค๋ช ํ ํ, ์์ ๊ณํ์ ์์ฑํ์ธ์.
ํ์:
[์ฅ๋ฉด ์ดํด] ํ์ฌ ๋ณด์ด๋ ์ฅ๋ฉด์ 1-2์ค๋ก ์ค๋ช
[์ฃผ๋ณ ์๋ฆฌ ์ธ์] ๋ค๋ฆฌ๋ ์๋ฆฌ๋ ์์ฑ์ 1์ค๋ก ์ค๋ช
[์์ ๊ณํ]
Step_1: xxx
Step_2: xxx
Step_n: xxx"""

        system_prompt = system_prompts.get(task_type, system_prompts["general"])

        # Optional chain-of-thought instruction.
        if enable_thinking:
            system_prompt += "\n\n์ถ๋ก ๊ณผ์ ์ <thinking></thinking> ํ๊ทธ ์์ ์์ฑ ํ ์ต์ข ๋ต๋ณ์ ์ ์ํ์ธ์. ์ฅ๋ฉด ์ดํด๋ ์ถ๋ก ๊ณผ์ ๊ณผ ๋ณ๋๋ก ๋ฐ๋์ ํฌํจํ์ธ์."

        # Optional web search; results are prepended to the system prompt.
        combined_system = system_prompt
        if use_web_search:
            keywords = extract_keywords(prompt, top_k=5)
            if keywords:
                logger.info(f"์น ๊ฒ์ ํค์๋: {keywords}")
                search_results = do_web_search(keywords)
                combined_system = f"{search_results}\n\n{system_prompt}"

        # Append the recognised audio to the user prompt.
        user_prompt = prompt
        if audio_transcript:
            user_prompt += f"\n\n[์ธ์๋ ์ฃผ๋ณ ์๋ฆฌ: {audio_transcript}]"

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": combined_system}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": image},
                    {"type": "text", "text": user_prompt}
                ]
            }
        ]

        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device=model.device, dtype=torch.bfloat16)

        # Cap the prompt at MAX_INPUT_LENGTH tokens, keeping the tail.
        # The tensors are replaced through item assignment: attribute
        # assignment would not change the mapping unpacked into generate().
        # Every per-token tensor is cut the same way so the shapes stay
        # aligned.
        # NOTE(review): cutting from the left can remove image placeholder
        # tokens if the text after the image is very long — confirm this
        # cannot happen with the prompts used here.
        if inputs["input_ids"].shape[1] > MAX_INPUT_LENGTH:
            for key in ("input_ids", "attention_mask", "token_type_ids"):
                if key in inputs:
                    inputs[key] = inputs[key][:, -MAX_INPUT_LENGTH:]

        # Length of the prompt actually passed to generate(); used below to
        # separate newly generated tokens from the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
            )

        # Keep only the newly generated tokens.
        generated_tokens = outputs[0][prompt_len:]
        response = processor.decode(generated_tokens, skip_special_tokens=True).strip()

        # Drop a leftover role marker, if any.
        if response.startswith("model\n"):
            response = response[6:].strip()
        elif response.startswith("model"):
            response = response[5:].strip()
        return response
    except Exception as e:
        logger.error(f"์ด๋ฏธ์ง ๋ถ์ ์ค๋ฅ: {e}")
        import traceback
        return f"โ ๋ถ์ ์ค๋ฅ: {str(e)}\n{traceback.format_exc()}"
    finally:
        clear_cuda_cache()
############################################################################## | |
# Document analysis (streaming)
############################################################################## | |
def _model_gen_with_oom_catch(**kwargs):
    """Run ``model.generate`` and turn a CUDA OOM into a RuntimeError.

    Intended as a thread target: all keyword arguments are forwarded to
    ``model.generate`` unchanged. The CUDA cache is cleared afterwards
    whether generation succeeded or not.

    Raises:
        RuntimeError: When generation runs out of GPU memory; the original
            ``torch.cuda.OutOfMemoryError`` is attached as its cause.
    """
    global model
    try:
        model.generate(**kwargs)
    except torch.cuda.OutOfMemoryError as oom:
        # Chain the cause so the allocator's own message is not lost.
        raise RuntimeError("GPU ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ. Max Tokens๋ฅผ ์ค์ฌ์ฃผ์ธ์.") from oom
    finally:
        clear_cuda_cache()
def analyze_documents_streaming(
    files: List[str],
    prompt: str,
    use_web_search: bool = False,
    max_new_tokens: int = 2048
) -> Iterator[str]:
    """Analyse uploaded documents and stream the growing answer.

    Args:
        files: Paths of the uploaded files; only ``.csv``, ``.txt`` and
            ``.pdf`` (case-insensitive) are read, others are skipped.
        prompt: User request appended after the document contents.
        use_web_search: When True, keywords from *prompt* are searched and
            the results are prepended to the system prompt.
        max_new_tokens: Generation budget.

    Yields:
        The accumulated response text after each streamed chunk, or a single
        error message when loading or generation fails.
    """
    global model, processor

    if not model_loaded and not load_model():
        yield "โ ๋ชจ๋ธ ๋ก๋ฉ ์คํจ"
        return

    try:
        system_content = "๋น์ ์ ๋ฌธ์๋ฅผ ๋ถ์ํ๊ณ ์์ฝํ๋ ์ ๋ฌธ AI์ ๋๋ค."

        # Optional web search; results are prepended to the system prompt.
        if use_web_search:
            keywords = extract_keywords(prompt, top_k=5)
            if keywords:
                search_results = do_web_search(keywords)
                system_content = f"{search_results}\n\n{system_content}"

        # Read every supported file with the matching extractor.
        extractors = (
            ('.csv', analyze_csv_file),
            ('.txt', analyze_txt_file),
            ('.pdf', pdf_to_markdown),
        )
        doc_contents = []
        for file_path in files:
            lowered = file_path.lower()
            for suffix, extractor in extractors:
                if lowered.endswith(suffix):
                    doc_contents.append(extractor(file_path))
                    break

        user_text = "\n\n".join(doc_contents) + f"\n\n{prompt}"
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_content}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_text}
                ]
            }
        ]

        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device=model.device, dtype=torch.bfloat16)

        streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            top_p=0.9,
        )

        # Generation runs in a worker thread while this generator drains the
        # streamer.
        worker = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
        worker.start()

        output = ""
        for new_text in streamer:
            output += new_text
            yield output
    except Exception as exc:
        logger.error(f"๋ฌธ์ ๋ถ์ ์ค๋ฅ: {exc}")
        yield f"โ ์ค๋ฅ ๋ฐ์: {str(exc)}"
    finally:
        clear_cuda_cache()
############################################################################## | |
# Gradio UI (robot visualisation oriented)
############################################################################## | |
css = """ | |
.robot-header { | |
text-align: center; | |
background: linear-gradient(135deg, #1e3c72 0%, #2a5298 50%, #667eea 100%); | |
color: white; | |
padding: 20px; | |
border-radius: 10px; | |
margin-bottom: 20px; | |
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); | |
} | |
.status-box { | |
text-align: center; | |
padding: 10px; | |
border-radius: 5px; | |
margin: 10px 0; | |
font-weight: bold; | |
} | |
.info-box { | |
background: #f0f0f0; | |
padding: 15px; | |
border-radius: 8px; | |
margin: 10px 0; | |
border-left: 4px solid #2a5298; | |
} | |
.task-button { | |
min-height: 60px; | |
font-size: 1.1em; | |
} | |
.webcam-container { | |
border: 3px solid #2a5298; | |
border-radius: 10px; | |
padding: 10px; | |
background: #f8f9fa; | |
} | |
.auto-capture-status { | |
text-align: center; | |
padding: 5px; | |
border-radius: 5px; | |
margin: 5px 0; | |
font-weight: bold; | |
background: #e8f5e9; | |
color: #2e7d32; | |
} | |
.audio-status { | |
text-align: center; | |
padding: 5px; | |
border-radius: 5px; | |
margin: 5px 0; | |
font-weight: bold; | |
background: #e3f2fd; | |
color: #1565c0; | |
} | |
""" | |
with gr.Blocks(title="๐ค ๋ก๋ด ์๊ฐ ์์คํ (Gemma3-4B)", css=css) as demo: | |
gr.HTML(""" | |
<div class="robot-header"> | |
<h1>๐ค ๋ก๋ด ์๊ฐ ์์คํ </h1> | |
<h3>๐ฎ Gemma3-R1984-4B + ๐ท ์ค์๊ฐ ์น์บ + ๐ค ์์ฑ ์ธ์</h3> | |
<p>โก ๋ฉํฐ๋ชจ๋ฌ AI๋ก ๋ก๋ด ์์ ๋ถ์!</p> | |
</div> | |
""") | |
with gr.Row(): | |
# ์ผ์ชฝ: ์น์บ ๋ฐ ์ ๋ ฅ | |
with gr.Column(scale=1): | |
gr.Markdown("### ๐ท ์ค์๊ฐ ์น์บ ") | |
with gr.Group(elem_classes="webcam-container"): | |
webcam = gr.Image( | |
sources=["webcam"], | |
streaming=True, | |
type="numpy", | |
label="์ค์๊ฐ ์คํธ๋ฆฌ๋ฐ", | |
height=300 | |
) | |
# ์๋ ์บก์ฒ ์ํ ํ์ | |
auto_capture_status = gr.HTML( | |
'<div class="auto-capture-status">๐ ์๋ ์บก์ฒ: ๋๊ธฐ ์ค</div>' | |
) | |
# ์บก์ฒ๋ ์ด๋ฏธ์ง ํ์ | |
captured_image = gr.Image( | |
label="์บก์ฒ๋ ์ด๋ฏธ์ง", | |
height=180, | |
visible=False | |
) | |
# ์ค๋์ค ์ปจํธ๋กค | |
gr.Markdown("### ๐ค ์์ฑ ์ธ์") | |
with gr.Group(): | |
# ์ค๋์ค ์ํ ํ์ | |
audio_status = gr.HTML( | |
'<div class="audio-status">๐ค ์์ฑ ์ธ์: ๋นํ์ฑํ</div>' | |
) | |
# ๋ น์ ์ธํฐํ์ด์ค (์จ๊น ์ํ๋ก ์์) | |
audio_recorder = gr.Audio( | |
sources=["microphone"], | |
type="numpy", | |
label="๐ค 10์ด ๋ น์", | |
visible=False | |
) | |
# ๋ง์ง๋ง ์ธ์๋ ํ ์คํธ | |
last_transcript = gr.Textbox( | |
label="์ธ์๋ ์์ฑ", | |
value="", | |
lines=2, | |
interactive=False | |
) | |
# ๋ก๋ด ์์ ๋ฒํผ๋ค | |
gr.Markdown("### ๐ฏ ๋ก๋ด ์์ ") | |
with gr.Row(): | |
capture_btn = gr.Button("๐ธ ์๋ ์บก์ฒ", variant="primary", elem_classes="task-button") | |
clear_capture_btn = gr.Button("๐๏ธ ์ด๊ธฐํ", elem_classes="task-button") | |
with gr.Column(): | |
auto_capture_toggle = gr.Checkbox( | |
label="๐ ์๋ ์บก์ฒ (10์ด๋ง๋ค)", | |
value=False | |
) | |
use_audio_toggle = gr.Checkbox( | |
label="๐ค ์์ฑ ์ธ์ ์ฌ์ฉ", | |
value=False, | |
info="10์ด๋ง๋ค ์์ฑ์ ์ธ์ํ์ฌ ๋ถ์์ ํฌํจ" | |
) | |
with gr.Row(): | |
planning_btn = gr.Button("๐ ์์ ๊ณํ", elem_classes="task-button") | |
grounding_btn = gr.Button("๐ ๊ฐ์ฒด ์์น", elem_classes="task-button") | |
# ์ค๋ฅธ์ชฝ: ๋ถ์ ์ค์ ๋ฐ ๊ฒฐ๊ณผ | |
with gr.Column(scale=2): | |
gr.Markdown("### โ๏ธ ๋ถ์ ์ค์ ") | |
with gr.Row(): | |
with gr.Column(): | |
task_prompt = gr.Textbox( | |
label="์์ ์ค๋ช ", | |
placeholder="์: ํ ์ด๋ธ ์์ ์ปต์ ์ก์์ ์ฑํฌ๋์ ๋๊ธฐ", | |
value="ํ์ฌ ์ฅ๋ฉด์ ๋ถ์ํ๊ณ ๋ก๋ด์ด ์ํํ ์ ์๋ ์์ ์ ์ ์ํ์ธ์.", | |
lines=2 | |
) | |
with gr.Row(): | |
use_web_search = gr.Checkbox( | |
label="๐ ์น ๊ฒ์", | |
value=False | |
) | |
enable_thinking = gr.Checkbox( | |
label="๐ค ์ถ๋ก ๊ณผ์ ", | |
value=False | |
) | |
max_tokens = gr.Slider( | |
label="์ต๋ ํ ํฐ", | |
minimum=100, | |
maximum=1000, | |
value=300, | |
step=50 | |
) | |
gr.Markdown("### ๐ ๋ถ์ ๊ฒฐ๊ณผ") | |
result_output = gr.Textbox( | |
label="AI ๋ถ์ ๊ฒฐ๊ณผ", | |
lines=18, | |
max_lines=35, | |
show_copy_button=True, | |
elem_id="result" | |
) | |
status_display = gr.HTML( | |
'<div class="status-box" style="background:#d4edda; color:#155724;">๐ฎ ์์คํ ์ค๋น</div>' | |
) | |
# ๋ฌธ์ ๋ถ์ ํญ (์จ๊น) | |
with gr.Tab("๐ ๋ฌธ์ ๋ถ์", visible=False): | |
with gr.Row(): | |
with gr.Column(): | |
doc_files = gr.File( | |
label="๋ฌธ์ ์ ๋ก๋", | |
file_count="multiple", | |
file_types=[".pdf", ".csv", ".txt"], | |
type="filepath" | |
) | |
doc_prompt = gr.Textbox( | |
label="๋ถ์ ์์ฒญ", | |
placeholder="์: ์ด ๋ฌธ์๋ค์ ํต์ฌ ๋ด์ฉ์ ์์ฝํ๊ณ ๋น๊ต ๋ถ์ํ์ธ์.", | |
lines=3 | |
) | |
doc_web_search = gr.Checkbox( | |
label="๐ ์น ๊ฒ์ ์ฌ์ฉ", | |
value=False | |
) | |
analyze_docs_btn = gr.Button("๐ ๋ฌธ์ ๋ถ์", variant="primary") | |
with gr.Column(): | |
doc_result = gr.Textbox( | |
label="๋ถ์ ๊ฒฐ๊ณผ", | |
lines=25, | |
max_lines=50 | |
) | |
# ์ด๋ฒคํธ ํธ๋ค๋ฌ | |
webcam_state = gr.State(None) | |
def capture_webcam(frame):
    """Freeze the latest webcam frame as the captured image.

    Returns:
        ``(state value, captured-image update, status HTML)``; the first two
        are None when no frame is available.
    """
    if frame is not None:
        ok_status = '<div class="status-box" style="background:#d4edda; color:#155724;">โ ์ด๋ฏธ์ง ์บก์ฒ ์๋ฃ</div>'
        return frame, gr.update(value=frame, visible=True), ok_status
    missing_status = '<div class="status-box" style="background:#f8d7da; color:#721c24;">โ ์น์บ ํ๋ ์ ์์</div>'
    return None, None, missing_status
def clear_capture():
    """Reset the capture and the stored audio state.

    Returns:
        ``(state value, captured-image update, status HTML, transcript text)``
        used to clear the corresponding components.
    """
    global last_transcription, last_audio_data, audio_lock
    with audio_lock:
        last_transcription = ""
        last_audio_data = None
    ready_status = '<div class="status-box" style="background:#d4edda; color:#155724;">๐ฎ ์์คํ ์ค๋น</div>'
    hide_image = gr.update(visible=False)
    return None, hide_image, ready_status, ""
def analyze_with_task(image, prompt, task_type, use_search, thinking, tokens):
    """Run one robot-vision task on the captured image and format the result.

    Args:
        image: Captured image, or None when nothing has been captured yet.
        prompt: User task description.
        task_type: Task key forwarded to ``analyze_image_for_robot``.
        use_search: Whether to include web-search results.
        thinking: Whether to request a reasoning section.
        tokens: Generation budget.

    Returns:
        ``(result text, status HTML)``.
    """
    global last_transcription, audio_lock
    if image is None:
        return "โ ๋จผ์ ์ด๋ฏธ์ง๋ฅผ ์บก์ฒํ์ธ์.", '<div class="status-box" style="background:#f8d7da; color:#721c24;">โ ์ด๋ฏธ์ง ์์</div>'

    # Snapshot the latest transcription under the lock.
    with audio_lock:
        transcript = last_transcription

    result = analyze_image_for_robot(
        image=image,
        prompt=prompt,
        task_type=task_type,
        use_web_search=use_search,
        enable_thinking=thinking,
        max_new_tokens=tokens,
        audio_transcript=transcript if transcript else None
    )

    # Header shows a per-task label; unknown tasks use a generic label.
    timestamp = time.strftime("%H:%M:%S")
    task_names = {
        "planning": "์์ ๊ณํ",
        "grounding": "๊ฐ์ฒด ์์น",
        "affordance": "ํ์ง์ ",
        "trajectory": "๊ฒฝ๋ก ๊ณํ"
    }
    formatted_result = f"""๐ค {task_names.get(task_type, '๋ถ์')} ๊ฒฐ๊ณผ ({timestamp})
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
{result}
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ"""
    complete_status = '<div class="status-box" style="background:#d4edda; color:#155724;">โ ๋ถ์ ์๋ฃ!</div>'
    return formatted_result, complete_status
# ์๋ ์บก์ฒ ๋ฐ ๋ถ์ ํจ์ | |
def auto_capture_and_analyze(webcam_frame, task_prompt, use_search, thinking, tokens, use_audio, audio_data):
    """Timer callback: analyse the current frame in planning mode.

    When audio is enabled, a pending recording is transcribed first and the
    most recent transcription is passed to the analysis.

    Returns:
        ``(captured image, result text, status HTML, auto-capture HTML,
        transcript text, audio reset value)``. The last element is always
        None so the recorder is cleared for the next recording.
    """
    global last_transcription, audio_lock

    if webcam_frame is None:
        waiting = (
            None,
            "์๋ ์บก์ฒ ๋๊ธฐ ์ค...",
            '<div class="status-box" style="background:#fff3cd; color:#856404;">โณ ์น์บ ๋๊ธฐ ์ค</div>',
            '<div class="auto-capture-status">๐ ์๋ ์บก์ฒ: ์น์บ ๋๊ธฐ ์ค</div>',
            "๋๊ธฐ ์ค...",
            None  # reset the audio recorder
        )
        return waiting

    timestamp = time.strftime("%H:%M:%S")

    audio_transcript = ""
    if use_audio:
        # Transcribe a pending recording, if there is one.
        if audio_data is not None:
            logger.info(f"[{timestamp}] ์ค๋์ค ์ฒ๋ฆฌ ์์")
            transcription = process_audio_recording(audio_data)
            if transcription:
                logger.info(f"์๋ก์ด ์ ์ฌ: {transcription[:50]}...")
        # Use the most recent transcription, new or old.
        with audio_lock:
            audio_transcript = last_transcription
        if audio_transcript:
            logger.info(f"๋ถ์์ ์ฌ์ฉํ ์์ฑ: {audio_transcript[:50]}...")

    # Automatic captures always run in planning mode.
    result = analyze_image_for_robot(
        image=webcam_frame,
        prompt=task_prompt,
        task_type="planning",
        use_web_search=use_search,
        enable_thinking=thinking,
        max_new_tokens=tokens,
        audio_transcript=audio_transcript or None
    )
    formatted_result = f"""๐ ์๋ ๋ถ์ ์๋ฃ ({timestamp})
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
{result}
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ"""

    transcript_display = audio_transcript or "์์ฑ ์ธ์ ๋๊ธฐ ์ค..."
    done_status = '<div class="status-box" style="background:#d4edda; color:#155724;">โ ์๋ ๋ถ์ ์๋ฃ</div>'
    capture_status = f'<div class="auto-capture-status">๐ ์๋ ์บก์ฒ: ๋ง์ง๋ง ๋ถ์ {timestamp}</div>'
    return (
        webcam_frame,
        formatted_result,
        done_status,
        capture_status,
        transcript_display,
        None  # reset the audio recorder for the next recording
    )
# ์น์บ ์คํธ๋ฆฌ๋ฐ | |
webcam.stream( | |
fn=lambda x: x, | |
inputs=[webcam], | |
outputs=[webcam_state] | |
) | |
# ์๋ ์บก์ฒ ๋ฒํผ | |
capture_btn.click( | |
fn=capture_webcam, | |
inputs=[webcam_state], | |
outputs=[webcam_state, captured_image, status_display] | |
) | |
# ์ด๊ธฐํ ๋ฒํผ | |
clear_capture_btn.click( | |
fn=clear_capture, | |
outputs=[webcam_state, captured_image, status_display, last_transcript] | |
) | |
# ์์ ๋ฒํผ๋ค | |
planning_btn.click( | |
fn=lambda img, p, s, t, tk: analyze_with_task(img, p, "planning", s, t, tk), | |
inputs=[captured_image, task_prompt, use_web_search, enable_thinking, max_tokens], | |
outputs=[result_output, status_display] | |
) | |
grounding_btn.click( | |
fn=lambda img, p, s, t, tk: analyze_with_task(img, p, "grounding", s, t, tk), | |
inputs=[captured_image, task_prompt, use_web_search, enable_thinking, max_tokens], | |
outputs=[result_output, status_display] | |
) | |
# ๋ฌธ์ ๋ถ์ | |
def analyze_docs(files, prompt, use_search):
    """Run the document analysis to completion and return the final text.

    Returns:
        A prompt to upload documents when *files* is empty; otherwise the
        last value yielded by ``analyze_documents_streaming`` (an empty
        string if it yields nothing).
    """
    if not files:
        return "โ ๋ฌธ์๋ฅผ ์ ๋ก๋ํ์ธ์."
    output = ""
    # Each yielded value is the full text so far; only the last one is kept.
    for output in analyze_documents_streaming(files, prompt, use_search):
        pass
    return output
analyze_docs_btn.click( | |
fn=analyze_docs, | |
inputs=[doc_files, doc_prompt, doc_web_search], | |
outputs=[doc_result] | |
) | |
# ์๋ ์บก์ฒ ํ์ด๋จธ (10์ด๋ง๋ค) | |
timer = gr.Timer(10.0, active=False) | |
# ์๋ ์บก์ฒ ํ ๊ธ ์ด๋ฒคํธ | |
def toggle_auto_capture(enabled):
    """Start or stop the 10-second auto-capture timer.

    Returns:
        ``(timer update, auto-capture status HTML)``.
    """
    if not enabled:
        return gr.Timer(active=False), '<div class="auto-capture-status">๐ ์๋ ์บก์ฒ: ๋นํ์ฑํ๋จ</div>'
    return gr.Timer(10.0, active=True), '<div class="auto-capture-status">๐ ์๋ ์บก์ฒ: ํ์ฑํ๋จ (10์ด๋ง๋ค)</div>'
auto_capture_toggle.change( | |
fn=toggle_auto_capture, | |
inputs=[auto_capture_toggle], | |
outputs=[timer, auto_capture_status] | |
) | |
# ์ค๋์ค ํ ๊ธ ์ด๋ฒคํธ | |
def toggle_audio(enabled):
    """Enable or disable speech recognition and reset the stored audio state.

    Enabling loads the Whisper pipeline first. In both directions the last
    transcription and the last audio clip are cleared.

    Returns:
        ``(recorder visibility update, audio status HTML)``.
    """
    global last_transcription, last_audio_data, audio_lock

    if enabled:
        load_whisper()

    # Start from a clean slate in either direction.
    with audio_lock:
        last_transcription = ""
        last_audio_data = None

    if enabled:
        logger.info("์ค๋์ค ์ธ์ ํ์ฑํ๋จ")
        return (
            gr.update(visible=True),  # show the recorder
            '<div class="audio-status">๐ค ์์ฑ ์ธ์: ํ์ฑํ๋จ</div>'
        )

    logger.info("์ค๋์ค ์ธ์ ๋นํ์ฑํ๋จ")
    return (
        gr.update(visible=False),  # hide the recorder
        '<div class="audio-status">๐ค ์์ฑ ์ธ์: ๋นํ์ฑํ</div>'
    )
use_audio_toggle.change( | |
fn=toggle_audio, | |
inputs=[use_audio_toggle], | |
outputs=[audio_recorder, audio_status] | |
) | |
# ์ค๋์ค ๋ น์ ์๋ฃ ์ ์ฒ๋ฆฌ | |
def on_audio_recorded(audio_data):
    """Transcribe a finished recording and return the text to display.

    Returns:
        The new transcription when one was produced; otherwise the previous
        transcription, or a waiting message when there is none.
    """
    global last_transcription, audio_lock

    if audio_data is not None:
        logger.info("์ ์ค๋์ค ๋ น์ ๊ฐ์ง")
        fresh_text = process_audio_recording(audio_data)
        if fresh_text:
            return fresh_text

    # Fall back to whatever was recognised last.
    with audio_lock:
        return last_transcription or "์์ฑ ์ธ์ ๋๊ธฐ ์ค..."
audio_recorder.change( | |
fn=on_audio_recorded, | |
inputs=[audio_recorder], | |
outputs=[last_transcript] | |
) | |
# ํ์ด๋จธ ํฑ ์ด๋ฒคํธ | |
timer.tick( | |
fn=auto_capture_and_analyze, | |
inputs=[webcam_state, task_prompt, use_web_search, enable_thinking, max_tokens, use_audio_toggle, audio_recorder], | |
outputs=[captured_image, result_output, status_display, auto_capture_status, last_transcript, audio_recorder] | |
) | |
# ์ด๊ธฐ ๋ชจ๋ธ ๋ก๋ | |
def initial_load():
    """Load the vision model when the page loads and report readiness."""
    load_model()
    ready_message = "์์คํ ์ค๋น ์๋ฃ! ๐"
    return ready_message
demo.load( | |
fn=initial_load, | |
outputs=None | |
) | |
# Script entry point: print a start-up banner, then serve the Gradio app.
if __name__ == "__main__":
    print("๐ ๋ก๋ด ์๊ฐ ์์คํ ์์ (Gemma3-R1984-4B + Whisper)...")
    # Request queuing is enabled; the server binds to all interfaces on port
    # 7860 without a public share link.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        debug=False
    )