# gemma-3-270m-it / app.py — by Norod78
# Commit 8f5eed9 ("Update app.py", verified), 6.18 kB
import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import spaces
import time
# --- UI text -----------------------------------------------------------------
TITLE = " google/gemma-3-270m-it "
DESCRIPTION = """
It's so small
"""

# --- Model / processor -------------------------------------------------------
# NOTE(review): google/gemma-3-270m-it may be a text-only checkpoint; confirm
# that loading it through the multimodal Gemma3ForConditionalGeneration class
# (and sending it images/video frames) matches the checkpoint's config.
model_270m_name = "google/gemma-3-270m-it"

model_270m = Gemma3ForConditionalGeneration.from_pretrained(
    model_270m_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Inference only: put the module in evaluation mode.
model_270m.eval()

# Despite the "4b" in its name, this is the processor for the 270m checkpoint
# loaded above; the name is kept because generate_response refers to it.
processor_4b = AutoProcessor.from_pretrained(model_270m_name)
# TODO: attach a timestamp to each sampled frame.
def extract_video_frames(video_path, num_frames=8):
    """Sample up to ``num_frames`` RGB frames from a video file.

    Frames are read at a fixed stride of ``frame_count // num_frames``
    (at least 1), starting at frame 0. Frames that fail to decode are
    skipped, so fewer than ``num_frames`` images may be returned.

    Args:
        video_path: Path of the video file to open with OpenCV.
        num_frames: Maximum number of frames to return. Values <= 0 yield
            an empty list.

    Returns:
        A list of ``PIL.Image.Image`` objects converted to RGB. The list is
        empty when the file cannot be opened or nothing decodes.
    """
    if num_frames <= 0:
        # Guard the stride computation below against division by zero.
        return []

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            return frames
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # If the reported count is 0 or smaller than num_frames, fall back
        # to reading consecutive frames (stride 1).
        step = max(total_frames // num_frames, 1)
        for i in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ok, frame = cap.read()
            if not ok:
                continue
            # OpenCV hands back BGR channel order; convert to RGB for PIL.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(rgb))
    finally:
        # Release the capture even if decoding/conversion raises.
        cap.release()
    return frames
_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')
_VIDEO_SUFFIXES = ('.mp4', '.mov')


def _file_path(file):
    """Return the filesystem path of an attachment (str/Path, or object with ``.name``)."""
    if isinstance(file, (str, Path)):
        return str(file)
    return file.name


def _attachment_to_content(file, from_placeholder=False):
    """Convert one attachment into zero or more chat content items.

    Video suffixes become sampled frames and known image suffixes become a
    single image. Any other suffix is opened as an image only when an
    explicit ``<image>`` placeholder asked for it; otherwise it is ignored.
    """
    path = _file_path(file)
    suffix = Path(path).suffix.lower()
    if suffix in _VIDEO_SUFFIXES:
        return [{"type": "image", "image": frame} for frame in extract_video_frames(path)]
    if suffix in _IMAGE_SUFFIXES or from_placeholder:
        return [{"type": "image", "image": Image.open(path)}]
    return []


def format_message(content, files):
    """Build the content list for one user turn from text and attachments.

    Each ``<image>`` marker in ``content`` is replaced, in order, by the next
    attachment. Text around markers becomes stripped text items, and
    whitespace-only pieces are dropped. Attachments not consumed by a marker
    are appended afterwards:
    ``.jpg``/``.jpeg``/``.png`` files become one image each, and
    ``.mp4``/``.mov`` files become frames from ``extract_video_frames``.
    Other suffixes are skipped.

    Args:
        content: User text, possibly containing ``<image>`` markers; a falsy
            value means "no text".
        files: Attachments, each a path string/``Path`` or an object whose
            ``.name`` is the path. ``None`` means no attachments. The
            caller's list is never modified.

    Returns:
        A list of ``{"type": "text", "text": ...}`` and
        ``{"type": "image", "image": PIL.Image}`` dicts.
    """
    # Work on a copy so consuming attachments does not mutate the caller's list.
    remaining = list(files or [])
    message_content = []
    if content:
        parts = content.split('<image>')
        last = len(parts) - 1
        for i, part in enumerate(parts):
            stripped = part.strip()
            if stripped:
                message_content.append({"type": "text", "text": stripped})
            # A marker sits between parts i and i+1; fill it with the next file.
            if i < last and remaining:
                message_content.extend(
                    _attachment_to_content(remaining.pop(0), from_placeholder=True)
                )
    for file in remaining:
        message_content.extend(_attachment_to_content(file))
    return message_content
def format_conversation_history(chat_history):
    """Convert Gradio "messages"-style history into chat-template messages.

    Consecutive user entries are merged into a single user turn. Each
    assistant entry closes the pending user turn (if any) and becomes its own
    assistant turn with its content stringified. Entries with any other role
    are skipped.

    User content handling:
      * ``str``  -> one text item
      * ``list`` -> its items are appended as-is
      * other    -> ``str(...)`` of it as one text item

    Args:
        chat_history: Iterable of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        A list of ``{"role": ..., "content": [items...]}`` dicts.
    """
    def _user_items(value):
        # Normalise one user entry's content into a list of content items.
        if isinstance(value, str):
            return [{"type": "text", "text": value}]
        if isinstance(value, list):
            return value
        return [{"type": "text", "text": str(value)}]

    converted = []
    pending_user = []
    for entry in chat_history:
        role = entry["role"]
        body = entry["content"]
        if role == "user":
            pending_user.extend(_user_items(body))
            continue
        if role != "assistant":
            continue
        if pending_user:
            converted.append({"role": "user", "content": pending_user})
            pending_user = []
        converted.append(
            {"role": "assistant", "content": [{"type": "text", "text": str(body)}]}
        )

    if pending_user:
        converted.append({"role": "user", "content": pending_user})
    return converted
def _build_messages(input_data, chat_history, system_prompt):
    """Assemble the chat-template message list for the next model turn.

    Args:
        input_data: The MultimodalTextbox value, i.e. a dict with ``"text"``
            and optionally ``"files"``. Any other value is coerced with
            ``str()`` and treated as text with no attachments.
        chat_history: Prior turns as ``{"role", "content"}`` dicts.
        system_prompt: Optional system text; falsy means no system turn.

    Returns:
        A list of ``{"role": ..., "content": [items...]}`` dicts. The new user
        content is merged into a trailing user turn when history ends with one.
    """
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        # Treat an explicit None as "no files". Copy the list so that
        # format_message cannot consume the caller's list.
        files = list(input_data.get("files") or [])
    else:
        text = str(input_data)
        files = []
    new_content = format_message(text, files)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(format_conversation_history(chat_history))

    # Avoid emitting two consecutive user turns (presumably the chat template
    # expects alternating roles): merge into a trailing user turn instead.
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_content)
    else:
        messages.append({"role": "user", "content": new_content})
    return messages


@spaces.GPU(duration=120)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    """Stream a sampled model reply for the ChatInterface.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``. This generator yields the cumulative decoded
    text after every streamed chunk.

    Args:
        input_data: MultimodalTextbox value (dict with ``"text"``/``"files"``),
            or any value to be used as plain text.
        chat_history: Prior turns as ``{"role", "content"}`` dicts.
        max_new_tokens: Upper bound on generated tokens (cast to ``int``).
        system_prompt: Optional system text.
        temperature, top_p, repetition_penalty: Sampling parameters passed
            through to ``generate``.
        top_k: Top-k sampling size (cast to ``int``).

    Yields:
        The reply text produced so far.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread is
        re-raised here after streaming stops, instead of being lost.
    """
    messages = _build_messages(input_data, chat_history, system_prompt)

    model = model_270m
    processor = processor_4b
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        # Token counts are cast defensively in case the UI delivers floats.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
    )

    errors = []

    def _generate():
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: re-raised in the caller below
            errors.append(exc)
            # generate() failed before finishing, so the streamer's stop
            # signal may never have been queued; send it so the loop ends.
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    pieces = []
    for chunk in streamer:
        pieces.append(chunk)
        yield "".join(pieces)

    worker.join()
    if errors:
        raise errors[0]
# Chat transcript widget (messages format, copy button, rtl=True layout).
_chatbot = gr.Chatbot(rtl=True, show_copy_button=True, type="messages")

# Extra controls; ChatInterface passes their values to generate_response after
# (input_data, chat_history), in this order: max_new_tokens, system_prompt,
# temperature, top_p, top_k, repetition_penalty.
# NOTE(review): the default system prompt says "multimodal"; confirm this
# matches the capabilities of the loaded checkpoint.
_generation_controls = [
    gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
    gr.Textbox(
        label="System Prompt",
        value="You are a very helpful multimodal assistant",
        lines=4,
        placeholder="Change the settings",
        text_align='left',
        rtl=False,
    ),
    gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.2),
    gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.4),
    gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=30),
    gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
]

# Input box accepting text plus any number of image/video attachments.
_input_box = gr.MultimodalTextbox(
    rtl=False,
    label="input",
    file_types=["image", "video"],
    file_count="multiple",
    placeholder="Input text, image or video",
)

_examples = [
    [{"text": "Write a poem which describes this image", "files": ["examples/image1.jpg"]}],
]

chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=_chatbot,
    additional_inputs=_generation_controls,
    examples=_examples,
    textbox=_input_box,
    cache_examples=False,
    type="messages",
    fill_height=True,
    stop_btn="Stop",
    css_paths=["style.css"],
    multimodal=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=gr.themes.Soft(),
)
if __name__ == "__main__":
    # Enable the request queue (at most 20 pending events), then serve the UI.
    app = chat_interface.queue(max_size=20)
    app.launch()