import io
import os

import gradio as gr
import requests
import torch
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

# Authenticate with the Hugging Face Hub; the Gemma 3n weights are gated
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
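# In a Hugging Face Space, HF_TOKEN is typically stored as a secret in the
# Space settings rather than hard-coded in the source.
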
# Initialize the model and processor
model_id = "google/gemma-3n-e4b-it"
try:
    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id)
except Exception as e:
    raise RuntimeError(f"Failed to load model or processor: {e}") from e
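
# Note: device_map="auto" assumes the `accelerate` package is installed;
# without a GPU, bfloat16 inference falls back to much slower CPU kernels.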

def process_inputs(image_input, image_url, text_prompt):
    """
    Process an image (from a file or a URL) and a text prompt, and generate
    a response with the Gemma model.

    Args:
        image_input: Filepath of the uploaded image, or None.
        image_url: URL of an image, or an empty string.
        text_prompt: Text input from the user.

    Returns:
        The generated text response, or an error message.
    """
    try:
        # Handle the image input: prefer the uploaded file, then the URL
        image = None
        if image_input is not None:
            image = Image.open(image_input).convert("RGB")
        elif image_url:
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")

        # A text prompt is required; the image is optional
        if not text_prompt:
            return "Please provide a text prompt."

        # Build the chat messages for the model
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}],
            },
            {"role": "user", "content": []},
        ]
        # Attach the image (if any) before the text prompt
        if image is not None:
            messages[1]["content"].append({"type": "image", "image": image})
        messages[1]["content"].append({"type": "text", "text": text_prompt})

        # Tokenize the conversation with the model's chat template
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
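
        # Record the prompt length so the echoed input tokens can be sliced
        # off the generated sequence.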
        input_len = inputs["input_ids"].shape[-1]

        # Generate response
        with torch.inference_mode():
            generation = model.generate(**inputs, max_new_tokens=500, do_sample=False)
            generation = generation[0][input_len:]

        # Decode and return the response
        decoded = processor.decode(generation, skip_special_tokens=True)
        return decoded
    except Exception as e:
        return f"Error: {e}"

# Define the Gradio interface
iface = gr.Interface(
    fn=process_inputs,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (optional)"),
        gr.Textbox(label="Image URL (optional)", placeholder="Enter image URL"),
        gr.Textbox(label="Text Prompt", placeholder="Enter your prompt here"),
    ],
    outputs=gr.Textbox(label="Model Response"),
    title="Gemma 3n Multimodal App (Authenticated)",
    description=(
        "Upload an image or provide an image URL, and enter a text prompt to "
        "interact with the Gemma 3n model. The Space must be authenticated "
        "with a Hugging Face access token that has access to the gated model."
    ),
    allow_flagging="never",
)
# Launch the app
iface.launch()
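
# To run locally (assuming your token has been granted access to the gated
# Gemma 3n weights):
#   HF_TOKEN=<your token> python app.py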