from datetime import datetime, timedelta
import time
import gradio as gr
import numpy as np
from llama_index.core import VectorStoreIndex, StorageContext, Settings
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.prompts import PromptTemplate
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.together import TogetherLLM
from qdrant_client import QdrantClient
from sentence_transformers import CrossEncoder
from typing import Generator, Iterable, Tuple, Any
# === Config ===
QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.9Pj8v4ACpX3m5U3SZUrG_jzrjGF-T41J5icZ6EPMxnc"
QDRANT_URL = "https://d36718f0-be68-4040-b276-f1f39bc1aeb9.us-east4-0.gcp.cloud.qdrant.io"

qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

AVAILABLE_COLLECTIONS = ["ImageOnline", "tezjet-site", "anish-pharma"]
index_cache = {}
active_state = {"collection": None, "query_engine": None}
# === Normalized Embedding Wrapper ===
def normalize_vector(vec):
    vec = np.array(vec)
    return vec / np.linalg.norm(vec)

class NormalizedEmbedding(HuggingFaceEmbedding):
    def get_text_embedding(self, text: str):
        vec = super().get_text_embedding(text)
        return normalize_vector(vec)

    def get_query_embedding(self, query: str):
        vec = super().get_query_embedding(query)
        return normalize_vector(vec)
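# With unit-length vectors, dot-product similarity and cosine similarity coincide,
# so a collection built with dot-product distance behaves like cosine search.
# Illustrative check (not executed here; `embed_model` is defined just below):
#
#     v = embed_model.get_query_embedding("hello world")
#     print(np.linalg.norm(v))  # expected to be ~1.0 after normalization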
embed_model = NormalizedEmbedding(model_name="BAAI/bge-base-en-v1.5")

# === LLM (kept for compatibility; streaming uses Together SDK directly) ===
llm = TogetherLLM(
    model="meta-llama/Llama-3-8b-chat-hf",
    api_key="a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6",
    temperature=0.3,
    max_tokens=1024,
    top_p=0.7
)

Settings.embed_model = embed_model
Settings.llm = llm

# === Cross-Encoder for Reranking ===
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# === Prompt Template ===
custom_prompt = PromptTemplate(
    "You are an expert assistant for ImageOnline Pvt Ltd.\n"
    "Answer the user's query using relevant information from the context below.\n\n"
    "Context:\n{context_str}\n\n"
    "Query: {query_str}\n\n"
)
# === Load Index ===
def load_index_for_collection(collection_name: str) -> VectorStoreIndex:
    vector_store = QdrantVectorStore(
        client=qdrant_client,
        collection_name=collection_name,
        enable_hnsw=True
    )
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex.from_vector_store(vector_store=vector_store, storage_context=storage_context)
# === Reference Renderer ===
def get_clickable_references_from_response(source_nodes, max_refs=2):
    seen = set()
    links = []
    for node in source_nodes:
        metadata = node.node.metadata
        section = metadata.get("section") or metadata.get("title") or "Unknown"
        source = metadata.get("source") or "Unknown"
        key = (section, source)
        if key not in seen:
            seen.add(key)
            if source.startswith("http"):
                links.append(f"- [{section}]({source})")
            else:
                links.append(f"- {section}: {source}")
        if len(links) >= max_refs:
            break
    return links
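# Example of the markdown this produces (illustrative metadata, not from a real node):
#
#     - [Services](https://example.com/services)
#     - About Us: about.txt
#
# Entries are de-duplicated on (section, source) and capped at `max_refs`.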
# === Safe Streaming Adapter for Together API (True Streaming) ===
# Requires: pip install together
from together import Together

def _extract_event_text(event: Any) -> str:
    """
    Safely extract the streamed text delta from an event returned by the Together SDK.
    Supports dict-like and object-like events.
    Returns an empty string if nothing is found.
    """
    try:
        # Try object attribute access
        choices = getattr(event, "choices", None)
        if choices:
            # event.choices[0].delta could be object-like
            first = choices[0]
            delta = getattr(first, "delta", None)
            if delta:
                text = getattr(delta, "content", None)
                if text:
                    return text
            # sometimes content is directly on the choice
            text = getattr(first, "text", None)
            if text:
                return text
    except Exception:
        pass
    # Try dict-like access
    try:
        if isinstance(event, dict):
            choices = event.get("choices")
            if choices and len(choices) > 0:
                first = choices[0]
                # delta may be nested
                delta = first.get("delta") if isinstance(first, dict) else None
                if isinstance(delta, dict):
                    return delta.get("content", "") or delta.get("text", "") or ""
                # fall back to message.content, then to a bare "text" field
                message = first.get("message")
                if isinstance(message, dict) and message.get("content"):
                    return message["content"]
                return first.get("text", "") or ""
    except Exception:
        pass
    return ""
def _extract_response_text(resp: Any) -> str:
    """
    Safely extract the full response text from a non-streaming response object/dict from the Together SDK.
    """
    try:
        # object-like
        choices = getattr(resp, "choices", None)
        if choices and len(choices) > 0:
            first = choices[0]
            # message may be an attribute or a dict
            message = getattr(first, "message", None)
            if message:
                # message.content may be an attribute
                content = getattr(message, "content", None)
                if content:
                    return content
                # dict
                if isinstance(message, dict):
                    return message.get("content", "") or ""
            # fall back to text on the choice
            text = getattr(first, "text", None)
            if text:
                return text
    except Exception:
        pass
    # dict-like
    try:
        if isinstance(resp, dict):
            choices = resp.get("choices", [])
            if choices:
                first = choices[0]
                # prefer message.content, then fall back to a bare "text" field
                message = first.get("message")
                if isinstance(message, dict) and message.get("content"):
                    return message["content"]
                return first.get("text", "") or ""
    except Exception:
        pass
    # final fallback
    return str(resp)
class StreamingLLMAdapter:
    def __init__(self, api_key: str, model: str, temperature: float = 0.3, top_p: float = 0.7, chunk_size: int = 64):
        self.client = Together(api_key=api_key)
        self.model = model
        self.temperature = temperature
        self.top_p = top_p
        self.chunk_size = chunk_size

    def stream_complete(self, prompt: str, max_tokens: int = 1024, **kwargs) -> Generator[str, None, None]:
        """
        Use Together's native streaming API to yield tokens in real time.
        Falls back to a non-streamed response if streaming isn't available or errors.
        """
        try:
            # the Together SDK exposes an iterator when stream=True
            events = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                stream=True
            )
            for event in events:
                # robust extraction (handles dicts or objects)
                text_piece = _extract_event_text(event)
                if text_piece:
                    yield text_piece
        except Exception:
            # fall back to a synchronous, non-streaming call
            yield from self._sync_fallback(prompt, max_tokens, **kwargs)

    def _sync_fallback(self, prompt: str, max_tokens: int = 1024, **kwargs) -> Generator[str, None, None]:
        """Call the Together API without streaming and yield the text in fixed-size chunks."""
        try:
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=self.temperature,
                top_p=self.top_p
            )
            text = _extract_response_text(resp)
        except Exception as e:
            text = f"[Error from LLM: {e}]"
        for i in range(0, len(text), self.chunk_size):
            yield text[i:i + self.chunk_size]
# instantiate streaming adapter (keep your API key here)
streaming_llm = StreamingLLMAdapter(
    api_key="a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6",
    model="meta-llama/Llama-3-8b-chat-hf",
    temperature=0.3,
    top_p=0.7
)
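# Minimal usage sketch for the adapter (illustrative; not executed at import time):
#
#     for piece in streaming_llm.stream_complete("Say hello", max_tokens=32):
#         print(piece, end="", flush=True)
#
# Each yielded piece is a text delta; the caller concatenates them into the full answer.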
# === Query Chain with Reranking ===
def rag_chain_prompt_and_sources(query: str, top_k: int = 3):
    """
    Returns (prompt_text, top_nodes, error) using the existing retrieval + reranking flow.
    On success, error is None; on failure, prompt_text and top_nodes are None.
    Building the prompt is kept separate from calling the LLM so the final call can be streamed.
    """
    if not active_state["query_engine"]:
        return None, None, "⚠️ Please select a website collection first."

    # Step 1: Retrieve candidate nodes
    raw_nodes = active_state["query_engine"].retrieve(query)

    # Step 2: Rerank with the cross-encoder
    pairs = [(query, n.node.get_content()) for n in raw_nodes]
    scores = reranker.predict(pairs)
    scored_nodes = sorted(zip(raw_nodes, scores), key=lambda x: x[1], reverse=True)
    top_nodes = [n for n, _ in scored_nodes[:top_k]]

    # Step 3: Compose prompt
    context = "\n\n".join([n.node.get_content() for n in top_nodes])
    prompt = custom_prompt.format(context_str=context, query_str=query)
    return prompt, top_nodes, None
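# Illustrative call pattern (the same unpacking used by the chat handlers below):
#
#     prompt, nodes, err = rag_chain_prompt_and_sources("What services do you offer?")
#     if err is None:
#         answer = "".join(streaming_llm.stream_complete(prompt))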
# === Collection Switch ===
def handle_collection_change(selected):
    now = datetime.utcnow()
    cached = index_cache.get(selected)
    if cached:
        query_engine, ts = cached
        if now - ts < timedelta(hours=1):
            active_state["collection"] = selected
            active_state["query_engine"] = query_engine
            return f"✅ Now chatting with: `{selected}`", [], []
    index = load_index_for_collection(selected)
    query_engine = index.as_query_engine(similarity_top_k=10, vector_store_query_mode="default")
    index_cache[selected] = (query_engine, now)
    active_state["collection"] = selected
    active_state["query_engine"] = query_engine
    return f"✅ Now chatting with: `{selected}`", [], []
# === Streaming Chat Handler ===
def chat_interface_stream(message: str, history: list) -> Generator[Tuple[list, list, str], None, None]:
    """
    Yields tuples of (chatbot_history, state, textbox_value) so Gradio gets
    the right number of outputs for each yield when using streaming.
    """
    history = history or []
    message = (message or "").strip()
    if not message:
        # still return all outputs
        yield history, history, ""
        return

    timestamp_user = datetime.now().strftime("%H:%M:%S")
    user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}"

    # append placeholder bot typing state
    history.append((user_msg, "⏳ _Bot is typing..._"))

    # initial update (user message + typing)
    yield history, history, ""

    prompt, top_nodes, err = rag_chain_prompt_and_sources(message)
    if err:
        history[-1] = (user_msg, f"🤖 **Bot**\n{err}")
        yield history, history, ""
        return

    assistant_text = ""
    chunk_count = 0
    flush_every_n = 3  # flush every 3 small deltas (tweak if you want more frequent updates)

    try:
        # stream from Together
        for chunk in streaming_llm.stream_complete(prompt, max_tokens=1024):
            assistant_text += chunk
            chunk_count += 1
            # periodically flush partial output to UI
            if chunk_count % flush_every_n == 0:
                history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}")
                yield history, history, ""
        # after streaming completes, flush any leftover partial output (if not flushed recently)
        history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}")
    except Exception as e:
        # on error, show error message
        history[-1] = (user_msg, f"🤖 **Bot**\n⚠️ {str(e)}")
        yield history, history, ""
        return

    # Add references at the end
    references = get_clickable_references_from_response(top_nodes)
    if references:
        assistant_text += "\n\n📚 **Reference(s):**\n" + "\n".join(references)

    timestamp_bot = datetime.now().strftime("%H:%M:%S")
    history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text.strip()}\n\n⏱️ {timestamp_bot}")

    # final yield with textbox cleared
    yield history, history, ""
# Fallback synchronous chat (kept for compatibility if you want non-streaming)
def chat_interface_sync(message, history):
    history = history or []
    message = message.strip()
    if not message:
        raise ValueError("Please enter a valid question.")

    timestamp_user = datetime.now().strftime("%H:%M:%S")
    user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}"
    bot_msg = "⏳ _Bot is typing..._"
    history.append((user_msg, bot_msg))

    try:
        time.sleep(0.5)
        prompt, top_nodes, err = rag_chain_prompt_and_sources(message)
        if err:
            timestamp_bot = datetime.now().strftime("%H:%M:%S")
            history[-1] = (user_msg, f"🤖 **Bot**\n{err}\n\n⏱️ {timestamp_bot}")
            return history, history, ""
        resp = llm.complete(prompt).text
        references = get_clickable_references_from_response(top_nodes)
        if references:
            resp += "\n\n📚 **Reference(s):**\n" + "\n".join(references)
        timestamp_bot = datetime.now().strftime("%H:%M:%S")
        bot_msg = f"🤖 **Bot**\n{resp.strip()}\n\n⏱️ {timestamp_bot}"
        history[-1] = (user_msg, bot_msg)
    except Exception as e:
        timestamp_bot = datetime.now().strftime("%H:%M:%S")
        error_msg = f"🤖 **Bot**\n⚠️ {str(e)}\n\n⏱️ {timestamp_bot}"
        history[-1] = (user_msg, error_msg)
    return history, history, ""
# === Gradio UI ===
def launch_gradio():
    with gr.Blocks() as demo:
        gr.Markdown("# 💬 Demo IOPL Multi-Website Chatbot")
        gr.Markdown("Choose a website you want to chat with.")

        with gr.Row():
            collection_dropdown = gr.Dropdown(choices=AVAILABLE_COLLECTIONS, label="Select Website to chat")
            load_button = gr.Button("Load Website")

        collection_status = gr.Markdown("")
        chatbot = gr.Chatbot()
        state = gr.State([])

        with gr.Row(equal_height=True):
            msg = gr.Textbox(placeholder="Ask your question...", show_label=False, scale=9)
            send_btn = gr.Button("🚀 Send", scale=1)

        load_button.click(
            fn=handle_collection_change,
            inputs=collection_dropdown,
            outputs=[collection_status, chatbot, state]
        )

        # Use the streaming generator for submit/click so Gradio receives yields
        msg.submit(chat_interface_stream, inputs=[msg, state], outputs=[chatbot, state, msg])
        send_btn.click(chat_interface_stream, inputs=[msg, state], outputs=[chatbot, state, msg])

        with gr.Row():
            clear_btn = gr.Button("🧹 Clear Chat")
        clear_btn.click(fn=lambda: ([], []), outputs=[chatbot, state])

    return demo
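# Note (assumption about the Gradio version in use): on Gradio 3.x, generator-based
# handlers like chat_interface_stream only stream to the UI when the queue is enabled,
# e.g. `demo.queue()` before `demo.launch()`; on Gradio 4.x the queue is on by default.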
demo = launch_gradio()
demo.launch()