import os
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List

import gradio as gr
import psycopg2
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


@dataclass
class ChatConfig:
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 3
    system_prompt: str = "You are a helpful AI assistant that provides accurate information based on the given context."
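
# Naming note: temperature and top_p control token sampling at generation
# time; top_k in this config is the number of documents retrieved from the
# vector store, not token-level top-k sampling.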


class RAGPipeline:
    def __init__(self):
        # Read the connection string from the environment rather than
        # hard-coding credentials in source (DATABASE_URL is an assumed
        # variable name).
        self.connection_string = os.environ["DATABASE_URL"]
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.load_llm()
        self.chat_config = ChatConfig()

    def load_llm(self):
        # 4-bit NF4 quantization via bitsandbytes keeps memory use down;
        # pass an explicit BitsAndBytesConfig rather than a plain dict.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            "deepseek-ai/DeepSeek-R1",
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            quantization_config=quantization_config,
        )
        self.llm_tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1", trust_remote_code=True)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
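
    # Note: DeepSeek-R1 is a very large mixture-of-experts model; if it does
    # not fit in memory, a distilled variant such as
    # "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" (a suggested smaller
    # substitute, not part of the original setup) loads through the same
    # AutoModelForCausalLM call.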

    def generate_embedding(self, text: str) -> List[float]:
        # all-MiniLM-L6-v2 returns a 384-dimensional sentence embedding.
        return self.embedding_model.encode(text).tolist()

    def similarity_search(self, query_embedding: List[float]) -> List[dict]:
        try:
            with psycopg2.connect(self.connection_string) as conn:
                with conn.cursor() as cur:
                    # pgvector expects a vector literal like '[0.1,0.2,...]';
                    # building it as a string avoids depending on the
                    # pgvector psycopg2 adapter.
                    embedding_literal = "[" + ",".join(str(x) for x in query_embedding) + "]"
                    # <=> is pgvector's cosine-distance operator, so
                    # 1 - distance gives cosine similarity.
                    query = """
                        SELECT text, title, url,
                               1 - (vector <=> %s::vector) AS similarity
                        FROM bents
                        ORDER BY vector <=> %s::vector
                        LIMIT %s;
                    """
                    cur.execute(query, (embedding_literal, embedding_literal, self.chat_config.top_k))
                    results = cur.fetchall()
                    return [
                        {
                            'text': row[0],
                            'title': row[1],
                            'url': row[2],
                            'similarity': row[3],
                        }
                        for row in results
                    ]
        except Exception as e:
            print(f"Database error: {str(e)}")
            return []
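
    # For larger tables, an approximate-nearest-neighbor index keeps the
    # ORDER BY fast. A typical pgvector setup (assuming the bents.vector
    # column; run once in the database, not from this app):
    #
    #   CREATE INDEX ON bents USING ivfflat (vector vector_cosine_ops)
    #       WITH (lists = 100);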

    def format_conversation(self, messages: List[Dict[str, str]]) -> str:
        formatted = f"System: {self.chat_config.system_prompt}\n\n"
        for msg in messages:
            role = msg["role"].capitalize()
            content = msg["content"]
            formatted += f"{role}: {content}\n\n"
        return formatted.strip()
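
    # Illustrative output of format_conversation for a one-turn history
    # (placeholders, not real data):
    #
    #   System: You are a helpful AI assistant that provides accurate
    #   information based on the given context.
    #
    #   User: <user question>
    #
    #   Assistant: <previous answer>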

    def generate_response(self, messages: List[Dict[str, str]], context: str) -> str:
        try:
            conversation = self.format_conversation(messages)
            context_prompt = f"Context:\n{context}\n\nCurrent conversation:\n{conversation}\n\nAssistant:"
            inputs = self.llm_tokenizer(context_prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.device)
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=self.chat_config.max_tokens,
                    do_sample=True,
                    temperature=self.chat_config.temperature,
                    top_p=self.chat_config.top_p,
                    pad_token_id=self.llm_tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens, skipping the prompt.
            response = self.llm_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def process_query(self, message: str, chat_history: List[Dict[str, str]]) -> tuple[str, List[dict]]:
        query_embedding = self.generate_embedding(message)
        similar_docs = self.similarity_search(query_embedding)
        context = "\n".join(doc['text'] for doc in similar_docs)
        messages = chat_history + [{"role": "user", "content": message}]
        response = self.generate_response(messages, context)
        return response, similar_docs


class GradioRAGChat:
    def __init__(self):
        self.rag = RAGPipeline()

    def process_message(self, message: str, history: List[tuple[str, str]]) -> str:
        # Convert Gradio's (user, assistant) tuple history to role/content dicts.
        chat_history = []
        for user_msg, assistant_msg in history:
            if user_msg:
                chat_history.append({"role": "user", "content": user_msg})
            if assistant_msg:
                chat_history.append({"role": "assistant", "content": assistant_msg})
        response, sources = self.rag.process_query(message, chat_history)
        if not sources:
            return response
        # Append the retrieved sources so answers can be verified.
        formatted_sources = "\n\nSources:\n" + "\n".join(
            f"- {doc['title']} (Similarity: {doc['similarity']:.2f})\n  URL: {doc['url']}"
            for doc in sources
        )
        return response + formatted_sources

    def update_config(
        self,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        system_prompt: str
    ) -> str:
        # Gradio sliders return floats; cast the count-valued fields to int.
        self.rag.chat_config = ChatConfig(
            max_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
            system_prompt=system_prompt
        )
        return f"Configuration updated successfully at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

    def create_interface(self):
        with gr.Blocks(theme=gr.themes.Soft()) as interface:
            gr.Markdown("# RAG-Powered Chat Assistant")
            with gr.Tabs():
                with gr.Tab("Chat"):
                    chatbot = gr.ChatInterface(
                        fn=self.process_message,
                        title="",
                        description="Ask questions about the content in the database."
                    )
                with gr.Tab("Configuration"):
                    with gr.Group():
                        gr.Markdown("### Model Parameters")
                        with gr.Row():
                            max_tokens = gr.Slider(
                                minimum=64, maximum=2048, value=512, step=64,
                                label="Max Tokens"
                            )
                            temperature = gr.Slider(
                                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                                label="Temperature"
                            )
                        with gr.Row():
                            top_p = gr.Slider(
                                minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                                label="Top-p"
                            )
                            top_k = gr.Slider(
                                minimum=1, maximum=10, value=3, step=1,
                                label="Top-k Documents"
                            )
                        system_prompt = gr.Textbox(
                            value=self.rag.chat_config.system_prompt,
                            label="System Prompt",
                            lines=3
                        )
                        update_btn = gr.Button("Update Configuration")
                        config_status = gr.Textbox(label="Status", interactive=False)
                        update_btn.click(
                            fn=self.update_config,
                            inputs=[max_tokens, temperature, top_p, top_k, system_prompt],
                            outputs=[config_status]
                        )
            gr.Markdown("""
            ### About
            This chat interface uses RAG (Retrieval-Augmented Generation) to provide informed responses based on the content in the database.
            The assistant retrieves relevant documents and uses them as context for generating responses.

            - Use the Chat tab for asking questions
            - Use the Configuration tab to adjust model parameters
            """)
        return interface


def main():
    chat_app = GradioRAGChat()
    interface = chat_app.create_interface()
    interface.launch(share=False)  # Set share=True for a public URL


if __name__ == "__main__":
    main()
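
# Minimal local-run sketch (assumed setup; package list is illustrative):
#
#   export DATABASE_URL="postgresql://user:password@host/dbname?sslmode=require"
#   pip install gradio torch transformers accelerate bitsandbytes \
#       sentence-transformers psycopg2-binary
#   python app.py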