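"""
Gemma-2 Multimodal Chat Application.

A Gradio chat interface built on Google's Gemma-2-2B-IT model. It supports text
conversations with persistent history, PDF/TXT upload for document context, and
image upload that is acknowledged in text only, since Gemma-2 has no vision
capabilities.
"""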
import os
from typing import List, Optional, Tuple

import gradio as gr
import PyPDF2
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model and tokenizer are loaded lazily on the first request (see load_gemma_model)
model = None
tokenizer = None


def load_gemma_model():
    """
    Load the Gemma model and tokenizer from Hugging Face.
    """
    global model, tokenizer
    try:
        model_name = "google/gemma-2-2b-it"
        print(f"Loading {model_name}...")

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # Half precision with automatic device placement on GPU; full precision on CPU
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )

        # Fall back to the EOS token for padding if no pad token is defined
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        print("Model loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False


def gemma_2_inference(prompt_text: str, pil_image: Optional[Image.Image] = None, chat_history: Optional[List[Tuple[str, str]]] = None) -> str:
    """
    Run inference with the Gemma-2 model.
    """
    global model, tokenizer

    # Load the model lazily on the first call
    if model is None or tokenizer is None:
        if not load_gemma_model():
            return "❌ Error: Could not load Gemma model. Please check your internet connection and try again."

    try:
        conversation = []

        # Include the last few exchanges so the model keeps conversational context
        if chat_history:
            for user_msg, bot_msg in chat_history[-3:]:
                conversation.append({"role": "user", "content": user_msg})
                conversation.append({"role": "assistant", "content": bot_msg})

        # Gemma-2 is text-only, so an uploaded image is only acknowledged in the prompt
        if pil_image:
            prompt_text = f"[Image uploaded - Note: This model doesn't have vision capabilities, but I can help with text-based questions about images] {prompt_text}"

        conversation.append({"role": "user", "content": prompt_text})

        # Format the conversation with the model's chat template
        formatted_prompt = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        )

        # Move inputs onto the GPU when the model lives there
        if torch.cuda.is_available() and model.device.type == 'cuda':
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, skipping the prompt
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        return f"🤖 Gemma Response: {response}"

    except Exception as e:
        return f"❌ Error generating response: {str(e)}. Please try again."


def extract_text_from_pdf(file_path: str) -> str:
    """
    Extract text from a PDF file using PyPDF2.
    """
    try:
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                # extract_text() may return None for pages without extractable text
                text += (page.extract_text() or "") + "\n"
            return text.strip()
    except Exception as e:
        return f"Error reading PDF: {str(e)}"


def extract_text_from_txt(file_path: str) -> str:
    """
    Extract text from a TXT file.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read().strip()
    except Exception:
        try:
            # Fall back to latin-1 for files that are not valid UTF-8
            with open(file_path, 'r', encoding='latin-1') as file:
                return file.read().strip()
        except Exception as e2:
            return f"Error reading text file: {str(e2)}"


def process_file_input(file_input) -> str:
    """
    Process an uploaded file and extract its text content.
    """
    if file_input is None:
        return ""

    # gr.File may provide a path string or a file-like object with a .name attribute,
    # depending on the Gradio version and component configuration
    file_path = file_input if isinstance(file_input, str) else file_input.name
    file_extension = os.path.splitext(file_path)[1].lower()

    if file_extension == '.pdf':
        extracted_text = extract_text_from_pdf(file_path)
        return f"📄 Content from PDF ({os.path.basename(file_path)}):\n{extracted_text[:1000]}{'...' if len(extracted_text) > 1000 else ''}"
    elif file_extension == '.txt':
        extracted_text = extract_text_from_txt(file_path)
        return f"📄 Content from text file ({os.path.basename(file_path)}):\n{extracted_text[:1000]}{'...' if len(extracted_text) > 1000 else ''}"
    else:
        return f"❌ Unsupported file type: {file_extension}. Please upload PDF or TXT files only."


def process_input(user_text: str, image_input: Optional[Image.Image], file_input, chat_history: List[Tuple[str, str]], file_context: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], str, None, None, str]:
    """
    Main function to process user input and generate a response.
    Returns: (chatbot_display, updated_chat_history, cleared_text, cleared_image, cleared_file, updated_file_context)
    """
    if not user_text.strip() and image_input is None and file_input is None:
        return chat_history, chat_history, "", None, None, file_context

    # Extract text from a newly uploaded file, if any
    current_file_context = ""
    if file_input is not None:
        current_file_context = process_file_input(file_input)

    # Build the prompt, prepending file content as context when available
    if current_file_context:
        combined_prompt = f"{current_file_context}\n\nUser Query: {user_text}"
        # Persist the new file content as context for follow-up questions
        file_context = current_file_context
    elif file_context and user_text.strip():
        combined_prompt = f"{file_context}\n\nUser Query: {user_text}"
    else:
        combined_prompt = user_text

    # Generate the response and build the user-facing message label
    if image_input is not None:
        bot_response = gemma_2_inference(combined_prompt, pil_image=image_input, chat_history=chat_history)
        user_display = f"{user_text} [Image uploaded]"
    else:
        bot_response = gemma_2_inference(combined_prompt, chat_history=chat_history)
        if current_file_context:
            file_name = os.path.basename(file_input if isinstance(file_input, str) else file_input.name)
            user_display = f"{user_text} [File: {file_name}]"
        else:
            user_display = user_text

    chat_history.append((user_display, bot_response))

    # Return the history twice: once for the Chatbot display, once for the session State
    return chat_history, chat_history, "", None, None, file_context


def clear_chat() -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], str, None, None, str]:
    """
    Clear the chat history and reset all inputs.
    """
    return [], [], "", None, None, ""


with gr.Blocks(title="Gemma-2 Multimodal Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚀 Gemma-2 Multimodal Chat Application

        Welcome to the Gemma-2 chat interface powered by Google's Gemma-2-2B-IT model! This application supports:
        - 💬 **Text conversations** with persistent chat history
        - 🖼️ **File processing** - upload PDF or TXT files for context
        - 📄 **Document analysis** - extract and analyze text from uploaded files
        - 🧠 **Contextual responses** - the model remembers your conversation

        **How to use:**
        1. Type your message in the text box
        2. Optionally upload a file (PDF/TXT) for document analysis
        3. Click Submit or press Enter
        4. Use Clear to reset the conversation

        *Note: This application uses the real Gemma-2-2B-IT model from Hugging Face. The first message may take longer while the model loads.*
        """
    )

    # Per-session state for the conversation history and the active file context
    chat_history_state = gr.State([])
    file_context_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Chat History",
                height=400,
                show_label=True,
                container=True,
                bubble_full_width=False
            )

            with gr.Row():
                user_input = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4
                )
                submit_btn = gr.Button("Submit", variant="primary", scale=1)

            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

        with gr.Column(scale=1):
            gr.Markdown("### 📁 Upload Content")

            image_input = gr.Image(
                label="Upload Image (for vision tasks)",
                type="pil",
                height=200
            )

            file_input = gr.File(
                label="Upload File (PDF or TXT)",
                file_types=[".pdf", ".txt"],
                height=100
            )

            gr.Markdown(
                """
                **Tips:**
                - Upload either an image OR a file per message
                - PDF files will have their text extracted
                - File content persists as context for follow-up questions
                - Note: Gemma-2 doesn't have native vision capabilities, but you can still upload images and ask text-based questions about them
                """
            )

    # Each handler returns the updated history for both the Chatbot display and the session State
    submit_btn.click(
        fn=process_input,
        inputs=[user_input, image_input, file_input, chat_history_state, file_context_state],
        outputs=[chatbot, chat_history_state, user_input, image_input, file_input, file_context_state]
    )

    user_input.submit(
        fn=process_input,
        inputs=[user_input, image_input, file_input, chat_history_state, file_context_state],
        outputs=[chatbot, chat_history_state, user_input, image_input, file_input, file_context_state]
    )

    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, chat_history_state, user_input, image_input, file_input, file_context_state]
    )


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )
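
# Rough setup notes (assumed filename app.py; adjust to your environment):
#   pip install gradio transformers torch PyPDF2 pillow
#   python app.py
# The google/gemma-2-2b-it checkpoint is gated on Hugging Face, so accepting its
# license and authenticating (e.g. `huggingface-cli login`) may be required first.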