#!/usr/bin/env python3
"""
Gradio UI for Multimodal Gemma Model - Hugging Face Space version.

Fixed: added all missing modules (projectors.py, lightning_module.py,
logging.py, data/, training/) and updated requirements.txt with the
rich and datasets libraries.
"""
import os

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, login
# Try to log in with an HF token if available (for Spaces with secrets)
try:
    # Try environment variables (for Space secrets)
    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("✅ Logged in to Hugging Face with token")
    else:
        print("⚠️ No HF token found in environment variables")
        print("Please set HF_TOKEN as a Space secret to access gated models")
except Exception as e:
    print(f"⚠️ HF login failed: {e}")

# Model imports
from src.models import MultimodalGemmaLightning
from src.utils.config import load_config, merge_configs

# Global model and config state
model = None
config = None

def load_model():
    """Download and load the trained multimodal model from HF."""
    global model, config

    if model is not None:
        return "✅ Model already loaded!"

    try:
        print("🚀 Loading multimodal Gemma model...")

        # Get token and show status
        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
        if hf_token:
            print(f"✅ Token found: {hf_token[:10]}...")
            try:
                login(token=hf_token)
                print("✅ Authentication successful")
            except Exception as e:
                print(f"❌ Authentication failed: {e}")
        else:
            print("❌ No HF_TOKEN found in environment")
            return "❌ Please add HF_TOKEN as a Space secret to access your gated model"

        print("📥 Downloading model checkpoint...")
        checkpoint_path = hf_hub_download(
            repo_id="sagar007/multimodal-gemma-270m-llava",
            filename="final_model.ckpt",
            cache_dir="./model_cache",
            token=hf_token,
        )
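        # hf_hub_download returns the local path of the cached file; repeated
        # loads reuse the copy in cache_dir instead of re-downloading.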

        # Use local config files (included in the Space)
        print("📋 Loading configs...")
        model_config = load_config("configs/model_config.yaml")
        training_config = load_config("configs/training_config.yaml")
        data_config = load_config("configs/data_config.yaml")
        config = merge_configs([model_config, training_config, data_config])

        print("🔄 Loading model from checkpoint...")
        # Load exactly like the local gradio_app.py; strict=False tolerates
        # checkpoint keys that do not match the current module definition
        model = MultimodalGemmaLightning.load_from_checkpoint(
            checkpoint_path,
            config=config,
            strict=False,
            map_location="cuda" if torch.cuda.is_available() else "cpu",
        )
        model.eval()

        # Move to the appropriate device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)

        print(f"✅ Model loaded successfully on {device}!")
        return f"✅ Model loaded successfully on {device}!"

    except Exception as e:
        error_msg = f"❌ Error loading model: {str(e)}"
        print(error_msg)
        return error_msg

def predict_with_image(image, question, max_tokens=100, temperature=0.7):
    """Generate a response for image + text input."""
    global model, config

    if model is None:
        return "❌ Please load the model first using the 'Load Model' button!"
    if image is None:
        return "❌ Please upload an image!"
    if not question.strip():
        question = "What do you see in this image?"

    try:
        # Get the device the model lives on
        device = next(model.parameters()).device

        # Normalize the input to a PIL RGB image
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        elif not isinstance(image, Image.Image):
            image = Image.fromarray(image).convert("RGB")

        # Prepare the image for the model
        vision_inputs = model.model.vision_processor(
            images=[image],
            return_tensors="pt",
        )
        pixel_values = vision_inputs["pixel_values"].to(device)

        # Prepare the text prompt (real newlines, not escaped "\n" literals)
        prompt = f"<image>\nHuman: {question}\nAssistant:"

        # Tokenize the text
        text_inputs = model.model.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        input_ids = text_inputs["input_ids"].to(device)
        attention_mask = text_inputs["attention_mask"].to(device)

        # Generate a response with the full multimodal model
        with torch.no_grad():
            outputs = model.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=pixel_values,
                max_new_tokens=min(max_tokens, 150),
                temperature=min(max(temperature, 0.1), 2.0),
                do_sample=temperature > 0.1,
                repetition_penalty=1.1,
            )

        # Decode only the newly generated tokens
        input_length = input_ids.shape[1]
        generated_tokens = outputs[0][input_length:]
        response = model.model.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Clean up the response
        response = response.strip()
        if not response:
            response = "I can see the image, but I'm having trouble generating a detailed response."
        return response

    except Exception as e:
        error_msg = f"❌ Error during inference: {str(e)}"
        print(error_msg)
        return error_msg
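
# Minimal smoke test for predict_with_image (a sketch, assuming the model has
# already been loaded via load_model() and that example.jpg is a hypothetical
# local file):
#
#   img = Image.open("example.jpg")
#   print(predict_with_image(img, "What do you see in this image?"))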

def chat_with_image(image, question, history, max_tokens, temperature):
    """Chat interface function."""
    if model is None:
        response = "❌ Please load the model first!"
    else:
        response = predict_with_image(image, question, max_tokens, temperature)

    # Add the turn to history, using the messages format
    history.append({"role": "user", "content": question})
    history.append({"role": "assistant", "content": response})
    return history, ""
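
# The {"role": ..., "content": ...} dicts above follow the "messages" format
# that gr.Chatbot(type="messages"), configured below, expects.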

def create_gradio_interface():
    """Create the Gradio interface."""
    # Custom CSS for better styling
    css = """
    .container {
        max-width: 1200px;
        margin: auto;
        padding: 20px;
    }
    .header {
        text-align: center;
        margin-bottom: 30px;
    }
    .model-info {
        background-color: #f0f8ff;
        padding: 15px;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    """

    with gr.Blocks(css=css, title="Multimodal Gemma Chat") as demo:
        gr.HTML("""
        <div class="header">
            <h1>🚀 Multimodal Gemma-270M Chat</h1>
            <p>Upload an image and chat with your trained vision-language model!</p>
            <p><a href="https://huggingface.co/sagar007/multimodal-gemma-270m-llava">🤗 Model</a></p>
        </div>
        """)

        # Model status section
        with gr.Row():
            with gr.Column():
                gr.HTML("""
                <div class="model-info">
                    <h3>📊 Model Info</h3>
                    <ul>
                        <li><strong>Base Model:</strong> Google Gemma-270M</li>
                        <li><strong>Vision:</strong> CLIP ViT-Large</li>
                        <li><strong>Training:</strong> LLaVA-150K + COCO Images</li>
                        <li><strong>Parameters:</strong> 18.6M trainable / 539M total</li>
                    </ul>
                </div>
                """)

                # Model loading
                load_btn = gr.Button("🚀 Load Model", variant="primary", size="lg")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Click 'Load Model' to start",
                    interactive=False,
                )

        gr.HTML("<hr>")

        # Main interface
        with gr.Row():
            # Left column - image and controls
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="📸 Upload Image",
                    type="pil",
                    height=300,
                )

                # Example images
                gr.HTML("<p><strong>💡 Tip:</strong> Upload any image and ask questions about it</p>")

                # Generation settings
                with gr.Accordion("⚙️ Generation Settings", open=False):
                    max_tokens = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=100,
                        step=10,
                        label="Max Tokens",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                    )

            # Right column - chat interface
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="💬 Chat with Image",
                    height=400,
                    show_label=True,
                    type="messages",
                )
                question_input = gr.Textbox(
                    label="❓ Ask about the image",
                    placeholder="What do you see in this image?",
                    lines=2,
                )
                with gr.Row():
                    submit_btn = gr.Button("💬 Send", variant="primary")
                    clear_btn = gr.Button("🗑️ Clear Chat")

        # Example prompts
        with gr.Row():
            gr.HTML("<h3>💡 Example Questions:</h3>")
        example_questions = [
            "What do you see in this image?",
            "Describe the main objects in the picture.",
            "What colors are prominent in this image?",
            "Are there any people in the image?",
            "What's the setting or location?",
            "What objects are in the foreground?",
        ]
        # Lay the example buttons out three to a row; each click copies its
        # question into the input box
        for i in range(0, len(example_questions), 3):
            with gr.Row():
                for question in example_questions[i:i + 3]:
                    gr.Button(question, size="sm").click(
                        lambda q=question: q,
                        outputs=question_input,
                    )

        # Footer
        gr.HTML("""
        <hr>
        <div style="text-align: center; margin-top: 20px;">
            <p><strong>🎯 Your Multimodal Gemma Model</strong></p>
            <p>Text-only → Vision-Language Model using the LLaVA architecture</p>
            <p>Model: <a href="https://huggingface.co/sagar007/multimodal-gemma-270m-llava">sagar007/multimodal-gemma-270m-llava</a></p>
        </div>
        """)

        # Event handlers
        load_btn.click(
            fn=load_model,
            outputs=model_status,
        )
        submit_btn.click(
            fn=chat_with_image,
            inputs=[image_input, question_input, chatbot, max_tokens, temperature],
            outputs=[chatbot, question_input],
        )
        question_input.submit(
            fn=chat_with_image,
            inputs=[image_input, question_input, chatbot, max_tokens, temperature],
            outputs=[chatbot, question_input],
        )
        clear_btn.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, question_input],
        )

    return demo

def main():
    """Launch the Gradio app."""
    print("🚀 Starting Multimodal Gemma Gradio Space...")

    # Create the interface
    demo = create_gradio_interface()

    # Launch
    print("🌐 Launching Gradio interface...")
    demo.launch()


if __name__ == "__main__":
    main()
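
# Usage notes (assumptions, not verified against this Space's settings):
# - On a Hugging Face Space using the Gradio SDK, this file is expected to be
#   named app.py and is run automatically on startup.
# - For local testing, run `python app.py` with HF_TOKEN exported so the gated
#   checkpoint can be downloaded.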