import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import numpy as np
import os
import tempfile
import spaces
import gradio as gr
import subprocess
import sys
import cv2
import threading
import queue
import time
from collections import deque
from deep_translator import GoogleTranslator
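
# Flash attention is installed at startup from a prebuilt wheel that matches the
# Space's CUDA / torch / Python combination.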
def install_flash_attn_wheel():
    flash_attn_wheel_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", flash_attn_wheel_url])
        print("Wheel installed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install the flash attention wheel. Error: {e}")

install_flash_attn_wheel()
try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")
# Load the model and tokenizer
model_path = "ByteDance/Sa2VA-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="cuda:0",
    trust_remote_code=True,
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
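
# Captures frames from the local webcam on a background thread and periodically
# runs the model over a small buffer of recent frames on a second thread.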
class WebcamProcessor:
    def __init__(self, model, tokenizer, fps_target=15, buffer_size=5):
        self.model = model
        self.tokenizer = tokenizer
        self.fps_target = fps_target
        self.frame_interval = 1.0 / fps_target
        self.buffer_size = buffer_size
        self.frame_buffer = deque(maxlen=buffer_size)
        self.result_queue = queue.Queue()
        self.is_running = False
        self.last_process_time = 0
    def start(self):
        try:
            self.is_running = True
            self.capture = cv2.VideoCapture(0)
            if not self.capture.isOpened():
                raise Exception("Failed to open webcam")
            # Set camera properties
            self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
            self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
            self.capture_thread = threading.Thread(target=self._capture_loop)
            self.process_thread = threading.Thread(target=self._process_loop)
            self.capture_thread.daemon = True
            self.process_thread.daemon = True
            self.capture_thread.start()
            self.process_thread.start()
            return "Webcam started successfully"
        except Exception as e:
            self.is_running = False
            return f"Failed to start webcam: {str(e)}"
    def stop(self):
        try:
            self.is_running = False
            if hasattr(self, 'capture_thread'):
                self.capture_thread.join(timeout=1.0)
            if hasattr(self, 'process_thread'):
                self.process_thread.join(timeout=1.0)
            if hasattr(self, 'capture'):
                self.capture.release()
            return "Webcam stopped successfully"
        except Exception as e:
            return f"Error stopping webcam: {str(e)}"
    def _capture_loop(self):
        while self.is_running:
            try:
                ret, frame = self.capture.read()
                if ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = cv2.resize(frame, (640, 480))
                    current_time = time.time()
                    if current_time - self.last_process_time >= self.frame_interval:
                        self.frame_buffer.append(frame)
                        self.last_process_time = current_time
                time.sleep(0.01)  # Small delay to prevent CPU overuse
            except Exception as e:
                print(f"Capture error: {e}")
                time.sleep(0.1)
    def _process_loop(self):
        while self.is_running:
            try:
                if len(self.frame_buffer) >= self.buffer_size:
                    frames = list(self.frame_buffer)
                    result = self.model.predict_forward(
                        video=frames,
                        text="<image>Describe what you see",
                        tokenizer=self.tokenizer
                    )
                    self.result_queue.put(result)
                    self.frame_buffer.clear()
                time.sleep(0.1)
            except Exception as e:
                print(f"Processing error: {e}")
                time.sleep(0.1)
from third_parts import VideoReader
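
# Samples every `video_interval`-th frame from the video, flips the channel order
# to RGB, and saves each sampled frame as a JPEG so masks can be drawn on it later.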
def read_video(video_path, video_interval):
    vid_frames = VideoReader(video_path)[::video_interval]
    temp_dir = tempfile.mkdtemp()
    os.makedirs(temp_dir, exist_ok=True)
    image_paths = []
    for frame_idx in range(len(vid_frames)):
        frame_image = vid_frames[frame_idx]
        frame_image = frame_image[..., ::-1]
        frame_image = Image.fromarray(frame_image)
        vid_frames[frame_idx] = frame_image
        image_path = os.path.join(temp_dir, f"frame_{frame_idx:04d}.jpg")
        frame_image.save(image_path, format="JPEG")
        image_paths.append(image_path)
    return vid_frames, image_paths
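
# Overlays a binary segmentation mask on one frame using mmengine's Visualizer
# and writes the composited image into work_dir.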
def visualize(pred_mask, image_path, work_dir):
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()
    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)
    return output_path
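
# Translates English model output to Korean; falls back to the original text if
# the translation service fails.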
def translate_to_korean(text):
    try:
        translator = GoogleTranslator(source='en', target='ko')
        return translator.translate(text)
    except Exception as e:
        print(f"Translation error: {e}")
        return text
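
# Single-image inference: if the prompt contains Hangul, the answer is translated
# back to Korean, and any predicted [SEG] mask is rendered on top of the input image.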
def image_vision(image_input_path, prompt):
    is_korean = any(ord('가') <= ord(char) <= ord('힣') for char in prompt)
    image_path = image_input_path
    text_prompts = f"<image>{prompt}"
    image = Image.open(image_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    print(return_dict)
    answer = return_dict["prediction"]
    if is_korean:
        if '[SEG]' in answer:
            parts = answer.split('[SEG]')
            translated_parts = [translate_to_korean(part.strip()) for part in parts]
            answer = '[SEG]'.join(translated_parts)
        else:
            answer = translate_to_korean(answer)
    seg_image = return_dict["prediction_masks"]
    if '[SEG]' in answer and Visualizer is not None:
        pred_mask = seg_image[0]
        temp_dir = tempfile.mkdtemp()
        os.makedirs(temp_dir, exist_ok=True)
        seg_result = visualize(pred_mask, image_input_path, temp_dir)
        return answer, seg_result
    else:
        return answer, None
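
# Video inference: sampled frames go through predict_forward; if a [SEG] mask is
# predicted, the per-frame masks are drawn and re-encoded into an output video at
# the reduced frame rate.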
def video_vision(video_input_path, prompt, video_interval):
    is_korean = any(ord('가') <= ord(char) <= ord('힣') for char in prompt)
    cap = cv2.VideoCapture(video_input_path)
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    frame_skip_factor = video_interval
    new_fps = original_fps / frame_skip_factor
    vid_frames, image_paths = read_video(video_input_path, video_interval)
    question = f"<image>{prompt}"
    result = model.predict_forward(
        video=vid_frames,
        text=question,
        tokenizer=tokenizer,
    )
    prediction = result['prediction']
    print(prediction)
    if is_korean:
        if '[SEG]' in prediction:
            parts = prediction.split('[SEG]')
            translated_parts = [translate_to_korean(part.strip()) for part in parts]
            prediction = '[SEG]'.join(translated_parts)
        else:
            prediction = translate_to_korean(prediction)
    if '[SEG]' in prediction and Visualizer is not None:
        _seg_idx = 0
        pred_masks = result['prediction_masks'][_seg_idx]
        seg_frames = []
        for frame_idx in range(len(vid_frames)):
            pred_mask = pred_masks[frame_idx]
            temp_dir = tempfile.mkdtemp()
            os.makedirs(temp_dir, exist_ok=True)
            seg_frame = visualize(pred_mask, image_paths[frame_idx], temp_dir)
            seg_frames.append(seg_frame)
        output_video = "output_video.mp4"
        frame = cv2.imread(seg_frames[0])
        height, width, layers = frame.shape
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(output_video, fourcc, new_fps, (width, height))
        for img_path in seg_frames:
            frame = cv2.imread(img_path)
            video.write(frame)
        video.release()
        print(f"Video created successfully at {output_video}")
        return prediction, output_video
    else:
        return prediction, None
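
# Webcam inference: lazily creates a shared WebcamProcessor, starts it if needed,
# and waits up to five seconds for a result from the background processing thread.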
def webcam_vision(prompt):
    try:
        if not hasattr(webcam_vision, 'processor'):
            webcam_vision.processor = WebcamProcessor(model, tokenizer)
        if not webcam_vision.processor.is_running:
            status = webcam_vision.processor.start()
            if "Failed" in status:
                return f"Error: {status}"
        try:
            result = webcam_vision.processor.result_queue.get(timeout=5)
            prediction = result['prediction']
            # Check if Korean translation is needed
            is_korean = any(ord('가') <= ord(char) <= ord('힣') for char in prompt)
            if is_korean:
                prediction = translate_to_korean(prediction)
            return prediction
        except queue.Empty:
            return "No results available yet. Please try again."
        except Exception as e:
            return f"Processing error: {str(e)}"
    except Exception as e:
        return f"System error: {str(e)}"
# Gradio UI
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Tab("Single Image"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(label="Image IN", type="filepath")
                    with gr.Row():
                        instruction = gr.Textbox(label="Instruction", scale=4)
                        submit_image_btn = gr.Button("Submit", scale=1)
                with gr.Column():
                    output_res = gr.Textbox(label="Response")
                    output_image = gr.Image(label="Segmentation", type="numpy")
            submit_image_btn.click(
                fn=image_vision,
                inputs=[image_input, instruction],
                outputs=[output_res, output_image]
            )
        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Video IN")
                    frame_interval = gr.Slider(label="Frame interval", step=1, minimum=1, maximum=12, value=6)
                    with gr.Row():
                        vid_instruction = gr.Textbox(label="Instruction", scale=4)
                        submit_video_btn = gr.Button("Submit", scale=1)
                with gr.Column():
                    vid_output_res = gr.Textbox(label="Response")
                    output_video = gr.Video(label="Segmentation")
            submit_video_btn.click(
                fn=video_vision,
                inputs=[video_input, vid_instruction, frame_interval],
                outputs=[vid_output_res, output_video]
            )
| with gr.Tab("Webcam"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| # μΉμΊ μ λ ₯μ μν μ»΄ν¬λνΈ | |
| webcam_input = gr.Image( | |
| label="Webcam Input", | |
| type="numpy", | |
| sources="webcam", | |
| streaming=True, | |
| mirror_webcam=True | |
| ) | |
| with gr.Row(): | |
| webcam_instruction = gr.Textbox( | |
| label="Instruction", | |
| placeholder="Enter instruction here...", | |
| scale=4 | |
| ) | |
| start_button = gr.Button("Start", scale=1) | |
| stop_button = gr.Button("Stop", scale=1) | |
| with gr.Column(): | |
| webcam_output = gr.Textbox(label="Response") | |
| processed_view = gr.Image(label="Processed View") | |
| status_text = gr.Textbox(label="Status", value="Ready") | |
            def start_webcam_processing(instruction):
                try:
                    if hasattr(webcam_vision, 'processor'):
                        webcam_vision.processor.stop()
                    webcam_vision.processor = WebcamProcessor(model, tokenizer)
                    status = webcam_vision.processor.start()
                    return webcam_vision(instruction)
                except Exception as e:
                    return f"Error starting webcam: {str(e)}"

            def stop_webcam_processing():
                if hasattr(webcam_vision, 'processor'):
                    webcam_vision.processor.stop()
                    return "Stopped"
                return "Not running"

            start_button.click(
                fn=start_webcam_processing,
                inputs=[webcam_instruction],
                outputs=[webcam_output]
            )
            stop_button.click(
                fn=stop_webcam_processing,
                outputs=[status_text]
            )
# Launch settings for webcam access
demo.queue().launch(
    server_name="0.0.0.0",  # accessible from any IP
    server_port=7860,       # fixed port
    share=True,             # create a public link
    show_api=False,
    show_error=True
)