import gradio as gr
import torch
from typing import List, Dict
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import psycopg2
from dataclasses import dataclass
from datetime import datetime
import json
import os
@dataclass
class ChatConfig:
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 3
    system_prompt: str = "You are a helpful AI assistant that provides accurate information based on the given context."
class RAGPipeline:
    def __init__(self):
        # Read the Postgres connection string from the environment (the
        # DATABASE_URL variable name is an assumption) instead of hard-coding
        # credentials in the source.
        self.connection_string = os.environ["DATABASE_URL"]
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.load_llm()
        self.chat_config = ChatConfig()
    def load_llm(self):
        compute_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            "deepseek-ai/DeepSeek-R1",
            trust_remote_code=True,
            torch_dtype=compute_dtype,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            ),
        )
        self.llm_tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1", trust_remote_code=True)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
    def generate_embedding(self, text: str) -> List[float]:
        return self.embedding_model.encode(text).tolist()
    def similarity_search(self, query_embedding: List[float]) -> List[dict]:
        try:
            with psycopg2.connect(self.connection_string) as conn:
                with conn.cursor() as cur:
                    # pgvector expects a vector literal like '[0.1, 0.2, ...]',
                    # so serialize the embedding and cast the parameter explicitly.
                    embedding_literal = json.dumps(query_embedding)
                    query = """
                        SELECT text, title, url,
                               1 - (vector <=> %s::vector) AS similarity
                        FROM bents
                        ORDER BY vector <=> %s::vector
                        LIMIT %s;
                    """
                    cur.execute(query, (embedding_literal, embedding_literal, self.chat_config.top_k))
                    results = cur.fetchall()
                    return [
                        {
                            'text': row[0],
                            'title': row[1],
                            'url': row[2],
                            'similarity': row[3]
                        }
                        for row in results
                    ]
        except Exception as e:
            print(f"Database error: {str(e)}")
            return []
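    # The table layout below is an assumption inferred from the query above
    # (columns text, title, url, vector), not taken from the original Space;
    # vector(384) matches the output dimension of all-MiniLM-L6-v2.
    # A minimal pgvector setup consistent with this code could look like:
    #
    #   CREATE EXTENSION IF NOT EXISTS vector;
    #   CREATE TABLE bents (
    #       id     bigserial PRIMARY KEY,
    #       text   text,
    #       title  text,
    #       url    text,
    #       vector vector(384)
    #   );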
    def format_conversation(self, messages: List[Dict[str, str]]) -> str:
        formatted = f"System: {self.chat_config.system_prompt}\n\n"
        for msg in messages:
            role = msg["role"].capitalize()
            content = msg["content"]
            formatted += f"{role}: {content}\n\n"
        return formatted.strip()
    def generate_response(self, messages: List[Dict[str, str]], context: str) -> str:
        try:
            conversation = self.format_conversation(messages)
            context_prompt = f"Context:\n{context}\n\nCurrent conversation:\n{conversation}\n\nAssistant:"
            inputs = self.llm_tokenizer(context_prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.device)
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=self.chat_config.max_tokens,
                    do_sample=True,
                    temperature=self.chat_config.temperature,
                    top_p=self.chat_config.top_p,
                    pad_token_id=self.llm_tokenizer.eos_token_id,
                )
            response = self.llm_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            return f"Error generating response: {str(e)}"
    def process_query(self, message: str, chat_history: List[Dict[str, str]]) -> tuple[str, List[dict]]:
        query_embedding = self.generate_embedding(message)
        similar_docs = self.similarity_search(query_embedding)
        context = "\n".join([doc['text'] for doc in similar_docs])
        messages = chat_history + [{"role": "user", "content": message}]
        response = self.generate_response(messages, context)
        return response, similar_docs
class GradioRAGChat:
    def __init__(self):
        self.rag = RAGPipeline()
        self.chat_history = []
    def process_message(self, message: str, history: List[tuple[str, str]]) -> str:
        # Convert Gradio's (user, assistant) tuple history into role/content messages
        chat_history = []
        for user_msg, assistant_msg in history:
            if user_msg:
                chat_history.append({"role": "user", "content": user_msg})
            if assistant_msg:
                chat_history.append({"role": "assistant", "content": assistant_msg})
        response, sources = self.rag.process_query(message, chat_history)
        if not sources:
            return response
        # Append the retrieved documents so the answer can be traced back to its sources
        formatted_sources = "\n\nSources:\n" + "\n".join([
            f"- {doc['title']} (Similarity: {doc['similarity']:.2f})\n  URL: {doc['url']}"
            for doc in sources
        ])
        return response + formatted_sources
    def update_config(
        self,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        system_prompt: str
    ) -> str:
        self.rag.chat_config = ChatConfig(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            system_prompt=system_prompt
        )
        return f"Configuration updated successfully at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
    def create_interface(self):
        with gr.Blocks(theme=gr.themes.Soft()) as interface:
            gr.Markdown("# RAG-Powered Chat Assistant")
            with gr.Tabs():
                with gr.Tab("Chat"):
                    chatbot = gr.ChatInterface(
                        fn=self.process_message,
                        title="",
                        description="Ask questions about the content in the database."
                    )
                with gr.Tab("Configuration"):
                    with gr.Group():
                        gr.Markdown("### Model Parameters")
                        with gr.Row():
                            max_tokens = gr.Slider(
                                minimum=64, maximum=2048, value=512, step=64,
                                label="Max Tokens"
                            )
                            temperature = gr.Slider(
                                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                                label="Temperature"
                            )
                        with gr.Row():
                            top_p = gr.Slider(
                                minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                                label="Top-p"
                            )
                            top_k = gr.Slider(
                                minimum=1, maximum=10, value=3, step=1,
                                label="Top-k Documents"
                            )
                        system_prompt = gr.Textbox(
                            value=self.rag.chat_config.system_prompt,
                            label="System Prompt",
                            lines=3
                        )
                        update_btn = gr.Button("Update Configuration")
                        config_status = gr.Textbox(label="Status", interactive=False)
                        update_btn.click(
                            fn=self.update_config,
                            inputs=[max_tokens, temperature, top_p, top_k, system_prompt],
                            outputs=[config_status]
                        )
            gr.Markdown("""
### About
This chat interface uses RAG (Retrieval Augmented Generation) to provide informed responses based on the content in the database.
The assistant retrieves relevant documents and uses them as context for generating responses.

- Use the Chat tab for asking questions
- Use the Configuration tab to adjust model parameters
""")
        return interface
def main():
    chat_app = GradioRAGChat()
    interface = chat_app.create_interface()
    interface.launch(share=False)  # Set share=True for public URL


if __name__ == "__main__":
    main()
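# The dependency list below is a sketch inferred from this script's imports, not
# a pinned requirements file from the original Space; bitsandbytes and accelerate
# are assumed to be needed for 4-bit quantization and device_map="auto".
#
# requirements.txt (sketch):
#   gradio
#   torch
#   transformers
#   accelerate
#   bitsandbytes
#   sentence-transformers
#   psycopg2-binary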