#!/usr/bin/env python3
"""
Gradio UI for Multimodal Gemma Model - Hugging Face Space Version
Fixed: Added all missing modules (projectors.py, lightning_module.py, logging.py, data/, training/)
Updated requirements.txt with rich and datasets libraries
"""
import sys
import torch
import gradio as gr
from pathlib import Path
from PIL import Image
import io
import time
import logging
import os
from huggingface_hub import hf_hub_download, login
# Try to login with HF token if available (for Spaces with secrets)
try:
    # Try environment variables (for Space secrets)
    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("✅ Logged in to Hugging Face with token")
    else:
        print("⚠️ No HF token found in environment variables")
        print("Please set HF_TOKEN as a Space secret to access gated models")
except Exception as e:
    print(f"⚠️ HF login failed: {e}")
# Model imports
from src.models import MultimodalGemmaLightning
from src.utils.config import load_config, merge_configs
# Global model variable
model = None
config = None
def load_model():
    """Download and load the trained multimodal model from HF"""
    global model, config
    if model is not None:
        return "✅ Model already loaded!"
    try:
        print("🔄 Loading multimodal Gemma model...")
        # Get token and show status
        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
        if hf_token:
            print(f"✅ Token found: {hf_token[:10]}...")
            try:
                login(token=hf_token)
                print("✅ Authentication successful")
            except Exception as e:
                print(f"❌ Authentication failed: {e}")
        else:
            print("❌ No HF_TOKEN found in environment")
            return "❌ Please add HF_TOKEN as a Space secret to access your gated model"
        print("📥 Downloading model checkpoint...")
        checkpoint_path = hf_hub_download(
            repo_id="sagar007/multimodal-gemma-270m-llava",
            filename="final_model.ckpt",
            cache_dir="./model_cache",
            token=hf_token
        )
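        # Note: hf_hub_download caches the file under cache_dir and returns the
        # local path, so clicking "Load Model" again skips the network download.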
        # Use local config files (included in Space)
        print("📁 Loading configs...")
        model_config = load_config("configs/model_config.yaml")
        training_config = load_config("configs/training_config.yaml")
        data_config = load_config("configs/data_config.yaml")
        config = merge_configs([model_config, training_config, data_config])
        print("📁 Loading model from checkpoint...")
        # Load model exactly like local gradio_app.py
        model = MultimodalGemmaLightning.load_from_checkpoint(
            checkpoint_path,
            config=config,
            strict=False,
            map_location="cuda" if torch.cuda.is_available() else "cpu"
        )
        model.eval()
        # Move to appropriate device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        print(f"✅ Model loaded successfully on {device}!")
        return f"✅ Model loaded successfully on {device}!"
    except Exception as e:
        error_msg = f"❌ Error loading model: {str(e)}"
        print(error_msg)
        return error_msg
def predict_with_image(image, question, max_tokens=100, temperature=0.7):
    """Generate response for image + text input"""
    global model, config
    if model is None:
        return "❌ Please load the model first using the 'Load Model' button!"
    if image is None:
        return "❌ Please upload an image!"
    if not question.strip():
        question = "What do you see in this image?"
    try:
        # Get device
        device = next(model.parameters()).device
        # Process image
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            image = Image.fromarray(image).convert('RGB')
        # Prepare image for model
        vision_inputs = model.model.vision_processor(
            images=[image],
            return_tensors="pt"
        )
        pixel_values = vision_inputs["pixel_values"].to(device)
        # Prepare text prompt (real newlines, not escaped "\n" literals)
        prompt = f"<image>\nHuman: {question}\nAssistant:"
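        # Note: this Human/Assistant template with a leading <image> placeholder is
        # assumed to match the LLaVA-style format the checkpoint was trained on;
        # diverging from the training template usually degrades responses.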
        # Tokenize text
        text_inputs = model.model.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256
        )
        input_ids = text_inputs["input_ids"].to(device)
        attention_mask = text_inputs["attention_mask"].to(device)
        # Generate response
        with torch.no_grad():
            # Use the full multimodal model with image inputs
            outputs = model.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=pixel_values,
                max_new_tokens=min(max_tokens, 150),
                temperature=min(max(temperature, 0.1), 2.0),
                do_sample=temperature > 0.1,
                repetition_penalty=1.1
            )
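            # Note: `images=` is not a standard transformers generate() argument; it is
            # presumably handled by this project's multimodal wrapper, which splices the
            # projected CLIP features into the input embeddings before decoding.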
        # Decode response
        input_length = input_ids.shape[1]
        generated_tokens = outputs[0][input_length:]
        response = model.model.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        # Clean up response
        response = response.strip()
        if not response:
            response = "I can see the image, but I'm having trouble generating a detailed response."
        return response
    except Exception as e:
        error_msg = f"❌ Error during inference: {str(e)}"
        print(error_msg)
        return error_msg
def chat_with_image(image, question, history, max_tokens, temperature):
    """Chat interface function"""
    if model is None:
        response = "❌ Please load the model first!"
    else:
        response = predict_with_image(image, question, max_tokens, temperature)
    # Add to history, using the messages format expected by gr.Chatbot(type="messages")
    history.append({"role": "user", "content": question})
    history.append({"role": "assistant", "content": response})
    return history, ""
def create_gradio_interface():
    """Create the Gradio interface"""
    # Custom CSS for better styling
    css = """
    .container {
        max-width: 1200px;
        margin: auto;
        padding: 20px;
    }
    .header {
        text-align: center;
        margin-bottom: 30px;
    }
    .model-info {
        background-color: #f0f8ff;
        padding: 15px;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    """
    with gr.Blocks(css=css, title="Multimodal Gemma Chat") as demo:
        gr.HTML("""
        <div class="header">
            <h1>🎉 Multimodal Gemma-270M Chat</h1>
            <p>Upload an image and chat with your trained vision-language model!</p>
            <p><a href="https://huggingface.co/sagar007/multimodal-gemma-270m-llava">🤗 Model</a></p>
        </div>
        """)
        # Model status section
        with gr.Row():
            with gr.Column():
                gr.HTML("""
                <div class="model-info">
                    <h3>📊 Model Info</h3>
                    <ul>
                        <li><strong>Base Model:</strong> Google Gemma-270M</li>
                        <li><strong>Vision:</strong> CLIP ViT-Large</li>
                        <li><strong>Training:</strong> LLaVA-150K + COCO Images</li>
                        <li><strong>Parameters:</strong> 18.6M trainable / 539M total</li>
                    </ul>
                </div>
                """)
                # Model loading
                load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Click 'Load Model' to start",
                    interactive=False
                )
gr.HTML("<hr>")
# Main interface
with gr.Row():
# Left column - Image and controls
with gr.Column(scale=1):
image_input = gr.Image(
label="πŸ“Έ Upload Image",
type="pil",
height=300
)
# Example images
gr.HTML("<p><strong>πŸ’‘ Tip:</strong> Upload any image and ask questions about it</p>")
# Generation settings
with gr.Accordion("βš™οΈ Generation Settings", open=False):
max_tokens = gr.Slider(
minimum=10,
maximum=200,
value=100,
step=10,
label="Max Tokens"
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature"
)
# Right column - Chat interface
with gr.Column(scale=2):
chatbot = gr.Chatbot(
label="πŸ’¬ Chat with Image",
height=400,
show_label=True,
type="messages"
)
question_input = gr.Textbox(
label="❓ Ask about the image",
placeholder="What do you see in this image?",
lines=2
)
with gr.Row():
submit_btn = gr.Button("πŸ’¬ Send", variant="primary")
clear_btn = gr.Button("πŸ—‘οΈ Clear Chat")
        # Example prompts
        with gr.Row():
            gr.HTML("<h3>💡 Example Questions:</h3>")
        example_questions = [
            "What do you see in this image?",
            "Describe the main objects in the picture.",
            "What colors are prominent in this image?",
            "Are there any people in the image?",
            "What's the setting or location?",
            "What objects are in the foreground?"
        ]
        # Lay the example buttons out three per row; the original loop opened an
        # empty `with gr.Row(): pass` that closed immediately, so all buttons
        # landed in one row. Clicking a button copies its text into the question box.
        for i in range(0, len(example_questions), 3):
            with gr.Row():
                for question in example_questions[i:i + 3]:
                    gr.Button(
                        question,
                        size="sm"
                    ).click(
                        lambda q=question: q,
                        outputs=question_input
                    )
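        # Note: the `q=question` default argument freezes each button's text at
        # definition time; a bare `lambda: question` would late-bind and make every
        # button emit the last question in the list.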
        # Footer
        gr.HTML("""
        <hr>
        <div style="text-align: center; margin-top: 20px;">
            <p><strong>🎯 Your Multimodal Gemma Model</strong></p>
            <p>Text-only → Vision-Language Model using LLaVA Architecture</p>
            <p>Model: <a href="https://huggingface.co/sagar007/multimodal-gemma-270m-llava">sagar007/multimodal-gemma-270m-llava</a></p>
        </div>
        """)
        # Event handlers
        load_btn.click(
            fn=load_model,
            outputs=model_status
        )
        submit_btn.click(
            fn=chat_with_image,
            inputs=[image_input, question_input, chatbot, max_tokens, temperature],
            outputs=[chatbot, question_input]
        )
        question_input.submit(
            fn=chat_with_image,
            inputs=[image_input, question_input, chatbot, max_tokens, temperature],
            outputs=[chatbot, question_input]
        )
        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, question_input]
        )
    return demo
def main():
    """Main function to launch the Gradio app"""
    print("🚀 Starting Multimodal Gemma Gradio Space...")
    # Create interface
    demo = create_gradio_interface()
    # Launch
    print("🌐 Launching Gradio interface...")
    demo.launch()
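    # Note: on Hugging Face Spaces the default launch() is all that's needed; for a
    # local Docker run you would typically pass server_name="0.0.0.0" (Gradio's
    # default port is 7860) so the server is reachable from outside the container.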
if __name__ == "__main__":
    main()