# gemma/app.py
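"""
Gradio chat application around google/gemma-2-2b-it.

Provides a text chat with per-session history, optional image upload
(noted in the prompt only, since Gemma-2 has no native vision input),
and PDF/TXT uploads whose extracted text is used as context for the model.
"""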
import gradio as gr
import os
from PIL import Image
import tempfile
import PyPDF2
import io
from typing import List, Tuple, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Global variables for model and tokenizer
model = None
tokenizer = None


def load_gemma_model():
    """
    Load Gemma model and tokenizer from Hugging Face.
    """
    global model, tokenizer
    try:
        model_name = "google/gemma-2-2b-it"  # Using Gemma-2 as it's available
        print(f"Loading {model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )
        # Add padding token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("Model loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
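
# NOTE (assumption): google/gemma-2-2b-it is a gated checkpoint on Hugging Face,
# so the from_pretrained() calls above may also need an accepted license and an
# access token (e.g. an HF_TOKEN secret on Spaces) before the weights download.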


def gemma_3_inference(prompt_text: str, pil_image: Optional[Image.Image] = None, chat_history: Optional[List[Tuple[str, str]]] = None) -> str:
    """
    Gemma-2 model inference function.
    """
    global model, tokenizer
    # Load model if not already loaded
    if model is None or tokenizer is None:
        if not load_gemma_model():
            return "❌ Error: Could not load Gemma model. Please check your internet connection and try again."
    try:
        # Build conversation context
        conversation = []
        # Add chat history for context (last 3 exchanges)
        if chat_history:
            for user_msg, bot_msg in chat_history[-3:]:
                conversation.append({"role": "user", "content": user_msg})
                conversation.append({"role": "assistant", "content": bot_msg})
        # Handle image input (note: Gemma-2 doesn't have native vision, so we describe the limitation)
        if pil_image:
            prompt_text = f"[Image uploaded - Note: This model doesn't have vision capabilities, but I can help with text-based questions about images] {prompt_text}"
        # Add current user message
        conversation.append({"role": "user", "content": prompt_text})
        # Format conversation for Gemma
        formatted_prompt = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )
        # Tokenize input
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        )
        # Move to same device as model
        if torch.cuda.is_available() and model.device.type == 'cuda':
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens; slicing the decoded string by the
        # prompt length is unreliable once special tokens have been stripped
        generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        return f"🤖 Gemma Response: {response}"
    except Exception as e:
        return f"❌ Error generating response: {str(e)}. Please try again."


def extract_text_from_pdf(file_path: str) -> str:
    """
    Extract text from PDF file using PyPDF2.
    """
    try:
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                # extract_text() can return an empty result for image-only pages
                text += (page.extract_text() or "") + "\n"
            return text.strip()
    except Exception as e:
        return f"Error reading PDF: {str(e)}"


def extract_text_from_txt(file_path: str) -> str:
    """
    Extract text from TXT file.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read().strip()
    except Exception:
        try:
            # Try with different encoding if UTF-8 fails
            with open(file_path, 'r', encoding='latin-1') as file:
                return file.read().strip()
        except Exception as e2:
            return f"Error reading text file: {str(e2)}"


def process_file_input(file_input) -> str:
    """
    Process uploaded file and extract text content.
    """
    if file_input is None:
        return ""
    file_path = file_input.name
    file_extension = os.path.splitext(file_path)[1].lower()
    if file_extension == '.pdf':
        extracted_text = extract_text_from_pdf(file_path)
        return f"📄 Content from PDF ({os.path.basename(file_path)}):\n{extracted_text[:1000]}{'...' if len(extracted_text) > 1000 else ''}"
    elif file_extension == '.txt':
        extracted_text = extract_text_from_txt(file_path)
        return f"📝 Content from text file ({os.path.basename(file_path)}):\n{extracted_text[:1000]}{'...' if len(extracted_text) > 1000 else ''}"
    else:
        return f"❌ Unsupported file type: {file_extension}. Please upload PDF or TXT files only."


def process_input(user_text: str, image_input: Optional[Image.Image], file_input, chat_history: List[Tuple[str, str]], file_context: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], str, None, None, str]:
    """
    Main function to process user input and generate a response.
    Returns: (updated_chat_display, updated_chat_state, cleared_text, cleared_image, cleared_file, updated_file_context)
    """
    if not user_text.strip() and image_input is None and file_input is None:
        return chat_history, chat_history, "", None, None, file_context
    # Process file input if provided
    current_file_context = ""
    if file_input is not None:
        current_file_context = process_file_input(file_input)
    # Combine file context with user text
    if current_file_context:
        combined_prompt = f"{current_file_context}\n\nUser Query: {user_text}"
        # Update persistent file context
        file_context = current_file_context
    elif file_context and user_text.strip():  # Use previous file context if available
        combined_prompt = f"{file_context}\n\nUser Query: {user_text}"
    else:
        combined_prompt = user_text
    # Generate response using Gemma model
    if image_input is not None:
        # Handle image + text input
        bot_response = gemma_3_inference(combined_prompt, pil_image=image_input, chat_history=chat_history)
        user_display = f"{user_text} [Image uploaded]"
    else:
        # Handle text-only input (potentially with file context)
        bot_response = gemma_3_inference(combined_prompt, chat_history=chat_history)
        if current_file_context:
            user_display = f"{user_text} [File: {os.path.basename(file_input.name) if file_input else 'Unknown'}]"
        else:
            user_display = user_text
    # Update chat history
    chat_history.append((user_display, bot_response))
    # Return updated history (for both the chatbot display and the session state) and clear the inputs
    return chat_history, chat_history, "", None, None, file_context


def clear_chat(chat_history: List[Tuple[str, str]], file_context: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], str, None, None, str]:
    """
    Clear chat history and reset all inputs.
    """
    return [], [], "", None, None, ""


# Create Gradio interface
with gr.Blocks(title="Gemma-2 Multimodal Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚀 Gemma-2 Multimodal Chat Application

        Welcome to the Gemma-2 chat interface powered by Google's Gemma-2-2B-IT model! This application supports:

        - 💬 **Text conversations** with persistent chat history
        - 🖼️ **File processing** - upload PDF or TXT files for context
        - 📄 **Document analysis** - extract and analyze text from uploaded files
        - 🧠 **Contextual responses** - the model remembers your conversation

        **How to use:**
        1. Type your message in the text box
        2. Optionally upload a file (PDF/TXT) for document analysis
        3. Click Submit or press Enter
        4. Use Clear to reset the conversation

        *Note: This application uses the real Gemma-2-2B-IT model from Hugging Face. The first message may take longer while the model loads.*
        """
    )

    # State variables
    chat_history_state = gr.State([])
    file_context_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=2):
            # Chat interface
            chatbot = gr.Chatbot(
                label="Chat History",
                height=400,
                show_label=True,
                container=True,
                bubble_full_width=False
            )
            # Input area
            with gr.Row():
                user_input = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4
                )
                submit_btn = gr.Button("Submit", variant="primary", scale=1)
            # Clear button
            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
        with gr.Column(scale=1):
            # File upload area
            gr.Markdown("### 📎 Upload Content")
            image_input = gr.Image(
                label="Upload Image (for vision tasks)",
                type="pil",
                height=200
            )
            file_input = gr.File(
                label="Upload File (PDF or TXT)",
                file_types=[".pdf", ".txt"],
                height=100
            )
            gr.Markdown(
                """
                **Tips:**
                - Upload either an image OR a file per message
                - PDF files will have their text extracted
                - File content persists as context for follow-up questions
                - Note: Gemma-2 doesn't have native vision capabilities, but you can still upload images and ask text-based questions about them
                """
            )

    # Event handlers: each handler returns the updated history twice, once for the
    # Chatbot display and once for the session State, so no extra sync step is needed
    submit_btn.click(
        fn=process_input,
        inputs=[user_input, image_input, file_input, chat_history_state, file_context_state],
        outputs=[chatbot, chat_history_state, user_input, image_input, file_input, file_context_state]
    )
    user_input.submit(
        fn=process_input,
        inputs=[user_input, image_input, file_input, chat_history_state, file_context_state],
        outputs=[chatbot, chat_history_state, user_input, image_input, file_input, file_context_state]
    )
    clear_btn.click(
        fn=clear_chat,
        inputs=[chat_history_state, file_context_state],
        outputs=[chatbot, chat_history_state, user_input, image_input, file_input, file_context_state]
    )
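

# NOTE (assumption): on Hugging Face Spaces the app is served on port 7860 by the
# platform itself, so share=True is unnecessary there (Gradio ignores it with a
# warning); it mainly matters when running this script locally.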
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )