import gradio as gr import os from PIL import Image import tempfile import PyPDF2 import io from typing import List, Tuple, Optional from transformers import AutoTokenizer, AutoModelForCausalLM import torch # Global variables for model and tokenizer model = None tokenizer = None def load_gemma_model(): """ Load Gemma model and tokenizer from Hugging Face. """ global model, tokenizer try: model_name = "google/gemma-2-2b-it" # Using Gemma-2 as it's available print(f"Loading {model_name}...") tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" if torch.cuda.is_available() else None ) # Add padding token if it doesn't exist if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print("Model loaded successfully!") return True except Exception as e: print(f"Error loading model: {e}") return False def gemma_3_inference(prompt_text: str, pil_image: Optional[Image.Image] = None, chat_history: Optional[List[Tuple[str, str]]] = None) -> str: """ Gemma-2 model inference function. """ global model, tokenizer # Load model if not already loaded if model is None or tokenizer is None: if not load_gemma_model(): return "❌ Error: Could not load Gemma model. Please check your internet connection and try again." try: # Build conversation context conversation = [] # Add chat history for context (last 3 exchanges) if chat_history: for user_msg, bot_msg in chat_history[-3:]: conversation.append({"role": "user", "content": user_msg}) conversation.append({"role": "assistant", "content": bot_msg}) # Handle image input (note: Gemma-2 doesn't have native vision, so we'll describe the limitation) if pil_image: prompt_text = f"[Image uploaded - Note: This model doesn't have vision capabilities, but I can help with text-based questions about images] {prompt_text}" # Add current user message conversation.append({"role": "user", "content": prompt_text}) # Format conversation for Gemma formatted_prompt = tokenizer.apply_chat_template( conversation, tokenize=False, add_generation_prompt=True ) # Tokenize input inputs = tokenizer( formatted_prompt, return_tensors="pt", truncation=True, max_length=2048 ) # Move to same device as model if torch.cuda.is_available() and model.device.type == 'cuda': inputs = {k: v.to(model.device) for k, v in inputs.items()} # Generate response with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=512, temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id ) # Decode response response = tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract only the new generated part response = response[len(formatted_prompt):].strip() return f"🤖 Gemma Response: {response}" except Exception as e: return f"❌ Error generating response: {str(e)}. Please try again." def extract_text_from_pdf(file_path: str) -> str: """ Extract text from PDF file using PyPDF2. """ try: with open(file_path, 'rb') as file: pdf_reader = PyPDF2.PdfReader(file) text = "" for page in pdf_reader.pages: text += page.extract_text() + "\n" return text.strip() except Exception as e: return f"Error reading PDF: {str(e)}" def extract_text_from_txt(file_path: str) -> str: """ Extract text from TXT file. """ try: with open(file_path, 'r', encoding='utf-8') as file: return file.read().strip() except Exception as e: try: # Try with different encoding if UTF-8 fails with open(file_path, 'r', encoding='latin-1') as file: return file.read().strip() except Exception as e2: return f"Error reading text file: {str(e2)}" def process_file_input(file_input) -> str: """ Process uploaded file and extract text content. """ if file_input is None: return "" file_path = file_input.name file_extension = os.path.splitext(file_path)[1].lower() if file_extension == '.pdf': extracted_text = extract_text_from_pdf(file_path) return f"📄 Content from PDF ({os.path.basename(file_path)}):\n{extracted_text[:1000]}{'...' if len(extracted_text) > 1000 else ''}" elif file_extension == '.txt': extracted_text = extract_text_from_txt(file_path) return f"📝 Content from text file ({os.path.basename(file_path)}):\n{extracted_text[:1000]}{'...' if len(extracted_text) > 1000 else ''}" else: return f"❌ Unsupported file type: {file_extension}. Please upload PDF or TXT files only." def process_input(user_text: str, image_input: Optional[Image.Image], file_input, chat_history: List[Tuple[str, str]], file_context: str) -> Tuple[List[Tuple[str, str]], str, None, None, str]: """ Main function to process user input and generate response. Returns: (updated_chat_history, cleared_text, cleared_image, cleared_file, updated_file_context) """ if not user_text.strip() and image_input is None and file_input is None: return chat_history, "", None, None, file_context # Process file input if provided current_file_context = "" if file_input is not None: current_file_context = process_file_input(file_input) # Combine file context with user text combined_prompt = "" if current_file_context: combined_prompt = f"{current_file_context}\n\nUser Query: {user_text}" # Update persistent file context file_context = current_file_context elif file_context and user_text.strip(): # Use previous file context if available combined_prompt = f"{file_context}\n\nUser Query: {user_text}" else: combined_prompt = user_text # Generate response using Gemma model if image_input is not None: # Handle image + text input bot_response = gemma_3_inference(combined_prompt, pil_image=image_input, chat_history=chat_history) user_display = f"{user_text} [Image uploaded]" else: # Handle text-only input (potentially with file context) bot_response = gemma_3_inference(combined_prompt, chat_history=chat_history) if current_file_context: user_display = f"{user_text} [File: {os.path.basename(file_input.name) if file_input else 'Unknown'}]" else: user_display = user_text # Update chat history chat_history.append((user_display, bot_response)) # Return updated history and clear inputs return chat_history, "", None, None, file_context if current_file_context else file_context def clear_chat(chat_history: List[Tuple[str, str]], file_context: str) -> Tuple[List[Tuple[str, str]], str, None, None, str]: """ Clear chat history and reset all inputs. """ return [], "", None, None, "" # Create Gradio interface with gr.Blocks(title="Gemma-2 Multimodal Chat", theme=gr.themes.Soft()) as demo: gr.Markdown( """ # 🚀 Gemma-2 Multimodal Chat Application Welcome to the sophisticated Gemma-2 chat interface powered by Google's Gemma-2-2B-IT model! This application supports: - 💬 **Text conversations** with persistent chat history - 🖼️ **File processing** - upload PDF or TXT files for context - 📄 **Document analysis** - extract and analyze text from uploaded files - 🧠 **Contextual responses** - the model remembers your conversation **How to use:** 1. Type your message in the text box 2. Optionally upload a file (PDF/TXT) for document analysis 3. Click Submit or press Enter 4. Use Clear to reset the conversation *Note: This application uses the real Gemma-2-2B-IT model from Hugging Face. First message may take longer as the model loads.* """ ) # State variables chat_history_state = gr.State([]) file_context_state = gr.State("") with gr.Row(): with gr.Column(scale=2): # Chat interface chatbot = gr.Chatbot( label="Chat History", height=400, show_label=True, container=True, bubble_full_width=False ) # Input area with gr.Row(): user_input = gr.Textbox( label="Your message", placeholder="Type your message here...", lines=2, scale=4 ) submit_btn = gr.Button("Submit", variant="primary", scale=1) # Clear button clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary") with gr.Column(scale=1): # File upload area gr.Markdown("### 📎 Upload Content") image_input = gr.Image( label="Upload Image (for vision tasks)", type="pil", height=200 ) file_input = gr.File( label="Upload File (PDF or TXT)", file_types=[".pdf", ".txt"], height=100 ) gr.Markdown( """ **Tips:** - Upload either an image OR a file per message - PDF files will have their text extracted - File content persists as context for follow-up questions - Note: Gemma-2 doesn't have native vision capabilities, but you can still upload images and ask text-based questions about them """ ) # Event handlers submit_btn.click( fn=process_input, inputs=[user_input, image_input, file_input, chat_history_state, file_context_state], outputs=[chatbot, user_input, image_input, file_input, file_context_state] ).then( lambda: gr.update(value=chat_history_state.value), outputs=[chat_history_state] ) user_input.submit( fn=process_input, inputs=[user_input, image_input, file_input, chat_history_state, file_context_state], outputs=[chatbot, user_input, image_input, file_input, file_context_state] ).then( lambda: gr.update(value=chat_history_state.value), outputs=[chat_history_state] ) clear_btn.click( fn=clear_chat, inputs=[chat_history_state, file_context_state], outputs=[chatbot, user_input, image_input, file_input, file_context_state] ).then( lambda: gr.update(value=[]), outputs=[chat_history_state] ) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, share=True, show_error=True )