import os
import pathlib
import tempfile
from collections.abc import Iterator
from threading import Thread

import av
import gradio as gr
import spaces
import torch
from gradio.utils import get_upload_folder
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.generation.streamers import TextIteratorStreamer

model_id = "google/gemma-3n-E4B-it"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
AUDIO_FILE_TYPES = (".mp3", ".wav")
GRADIO_TEMP_DIR = get_upload_folder()

# All three limits can be overridden via environment variables.
TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))


def get_file_type(path: str) -> str:
    if path.endswith(IMAGE_FILE_TYPES):
        return "image"
    if path.endswith(VIDEO_FILE_TYPES):
        return "video"
    if path.endswith(AUDIO_FILE_TYPES):
        return "audio"
    error_message = f"Unsupported file type: {path}"
    raise ValueError(error_message)


def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    video_count = 0
    non_video_count = 0
    for path in paths:
        if path.endswith(VIDEO_FILE_TYPES):
            video_count += 1
        else:
            non_video_count += 1
    return video_count, non_video_count


def validate_media_constraints(message: dict) -> bool:
    """Allow at most one video per message, and never mixed with other files."""
    video_count, non_video_count = count_files_in_new_message(message["files"])
    if video_count > 1:
        gr.Warning("Only one video is supported.")
        return False
    if video_count == 1 and non_video_count > 0:
        gr.Warning("Mixing images and videos is not allowed.")
        return False
    return True


def extract_frames_to_tempdir(
    video_path: str,
    target_fps: float,
    max_frames: int | None = None,
    parent_dir: str | None = None,
    prefix: str = "frames_",
) -> str:
    """Sample frames from a video at `target_fps` and save them as JPEGs in a temp dir."""
    temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)

    container = av.open(video_path)
    video_stream = container.streams.video[0]

    if video_stream.duration is None or video_stream.time_base is None:
        raise ValueError("video_stream is missing duration or time_base")

    time_base = video_stream.time_base
    duration = float(video_stream.duration * time_base)
    interval = 1.0 / target_fps

    total_frames = int(duration * target_fps)
    if max_frames is not None:
        total_frames = min(total_frames, max_frames)

    target_times = [i * interval for i in range(total_frames)]
    target_index = 0

    for frame in container.decode(video=0):
        if frame.pts is None:
            continue
        timestamp = float(frame.pts * time_base)
        # Save the first decoded frame that falls within half an interval of
        # the next target timestamp.
        if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
            frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
            frame.to_image().save(frame_path)
            target_index += 1
            if max_frames is not None and target_index >= max_frames:
                break

    container.close()
    return temp_dir


def process_new_user_message(message: dict) -> list[dict]:
    if not message["files"]:
        return [{"type": "text", "text": message["text"]}]

    file_types = [get_file_type(path) for path in message["files"]]

    # A single video is expanded into a sequence of sampled frames, which are
    # passed to the model as images.
    if len(file_types) == 1 and file_types[0] == "video":
        gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")
        temp_dir = extract_frames_to_tempdir(
            message["files"][0],
            target_fps=TARGET_FPS,
            max_frames=MAX_FRAMES,
            parent_dir=GRADIO_TEMP_DIR,
        )
        paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
        return [
            {"type": "text", "text": message["text"]},
            *[{"type": "image", "image": path.as_posix()} for path in paths],
        ]

    return [
        {"type": "text", "text": message["text"]},
        *[{"type": file_type, file_type: path} for path, file_type in zip(message["files"], file_types, strict=True)],
    ]


def process_history(history: list[dict]) -> list[dict]:
    """Convert Gradio chat history into the chat-template message format,
    merging consecutive user items (text and files) into a single user turn."""
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                # File messages arrive as a (filepath,) tuple.
                filepath = content[0]
                file_type = get_file_type(filepath)
                current_user_content.append({"type": file_type, file_type: filepath})
    return messages


@spaces.GPU(duration=120)
@torch.inference_mode()
def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
    if not validate_media_constraints(message):
        yield ""
        return

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > MAX_INPUT_TOKENS:
        gr.Warning(
            f"Input too long. Max {MAX_INPUT_TOKENS} tokens. Got {n_tokens} tokens. "
            "This limit is set to avoid CUDA out-of-memory errors in this Space."
        )
        yield ""
        return

    # BatchFeature.to() only casts floating-point tensors, so input_ids stay integer.
    inputs = inputs.to(device=model.device, dtype=torch.bfloat16)

    # The processor forwards decode() to its tokenizer, so it can serve as the
    # streamer's tokenizer. Generation runs in a background thread while this
    # generator streams partial output back to the UI.
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        disable_compile=True,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output


examples = [
    [
        {
            "text": "What is the capital of France?",
            "files": [],
        }
    ],
    [
        {
            "text": "Describe this image in detail.",
            "files": ["assets/cat.jpeg"],
        }
    ],
    [
        {
            "text": "Transcribe the following speech segment in English.",
            "files": ["assets/speech.wav"],
        }
    ],
    [
        {
            "text": "Transcribe the following speech segment in English.",
            "files": ["assets/speech2.wav"],
        }
    ],
    [
        {
            "text": "Describe this video",
            "files": ["assets/holding_phone.mp4"],
        }
    ],
]

demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    textbox=gr.MultimodalTextbox(
        file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES + AUDIO_FILE_TYPES),
        file_count="multiple",
        autofocus=True,
    ),
    multimodal=True,
    additional_inputs=[
        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
    ],
    stop_btn=False,
    title="Gemma 3n E4B it",
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)

if __name__ == "__main__":
    demo.launch()