from datetime import datetime, timedelta
import time
import gradio as gr
import numpy as np
from llama_index.core import VectorStoreIndex, StorageContext, Settings
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.prompts import PromptTemplate
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.together import TogetherLLM
from qdrant_client import QdrantClient
from sentence_transformers import CrossEncoder
from typing import Generator, Iterable, Tuple, Any

# === Config ===
MAX_OUTPUT_TOKENS = 300  # hard cap for concise answers

QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.9Pj8v4ACpX3m5U3SZUrG_jzrjGF-T41J5icZ6EPMxnc"
QDRANT_URL = "https://d36718f0-be68-4040-b276-f1f39bc1aeb9.us-east4-0.gcp.cloud.qdrant.io"
qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

AVAILABLE_COLLECTIONS = ["ImageOnline", "tezjet-site", "anish-pharma"]

index_cache = {}
active_state = {"collection": None, "query_engine": None}

# === Normalized Embedding Wrapper ===
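# Wraps HuggingFaceEmbedding so every text/query vector is L2-normalized before use; with
# unit-length vectors, dot-product scores in Qdrant behave like cosine similarity.
# (Assumes the vectors already stored in the Qdrant collections were normalized the same way.)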
def normalize_vector(vec):
    vec = np.array(vec)
    return vec / np.linalg.norm(vec)


class NormalizedEmbedding(HuggingFaceEmbedding):
    def get_text_embedding(self, text: str):
        vec = super().get_text_embedding(text)
        return normalize_vector(vec)

    def get_query_embedding(self, query: str):
        vec = super().get_query_embedding(query)
        return normalize_vector(vec)


embed_model = NormalizedEmbedding(model_name="BAAI/bge-base-en-v1.5")

# === LLM (kept for compatibility; streaming uses Together SDK directly) ===
llm = TogetherLLM(
    model="meta-llama/Llama-3-8b-chat-hf",
    api_key="a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6",
    temperature=0.3,
    max_tokens=MAX_OUTPUT_TOKENS,
    top_p=0.7
)
Settings.embed_model = embed_model
Settings.llm = llm

# === Cross-Encoder for Reranking ===
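# The cross-encoder scores each (query, passage) pair jointly, which is slower but more accurate
# than the bi-encoder retrieval step; it is used below to rerank the retrieved candidates before
# the top few are passed to the LLM.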
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# === Prompt Template (Optimized for Conciseness & Token Limit) ===
custom_prompt = PromptTemplate(
    "You are an expert assistant for ImageOnline Pvt Ltd.\n"
    "Instructions:\n"
    "- Be concise, factual, and to the point.\n"
    "- Use bullet points where possible.\n"
    "- Do not repeat previous answers unless asked.\n"
    "- Stop once the question is addressed.\n"
    "- If the user may need more detail, invite follow-up questions.\n"
    f"- Keep the answer within {MAX_OUTPUT_TOKENS} tokens.\n\n"
    "Context (summarize if long):\n{context_str}\n\n"
    "Query: {query_str}\n\n"
    "Answer:\n"
)

# === Load Index ===
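# Builds a VectorStoreIndex directly on top of an existing Qdrant collection; nothing is
# re-ingested here, the index object is only a query-time wrapper around the stored vectors.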
def load_index_for_collection(collection_name: str) -> VectorStoreIndex:
    vector_store = QdrantVectorStore(
        client=qdrant_client,
        collection_name=collection_name,
        enable_hnsw=True
    )
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex.from_vector_store(vector_store=vector_store, storage_context=storage_context)

# === Reference Renderer ===
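# Turns source-node metadata into at most `max_refs` markdown reference lines (section name plus
# URL or source label), de-duplicated, for appending to the bot answer.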
def get_clickable_references_from_response(source_nodes, max_refs=2):
    seen = set()
    links = []
    for node in source_nodes:
        metadata = node.node.metadata
        section = metadata.get("section") or metadata.get("title") or "Unknown"
        source = metadata.get("source") or "Unknown"
        key = (section, source)
        if key not in seen:
            seen.add(key)
            if source.startswith("http"):
                links.append(f"- [{section}]({source})")
            else:
                links.append(f"- {section}: {source}")
        if len(links) >= max_refs:
            break
    return links

# === Safe Streaming Adapter for Together API (True Streaming) ===
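# Streams tokens straight from the Together SDK (chat.completions with stream=True) instead of
# going through LlamaIndex; if streaming fails for any reason, a non-streaming call is made and
# its full text is re-chunked so the calling UI code path stays the same.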
from together import Together


def _extract_event_text(event: Any) -> str:
    """Pull the text fragment out of a streaming event, whether it is an SDK object or a dict."""
    try:
        choices = getattr(event, "choices", None)
        if choices:
            first = choices[0]
            delta = getattr(first, "delta", None)
            if delta:
                text = getattr(delta, "content", None)
                if text:
                    return text
            text = getattr(first, "text", None)
            if text:
                return text
    except Exception:
        pass
    try:
        if isinstance(event, dict):
            choices = event.get("choices")
            if choices and len(choices) > 0:
                first = choices[0]
                delta = first.get("delta") if isinstance(first, dict) else None
                if isinstance(delta, dict):
                    return delta.get("content", "") or delta.get("text", "") or ""
                message = first.get("message") or {}
                if isinstance(message, dict):
                    return message.get("content", "") or ""
                return first.get("text", "") or ""
    except Exception:
        pass
    return ""
def _extract_response_text(resp: Any) -> str:
    """Pull the full completion text out of a non-streaming response object or dict."""
    try:
        choices = getattr(resp, "choices", None)
        if choices and len(choices) > 0:
            first = choices[0]
            message = getattr(first, "message", None)
            if message:
                content = getattr(message, "content", None)
                if content:
                    return content
                if isinstance(message, dict):
                    return message.get("content", "") or ""
            text = getattr(first, "text", None)
            if text:
                return text
    except Exception:
        pass
    try:
        if isinstance(resp, dict):
            choices = resp.get("choices", [])
            if choices:
                first = choices[0]
                message = first.get("message") or {}
                if isinstance(message, dict):
                    return message.get("content", "") or ""
                return first.get("text", "") or ""
    except Exception:
        pass
    return str(resp)
class StreamingLLMAdapter:
    """Thin wrapper around the Together client that yields answer text incrementally."""

    def __init__(self, api_key: str, model: str, temperature: float = 0.3, top_p: float = 0.7, chunk_size: int = 64):
        self.client = Together(api_key=api_key)
        self.model = model
        self.temperature = temperature
        self.top_p = top_p
        self.chunk_size = chunk_size

    def stream_complete(self, prompt: str, max_tokens: int = MAX_OUTPUT_TOKENS, **kwargs) -> Generator[str, None, None]:
        try:
            events = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                stream=True
            )
            for event in events:
                text_piece = _extract_event_text(event)
                if text_piece:
                    yield text_piece
        except Exception:
            # Fall back to a blocking request if streaming is unavailable.
            yield from self._sync_fallback(prompt, max_tokens, **kwargs)

    def _sync_fallback(self, prompt: str, max_tokens: int = MAX_OUTPUT_TOKENS, **kwargs) -> Generator[str, None, None]:
        try:
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=self.temperature,
                top_p=self.top_p
            )
            text = _extract_response_text(resp)
        except Exception as e:
            text = f"[Error from LLM: {e}]"
        for i in range(0, len(text), self.chunk_size):
            yield text[i:i + self.chunk_size]


streaming_llm = StreamingLLMAdapter(
    api_key="a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6",
    model="meta-llama/Llama-3-8b-chat-hf",
    temperature=0.3,
    top_p=0.7
)

# === Query Chain with Reranking ===
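# Retrieval pipeline: pull candidate chunks from the active collection, rerank them with the
# cross-encoder, keep the top_k, truncate the combined context to ~4000 characters, and format
# the final prompt. Returns (prompt, top_nodes, error_message).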
def rag_chain_prompt_and_sources(query: str, top_k: int = 3):
    if not active_state["query_engine"]:
        return None, None, "⚠️ Please select a website collection first."

    raw_nodes = active_state["query_engine"].retrieve(query)
    pairs = [(query, n.node.get_content()) for n in raw_nodes]
    scores = reranker.predict(pairs)
    scored_nodes = sorted(zip(raw_nodes, scores), key=lambda x: x[1], reverse=True)
    top_nodes = [n for n, _ in scored_nodes[:top_k]]

    # Truncate context if too large to save tokens
    context = "\n\n".join([n.node.get_content() for n in top_nodes])
    if len(context) > 4000:
        context = context[:4000] + "...\n[Context truncated for brevity]"

    prompt = custom_prompt.format(context_str=context, query_str=query)
    return prompt, top_nodes, None

# === Collection Switch ===
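# Query engines are cached per collection for one hour; switching collections either reuses the
# cached engine or rebuilds it from Qdrant, and always resets the visible chat history.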
def handle_collection_change(selected):
    now = datetime.utcnow()
    cached = index_cache.get(selected)
    if cached:
        query_engine, ts = cached
        if now - ts < timedelta(hours=1):
            active_state["collection"] = selected
            active_state["query_engine"] = query_engine
            return f"✅ Now chatting with: `{selected}`", [], []

    index = load_index_for_collection(selected)
    query_engine = index.as_query_engine(similarity_top_k=10, vector_store_query_mode="default")
    index_cache[selected] = (query_engine, now)
    active_state["collection"] = selected
    active_state["query_engine"] = query_engine
    return f"✅ Now chatting with: `{selected}`", [], []

# === Streaming Chat Handler ===
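# Generator-based handler: it yields (chatbot, state, textbox) tuples as the answer grows, so
# Gradio re-renders the last message every few streamed chunks instead of waiting for the full
# completion.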
def chat_interface_stream(message: str, history: list) -> Generator[Tuple[list, list, str], None, None]:
    history = history or []
    message = (message or "").strip()
    if not message:
        yield history, history, ""
        return

    timestamp_user = datetime.now().strftime("%H:%M:%S")
    user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}"
    history.append((user_msg, "⏳ _Bot is typing..._"))
    yield history, history, ""

    prompt, top_nodes, err = rag_chain_prompt_and_sources(message)
    if err:
        history[-1] = (user_msg, f"🤖 **Bot**\n{err}")
        yield history, history, ""
        return

    assistant_text = ""
    chunk_count = 0
    flush_every_n = 3
    try:
        for chunk in streaming_llm.stream_complete(prompt, max_tokens=MAX_OUTPUT_TOKENS):
            assistant_text += chunk
            chunk_count += 1
            if chunk_count % flush_every_n == 0:
                history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}")
                yield history, history, ""
        history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}")
    except Exception as e:
        history[-1] = (user_msg, f"🤖 **Bot**\n⚠️ {str(e)}")
        yield history, history, ""
        return

    references = get_clickable_references_from_response(top_nodes)
    if references:
        assistant_text += "\n\n📚 **Reference(s):**\n" + "\n".join(references)

    timestamp_bot = datetime.now().strftime("%H:%M:%S")
    history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text.strip()}\n\n⏱️ {timestamp_bot}")
    yield history, history, ""

# Fallback synchronous chat
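# Non-streaming alternative to chat_interface_stream with the same inputs and outputs; it is not
# wired into the Gradio UI below.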
def chat_interface_sync(message, history):
    history = history or []
    message = message.strip()
    if not message:
        raise ValueError("Please enter a valid question.")

    timestamp_user = datetime.now().strftime("%H:%M:%S")
    user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}"
    bot_msg = "⏳ _Bot is typing..._"
    history.append((user_msg, bot_msg))

    try:
        time.sleep(0.5)
        prompt, top_nodes, err = rag_chain_prompt_and_sources(message)
        if err:
            timestamp_bot = datetime.now().strftime("%H:%M:%S")
            history[-1] = (user_msg, f"🤖 **Bot**\n{err}\n\n⏱️ {timestamp_bot}")
            return history, history, ""

        resp = llm.complete(prompt, max_tokens=MAX_OUTPUT_TOKENS).text
        references = get_clickable_references_from_response(top_nodes)
        if references:
            resp += "\n\n📚 **Reference(s):**\n" + "\n".join(references)

        timestamp_bot = datetime.now().strftime("%H:%M:%S")
        bot_msg = f"🤖 **Bot**\n{resp.strip()}\n\n⏱️ {timestamp_bot}"
        history[-1] = (user_msg, bot_msg)
    except Exception as e:
        timestamp_bot = datetime.now().strftime("%H:%M:%S")
        error_msg = f"🤖 **Bot**\n⚠️ {str(e)}\n\n⏱️ {timestamp_bot}"
        history[-1] = (user_msg, error_msg)

    return history, history, ""

# === Gradio UI ===
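# Layout: collection dropdown with a "Load Website" button, a status line, the chatbot window, a
# message box with a send button, and a clear button. Pressing Enter in the textbox and clicking
# Send both call the same streaming handler; `state` holds the (user, bot) message tuples.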
def launch_gradio():
    with gr.Blocks() as demo:
        gr.Markdown("# 💬 Demo IOPL Multi-Website Chatbot")
        gr.Markdown("Choose a website to chat with.")

        with gr.Row():
            collection_dropdown = gr.Dropdown(choices=AVAILABLE_COLLECTIONS, label="Select Website to chat")
            load_button = gr.Button("Load Website")

        collection_status = gr.Markdown("")
        chatbot = gr.Chatbot()
        state = gr.State([])

        with gr.Row(equal_height=True):
            msg = gr.Textbox(placeholder="Ask your question...", show_label=False, scale=9)
            send_btn = gr.Button("🚀 Send", scale=1)

        load_button.click(
            fn=handle_collection_change,
            inputs=collection_dropdown,
            outputs=[collection_status, chatbot, state]
        )
        msg.submit(chat_interface_stream, inputs=[msg, state], outputs=[chatbot, state, msg])
        send_btn.click(chat_interface_stream, inputs=[msg, state], outputs=[chatbot, state, msg])

        with gr.Row():
            clear_btn = gr.Button("🧹 Clear Chat")
        clear_btn.click(fn=lambda: ([], []), outputs=[chatbot, state])

    return demo


demo = launch_gradio()
demo.launch()