from datetime import datetime, timedelta
import time

import gradio as gr
import numpy as np
from llama_index.core import VectorStoreIndex, StorageContext, Settings
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.prompts import PromptTemplate
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.together import TogetherLLM
from qdrant_client import QdrantClient
from sentence_transformers import CrossEncoder
from typing import Generator, Iterable, Tuple, Any

# === Config ===
QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.9Pj8v4ACpX3m5U3SZUrG_jzrjGF-T41J5icZ6EPMxnc"
QDRANT_URL = "https://d36718f0-be68-4040-b276-f1f39bc1aeb9.us-east4-0.gcp.cloud.qdrant.io"

qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

AVAILABLE_COLLECTIONS = ["ImageOnline", "tezjet-site", "anish-pharma"]

index_cache = {}
active_state = {"collection": None, "query_engine": None}


# === Normalized Embedding Wrapper ===
def normalize_vector(vec):
    vec = np.array(vec)
    return vec / np.linalg.norm(vec)


class NormalizedEmbedding(HuggingFaceEmbedding):
    def get_text_embedding(self, text: str):
        vec = super().get_text_embedding(text)
        return normalize_vector(vec)

    def get_query_embedding(self, query: str):
        vec = super().get_query_embedding(query)
        return normalize_vector(vec)


embed_model = NormalizedEmbedding(model_name="BAAI/bge-base-en-v1.5")

# === LLM (kept for compatibility; streaming uses Together SDK directly) ===
llm = TogetherLLM(
    model="meta-llama/Llama-3-8b-chat-hf",
    api_key="a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6",
    temperature=0.3,
    max_tokens=1024,
    top_p=0.7
)

Settings.embed_model = embed_model
Settings.llm = llm

# === Cross-Encoder for Reranking ===
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# === Prompt Template ===
custom_prompt = PromptTemplate(
    "You are an expert assistant for ImageOnline Pvt Ltd.\n"
    "Answer the user's query using relevant information from the context below.\n\n"
    "Context:\n{context_str}\n\n"
    "Query: {query_str}\n\n"
)


# === Load Index ===
def load_index_for_collection(collection_name: str) -> VectorStoreIndex:
    vector_store = QdrantVectorStore(
        client=qdrant_client,
        collection_name=collection_name,
        enable_hnsw=True
    )
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex.from_vector_store(vector_store=vector_store, storage_context=storage_context)


# === Reference Renderer ===
def get_clickable_references_from_response(source_nodes, max_refs=2):
    seen = set()
    links = []
    for node in source_nodes:
        metadata = node.node.metadata
        section = metadata.get("section") or metadata.get("title") or "Unknown"
        source = metadata.get("source") or "Unknown"
        key = (section, source)
        if key not in seen:
            seen.add(key)
            if source.startswith("http"):
                links.append(f"- [{section}]({source})")
            else:
                links.append(f"- {section}: {source}")
        if len(links) >= max_refs:
            break
    return links


# === Safe Streaming Adapter for Together API (True Streaming) ===
# Requires: pip install together
from together import Together


def _extract_event_text(event: Any) -> str:
    """
    Safely extract the streamed text delta from an event returned by the Together SDK.
    Supports dict-like and object-like events. Returns empty string if nothing found.
""" try: # Try object attribute access choices = getattr(event, "choices", None) if choices: # event.choices[0].delta could be object-like first = choices[0] delta = getattr(first, "delta", None) if delta: text = getattr(delta, "content", None) if text: return text # sometimes content is directly in choice text = getattr(first, "text", None) if text: return text except Exception: pass # Try dict-like access try: if isinstance(event, dict): choices = event.get("choices") if choices and len(choices) > 0: first = choices[0] # delta may be nested delta = first.get("delta") if isinstance(first, dict) else None if isinstance(delta, dict): return delta.get("content", "") or delta.get("text", "") or "" # fallback to message/content message = first.get("message") or {} if isinstance(message, dict): return message.get("content", "") or "" return first.get("text", "") or "" except Exception: pass return "" def _extract_response_text(resp: Any) -> str: """ Safely extract full response text from a non-streaming response object/dict from Together SDK. """ try: # object-like choices = getattr(resp, "choices", None) if choices and len(choices) > 0: first = choices[0] # message may be attribute or dict message = getattr(first, "message", None) if message: # message.content may be attribute content = getattr(message, "content", None) if content: return content # dict if isinstance(message, dict): return message.get("content", "") or "" # fallback to text on choice text = getattr(first, "text", None) if text: return text except Exception: pass # dict-like try: if isinstance(resp, dict): choices = resp.get("choices", []) if choices: first = choices[0] message = first.get("message") or {} if isinstance(message, dict): return message.get("content", "") or "" return first.get("text", "") or "" except Exception: pass # final fallback return str(resp) class StreamingLLMAdapter: def __init__(self, api_key: str, model: str, temperature: float = 0.3, top_p: float = 0.7, chunk_size: int = 64): self.client = Together(api_key=api_key) self.model = model self.temperature = temperature self.top_p = top_p self.chunk_size = chunk_size def stream_complete(self, prompt: str, max_tokens: int = 1024, **kwargs) -> Generator[str, None, None]: """ Use Together's native streaming API to yield tokens in real time. Falls back to non-streamed response if streaming isn't available or errors. 
""" try: # the Together SDK exposes an iterator when stream=True events = self.client.chat.completions.create( model=self.model, messages=[{"role": "user", "content": prompt}], max_tokens=max_tokens, temperature=self.temperature, top_p=self.top_p, stream=True ) for event in events: # robust extraction (handles dicts or objects) text_piece = _extract_event_text(event) if text_piece: yield text_piece except Exception: # fallback to synchronous non-streaming yield from self._sync_fallback(prompt, max_tokens, **kwargs) def _sync_fallback(self, prompt: str, max_tokens: int = 1024, **kwargs) -> Generator[str, None, None]: """Call Together API without streaming and yield chunks.""" try: resp = self.client.chat.completions.create( model=self.model, messages=[{"role": "user", "content": prompt}], max_tokens=max_tokens, temperature=self.temperature, top_p=self.top_p ) text = _extract_response_text(resp) except Exception as e: text = f"[Error from LLM: {e}]" for i in range(0, len(text), self.chunk_size): yield text[i:i + self.chunk_size] # instantiate streaming adapter (keep your API key here) streaming_llm = StreamingLLMAdapter( api_key="a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6", model="meta-llama/Llama-3-8b-chat-hf", temperature=0.3, top_p=0.7 ) # === Query Chain with Reranking === def rag_chain_prompt_and_sources(query: str, top_k: int = 3): """ Returns (prompt_text, top_nodes) using the existing retrieval + reranking flow. We separate building prompt from calling the LLM so we can stream the final call. """ if not active_state["query_engine"]: return None, None, "⚠️ Please select a website collection first." raw_nodes = active_state["query_engine"].retrieve(query) # Step 2: Rerank pairs = [(query, n.node.get_content()) for n in raw_nodes] scores = reranker.predict(pairs) scored_nodes = sorted(zip(raw_nodes, scores), key=lambda x: x[1], reverse=True) top_nodes = [n for n, _ in scored_nodes[:top_k]] # Step 3: Compose prompt context = "\n\n".join([n.node.get_content() for n in top_nodes]) prompt = custom_prompt.format(context_str=context, query_str=query) return prompt, top_nodes, None # === Collection Switch === def handle_collection_change(selected): now = datetime.utcnow() cached = index_cache.get(selected) if cached: query_engine, ts = cached if now - ts < timedelta(hours=1): active_state["collection"] = selected active_state["query_engine"] = query_engine return f"✅ Now chatting with: `{selected}`", [], [] index = load_index_for_collection(selected) query_engine = index.as_query_engine(similarity_top_k=10, vector_store_query_mode="default") index_cache[selected] = (query_engine, now) active_state["collection"] = selected active_state["query_engine"] = query_engine return f"✅ Now chatting with: `{selected}`", [], [] # === Streaming Chat Handler === def chat_interface_stream(message: str, history: list) -> Generator[Tuple[list, list, str], None, None]: """ Yields tuples of (chatbot_history, state, textbox_value) so Gradio gets the right number of outputs for each yield when using streaming. 
""" history = history or [] message = (message or "").strip() if not message: # still return all outputs yield history, history, "" return timestamp_user = datetime.now().strftime("%H:%M:%S") user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}" # append placeholder bot typing state history.append((user_msg, "⏳ _Bot is typing..._")) # initial update (user message + typing) yield history, history, "" prompt, top_nodes, err = rag_chain_prompt_and_sources(message) if err: history[-1] = (user_msg, f"🤖 **Bot**\n{err}") yield history, history, "" return assistant_text = "" chunk_count = 0 flush_every_n = 3 # flush every 3 small deltas (tweak if you want more frequent updates) try: # stream from Together for chunk in streaming_llm.stream_complete(prompt, max_tokens=1024): assistant_text += chunk chunk_count += 1 # periodically flush partial output to UI if chunk_count % flush_every_n == 0: history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}") yield history, history, "" # after streaming completes, append any leftover partial (if not flushed recently) history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}") except Exception as e: # on error, show error message history[-1] = (user_msg, f"🤖 **Bot**\n⚠️ {str(e)}") yield history, history, "" return # Add references at the end references = get_clickable_references_from_response(top_nodes) if references: assistant_text += "\n\n📚 **Reference(s):**\n" + "\n".join(references) timestamp_bot = datetime.now().strftime("%H:%M:%S") history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text.strip()}\n\n⏱️ {timestamp_bot}") # final yield with textbox cleared yield history, history, "" # Fallback synchronous chat (kept for compatibility if you want non-streaming) def chat_interface_sync(message, history): history = history or [] message = message.strip() if not message: raise ValueError("Please enter a valid question.") timestamp_user = datetime.now().strftime("%H:%M:%S") user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}" bot_msg = "⏳ _Bot is typing..._" history.append((user_msg, bot_msg)) try: time.sleep(0.5) prompt, top_nodes, err = rag_chain_prompt_and_sources(message) if err: timestamp_bot = datetime.now().strftime("%H:%M:%S") history[-1] = (user_msg, f"🤖 **Bot**\n{err}\n\n⏱️ {timestamp_bot}") return history, history, "" resp = llm.complete(prompt).text references = get_clickable_references_from_response(top_nodes) if references: resp += "\n\n📚 **Reference(s):**\n" + "\n".join(references) timestamp_bot = datetime.now().strftime("%H:%M:%S") bot_msg = f"🤖 **Bot**\n{resp.strip()}\n\n⏱️ {timestamp_bot}" history[-1] = (user_msg, bot_msg) except Exception as e: timestamp_bot = datetime.now().strftime("%H:%M:%S") error_msg = f"🤖 **Bot**\n⚠️ {str(e)}\n\n⏱️ {timestamp_bot}" history[-1] = (user_msg, error_msg) return history, history, "" # === Gradio UI === def launch_gradio(): with gr.Blocks() as demo: gr.Markdown("# 💬 Demo IOPL Multi-Website Chatbot") gr.Markdown("Choose a website you want to chat with.") with gr.Row(): collection_dropdown = gr.Dropdown(choices=AVAILABLE_COLLECTIONS, label="Select Website to chat") load_button = gr.Button("Load Website") collection_status = gr.Markdown("") chatbot = gr.Chatbot() state = gr.State([]) with gr.Row(equal_height=True): msg = gr.Textbox(placeholder="Ask your question...", show_label=False, scale=9) send_btn = gr.Button("🚀 Send", scale=1) load_button.click( fn=handle_collection_change, inputs=collection_dropdown, outputs=[collection_status, chatbot, state] ) # Use the streaming generator for submit/click so Gradio 
        # Use the streaming generator for submit/click so Gradio receives yields
        msg.submit(chat_interface_stream, inputs=[msg, state], outputs=[chatbot, state, msg])
        send_btn.click(chat_interface_stream, inputs=[msg, state], outputs=[chatbot, state, msg])

        with gr.Row():
            clear_btn = gr.Button("🧹 Clear Chat")
            clear_btn.click(fn=lambda: ([], []), outputs=[chatbot, state])

    return demo


demo = launch_gradio()
demo.launch()
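# --- Optional launch variant (a sketch, not part of the original flow) ---
# If partial bot replies only show up after generation finishes instead of token by
# token, routing events through Gradio's request queue before launching usually lets
# generator handlers push each intermediate yield to the browser. Recent Gradio
# versions enable the queue by default; on older ones the explicit call is needed:
#
#   demo = launch_gradio()
#   demo.queue()   # enable the event queue so every `yield` reaches the UI
#   demo.launch()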