IProject-10 committed
Commit 7b1f23f · verified · 1 Parent(s): a707b48

Update app.py

Files changed (1)
  1. app.py +375 -251
app.py CHANGED
@@ -1,276 +1,401 @@
- import nltk
- nltk.download('punkt')
- nltk.download('punkt_tab')
-
- # SECTIONED URL LIST (in case we want to tag later)
- url_dict = {
-     "Website Designing": [
-         "https://www.imageonline.co.in/website-designing-mumbai.html",
-         "https://www.imageonline.co.in/domain-hosting-services-india.html",
-         "https://www.imageonline.co.in/best-seo-company-mumbai.html",
-         "https://www.imageonline.co.in/wordpress-blog-designing-india.html",
-         "https://www.imageonline.co.in/social-media-marketing-company-mumbai.html",
-         "https://www.imageonline.co.in/website-template-customization-india.html",
-         "https://www.imageonline.co.in/regular-website-maintanence-services.html",
-         "https://www.imageonline.co.in/mobile-app-designing-mumbai.html",
-         "https://www.imageonline.co.in/web-application-screen-designing.html"
-     ],
-     "Website Development": [
-         "https://www.imageonline.co.in/website-development-mumbai.html",
-         "https://www.imageonline.co.in/open-source-customization.html",
-         "https://www.imageonline.co.in/ecommerce-development-company-mumbai.html",
-         "https://www.imageonline.co.in/website-with-content-management-system.html",
-         "https://www.imageonline.co.in/web-application-development-india.html"
-     ],
-     "Mobile App Development": [
-         "https://www.imageonline.co.in/mobile-app-development-company-mumbai.html"
-     ],
-     "About Us": [
-         "https://www.imageonline.co.in/about-us.html",
-         "https://www.imageonline.co.in/vision.html",
-         "https://www.imageonline.co.in/team.html"
-     ],
-     "Testimonials": [
-         "https://www.imageonline.co.in/testimonial.html"
-     ]
- }
-
- import trafilatura
- import requests
-
- # Function to extract clean text using trafilatura
- def extract_clean_text(url):
      """
-     Fetch and extract clean main content from a URL using trafilatura.
-     Returns None if content couldn't be extracted.
      """
      try:
-         downloaded = trafilatura.fetch_url(url)
-         if downloaded:
-             content = trafilatura.extract(downloaded, include_comments=False, include_tables=False)
-             return content
-     except Exception as e:
-         print(f"Error fetching {url}: {e}")
-     return None
-
- # Scrape data and prepare for RAG with metadata
- scraped_data = []
-
- for section, urls in url_dict.items():
-     for url in urls:
-         print(f"🟩 Scraping: {url}")
-         text = extract_clean_text(url)
-         if text:
-             print(f"✅ Extracted {len(text)} characters.\n")
-             scraped_data.append({
-                 "content": text,
-                 "metadata": {
-                     "source": url,
-                     "section": section
-                 }
-             })
-         else:
-             print(f"❌ Failed to extract content from {url}.\n")
-
- print(f"Total pages scraped: {len(scraped_data)}")
-
- import tiktoken
- from nltk.tokenize import sent_tokenize
-
- # Initialize GPT tokenizer (cl100k_base works with Together.ai and OpenAI APIs)
- tokenizer = tiktoken.get_encoding("cl100k_base")
-
- def chunk_text(text, max_tokens=400):
-     """
-     Chunk text into overlapping segments based on sentence boundaries and token limits.
-     """
-     sentences = sent_tokenize(text)
-     chunks = []
-     current_chunk = []
-
-     for sentence in sentences:
-         current_chunk.append(sentence)
-         tokens = tokenizer.encode(" ".join(current_chunk))
-         if len(tokens) > max_tokens:
-             # Finalize current chunk without last sentence
-             current_chunk.pop()
-             chunks.append(" ".join(current_chunk).strip())
-             current_chunk = [sentence]  # Start new chunk with overflow sentence
-
-     # Append final chunk
-     if current_chunk:
-         chunks.append(" ".join(current_chunk).strip())
-
-     return chunks
-
- chunked_data = []
-
- for item in scraped_data:
-     text = item["content"]
-     metadata = item["metadata"]
-
-     chunks = chunk_text(text, max_tokens=400)
-
-     for chunk in chunks:
-         chunked_data.append({
-             "content": chunk,
-             "metadata": metadata  # Keep the same URL + section for each chunk
-         })
-
- # Extract text chunks from chunked_data for embedding
- texts_to_embed = [item["content"] for item in chunked_data]
-
- from sentence_transformers import SentenceTransformer
-
- # Load the embedding model
- embedding_model = SentenceTransformer("BAAI/bge-base-en-v1.5")
-
- def embed_chunks(text_list, model):
      """
-     Generate embeddings for a list of text chunks.
      """
-     return model.encode(text_list, convert_to_numpy=True)
-
- # Generate embeddings
- embeddings = embed_chunks(texts_to_embed, embedding_model)
-
- print(f"✅ Generated {len(embeddings)} embeddings")
- print(f"🔹 Shape of first embedding: {embeddings[0].shape}")
-
- import chromadb
- import uuid
-
- # Initialize ChromaDB client (persistent storage)
- chroma_client = chromadb.PersistentClient(path="./chroma_store")
-
- # Create or get collection
- collection = chroma_client.get_or_create_collection(name="imageonline_chunks")
-
- # Extract documents, embeddings, metadatas
- documents = [item["content"] for item in chunked_data]
- metadatas = [item["metadata"] for item in chunked_data]
- ids = [str(uuid.uuid4()) for _ in documents]
-
- # Safety check
- assert len(documents) == len(embeddings) == len(metadatas), "Data length mismatch!"
-
- # Add to ChromaDB
- collection.add(
-     documents=documents,
-     embeddings=embeddings.tolist(),
-     metadatas=metadatas,
-     ids=ids
- )
-
- # Sample query
- query = "web design company"
- query_embedding = embedding_model.encode([query])[0]
-
- # Query ChromaDB
- results = collection.query(
-     query_embeddings=[query_embedding.tolist()],
-     n_results=3
- )
-
- # Display results
- for i in range(len(results['documents'][0])):
-     print(f"\n🔍 Match {i+1}:")
-     print(f"Content: {results['documents'][0][i][:200]}...")
-     print(f"📎 Metadata: {results['metadatas'][0][i]}")
-
- from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.runnables import RunnableLambda, RunnablePassthrough
- from langchain_core.output_parsers import StrOutputParser
- from langchain_together import ChatTogether
-
- from langchain_community.vectorstores import Chroma
- from langchain_community.embeddings import HuggingFaceEmbeddings
-
- # Initialize vectorstore
- embedding_function = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
-
- vectorstore = Chroma(
-     client=chroma_client,  # from your previous chroma setup
-     collection_name="imageonline_chunks",
-     embedding_function=embedding_function
- )
-
- # Create retriever
- retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
-
- def retrieve_and_format(query):
-     docs = retriever.get_relevant_documents(query)
-
-     context_strings = []
-     for doc in docs:
-         content = doc.page_content
-         metadata = doc.metadata
-         source = metadata.get("source", "")
-         section = metadata.get("section", "")
-         context_strings.append(f"[{section}] {content}\n(Source: {source})")
-
-     return "\n\n".join(context_strings)
-
- llm = ChatTogether(
      model="meta-llama/Llama-3-8b-chat-hf",
      temperature=0.3,
-     max_tokens=1024,
-     top_p=0.7,
-     together_api_key="a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6"  # Replace before deployment or use os.getenv
  )

- prompt = ChatPromptTemplate.from_template("""
- You are an expert assistant for ImageOnline Web Solutions.
-
- Answer the user's query based ONLY on the following context:
-
- {context}
-
- Query: {question}
- """)

- rag_chain = (
-     {"context": RunnableLambda(retrieve_and_format), "question": RunnablePassthrough()}
-     | prompt
-     | llm
-     | StrOutputParser()
- )
-
- import gradio as gr
-
- def chat_interface(message, history):
      history = history or []
-
-     # Display user message
-     history.append(("🧑 You: " + message, "⏳ Generating response..."))

      try:
-         # Call RAG pipeline
-         answer = rag_chain.invoke(message)

-         # Replace placeholder with actual response
-         history[-1] = ("🧑 You: " + message, "🤖 Bot: " + answer)

      except Exception as e:
-         error_msg = f"⚠️ Error: {str(e)}"
-         history[-1] = ("🧑 You: " + message, f"🤖 Bot: {error_msg}")

-     return history, history

  def launch_gradio():
      with gr.Blocks() as demo:
-         gr.Markdown("# 💬 ImageOnline RAG Chatbot")
-         gr.Markdown("Ask about Website Designing, App Development, SEO, Hosting, etc.")

          chatbot = gr.Chatbot()
          state = gr.State([])

-         with gr.Row():
-             msg = gr.Textbox(placeholder="Ask your question here...", show_label=False, scale=8)
-             send_btn = gr.Button("📨 Send", scale=1)

-         msg.submit(chat_interface, inputs=[msg, state], outputs=[chatbot, state])
-         send_btn.click(chat_interface, inputs=[msg, state], outputs=[chatbot, state])

          with gr.Row():
              clear_btn = gr.Button("🧹 Clear Chat")
@@ -278,6 +403,5 @@ def launch_gradio():

      return demo

- if __name__ == "__main__":
-     demo = launch_gradio()
-     demo.launch()
+ from datetime import datetime, timedelta
+ import time
+ import gradio as gr
+ import numpy as np
+ from llama_index.core import VectorStoreIndex, StorageContext, Settings
+ from llama_index.core.node_parser import SimpleNodeParser
+ from llama_index.core.prompts import PromptTemplate
+ from llama_index.vector_stores.qdrant import QdrantVectorStore
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.llms.together import TogetherLLM
+ from qdrant_client import QdrantClient
+ from sentence_transformers import CrossEncoder
+ from typing import Generator, Iterable, Tuple, Any
+
+ # === Config ===
+ QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.9Pj8v4ACpX3m5U3SZUrG_jzrjGF-T41J5icZ6EPMxnc"
+ QDRANT_URL = "https://d36718f0-be68-4040-b276-f1f39bc1aeb9.us-east4-0.gcp.cloud.qdrant.io"
+
+ qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+ AVAILABLE_COLLECTIONS = ["demo-chatbot", "tezjet-site", "anish-pharma"]
+ index_cache = {}
+ active_state = {"collection": None, "query_engine": None}
+
+ # === Normalized Embedding Wrapper ===
+ def normalize_vector(vec):
+     vec = np.array(vec)
+     return vec / np.linalg.norm(vec)
+
+ class NormalizedEmbedding(HuggingFaceEmbedding):
+     def get_text_embedding(self, text: str):
+         vec = super().get_text_embedding(text)
+         return normalize_vector(vec)
+
+     def get_query_embedding(self, query: str):
+         vec = super().get_query_embedding(query)
+         return normalize_vector(vec)
+
+ embed_model = NormalizedEmbedding(model_name="BAAI/bge-base-en-v1.5")
+
+ # === LLM (kept for compatibility; streaming uses Together SDK directly) ===
+ llm = TogetherLLM(
+     model="meta-llama/Llama-3-8b-chat-hf",
+     api_key="a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6",
+     temperature=0.3,
+     max_tokens=1024,
+     top_p=0.7
+ )
+ Settings.embed_model = embed_model
+ Settings.llm = llm
+
+ # === Cross-Encoder for Reranking ===
+ reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+
+ # === Prompt Template ===
+ custom_prompt = PromptTemplate(
+     "You are an expert assistant for ImageOnline Pvt Ltd.\n"
+     "Answer the user's query using relevant information from the context below.\n\n"
+     "Context:\n{context_str}\n\n"
+     "Query: {query_str}\n\n"
+ )
+
+ # === Load Index ===
+ def load_index_for_collection(collection_name: str) -> VectorStoreIndex:
+     vector_store = QdrantVectorStore(
+         client=qdrant_client,
+         collection_name=collection_name,
+         enable_hnsw=True
+     )
+     storage_context = StorageContext.from_defaults(vector_store=vector_store)
+     return VectorStoreIndex.from_vector_store(vector_store=vector_store, storage_context=storage_context)
+
+ # === Reference Renderer ===
+ def get_clickable_references_from_response(source_nodes, max_refs=2):
+     seen = set()
+     links = []
+     for node in source_nodes:
+         metadata = node.node.metadata
+         section = metadata.get("section") or metadata.get("title") or "Unknown"
+         source = metadata.get("source") or "Unknown"
+         key = (section, source)
+         if key not in seen:
+             seen.add(key)
+             if source.startswith("http"):
+                 links.append(f"- [{section}]({source})")
+             else:
+                 links.append(f"- {section}: {source}")
+         if len(links) >= max_refs:
+             break
+     return links
+
+ # === Safe Streaming Adapter for Together API (True Streaming) ===
+ # Requires: pip install together
+ from together import Together
+
+ def _extract_event_text(event: Any) -> str:
      """
+     Safely extract the streamed text delta from an event returned by the Together SDK.
+     Supports dict-like and object-like events.
+     Returns empty string if nothing found.
      """
      try:
+         # Try object attribute access
+         choices = getattr(event, "choices", None)
+         if choices:
+             # event.choices[0].delta could be object-like
+             first = choices[0]
+             delta = getattr(first, "delta", None)
+             if delta:
+                 text = getattr(delta, "content", None)
+                 if text:
+                     return text
+             # sometimes content is directly in choice
+             text = getattr(first, "text", None)
+             if text:
+                 return text
+     except Exception:
+         pass
+
+     # Try dict-like access
+     try:
+         if isinstance(event, dict):
+             choices = event.get("choices")
+             if choices and len(choices) > 0:
+                 first = choices[0]
+                 # delta may be nested
+                 delta = first.get("delta") if isinstance(first, dict) else None
+                 if isinstance(delta, dict):
+                     return delta.get("content", "") or delta.get("text", "") or ""
+                 # fallback to message/content
+                 message = first.get("message") or {}
+                 if isinstance(message, dict):
+                     return message.get("content", "") or ""
+                 return first.get("text", "") or ""
+     except Exception:
+         pass
+
+     return ""
+
+ def _extract_response_text(resp: Any) -> str:
      """
+     Safely extract full response text from a non-streaming response object/dict from Together SDK.
      """
+     try:
+         # object-like
+         choices = getattr(resp, "choices", None)
+         if choices and len(choices) > 0:
+             first = choices[0]
+             # message may be attribute or dict
+             message = getattr(first, "message", None)
+             if message:
+                 # message.content may be attribute
+                 content = getattr(message, "content", None)
+                 if content:
+                     return content
+                 # dict
+                 if isinstance(message, dict):
+                     return message.get("content", "") or ""
+             # fallback to text on choice
+             text = getattr(first, "text", None)
+             if text:
+                 return text
+     except Exception:
+         pass
+
+     # dict-like
+     try:
+         if isinstance(resp, dict):
+             choices = resp.get("choices", [])
+             if choices:
+                 first = choices[0]
+                 message = first.get("message") or {}
+                 if isinstance(message, dict):
+                     return message.get("content", "") or ""
+                 return first.get("text", "") or ""
+     except Exception:
+         pass
+
+     # final fallback
+     return str(resp)
+
+ class StreamingLLMAdapter:
+     def __init__(self, api_key: str, model: str, temperature: float = 0.3, top_p: float = 0.7, chunk_size: int = 64):
+         self.client = Together(api_key=api_key)
+         self.model = model
+         self.temperature = temperature
+         self.top_p = top_p
+         self.chunk_size = chunk_size
+
+     def stream_complete(self, prompt: str, max_tokens: int = 1024, **kwargs) -> Generator[str, None, None]:
+         """
+         Use Together's native streaming API to yield tokens in real time.
+         Falls back to non-streamed response if streaming isn't available or errors.
+         """
+         try:
+             # the Together SDK exposes an iterator when stream=True
+             events = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[{"role": "user", "content": prompt}],
+                 max_tokens=max_tokens,
+                 temperature=self.temperature,
+                 top_p=self.top_p,
+                 stream=True
+             )
+             for event in events:
+                 # robust extraction (handles dicts or objects)
+                 text_piece = _extract_event_text(event)
+                 if text_piece:
+                     yield text_piece
+         except Exception:
+             # fallback to synchronous non-streaming
+             yield from self._sync_fallback(prompt, max_tokens, **kwargs)
+
+     def _sync_fallback(self, prompt: str, max_tokens: int = 1024, **kwargs) -> Generator[str, None, None]:
+         """Call Together API without streaming and yield chunks."""
+         try:
+             resp = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[{"role": "user", "content": prompt}],
+                 max_tokens=max_tokens,
+                 temperature=self.temperature,
+                 top_p=self.top_p
+             )
+             text = _extract_response_text(resp)
+         except Exception as e:
+             text = f"[Error from LLM: {e}]"
+
+         for i in range(0, len(text), self.chunk_size):
+             yield text[i:i + self.chunk_size]
+
+ # instantiate streaming adapter (keep your API key here)
+ streaming_llm = StreamingLLMAdapter(
+     api_key="a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6",
      model="meta-llama/Llama-3-8b-chat-hf",
      temperature=0.3,
+     top_p=0.7
  )

+ # === Query Chain with Reranking ===
+ def rag_chain_prompt_and_sources(query: str, top_k: int = 3):
+     """
+     Returns (prompt_text, top_nodes) using the existing retrieval + reranking flow.
+     We separate building prompt from calling the LLM so we can stream the final call.
+     """
+     if not active_state["query_engine"]:
+         return None, None, "⚠️ Please select a website collection first."
+
+     raw_nodes = active_state["query_engine"].retrieve(query)
+
+     # Step 2: Rerank
+     pairs = [(query, n.node.get_content()) for n in raw_nodes]
+     scores = reranker.predict(pairs)
+     scored_nodes = sorted(zip(raw_nodes, scores), key=lambda x: x[1], reverse=True)
+     top_nodes = [n for n, _ in scored_nodes[:top_k]]
+
+     # Step 3: Compose prompt
+     context = "\n\n".join([n.node.get_content() for n in top_nodes])
+     prompt = custom_prompt.format(context_str=context, query_str=query)
+     return prompt, top_nodes, None
+
+ # === Collection Switch ===
+ def handle_collection_change(selected):
+     now = datetime.utcnow()
+     cached = index_cache.get(selected)
+     if cached:
+         query_engine, ts = cached
+         if now - ts < timedelta(hours=1):
+             active_state["collection"] = selected
+             active_state["query_engine"] = query_engine
+             return f"✅ Now chatting with: `{selected}`", [], []
+
+     index = load_index_for_collection(selected)
+     query_engine = index.as_query_engine(similarity_top_k=10, vector_store_query_mode="default")
+     index_cache[selected] = (query_engine, now)
+     active_state["collection"] = selected
+     active_state["query_engine"] = query_engine
+
+     return f"✅ Now chatting with: `{selected}`", [], []
+
+ # === Streaming Chat Handler ===
+ def chat_interface_stream(message: str, history: list) -> Generator[Tuple[list, list, str], None, None]:
+     """
+     Yields tuples of (chatbot_history, state, textbox_value) so Gradio gets
+     the right number of outputs for each yield when using streaming.
+     """
      history = history or []
+     message = (message or "").strip()
+     if not message:
+         # still return all outputs
+         yield history, history, ""
+         return
+
+     timestamp_user = datetime.now().strftime("%H:%M:%S")
+     user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}"
+     # append placeholder bot typing state
+     history.append((user_msg, "⏳ _Bot is typing..._"))
+     # initial update (user message + typing)
+     yield history, history, ""
+
+     prompt, top_nodes, err = rag_chain_prompt_and_sources(message)
+     if err:
+         history[-1] = (user_msg, f"🤖 **Bot**\n{err}")
+         yield history, history, ""
+         return
+
+     assistant_text = ""
+     chunk_count = 0
+     flush_every_n = 3  # flush every 3 small deltas (tweak if you want more frequent updates)

      try:
+         # stream from Together
+         for chunk in streaming_llm.stream_complete(prompt, max_tokens=1024):
+             assistant_text += chunk
+             chunk_count += 1
+             # periodically flush partial output to UI
+             if chunk_count % flush_every_n == 0:
+                 history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}")
+                 yield history, history, ""
+         # after streaming completes, append any leftover partial (if not flushed recently)
+         history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}")
+     except Exception as e:
+         # on error, show error message
+         history[-1] = (user_msg, f"🤖 **Bot**\n⚠️ {str(e)}")
+         yield history, history, ""
+         return
+
+     # Add references at the end
+     references = get_clickable_references_from_response(top_nodes)
+     if references:
+         assistant_text += "\n\n📚 **Reference(s):**\n" + "\n".join(references)
+
+     timestamp_bot = datetime.now().strftime("%H:%M:%S")
+     history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text.strip()}\n\n⏱️ {timestamp_bot}")
+     # final yield with textbox cleared
+     yield history, history, ""
+
+ # Fallback synchronous chat (kept for compatibility if you want non-streaming)
+ def chat_interface_sync(message, history):
+     history = history or []
+     message = message.strip()
+     if not message:
+         raise ValueError("Please enter a valid question.")

+     timestamp_user = datetime.now().strftime("%H:%M:%S")
+     user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}"
+     bot_msg = "⏳ _Bot is typing..._"
+     history.append((user_msg, bot_msg))

+     try:
+         time.sleep(0.5)
+         prompt, top_nodes, err = rag_chain_prompt_and_sources(message)
+         if err:
+             timestamp_bot = datetime.now().strftime("%H:%M:%S")
+             history[-1] = (user_msg, f"🤖 **Bot**\n{err}\n\n⏱️ {timestamp_bot}")
+             return history, history, ""
+
+         resp = llm.complete(prompt).text
+         references = get_clickable_references_from_response(top_nodes)
+         if references:
+             resp += "\n\n📚 **Reference(s):**\n" + "\n".join(references)
+
+         timestamp_bot = datetime.now().strftime("%H:%M:%S")
+         bot_msg = f"🤖 **Bot**\n{resp.strip()}\n\n⏱️ {timestamp_bot}"
+         history[-1] = (user_msg, bot_msg)
      except Exception as e:
+         timestamp_bot = datetime.now().strftime("%H:%M:%S")
+         error_msg = f"🤖 **Bot**\n⚠️ {str(e)}\n\n⏱️ {timestamp_bot}"
+         history[-1] = (user_msg, error_msg)

+     return history, history, ""

+ # === Gradio UI ===
  def launch_gradio():
      with gr.Blocks() as demo:
+         gr.Markdown("# 💬 Multi-Website RAG Chatbot")
+         gr.Markdown("Choose a website collection to start chatting.")
+
+         with gr.Row():
+             collection_dropdown = gr.Dropdown(choices=AVAILABLE_COLLECTIONS, label="Select Website Collection")
+             load_button = gr.Button("Load Website")
+             collection_status = gr.Markdown("")

          chatbot = gr.Chatbot()
          state = gr.State([])

+         with gr.Row(equal_height=True):
+             msg = gr.Textbox(placeholder="Ask your question...", show_label=False, scale=9)
+             send_btn = gr.Button("🚀 Send", scale=1)
+
+         load_button.click(
+             fn=handle_collection_change,
+             inputs=collection_dropdown,
+             outputs=[collection_status, chatbot, state]
+         )

+         # Use the streaming generator for submit/click so Gradio receives yields
+         msg.submit(chat_interface_stream, inputs=[msg, state], outputs=[chatbot, state, msg])
+         send_btn.click(chat_interface_stream, inputs=[msg, state], outputs=[chatbot, state, msg])

          with gr.Row():
              clear_btn = gr.Button("🧹 Clear Chat")

      return demo

+ demo = launch_gradio()
+ demo.launch()
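
Both revisions hard-code the Together and Qdrant credentials in app.py; the removed ChatTogether block already points at the alternative ("use os.getenv"). A minimal sketch of that approach, assuming the keys are exported under the illustrative names TOGETHER_API_KEY, QDRANT_API_KEY and QDRANT_URL (for example as Space secrets), with the rest of the constructor arguments taken unchanged from the diff:

import os

# Illustrative: read credentials from the environment instead of committing them.
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")

qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
llm = TogetherLLM(model="meta-llama/Llama-3-8b-chat-hf", api_key=TOGETHER_API_KEY,
                  temperature=0.3, max_tokens=1024, top_p=0.7)
streaming_llm = StreamingLLMAdapter(api_key=TOGETHER_API_KEY,
                                    model="meta-llama/Llama-3-8b-chat-hf",
                                    temperature=0.3, top_p=0.7)

Only the source of the keys changes; the rest of the pipeline (collection loading, reranking, streaming) is untouched.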