# IOPL-Chatbot-2 / app.py
# IProject-10's picture
# Update app.py
# ccb5f9c verified
import os
import time
from datetime import datetime, timedelta
from typing import Generator, Iterable, Tuple, Any

import gradio as gr
import numpy as np
from llama_index.core import VectorStoreIndex, StorageContext, Settings
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.prompts import PromptTemplate
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.together import TogetherLLM
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from sentence_transformers import CrossEncoder
# === Config ===
# SECURITY: credentials belong in the environment (deployment secrets), not in
# source control. The literals below are kept only as a backward-compatible
# fallback so existing deployments keep working; rotate the key and drop the
# fallback once QDRANT_API_KEY / QDRANT_URL are configured.
QDRANT_API_KEY = os.environ.get(
    "QDRANT_API_KEY",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.9Pj8v4ACpX3m5U3SZUrG_jzrjGF-T41J5icZ6EPMxnc",
)
QDRANT_URL = os.environ.get(
    "QDRANT_URL",
    "https://d36718f0-be68-4040-b276-f1f39bc1aeb9.us-east4-0.gcp.cloud.qdrant.io",
)
qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

# Collection names offered in the UI dropdown.
AVAILABLE_COLLECTIONS = ["ImageOnline", "tezjet-site", "anish-pharma"]

# collection name -> (query_engine, time it was loaded); filled by
# handle_collection_change.
index_cache = {}

# The collection currently being chatted with and its query engine.
active_state = {"collection": None, "query_engine": None}
# === Normalized Embedding Wrapper ===
def normalize_vector(vec):
    """Scale *vec* to unit Euclidean (L2) length.

    Args:
        vec: Array-like of numbers; it is copied into a NumPy array.

    Returns:
        ``vec / ||vec||`` as a NumPy array. A vector whose norm is zero
        (all zeros, or empty) has no direction, so it is returned unscaled
        instead of being divided by zero and turned into NaNs.
    """
    arr = np.array(vec)
    norm = np.linalg.norm(arr)
    if norm == 0:
        return arr
    return arr / norm
class NormalizedEmbedding(HuggingFaceEmbedding):
    """HuggingFace embedding model whose embeddings are L2-normalised.

    Only the single-text and single-query entry points are overridden here;
    batch embedding methods are inherited unchanged.

    The normalised vector is converted back to a plain ``list`` because the
    llama-index embedding interface is typed as a list of floats, and a raw
    ``numpy.ndarray`` can break downstream code that serialises or validates
    embeddings.
    """

    def get_text_embedding(self, text: str) -> list:
        """Return the unit-length embedding of *text* as a list of floats."""
        vec = super().get_text_embedding(text)
        return normalize_vector(vec).tolist()

    def get_query_embedding(self, query: str) -> list:
        """Return the unit-length embedding of *query* as a list of floats."""
        vec = super().get_query_embedding(query)
        return normalize_vector(vec).tolist()
embed_model = NormalizedEmbedding(model_name="BAAI/bge-base-en-v1.5")

# === LLM (kept for compatibility; streaming uses Together SDK directly) ===
# SECURITY: the API key should come from the TOGETHER_API_KEY environment
# variable. The literal is kept only as a backward-compatible fallback; rotate
# it and remove it once the environment variable is configured.
llm = TogetherLLM(
    model="meta-llama/Llama-3-8b-chat-hf",
    api_key=os.environ.get(
        "TOGETHER_API_KEY",
        "a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6",
    ),
    temperature=0.3,
    max_tokens=1024,
    top_p=0.7,
)

# Register the models as llama-index global defaults.
Settings.embed_model = embed_model
Settings.llm = llm

# === Cross-Encoder for Reranking ===
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# === Prompt Template ===
# Text sent to the LLM: {context_str} receives the retrieved passages and
# {query_str} the user's question (both filled in via custom_prompt.format).
_PROMPT_TEXT = "".join(
    [
        "You are an expert assistant for ImageOnline Pvt Ltd.\n",
        "Answer the user's query using relevant information from the context below.\n\n",
        "Context:\n{context_str}\n\n",
        "Query: {query_str}\n\n",
    ]
)
custom_prompt = PromptTemplate(_PROMPT_TEXT)
# === Load Index ===
def load_index_for_collection(collection_name: str) -> VectorStoreIndex:
    """Wrap an existing Qdrant collection in a ``VectorStoreIndex``.

    Args:
        collection_name: Name of the Qdrant collection to open through the
            module-level ``qdrant_client``.

    Returns:
        An index built on top of the collection's vector store.
    """
    store = QdrantVectorStore(
        client=qdrant_client,
        collection_name=collection_name,
        enable_hnsw=True,
    )
    return VectorStoreIndex.from_vector_store(
        vector_store=store,
        storage_context=StorageContext.from_defaults(vector_store=store),
    )
# === Reference Renderer ===
def get_clickable_references_from_response(source_nodes, max_refs=2):
    """Render de-duplicated source references as Markdown bullet lines.

    Args:
        source_nodes: Iterable of retrieval results; each item exposes
            ``.node.metadata`` (a dict that may carry ``section``/``title``
            and ``source``). ``None`` is treated as empty.
        max_refs: Maximum number of references to return. Zero or a negative
            value yields no references.

    Returns:
        A list of Markdown strings: ``- [section](source)`` when the source
        starts with ``http``, otherwise ``- section: source``. Missing fields
        are rendered as ``Unknown``.
    """
    links = []
    if max_refs <= 0:
        return links

    seen = set()
    for node in source_nodes or ():
        metadata = node.node.metadata or {}
        section = metadata.get("section") or metadata.get("title") or "Unknown"
        source = metadata.get("source") or "Unknown"

        # The same (section, source) pair is listed only once.
        key = (section, source)
        if key in seen:
            continue
        seen.add(key)

        # Only string sources can be URLs; anything else is shown as plain text.
        if isinstance(source, str) and source.startswith("http"):
            links.append(f"- [{section}]({source})")
        else:
            links.append(f"- {section}: {source}")

        if len(links) >= max_refs:
            break
    return links
# === Safe Streaming Adapter for Together API (True Streaming) ===
# Requires: pip install together
from together import Together
def _extract_event_text(event: Any) -> str:
"""
Safely extract the streamed text delta from an event returned by the Together SDK.
Supports dict-like and object-like events.
Returns empty string if nothing found.
"""
try:
# Try object attribute access
choices = getattr(event, "choices", None)
if choices:
# event.choices[0].delta could be object-like
first = choices[0]
delta = getattr(first, "delta", None)
if delta:
text = getattr(delta, "content", None)
if text:
return text
# sometimes content is directly in choice
text = getattr(first, "text", None)
if text:
return text
except Exception:
pass
# Try dict-like access
try:
if isinstance(event, dict):
choices = event.get("choices")
if choices and len(choices) > 0:
first = choices[0]
# delta may be nested
delta = first.get("delta") if isinstance(first, dict) else None
if isinstance(delta, dict):
return delta.get("content", "") or delta.get("text", "") or ""
# fallback to message/content
message = first.get("message") or {}
if isinstance(message, dict):
return message.get("content", "") or ""
return first.get("text", "") or ""
except Exception:
pass
return ""
def _extract_response_text(resp: Any) -> str:
"""
Safely extract full response text from a non-streaming response object/dict from Together SDK.
"""
try:
# object-like
choices = getattr(resp, "choices", None)
if choices and len(choices) > 0:
first = choices[0]
# message may be attribute or dict
message = getattr(first, "message", None)
if message:
# message.content may be attribute
content = getattr(message, "content", None)
if content:
return content
# dict
if isinstance(message, dict):
return message.get("content", "") or ""
# fallback to text on choice
text = getattr(first, "text", None)
if text:
return text
except Exception:
pass
# dict-like
try:
if isinstance(resp, dict):
choices = resp.get("choices", [])
if choices:
first = choices[0]
message = first.get("message") or {}
if isinstance(message, dict):
return message.get("content", "") or ""
return first.get("text", "") or ""
except Exception:
pass
# final fallback
return str(resp)
class StreamingLLMAdapter:
    """Wrapper around the Together chat-completions client that yields text
    incrementally, with a non-streaming fallback."""

    def __init__(self, api_key: str, model: str, temperature: float = 0.3, top_p: float = 0.7, chunk_size: int = 64):
        """Create the Together client and remember the sampling settings.

        Args:
            api_key: Together API key.
            model: Model identifier passed to every completion request.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling cutoff.
            chunk_size: Number of characters per piece when a non-streamed
                response is replayed by ``_sync_fallback``.
        """
        self.client = Together(api_key=api_key)
        self.model = model
        self.temperature = temperature
        self.top_p = top_p
        self.chunk_size = chunk_size

    def stream_complete(self, prompt: str, max_tokens: int = 1024, **kwargs) -> Generator[str, None, None]:
        """Yield the completion for *prompt* piece by piece.

        Uses the streaming API first. If it fails before any text was
        produced, the request is retried without streaming. If it fails
        after partial text was already yielded, a short interruption notice
        is yielded instead: re-running the request at that point would
        append a second, complete answer after the partial one.

        Args:
            prompt: User prompt, sent as a single ``user`` message.
            max_tokens: Upper bound on generated tokens.
            **kwargs: Accepted for call compatibility; not forwarded to the API.

        Yields:
            Non-empty text fragments in generation order.
        """
        emitted = False
        try:
            # the Together SDK exposes an iterator when stream=True
            events = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                stream=True,
            )
            for event in events:
                # robust extraction (handles dicts or objects)
                text_piece = _extract_event_text(event)
                if text_piece:
                    emitted = True
                    yield text_piece
        except Exception as exc:
            if emitted:
                yield f"\n\n[Stream interrupted: {exc}]"
                return
            # nothing shown yet: fall back to synchronous non-streaming
            yield from self._sync_fallback(prompt, max_tokens, **kwargs)

    def _sync_fallback(self, prompt: str, max_tokens: int = 1024, **kwargs) -> Generator[str, None, None]:
        """Call the Together API without streaming and yield the text in chunks.

        Errors from the API are not raised; they are yielded as a
        ``[Error from LLM: ...]`` message so the caller can display them.
        """
        try:
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
            )
            text = _extract_response_text(resp)
        except Exception as e:
            text = f"[Error from LLM: {e}]"
        # range() rejects a zero step, so clamp a non-positive chunk_size to 1.
        step = max(1, self.chunk_size)
        for start in range(0, len(text), step):
            yield text[start:start + step]
# instantiate streaming adapter
# SECURITY: the API key should come from the TOGETHER_API_KEY environment
# variable. The literal is kept only as a backward-compatible fallback; rotate
# it and remove it once the environment variable is configured.
streaming_llm = StreamingLLMAdapter(
    api_key=os.environ.get(
        "TOGETHER_API_KEY",
        "a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6",
    ),
    model="meta-llama/Llama-3-8b-chat-hf",
    temperature=0.3,
    top_p=0.7,
)
# === Query Chain with Reranking ===
def rag_chain_prompt_and_sources(query: str, top_k: int = 3):
    """Build the LLM prompt for *query* from retrieved, reranked context.

    Prompt construction is kept separate from the LLM call so the caller can
    stream the final completion.

    Args:
        query: The user's question.
        top_k: Number of highest-scoring reranked nodes to keep as context.

    Returns:
        A ``(prompt, top_nodes, error)`` triple. On success ``error`` is
        ``None``. When no collection has been loaded, ``prompt`` and
        ``top_nodes`` are ``None`` and ``error`` is a user-facing message.
    """
    query_engine = active_state["query_engine"]
    if not query_engine:
        return None, None, "⚠️ Please select a website collection first."

    # Step 1: Retrieve
    raw_nodes = query_engine.retrieve(query)

    # Step 2: Rerank (skipped when nothing was retrieved, so the cross-encoder
    # is never asked to score an empty batch)
    if raw_nodes:
        pairs = [(query, n.node.get_content()) for n in raw_nodes]
        scores = reranker.predict(pairs)
        scored_nodes = sorted(zip(raw_nodes, scores), key=lambda x: x[1], reverse=True)
        top_nodes = [n for n, _ in scored_nodes[:top_k]]
    else:
        top_nodes = []

    # Step 3: Compose prompt
    context = "\n\n".join(n.node.get_content() for n in top_nodes)
    prompt = custom_prompt.format(context_str=context, query_str=query)
    return prompt, top_nodes, None
# === Collection Switch ===
def handle_collection_change(selected):
    """Activate the query engine for the chosen collection.

    A query engine is built on first use and cached per collection; a cached
    engine is reused for up to one hour, after which it is rebuilt.

    Args:
        selected: Collection name from the dropdown; ``None`` or ``""`` when
            nothing is selected.

    Returns:
        A ``(status_markdown, chatbot_history, state)`` triple. Both history
        values are empty lists, which resets the chat.
    """
    if not selected:
        # Nothing chosen in the dropdown: report it instead of trying to
        # open a collection named None.
        return "⚠️ Please select a website collection first.", [], []

    cache_ttl_seconds = 60 * 60
    # A monotonic clock is used for the age check so it is unaffected by
    # wall-clock adjustments.
    now = time.monotonic()

    cached = index_cache.get(selected)
    if cached and now - cached[1] < cache_ttl_seconds:
        query_engine = cached[0]
    else:
        index = load_index_for_collection(selected)
        query_engine = index.as_query_engine(similarity_top_k=10, vector_store_query_mode="default")
        index_cache[selected] = (query_engine, now)

    active_state["collection"] = selected
    active_state["query_engine"] = query_engine
    return f"✅ Now chatting with: `{selected}`", [], []
# === Streaming Chat Handler ===
def chat_interface_stream(message: str, history: list) -> Generator[Tuple[list, list, str], None, None]:
    """Answer *message* and stream the growing reply to the UI.

    Every yield is a ``(chatbot_history, state, textbox_value)`` triple so
    the caller always receives three outputs. ``history`` is a list of
    ``(user_msg, bot_msg)`` tuples; it is extended and updated in place, and
    the same list is yielded for both the chatbot and the state.

    Args:
        message: Raw textbox content; ``None`` or whitespace is a no-op.
        history: Existing chat history, or ``None`` for a fresh chat.

    Yields:
        The current history twice plus ``""`` to clear the textbox.
    """
    history = history or []
    message = (message or "").strip()
    if not message:
        # still return all outputs
        yield history, history, ""
        return

    timestamp_user = datetime.now().strftime("%H:%M:%S")
    user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}"

    # append placeholder bot typing state
    history.append((user_msg, "⏳ _Bot is typing..._"))
    # initial update (user message + typing)
    yield history, history, ""

    assistant_text = ""
    flush_every_n = 3  # flush every 3 small deltas (tweak if you want more frequent updates)
    try:
        # Retrieval sits inside the try so a failure there replaces the
        # "typing" placeholder with an error instead of leaving it stuck.
        prompt, top_nodes, err = rag_chain_prompt_and_sources(message)
        if err:
            history[-1] = (user_msg, f"🤖 **Bot**\n{err}")
            yield history, history, ""
            return

        # stream from Together
        stream = streaming_llm.stream_complete(prompt, max_tokens=1024)
        for chunk_count, chunk in enumerate(stream, start=1):
            assistant_text += chunk
            # periodically flush partial output to UI
            if chunk_count % flush_every_n == 0:
                history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}")
                yield history, history, ""
    except Exception as e:
        # on error, show error message
        history[-1] = (user_msg, f"🤖 **Bot**\n⚠️ {str(e)}")
        yield history, history, ""
        return

    # Add references at the end
    references = get_clickable_references_from_response(top_nodes)
    if references:
        assistant_text += "\n\n📚 **Reference(s):**\n" + "\n".join(references)

    timestamp_bot = datetime.now().strftime("%H:%M:%S")
    history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text.strip()}\n\n⏱️ {timestamp_bot}")
    # final yield with textbox cleared (also shows any text not flushed yet)
    yield history, history, ""
# Fallback synchronous chat (kept for compatibility if you want non-streaming)
def chat_interface_sync(message, history):
    """Answer *message* in one blocking call and return the updated history.

    Args:
        message: Raw textbox content; ``None`` is treated as empty.
        history: Existing list of ``(user_msg, bot_msg)`` tuples, or ``None``.
            The list is extended in place.

    Returns:
        ``(history, history, "")``: the history for the chatbot and for the
        state, plus an empty string to clear the textbox.

    Raises:
        ValueError: If the message is empty or whitespace only.
    """
    history = history or []
    message = (message or "").strip()
    if not message:
        raise ValueError("Please enter a valid question.")

    timestamp_user = datetime.now().strftime("%H:%M:%S")
    user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}"
    history.append((user_msg, "⏳ _Bot is typing..._"))

    try:
        prompt, top_nodes, err = rag_chain_prompt_and_sources(message)
        if err:
            reply = err
        else:
            reply = llm.complete(prompt).text
            references = get_clickable_references_from_response(top_nodes)
            if references:
                reply += "\n\n📚 **Reference(s):**\n" + "\n".join(references)
            reply = reply.strip()
    except Exception as e:
        # Any failure is reported in the chat instead of being raised.
        reply = f"⚠️ {str(e)}"

    timestamp_bot = datetime.now().strftime("%H:%M:%S")
    history[-1] = (user_msg, f"🤖 **Bot**\n{reply}\n\n⏱️ {timestamp_bot}")
    return history, history, ""
# === Gradio UI ===
def launch_gradio():
    """Assemble the chatbot UI and return the Blocks app without launching it.

    Layout, top to bottom: title, collection picker row, status line, chat
    window, question row, clear button.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# 💬 Demo IOPL Multi-Website Chatbot")
        gr.Markdown("Choose a website you want to chat with.")

        with gr.Row():
            site_picker = gr.Dropdown(choices=AVAILABLE_COLLECTIONS, label="Select Website to chat")
            load_site_btn = gr.Button("Load Website")
        status_line = gr.Markdown("")

        chat_view = gr.Chatbot()
        history_state = gr.State([])

        with gr.Row(equal_height=True):
            question_box = gr.Textbox(placeholder="Ask your question...", show_label=False, scale=9)
            ask_btn = gr.Button("🚀 Send", scale=1)

        load_site_btn.click(
            fn=handle_collection_change,
            inputs=site_picker,
            outputs=[status_line, chat_view, history_state],
        )

        # Pressing Enter and clicking Send both run the streaming generator,
        # so Gradio receives each intermediate yield.
        for register in (question_box.submit, ask_btn.click):
            register(
                chat_interface_stream,
                inputs=[question_box, history_state],
                outputs=[chat_view, history_state, question_box],
            )

        with gr.Row():
            wipe_btn = gr.Button("🧹 Clear Chat")
        wipe_btn.click(fn=lambda: ([], []), outputs=[chat_view, history_state])

    return demo
# Build the UI at import time and start serving it. `demo` stays a
# module-level name so the app object is reachable from outside this module.
demo = launch_gradio()
demo.launch()