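"""RAG-powered Gradio chat assistant.

Embeds user queries with sentence-transformers, retrieves similar documents
from a Postgres/pgvector table, and answers with a 4-bit-quantized
DeepSeek-R1 causal language model.
"""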
import gradio as gr
import torch
from typing import List, Dict
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import psycopg2
import numpy as np
from dataclasses import dataclass
from datetime import datetime
import os


@dataclass
class ChatConfig:
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 3
    system_prompt: str = "You are a helpful AI assistant that provides accurate information based on the given context."
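# The defaults above match the initial slider values in the Configuration tab
# built by GradioRAGChat.create_interface.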


class RAGPipeline:
    def __init__(self):
        # The connection string comes from the environment (assumed here to be
        # DATABASE_URL) rather than hardcoding database credentials in source.
        self.connection_string = os.environ["DATABASE_URL"]
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.load_llm()
        self.chat_config = ChatConfig()

    def load_llm(self):
        # Load DeepSeek-R1 with 4-bit NF4 quantization to reduce memory usage.
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            "deepseek-ai/DeepSeek-R1",
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            ),
        )
        self.llm_tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1", trust_remote_code=True)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
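
    # all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings; the
    # pgvector column searched below must use the same dimensionality.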

    def generate_embedding(self, text: str) -> List[float]:
        return self.embedding_model.encode(text).tolist()
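
    # A minimal sketch of the table the query below assumes (the actual
    # schema of "bents" is not defined in this file):
    #
    #   CREATE EXTENSION IF NOT EXISTS vector;
    #   CREATE TABLE bents (
    #       text   TEXT,
    #       title  TEXT,
    #       url    TEXT,
    #       vector vector(384)
    #   );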

    def similarity_search(self, query_embedding: List[float]) -> List[dict]:
        try:
            with psycopg2.connect(self.connection_string) as conn:
                with conn.cursor() as cur:
                    embedding_array = np.array(query_embedding)
                    # <=> is pgvector's cosine-distance operator, so
                    # 1 - distance is the cosine similarity. The explicit
                    # ::vector cast converts the float array psycopg2 sends
                    # into a pgvector value.
                    query = """
                        SELECT text, title, url,
                               1 - (vector <=> %s::vector) AS similarity
                        FROM bents
                        ORDER BY vector <=> %s::vector
                        LIMIT %s;
                    """
                    cur.execute(query, (embedding_array.tolist(), embedding_array.tolist(), self.chat_config.top_k))
                    results = cur.fetchall()
                    return [
                        {
                            'text': row[0],
                            'title': row[1],
                            'url': row[2],
                            'similarity': row[3],
                        }
                        for row in results
                    ]
        except Exception as e:
            print(f"Database error: {str(e)}")
            return []
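
    # The prompt is built as a plain "System:/User:/Assistant:" transcript
    # rather than via the tokenizer's chat template (if one is available,
    # tokenizer.apply_chat_template would be an alternative).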

    def format_conversation(self, messages: List[Dict[str, str]]) -> str:
        formatted = f"System: {self.chat_config.system_prompt}\n\n"
        for msg in messages:
            role = msg["role"].capitalize()
            content = msg["content"]
            formatted += f"{role}: {content}\n\n"
        return formatted.strip()

    def generate_response(self, messages: List[Dict[str, str]], context: str) -> str:
        try:
            conversation = self.format_conversation(messages)
            context_prompt = f"Context:\n{context}\n\nCurrent conversation:\n{conversation}\n\nAssistant:"
            inputs = self.llm_tokenizer(context_prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.device)
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=self.chat_config.max_tokens,
                    do_sample=True,
                    temperature=self.chat_config.temperature,
                    top_p=self.chat_config.top_p,
                    pad_token_id=self.llm_tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens, skipping the prompt.
            response = self.llm_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            return f"Error generating response: {str(e)}"
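
    # End-to-end RAG step: embed the query, retrieve the top-k most similar
    # documents, and generate a response conditioned on their concatenated text.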

    def process_query(self, message: str, chat_history: List[Dict[str, str]]) -> tuple[str, List[dict]]:
        query_embedding = self.generate_embedding(message)
        similar_docs = self.similarity_search(query_embedding)
        context = "\n".join([doc['text'] for doc in similar_docs])
        messages = chat_history + [{"role": "user", "content": message}]
        response = self.generate_response(messages, context)
        return response, similar_docs
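

# Thin UI layer: owns a RAGPipeline and adapts between Gradio's chat history
# format and the role/content message dicts used by the pipeline.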
class GradioRAGChat:
    def __init__(self):
        self.rag = RAGPipeline()
        self.chat_history = []

    def process_message(self, message: str, history: List[tuple[str, str]]) -> str:
        # Convert Gradio's (user, assistant) tuple history to role/content dicts.
        chat_history = []
        for user_msg, assistant_msg in history:
            if user_msg:
                chat_history.append({"role": "user", "content": user_msg})
            if assistant_msg:
                chat_history.append({"role": "assistant", "content": assistant_msg})
        response, sources = self.rag.process_query(message, chat_history)
        if not sources:
            return response
        # Append the retrieved sources with their similarity scores.
        formatted_sources = "\n\nSources:\n" + "\n".join([
            f"- {doc['title']} (Similarity: {doc['similarity']:.2f})\n URL: {doc['url']}"
            for doc in sources
        ])
        return response + formatted_sources
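
    # Replacing the ChatConfig wholesale means a single update changes both
    # generation settings (max tokens, temperature, top-p) and retrieval
    # depth (top-k documents).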

    def update_config(
        self,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        system_prompt: str,
    ) -> str:
        self.rag.chat_config = ChatConfig(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            system_prompt=system_prompt,
        )
        return f"Configuration updated successfully at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

    def create_interface(self):
        with gr.Blocks(theme=gr.themes.Soft()) as interface:
            gr.Markdown("# RAG-Powered Chat Assistant")
            with gr.Tabs():
                with gr.Tab("Chat"):
                    chatbot = gr.ChatInterface(
                        fn=self.process_message,
                        title="",
                        description="Ask questions about the content in the database."
                    )
                with gr.Tab("Configuration"):
                    with gr.Group():
                        gr.Markdown("### Model Parameters")
                        with gr.Row():
                            max_tokens = gr.Slider(
                                minimum=64, maximum=2048, value=512, step=64,
                                label="Max Tokens"
                            )
                            temperature = gr.Slider(
                                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                                label="Temperature"
                            )
                        with gr.Row():
                            top_p = gr.Slider(
                                minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                                label="Top-p"
                            )
                            top_k = gr.Slider(
                                minimum=1, maximum=10, value=3, step=1,
                                label="Top-k Documents"
                            )
                        system_prompt = gr.Textbox(
                            value=self.rag.chat_config.system_prompt,
                            label="System Prompt",
                            lines=3
                        )
                        update_btn = gr.Button("Update Configuration")
                        config_status = gr.Textbox(label="Status", interactive=False)
                        update_btn.click(
                            fn=self.update_config,
                            inputs=[max_tokens, temperature, top_p, top_k, system_prompt],
                            outputs=[config_status]
                        )
            gr.Markdown("""
            ### About
            This chat interface uses RAG (Retrieval-Augmented Generation) to provide informed responses based on the content in the database.
            The assistant retrieves relevant documents and uses them as context for generating responses.
            - Use the Chat tab to ask questions
            - Use the Configuration tab to adjust model parameters
            """)
        return interface
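

# Run with `python app.py`. The script expects the DATABASE_URL environment
# variable (see RAGPipeline.__init__) to point at a Postgres database with the
# pgvector extension and a populated "bents" table.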
def main():
    chat_app = GradioRAGChat()
    interface = chat_app.create_interface()
    interface.launch(share=False)  # Set share=True for a public URL


if __name__ == "__main__":
    main()