import gradio as gr
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from PIL import Image
import requests
import torch
import io
import os
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
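
# Note: Gemma checkpoints are gated on the Hugging Face Hub, so HF_TOKEN
# (e.g. set as a Space secret) is required for the download to succeed.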

# Initialize the model and processor
model_id = "google/gemma-3n-E4B-it"
try:
    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id)
except Exception as e:
    raise RuntimeError(f"Failed to load model or processor: {e}") from e
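
# Assumption: bfloat16 with device_map="auto" targets GPU hardware; on a
# CPU-only Space, torch.float32 would likely be the safer dtype.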

def process_inputs(image_input, image_url, text_prompt):
    """
    Process an image (from file or URL) and a text prompt to generate a
    response from the Gemma model.

    Args:
        image_input: Path to an uploaded image file
        image_url: URL of an image
        text_prompt: Text input from the user

    Returns:
        Generated text response from the model
    """
    try:
        # Handle image input: prioritize uploaded image, then URL, then None
        image = None
        if image_input is not None:
            image = Image.open(image_input).convert("RGB")
        elif image_url:
            response = requests.get(image_url, stream=True, timeout=30)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")
        # Prepare messages for the model
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}],
            },
            {
                "role": "user",
                "content": [],
            },
        ]

        # Add image to content if provided
        if image is not None:
            messages[1]["content"].append({"type": "image", "image": image})

        # Add text prompt if provided
        if text_prompt:
            messages[1]["content"].append({"type": "text", "text": text_prompt})
        else:
            return "Please provide a text prompt."
        # Process inputs using the processor
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
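        # input_ids holds the full prompt; its length is used below to strip
        # the prompt from the generated sequence before decoding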
        input_len = inputs["input_ids"].shape[-1]

        # Generate response
        with torch.inference_mode():
            generation = model.generate(**inputs, max_new_tokens=500, do_sample=False)
            generation = generation[0][input_len:]

        # Decode and return the response
        decoded = processor.decode(generation, skip_special_tokens=True)
        return decoded
    except Exception as e:
        return f"Error: {str(e)}"

# Define the Gradio interface
iface = gr.Interface(
    fn=process_inputs,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (optional)"),
        gr.Textbox(label="Image URL (optional)", placeholder="Enter image URL"),
        gr.Textbox(label="Text Prompt", placeholder="Enter your prompt here"),
    ],
    outputs=gr.Textbox(label="Model Response"),
    title="Gemma 3n Multimodal App (Authenticated)",
    description=(
        "Upload an image or provide an image URL, and enter a text prompt to "
        "interact with the Gemma 3n model. Ensure you have authenticated with "
        "a valid Hugging Face access token."
    ),
    allow_flagging="never",
)

# Launch the app
iface.launch()
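
# launch() with no arguments works on Spaces, which supplies its own server
# settings; for local testing, iface.launch(share=True) yields a temporary
# public URL (standard Gradio behavior, not specific to this app).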