MRasheq committed
Commit 07be9b4 · 1 Parent(s): 1332eb3

Third Commit

Files changed (2)
  1. app.py +189 -124
  2. requirements.txt +13 -2
app.py CHANGED
@@ -1,157 +1,222 @@
  import gradio as gr
  import torch
- from typing import List
  from sentence_transformers import SentenceTransformer
  from transformers import AutoModelForCausalLM, AutoTokenizer
  import psycopg2
  import numpy as np

  class RAGPipeline:
      def __init__(self):
-         # Database connection string
          self.connection_string = "postgresql://Data_owner:JsxygNDC15IO@ep-cool-hill-a5k13m05-pooler.us-east-2.aws.neon.tech/Data?sslmode=require"
-
-         # Initialize embedding model
          self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-
-         # Initialize LLM
          self.llm_model = AutoModelForCausalLM.from_pretrained(
              "deepseek-ai/DeepSeek-R1",
              trust_remote_code=True,
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
          )
          self.llm_tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1", trust_remote_code=True)
-
-         # Move model to GPU if available
          self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         self.llm_model = self.llm_model.to(self.device)
-
-         # Initialize prompt template
-         self.prompt_template = """
-         Use the following context to answer the question. If you cannot answer the question based on the context, say so.
-
-         Context: {context}
-
-         Question: {question}
-
-         Answer: Let me help you with that.
-         """
-
      def generate_embedding(self, text: str) -> List[float]:
-         """Generate embeddings for input text."""
-         embedding = self.embedding_model.encode(text)
-         return embedding.tolist()
-
-     def similarity_search(self, query_embedding: List[float], top_k: int = 3) -> List[dict]:
-         """Perform similarity search in PostgreSQL using vector comparison."""
-         with psycopg2.connect(self.connection_string) as conn:
-             with conn.cursor() as cur:
-                 embedding_array = np.array(query_embedding)
-
-                 query = """
-                     SELECT text, title, url,
-                            1 - (vector <=> %s) as similarity
-                     FROM bents
-                     ORDER BY vector <=> %s
-                     LIMIT %s;
-                 """
-                 cur.execute(query, (embedding_array.tolist(), embedding_array.tolist(), top_k))
-                 results = cur.fetchall()
-
-                 similar_docs = [
-                     {
-                         'text': row[0],
-                         'title': row[1],
-                         'url': row[2],
-                         'similarity': row[3]
-                     }
-                     for row in results
-                 ]
-
-                 return similar_docs
-
-     def generate_response(self, query: str, context: str,
-                           max_tokens: int = 512,
-                           temperature: float = 0.7,
-                           top_p: float = 0.95) -> str:
-         """Generate response using the LLM."""
-         prompt = self.prompt_template.format(context=context, question=query)

-         inputs = self.llm_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.device)

-         with torch.no_grad():
-             outputs = self.llm_model.generate(
-                 **inputs,
-                 max_new_tokens=max_tokens,
-                 do_sample=True,
-                 temperature=temperature,
-                 top_p=top_p,
-                 pad_token_id=self.llm_tokenizer.eos_token_id,
-             )

-         response = self.llm_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-         return response.strip()
-
-     def process_query(self,
-                       query: str,
-                       max_tokens: int = 512,
-                       temperature: float = 0.7,
-                       top_p: float = 0.95,
-                       top_k: int = 3) -> dict:
-         """Process user query through the complete RAG pipeline."""
-         query_embedding = self.generate_embedding(query)
-         similar_docs = self.similarity_search(query_embedding, top_k=top_k)
-         context = "\n".join([doc['text'] for doc in similar_docs])
-         response = self.generate_response(query, context, max_tokens, temperature, top_p)

-         # Format sources for display
-         sources = "\n\nSources:\n" + "\n".join([
              f"- {doc['title']} (Similarity: {doc['similarity']:.2f})\n URL: {doc['url']}"
-             for doc in similar_docs
          ])

-         return response + sources
-
- # Initialize RAG pipeline globally
- rag_pipeline = RAGPipeline()

- def process_message(
-     message: str,
-     history: List[tuple[str, str]],
-     max_tokens: int,
-     temperature: float,
-     top_p: float,
-     top_k: int
- ) -> str:
-     """Process message and maintain chat history."""
-     try:
-         response = rag_pipeline.process_query(
-             message,
              max_tokens=max_tokens,
              temperature=temperature,
              top_p=top_p,
-             top_k=top_k
          )
-         return response
-     except Exception as e:
-         return f"An error occurred: {str(e)}"

- # Create Gradio interface
- demo = gr.ChatInterface(
-     process_message,
-     additional_inputs=[
-         gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Max tokens"),
-         gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
-         gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-k documents"),
-     ],
-     title="RAG-Powered Chat Assistant",
-     description="""This chat interface uses RAG (Retrieval Augmented Generation) to provide informed responses
-         based on the content in the database. The assistant retrieves relevant documents and uses them
-         as context for generating responses.""",
-     theme="soft"
- )

- # Launch the interface
  if __name__ == "__main__":
-     demo.launch(share=True)  # Set share=False in production
 
  import gradio as gr
  import torch
+ from typing import List, Dict, Any
  from sentence_transformers import SentenceTransformer
  from transformers import AutoModelForCausalLM, AutoTokenizer
  import psycopg2
  import numpy as np
+ from dataclasses import dataclass
+ from datetime import datetime
+ import json
+ import os
+
+ @dataclass
+ class ChatConfig:
+     max_tokens: int = 512
+     temperature: float = 0.7
+     top_p: float = 0.95
+     top_k: int = 3
+     system_prompt: str = "You are a helpful AI assistant that provides accurate information based on the given context."

  class RAGPipeline:
      def __init__(self):
          self.connection_string = "postgresql://Data_owner:JsxygNDC15IO@ep-cool-hill-a5k13m05-pooler.us-east-2.aws.neon.tech/Data?sslmode=require"
          self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+         self.load_llm()
+         self.chat_config = ChatConfig()
+
+     def load_llm(self):
          self.llm_model = AutoModelForCausalLM.from_pretrained(
              "deepseek-ai/DeepSeek-R1",
              trust_remote_code=True,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             load_in_4bit=True,
+             device_map="auto",
+             quantization_config={
+                 "load_in_4bit": True,
+                 "bnb_4bit_compute_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
+                 "bnb_4bit_quant_type": "nf4",
+                 "bnb_4bit_use_double_quant": True
+             }
          )
          self.llm_tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1", trust_remote_code=True)
          self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
      def generate_embedding(self, text: str) -> List[float]:
+         return self.embedding_model.encode(text).tolist()
+
+     def similarity_search(self, query_embedding: List[float]) -> List[dict]:
+         try:
+             with psycopg2.connect(self.connection_string) as conn:
+                 with conn.cursor() as cur:
+                     embedding_array = np.array(query_embedding)
+                     query = """
+                         SELECT text, title, url,
+                                1 - (vector <=> %s) as similarity
+                         FROM bents
+                         ORDER BY vector <=> %s
+                         LIMIT %s;
+                     """
+                     cur.execute(query, (embedding_array.tolist(), embedding_array.tolist(), self.chat_config.top_k))
+                     results = cur.fetchall()
+                     return [
+                         {
+                             'text': row[0],
+                             'title': row[1],
+                             'url': row[2],
+                             'similarity': row[3]
+                         }
+                         for row in results
+                     ]
+         except Exception as e:
+             print(f"Database error: {str(e)}")
+             return []
+
+     def format_conversation(self, messages: List[Dict[str, str]]) -> str:
+         formatted = f"System: {self.chat_config.system_prompt}\n\n"
+         for msg in messages:
+             role = msg["role"].capitalize()
+             content = msg["content"]
+             formatted += f"{role}: {content}\n\n"
+         return formatted.strip()
+
+     def generate_response(self, messages: List[Dict[str, str]], context: str) -> str:
+         try:
+             conversation = self.format_conversation(messages)
+             context_prompt = f"Context:\n{context}\n\nCurrent conversation:\n{conversation}\n\nAssistant:"
+
+             inputs = self.llm_tokenizer(context_prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.llm_model.generate(
+                     **inputs,
+                     max_new_tokens=self.chat_config.max_tokens,
+                     do_sample=True,
+                     temperature=self.chat_config.temperature,
+                     top_p=self.chat_config.top_p,
+                     pad_token_id=self.llm_tokenizer.eos_token_id,
+                 )
+
+             response = self.llm_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+             return response.strip()
+         except Exception as e:
+             return f"Error generating response: {str(e)}"
+
+     def process_query(self, message: str, chat_history: List[Dict[str, str]]) -> tuple[str, List[dict]]:
+         query_embedding = self.generate_embedding(message)
+         similar_docs = self.similarity_search(query_embedding)
+         context = "\n".join([doc['text'] for doc in similar_docs])

+         messages = chat_history + [{"role": "user", "content": message}]
+         response = self.generate_response(messages, context)

+         return response, similar_docs
+
+ class GradioRAGChat:
+     def __init__(self):
+         self.rag = RAGPipeline()
+         self.chat_history = []
+
+     def process_message(self, message: str, history: List[tuple[str, str]]) -> tuple[str, List[dict]]:
+         # Convert Gradio history format to our format
+         chat_history = []
+         for user_msg, assistant_msg in history:
+             if user_msg:
+                 chat_history.append({"role": "user", "content": user_msg})
+             if assistant_msg:
+                 chat_history.append({"role": "assistant", "content": assistant_msg})

+         response, sources = self.rag.process_query(message, chat_history)

+         # Format response with sources
+         formatted_sources = "\n\nSources:\n" + "\n".join([
              f"- {doc['title']} (Similarity: {doc['similarity']:.2f})\n URL: {doc['url']}"
+             for doc in sources
          ])

+         return response + formatted_sources

+     def update_config(
+         self,
+         max_tokens: int,
+         temperature: float,
+         top_p: float,
+         top_k: int,
+         system_prompt: str
+     ) -> str:
+         self.rag.chat_config = ChatConfig(
              max_tokens=max_tokens,
              temperature=temperature,
              top_p=top_p,
+             top_k=top_k,
+             system_prompt=system_prompt
          )
+         return f"Configuration updated successfully at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
+
+     def create_interface(self):
+         with gr.Blocks(theme=gr.themes.Soft()) as interface:
+             gr.Markdown("# RAG-Powered Chat Assistant")
+
+             with gr.Tabs():
+                 with gr.Tab("Chat"):
+                     chatbot = gr.ChatInterface(
+                         fn=self.process_message,
+                         title="",
+                         description="Ask questions about the content in the database."
+                     )
+
+                 with gr.Tab("Configuration"):
+                     with gr.Group():
+                         gr.Markdown("### Model Parameters")
+                         with gr.Row():
+                             max_tokens = gr.Slider(
+                                 minimum=64, maximum=2048, value=512, step=64,
+                                 label="Max Tokens"
+                             )
+                             temperature = gr.Slider(
+                                 minimum=0.1, maximum=2.0, value=0.7, step=0.1,
+                                 label="Temperature"
+                             )
+                         with gr.Row():
+                             top_p = gr.Slider(
+                                 minimum=0.1, maximum=1.0, value=0.95, step=0.05,
+                                 label="Top-p"
+                             )
+                             top_k = gr.Slider(
+                                 minimum=1, maximum=10, value=3, step=1,
+                                 label="Top-k Documents"
+                             )
+
+                         system_prompt = gr.Textbox(
+                             value=self.rag.chat_config.system_prompt,
+                             label="System Prompt",
+                             lines=3
+                         )
+
+                         update_btn = gr.Button("Update Configuration")
+                         config_status = gr.Textbox(label="Status", interactive=False)
+
+                         update_btn.click(
+                             fn=self.update_config,
+                             inputs=[max_tokens, temperature, top_p, top_k, system_prompt],
+                             outputs=[config_status]
+                         )
+
+             gr.Markdown("""
+             ### About
+             This chat interface uses RAG (Retrieval Augmented Generation) to provide informed responses based on the content in the database.
+             The assistant retrieves relevant documents and uses them as context for generating responses.
+
+             - Use the Chat tab for asking questions
+             - Use the Configuration tab to adjust model parameters
+             """)
+
+         return interface

+ def main():
+     chat_app = GradioRAGChat()
+     interface = chat_app.create_interface()
+     interface.launch(share=False)  # Set share=True for public URL

  if __name__ == "__main__":
+     main()
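
Note on the load_llm change above: it passes the 4-bit options both as a top-level load_in_4bit flag and as a plain quantization_config dict. In the transformers API those same options are normally wrapped in a BitsAndBytesConfig object (backed by the bitsandbytes package pinned in requirements.txt). A minimal sketch of that equivalent form, assuming the same model id and dtype choice as in the diff; it is not part of this commit:

# Sketch only: the quantization options from load_llm expressed as a
# transformers BitsAndBytesConfig rather than a raw dict (assumed equivalent).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1",
    trust_remote_code=True,
    device_map="auto",            # accelerate places the quantized weights
    quantization_config=bnb_config,
)

With device_map="auto" the loader handles placement itself, which is why the new code also drops the explicit .to(self.device) call used in the first version.
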
requirements.txt CHANGED
@@ -5,6 +5,11 @@ torch>=2.0.0
  transformers>=4.36.0
  sentence-transformers>=2.2.2

+ # Web UI
+ gradio>=4.13.0
+ uvicorn>=0.27.0
+ fastapi>=0.109.0
+
  # Database
  psycopg2-binary>=2.9.9
  pgvector>=0.2.3
@@ -15,9 +20,15 @@ pandas>=2.0.0

  # Deep learning
  accelerate>=0.24.0
- bitsandbytes>=0.41.0
+ bitsandbytes>=0.41.3
  safetensors>=0.4.0
+ transformers>=4.36.2  # Specific version for compatibility

  # Utilities
  tqdm>=4.65.0
- python-dotenv>=1.0.0
+ python-dotenv>=1.0.0
+
+ # Optional: for better performance
+ httpx>=0.26.0
+ websockets>=12.0
+ aiohttp>=3.9.0
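
requirements.txt now pulls in python-dotenv, while app.py still hard-codes the Neon connection string (credentials included) in the source. A minimal sketch of reading it from the environment instead; DATABASE_URL is an assumed variable name, not one defined anywhere in this commit:

# Sketch only: load the connection string via python-dotenv instead of
# hard-coding it; DATABASE_URL is an assumed, hypothetical variable name.
import os
from dotenv import load_dotenv

load_dotenv()  # reads a local .env file if one is present
connection_string = os.environ["DATABASE_URL"]  # e.g. set in .env or the host's secret settings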