import gradio as gr
import torch
from typing import List, Dict
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import psycopg2
import numpy as np
from dataclasses import dataclass
from datetime import datetime
import os


@dataclass
class ChatConfig:
    """Tunable generation and retrieval parameters."""
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 3
    system_prompt: str = "You are a helpful AI assistant that provides accurate information based on the given context."


class RAGPipeline:
    def __init__(self):
        # Prefer a DATABASE_URL environment variable; the hardcoded fallback
        # below embeds credentials and should not ship in production code.
        self.connection_string = os.environ.get(
            "DATABASE_URL",
            "postgresql://Data_owner:JsxygNDC15IO@ep-cool-hill-a5k13m05-pooler.us-east-2.aws.neon.tech/Data?sslmode=require",
        )
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.load_llm()
        self.chat_config = ChatConfig()

    def load_llm(self):
        # transformers expects a BitsAndBytesConfig object for quantization
        # settings, not a plain dict.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            "deepseek-ai/DeepSeek-R1",
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            quantization_config=quantization_config,
        )
        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            "deepseek-ai/DeepSeek-R1", trust_remote_code=True
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def generate_embedding(self, text: str) -> List[float]:
        return self.embedding_model.encode(text).tolist()

    def similarity_search(self, query_embedding: List[float]) -> List[dict]:
        try:
            with psycopg2.connect(self.connection_string) as conn:
                with conn.cursor() as cur:
                    # pgvector expects a vector literal such as '[0.1, 0.2, ...]',
                    # so serialize the embedding and cast the parameters explicitly;
                    # passing a raw Python list binds as a Postgres array and fails.
                    embedding_str = str(np.array(query_embedding).tolist())
                    query = """
                        SELECT text, title, url,
                               1 - (vector <=> %s::vector) AS similarity
                        FROM bents
                        ORDER BY vector <=> %s::vector
                        LIMIT %s;
                    """
                    cur.execute(query, (embedding_str, embedding_str, self.chat_config.top_k))
                    results = cur.fetchall()
                    return [
                        {
                            'text': row[0],
                            'title': row[1],
                            'url': row[2],
                            'similarity': row[3],
                        }
                        for row in results
                    ]
        except Exception as e:
            print(f"Database error: {str(e)}")
            return []

    def format_conversation(self, messages: List[Dict[str, str]]) -> str:
        """Flatten the chat history into a plain-text prompt."""
        formatted = f"System: {self.chat_config.system_prompt}\n\n"
        for msg in messages:
            role = msg["role"].capitalize()
            content = msg["content"]
            formatted += f"{role}: {content}\n\n"
        return formatted.strip()

    def generate_response(self, messages: List[Dict[str, str]], context: str) -> str:
        try:
            conversation = self.format_conversation(messages)
            context_prompt = f"Context:\n{context}\n\nCurrent conversation:\n{conversation}\n\nAssistant:"
            inputs = self.llm_tokenizer(
                context_prompt, return_tensors="pt", truncation=True, max_length=2048
            ).to(self.device)
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=self.chat_config.max_tokens,
                    do_sample=True,
                    temperature=self.chat_config.temperature,
                    top_p=self.chat_config.top_p,
                    pad_token_id=self.llm_tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens, not the prompt.
            response = self.llm_tokenizer.decode(
                outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def process_query(self, message: str, chat_history: List[Dict[str, str]]) -> tuple[str, List[dict]]:
        query_embedding = self.generate_embedding(message)
        similar_docs = self.similarity_search(query_embedding)
        context = "\n".join([doc['text'] for doc in similar_docs])
        messages = chat_history + [{"role": "user", "content": message}]
        response = self.generate_response(messages, context)
        return response, similar_docs
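
# ---------------------------------------------------------------------------
# Sketch of the assumed `bents` schema. The table definition is not part of
# this file; the helper below is inferred from similarity_search() (columns
# text/title/url/vector) and the 384-dimensional output of all-MiniLM-L6-v2.
# Adjust names and dimensions to match the actual database before relying on it.
def ensure_schema(connection_string: str) -> None:
    """Create the pgvector extension and the `bents` table if missing."""
    with psycopg2.connect(connection_string) as conn:
        with conn.cursor() as cur:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS bents (
                    id     SERIAL PRIMARY KEY,
                    text   TEXT NOT NULL,
                    title  TEXT,
                    url    TEXT,
                    vector vector(384)
                );
                """
            )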
class GradioRAGChat:
    def __init__(self):
        self.rag = RAGPipeline()
        self.chat_history = []

    def process_message(self, message: str, history: List[tuple[str, str]]) -> str:
        # Convert Gradio's (user, assistant) tuple history into role dicts
        chat_history = []
        for user_msg, assistant_msg in history:
            if user_msg:
                chat_history.append({"role": "user", "content": user_msg})
            if assistant_msg:
                chat_history.append({"role": "assistant", "content": assistant_msg})

        response, sources = self.rag.process_query(message, chat_history)

        # Append the retrieved sources, if any, to the response
        if not sources:
            return response
        formatted_sources = "\n\nSources:\n" + "\n".join([
            f"- {doc['title']} (Similarity: {doc['similarity']:.2f})\n  URL: {doc['url']}"
            for doc in sources
        ])
        return response + formatted_sources

    def update_config(
        self,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        system_prompt: str
    ) -> str:
        self.rag.chat_config = ChatConfig(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            system_prompt=system_prompt
        )
        return f"Configuration updated successfully at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

    def create_interface(self):
        with gr.Blocks(theme=gr.themes.Soft()) as interface:
            gr.Markdown("# RAG-Powered Chat Assistant")
            with gr.Tabs():
                with gr.Tab("Chat"):
                    chatbot = gr.ChatInterface(
                        fn=self.process_message,
                        title="",
                        description="Ask questions about the content in the database."
                    )
                with gr.Tab("Configuration"):
                    with gr.Group():
                        gr.Markdown("### Model Parameters")
                        with gr.Row():
                            max_tokens = gr.Slider(
                                minimum=64,
                                maximum=2048,
                                value=512,
                                step=64,
                                label="Max Tokens"
                            )
                            temperature = gr.Slider(
                                minimum=0.1,
                                maximum=2.0,
                                value=0.7,
                                step=0.1,
                                label="Temperature"
                            )
                        with gr.Row():
                            top_p = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                value=0.95,
                                step=0.05,
                                label="Top-p"
                            )
                            top_k = gr.Slider(
                                minimum=1,
                                maximum=10,
                                value=3,
                                step=1,
                                label="Top-k Documents"
                            )
                        system_prompt = gr.Textbox(
                            value=self.rag.chat_config.system_prompt,
                            label="System Prompt",
                            lines=3
                        )
                        update_btn = gr.Button("Update Configuration")
                        config_status = gr.Textbox(label="Status", interactive=False)
                        update_btn.click(
                            fn=self.update_config,
                            inputs=[max_tokens, temperature, top_p, top_k, system_prompt],
                            outputs=[config_status]
                        )
            gr.Markdown("""
            ### About
            This chat interface uses RAG (Retrieval-Augmented Generation) to provide
            informed responses based on the content in the database. The assistant
            retrieves relevant documents and uses them as context for generating
            responses.

            - Use the Chat tab for asking questions
            - Use the Configuration tab to adjust model parameters
            """)
        return interface


def main():
    chat_app = GradioRAGChat()
    interface = chat_app.create_interface()
    interface.launch(share=False)  # Set share=True for a public URL


if __name__ == "__main__":
    main()
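
# ---------------------------------------------------------------------------
# Usage note (assumed dependencies; this file does not pin them):
#   pip install gradio torch transformers sentence-transformers \
#       psycopg2-binary numpy accelerate bitsandbytes
#   export DATABASE_URL="postgresql://user:password@host/db?sslmode=require"
#   python rag_chat.py   # or this file's actual name
# Caveat: deepseek-ai/DeepSeek-R1 is a very large model; even with 4-bit
# quantization it will not fit on typical single-GPU hardware, so a smaller
# distilled variant may be a more practical choice for local testing.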