import io
import os

import gradio as gr
import requests
import torch
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

# Authenticate with the Hugging Face Hub (the Gemma models are gated)
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set; a valid Hugging Face "
        "access token is required to download the gated Gemma model."
    )
login(token=hf_token)

# Initialize the model and processor
model_id = "google/gemma-3n-E4B-it"
try:
    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id)
except Exception as e:
    raise RuntimeError(f"Failed to load model or processor: {e}") from e


def process_inputs(image_input, image_url, text_prompt):
    """
    Process an image (from a file or a URL) and a text prompt to generate a
    response with the Gemma model.

    Args:
        image_input: Path to an uploaded image file (or None)
        image_url: URL of an image (or an empty string)
        text_prompt: Text input from the user

    Returns:
        Generated text response from the model
    """
    if not text_prompt:
        return "Please provide a text prompt."

    try:
        # Handle image input: prioritize the uploaded image, then the URL, then None
        image = None
        if image_input is not None:
            image = Image.open(image_input).convert("RGB")
        elif image_url:
            response = requests.get(image_url, stream=True, timeout=30)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")

        # Prepare messages in the chat format expected by the processor
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}],
            },
            {
                "role": "user",
                "content": [],
            },
        ]

        # Add the image to the user turn if one was provided
        if image is not None:
            messages[1]["content"].append({"type": "image", "image": image})

        # Add the text prompt
        messages[1]["content"].append({"type": "text", "text": text_prompt})

        # Tokenize the chat (including the image) and move the tensors to the model's device
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        input_len = inputs["input_ids"].shape[-1]

        # Generate the response (greedy decoding)
        with torch.inference_mode():
            generation = model.generate(**inputs, max_new_tokens=500, do_sample=False)
            generation = generation[0][input_len:]

        # Decode only the newly generated tokens and return them
        return processor.decode(generation, skip_special_tokens=True)
    except Exception as e:
        return f"Error: {str(e)}"


# Define the Gradio interface
iface = gr.Interface(
    fn=process_inputs,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (optional)"),
        gr.Textbox(label="Image URL (optional)", placeholder="Enter image URL"),
        gr.Textbox(label="Text Prompt", placeholder="Enter your prompt here"),
    ],
    outputs=gr.Textbox(label="Model Response"),
    title="Gemma 3n Multimodal App (Authenticated)",
    description=(
        "Upload an image or provide an image URL, and enter a text prompt to "
        "interact with the Gemma 3n model. Ensure you have authenticated with "
        "a valid Hugging Face access token."
    ),
    allow_flagging="never",
)

# Launch the app
iface.launch()
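
# Usage sketch (assuming this script is saved as app.py; the filename is an
# assumption, not part of the original):
#   HF_TOKEN=<your-hf-access-token> python app.py
# Gradio serves the interface on http://127.0.0.1:7860 by default.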