"""Gradio chat demo for the Maya multilingual multimodal model.

Loads the model once at import time and serves a streaming, multimodal
``gr.ChatInterface`` whose handler is :func:`process_chat_stream`.
"""

# NOTE: the import order below is kept exactly as in the original file
# (``spaces`` before ``torch``); do not let a formatter re-sort it without
# confirming that the hosting runtime does not depend on it.
import gradio as gr
import spaces
import time
import os
import traceback
import torch
from PIL import Image
from threading import Thread
from transformers import TextIteratorStreamer, AutoConfig, AutoModelForCausalLM
from constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from conversation import conv_templates
from eval_utils import load_maya_model
from utils import disable_torch_init
from mm_utils import tokenizer_image_token, process_images
from huggingface_hub._login import _login

# Import LLaVA modules to register model types
from model import *
from model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig

# Register model type and config
AutoConfig.register("llava_cohere", LlavaCohereConfig)
AutoModelForCausalLM.register(LlavaCohereConfig, LlavaCohereForCausalLM)

hf_token = os.getenv("hf_token")
if hf_token:
    _login(token=hf_token, add_to_git_credential=False)
else:
    # Without a token the login helper would fail on ``None``; continue
    # anonymously and let the model download report any access problem.
    print("Warning: environment variable 'hf_token' is not set; skipping Hugging Face login.")

# Global Variables
MODEL_BASE = "CohereForAI/aya-23-8B"
MODEL_PATH = "maya-multimodal/maya"
MODE = "finetuned"


def load_model():
    """Load the Maya model and required components.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple. The model has been
        moved to CUDA and switched to eval mode.
    """
    model, tokenizer, image_processor, _ = load_maya_model(
        MODEL_BASE, MODEL_PATH, None, MODE
    )
    model = model.cuda()
    model.eval()
    return model, tokenizer, image_processor


# Load model globally
print("Loading model...")
model, tokenizer, image_processor = load_model()
print("Model loaded successfully!")


def validate_image_file(image_path):
    """Validate that the image file exists and is in a supported format.

    Args:
        image_path: Filesystem path of the image to check.

    Returns:
        ``True`` when the file exists and PIL can verify it.

    Raises:
        gr.Error: If the file is missing or PIL cannot verify it.
    """
    if not os.path.isfile(image_path):
        raise gr.Error(f"Error: File {image_path} does not exist.")
    try:
        with Image.open(image_path) as img:
            img.verify()
        return True
    except (IOError, SyntaxError) as e:
        raise gr.Error(f"Error: {image_path} is not a valid image file. {e}") from e


def _file_path(entry):
    """Return the path of a file entry given as a dict with ``path`` or a plain path."""
    return entry["path"] if isinstance(entry, dict) else entry


def _find_image_path(message, history):
    """Return the most recent image path from the message or history, else ``None``.

    The last file of the current message wins; otherwise the history is
    scanned from newest to oldest for a file entry.
    """
    current_files = message.get("files") or []
    if current_files:
        return _file_path(current_files[-1])

    for hist in reversed(history):
        print("Processing history item:", hist)
        content = hist.get("content")
        if isinstance(content, tuple) and content:
            return content[0]
        if isinstance(content, dict):
            if content.get("files"):
                return _file_path(content["files"][0])
            # presumably some Gradio versions store a file as {"path": ...}
            # directly -- TODO confirm against the installed version
            if content.get("path"):
                return content["path"]
    return None


def _history_text(content):
    """Extract the textual part of one history entry's content ("" if none)."""
    if isinstance(content, str):
        return content
    if isinstance(content, (tuple, list)):
        # File entries look like (path,) or (path, text); only the text counts.
        return content[1] if len(content) > 1 and isinstance(content[1], str) else ""
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else ""
    return ""


def _collect_turns(message_text, history):
    """Rebuild the conversation as a list of ``(speaker, text)`` turns.

    Consecutive user entries (for example a file entry followed by its text)
    are merged into one user turn, so turns alternate the way the
    conversation template expects. The current message always closes the
    list as the final user turn.
    """
    turns = []
    user_parts = None  # None means no user turn is currently open
    for hist in history:
        role = hist.get("role")
        content = hist.get("content")
        if role == "user":
            if user_parts is None:
                user_parts = []
            text = _history_text(content)
            if text:
                user_parts.append(text)
        elif role == "assistant":
            if user_parts is not None:
                turns.append(("user", "\n".join(user_parts)))
                user_parts = None
            turns.append(("assistant", content))

    if user_parts is None:
        user_parts = []
    if message_text:
        user_parts.append(message_text)
    turns.append(("user", "\n".join(user_parts)))
    return turns


def _with_image_token(text):
    """Prefix ``text`` with the image placeholder token(s) the model config asks for."""
    if model.config.mm_use_im_start_end:
        return f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{text}"
    return f"{DEFAULT_IMAGE_TOKEN}\n{text}"


def _build_prompt(message, history):
    """Build the full prompt, placing the image token on the first user turn.

    The prompt is rebuilt from scratch on every call, so the image token has
    to be re-inserted each time; adding it only when the history is empty
    would leave follow-up turns with no image placeholder at all.
    """
    conv = conv_templates["aya"].copy()
    image_token_pending = True
    for speaker, text in _collect_turns(message.get("text") or "", history):
        if speaker == "user":
            if image_token_pending:
                text = _with_image_token(text)
                image_token_pending = False
            conv.append_message(conv.roles[0], text)
        else:
            conv.append_message(conv.roles[1], text)
    # Open an empty assistant turn for the model to complete.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


@spaces.GPU
def process_chat_stream(message, history):
    """Stream the model's reply to a multimodal chat message.

    Args:
        message: Dict with a ``"text"`` entry and an optional ``"files"``
            list (entries are paths or dicts with a ``"path"`` key).
        history: Earlier messages as dicts with ``"role"`` and ``"content"``.

    Yields:
        ``{"role": "assistant", "content": <text so far>}`` after every
        chunk produced by the streamer.

    Raises:
        gr.Error: If no image is available, the image is invalid, setup
            fails, or generation fails in the worker thread.
    """
    print(message)
    print("History:", history)
    history = history or []

    image_path = _find_image_path(message, history)
    if image_path is None:
        raise gr.Error("Please upload an image to start the conversation.")

    # Validate and load the image; the context manager closes the file handle.
    validate_image_file(image_path)
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Process image for the model
    image_tensor = process_images([image], image_processor, model.config)
    if image_tensor is None:
        raise gr.Error("Failed to process image")
    image_tensor = image_tensor.cuda()

    prompt = _build_prompt(message, history)

    # Filled by the worker thread; an exception raised there would otherwise
    # never reach this generator.
    generation_errors = []

    try:
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        if input_ids is None:
            raise ValueError("Tokenization returned None")

        # Ensure input_ids is 2D tensor
        if len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.cuda()

        # Validate vision tower before starting generation
        if not hasattr(model, 'get_vision_tower') or model.get_vision_tower() is None:
            raise ValueError("Model's vision tower is not properly initialized")

        # Setup streamer and generation
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
        generation_kwargs = {
            "inputs": input_ids,
            "images": image_tensor,
            "image_sizes": [image.size],
            "streamer": streamer,
            "temperature": 0.3,
            "do_sample": True,
            "top_p": 0.9,
            "num_beams": 1,
            "max_new_tokens": 4096,
            "use_cache": True,
        }

        def generate_with_error_handling():
            try:
                model.generate(**generation_kwargs)
            except Exception as e:
                generation_errors.append(
                    f"Generation error: {str(e)}\nTraceback:\n{traceback.format_exc()}"
                )
                # Signal end-of-stream so the consuming loop below does not
                # block forever waiting for text that will never arrive.
                streamer.end()

        thread = Thread(target=generate_with_error_handling)
        thread.start()
    except Exception as e:
        error_msg = f"Setup error: {str(e)}"
        error_msg += f"\nTraceback:\n{traceback.format_exc()}"
        raise gr.Error(error_msg) from e

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        time.sleep(0.1)  # kept from the original: paces updates sent to the UI
        yield {"role": "assistant", "content": partial_message}

    thread.join()
    if generation_errors:
        raise gr.Error(generation_errors[0])


# Create Gradio interface
chatbot = gr.Chatbot(
    show_label=False,
    height=450,
    show_share_button=False,
    show_copy_button=False,
    avatar_images=None,
    container=True,
    render_markdown=True,
    scale=1,
    type="messages"
)

chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)

with gr.Blocks(fill_height=True, ) as demo:
    gr.ChatInterface(
        fn=process_chat_stream,
        title="Maya: Multilingual Multimodal Model",
        examples=[{"text": "Describe this photo in detail.", "files": ["./asian_food.jpg"]},
                  {"text": "What is the name of this famous sight in the photo?", "files": ["./hawaii.jpg"]}],
        description="Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error. [Read the research paper](https://huggingface.co/papers/2412.07112)\n\nTeam 💚 Maya",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )


if __name__ == "__main__":
    demo.queue(api_open=False)
    demo.launch(show_api=False, share=False)