# IOPL-Chatbot-3 / app.py
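"""Gradio RAG chatbot for IOPL websites.

For each query, the app retrieves context from the selected website's Qdrant collection,
reranks candidates with a sentence-transformers cross-encoder, and streams a concise answer
from a Together-hosted Llama 3 model, appending clickable source references.
"""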
from datetime import datetime, timedelta
import time
from typing import Any, Generator, Tuple

import gradio as gr
import numpy as np
from llama_index.core import Settings, StorageContext, VectorStoreIndex
from llama_index.core.prompts import PromptTemplate
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.together import TogetherLLM
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from sentence_transformers import CrossEncoder
from together import Together
# === Config ===
MAX_OUTPUT_TOKENS = 300  # hard cap to keep answers concise
QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.9Pj8v4ACpX3m5U3SZUrG_jzrjGF-T41J5icZ6EPMxnc"
QDRANT_URL = "https://d36718f0-be68-4040-b276-f1f39bc1aeb9.us-east4-0.gcp.cloud.qdrant.io"
TOGETHER_API_KEY = "a36246d65d8290f43667350b364c5b6bb8562eb50a4b947eec5bd7e79f2dffc6"
qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
AVAILABLE_COLLECTIONS = ["ImageOnline", "tezjet-site", "anish-pharma"]
index_cache = {}
active_state = {"collection": None, "query_engine": None}
# === Normalized Embedding Wrapper ===
def normalize_vector(vec):
    vec = np.asarray(vec, dtype=float)
    norm = np.linalg.norm(vec)
    return vec if norm == 0 else vec / norm  # guard against division by zero for all-zero vectors
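# Wrap HuggingFaceEmbedding so that single-text and query embeddings are L2-normalized;
# with unit-length vectors, dot-product and cosine similarity rank results identically.
# Note: batch embedding paths are not overridden here and return unnormalized vectors.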
class NormalizedEmbedding(HuggingFaceEmbedding):
def get_text_embedding(self, text: str):
vec = super().get_text_embedding(text)
return normalize_vector(vec)
def get_query_embedding(self, query: str):
vec = super().get_query_embedding(query)
return normalize_vector(vec)
embed_model = NormalizedEmbedding(model_name="BAAI/bge-base-en-v1.5")
# === LLM (kept for compatibility; streaming uses Together SDK directly) ===
llm = TogetherLLM(
    model="meta-llama/Llama-3-8b-chat-hf",
    api_key=TOGETHER_API_KEY,
    temperature=0.3,
    max_tokens=MAX_OUTPUT_TOKENS,
    top_p=0.7
)
Settings.embed_model = embed_model
Settings.llm = llm
# === Cross-Encoder for Reranking ===
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
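# The cross-encoder scores each (query, passage) pair jointly and is used below to rerank
# the nodes returned by vector retrieval. Illustrative call (not executed here):
#   reranker.predict([("what services do you offer?", "We build websites ...")])  -> array of relevance scores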
# === Prompt Template (Optimized for Conciseness & Token Limit) ===
custom_prompt = PromptTemplate(
"You are an expert assistant for ImageOnline Pvt Ltd.\n"
"Instructions:\n"
"- Be concise, factual, and to the point.\n"
"- Use bullet points where possible.\n"
"- Do not repeat previous answers unless asked.\n"
"- Stop once the question is addressed.\n"
"- If user may need more detail, invite follow-up questions.\n"
f"- Keep the answer within {MAX_OUTPUT_TOKENS} tokens.\n\n"
"Context (summarize if long):\n{context_str}\n\n"
"Query: {query_str}\n\n"
"Answer:\n"
)
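# Note: {context_str} and {query_str} are PromptTemplate placeholders filled at query time,
# while the f-string line above bakes the MAX_OUTPUT_TOKENS value into the instructions.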
# === Load Index ===
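# Wraps an existing Qdrant collection as a LlamaIndex index; no documents are re-ingested here,
# the index simply reads vectors that were stored in the collection beforehand.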
def load_index_for_collection(collection_name: str) -> VectorStoreIndex:
vector_store = QdrantVectorStore(
client=qdrant_client,
collection_name=collection_name,
enable_hnsw=True
)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
return VectorStoreIndex.from_vector_store(vector_store=vector_store, storage_context=storage_context)
# === Reference Renderer ===
def get_clickable_references_from_response(source_nodes, max_refs=2):
seen = set()
links = []
for node in source_nodes:
metadata = node.node.metadata
section = metadata.get("section") or metadata.get("title") or "Unknown"
source = metadata.get("source") or "Unknown"
key = (section, source)
if key not in seen:
seen.add(key)
if source.startswith("http"):
links.append(f"- [{section}]({source})")
else:
links.append(f"- {section}: {source}")
if len(links) >= max_refs:
break
return links
# === Safe Streaming Adapter for Together API (True Streaming) ===
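# The two helpers below defensively pull text out of Together SDK responses. Depending on the
# SDK version, streaming events and completions may arrive as typed objects or as plain dicts,
# so both attribute-style and dict-style payloads are handled, falling back to an empty string.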
def _extract_event_text(event: Any) -> str:
try:
choices = getattr(event, "choices", None)
if choices:
first = choices[0]
delta = getattr(first, "delta", None)
if delta:
text = getattr(delta, "content", None)
if text:
return text
text = getattr(first, "text", None)
if text:
return text
except Exception:
pass
try:
if isinstance(event, dict):
choices = event.get("choices")
if choices and len(choices) > 0:
first = choices[0]
delta = first.get("delta") if isinstance(first, dict) else None
if isinstance(delta, dict):
return delta.get("content", "") or delta.get("text", "") or ""
message = first.get("message") or {}
if isinstance(message, dict):
return message.get("content", "") or ""
return first.get("text", "") or ""
except Exception:
pass
return ""
def _extract_response_text(resp: Any) -> str:
try:
choices = getattr(resp, "choices", None)
if choices and len(choices) > 0:
first = choices[0]
message = getattr(first, "message", None)
if message:
content = getattr(message, "content", None)
if content:
return content
if isinstance(message, dict):
return message.get("content", "") or ""
text = getattr(first, "text", None)
if text:
return text
except Exception:
pass
try:
if isinstance(resp, dict):
choices = resp.get("choices", [])
if choices:
first = choices[0]
message = first.get("message") or {}
if isinstance(message, dict):
return message.get("content", "") or ""
return first.get("text", "") or ""
except Exception:
pass
return str(resp)
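# Thin wrapper around the Together SDK that yields answer text incrementally. If the streaming
# request fails, it retries once without streaming and yields the full response in fixed-size
# chunks so the UI code path stays the same.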
class StreamingLLMAdapter:
def __init__(self, api_key: str, model: str, temperature: float = 0.3, top_p: float = 0.7, chunk_size: int = 64):
self.client = Together(api_key=api_key)
self.model = model
self.temperature = temperature
self.top_p = top_p
self.chunk_size = chunk_size
def stream_complete(self, prompt: str, max_tokens: int = MAX_OUTPUT_TOKENS, **kwargs) -> Generator[str, None, None]:
try:
events = self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens,
temperature=self.temperature,
top_p=self.top_p,
stream=True
)
for event in events:
text_piece = _extract_event_text(event)
if text_piece:
yield text_piece
except Exception:
yield from self._sync_fallback(prompt, max_tokens, **kwargs)
def _sync_fallback(self, prompt: str, max_tokens: int = MAX_OUTPUT_TOKENS, **kwargs) -> Generator[str, None, None]:
try:
resp = self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens,
temperature=self.temperature,
top_p=self.top_p
)
text = _extract_response_text(resp)
except Exception as e:
text = f"[Error from LLM: {e}]"
for i in range(0, len(text), self.chunk_size):
yield text[i:i + self.chunk_size]
streaming_llm = StreamingLLMAdapter(
    api_key=TOGETHER_API_KEY,
    model="meta-llama/Llama-3-8b-chat-hf",
    temperature=0.3,
    top_p=0.7
)
# === Query Chain with Reranking ===
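# Pipeline: retrieve up to 10 candidate nodes from the active collection, rerank them with the
# cross-encoder, keep the top_k best, truncate the combined context to ~4000 characters, and
# format the final prompt. Returns (prompt, top_nodes, error_message).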
def rag_chain_prompt_and_sources(query: str, top_k: int = 3):
if not active_state["query_engine"]:
return None, None, "⚠️ Please select a website collection first."
    raw_nodes = active_state["query_engine"].retrieve(query)
    if not raw_nodes:
        return None, None, "⚠️ No relevant content was found for this question."
    pairs = [(query, n.node.get_content()) for n in raw_nodes]
    scores = reranker.predict(pairs)
    scored_nodes = sorted(zip(raw_nodes, scores), key=lambda x: x[1], reverse=True)
    top_nodes = [n for n, _ in scored_nodes[:top_k]]
# Truncate context if too large to save tokens
context = "\n\n".join([n.node.get_content() for n in top_nodes])
if len(context) > 4000:
context = context[:4000] + "...\n[Context truncated for brevity]"
prompt = custom_prompt.format(context_str=context, query_str=query)
return prompt, top_nodes, None
# === Collection Switch ===
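# Query engines are cached per collection for one hour; reselecting a cached collection reuses
# the existing engine instead of rebuilding the index wrapper. Returns (status_markdown, chatbot, state).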
def handle_collection_change(selected):
    if not selected:
        return "⚠️ Please select a website first.", [], []
    now = datetime.utcnow()
    cached = index_cache.get(selected)
    if cached:
        query_engine, ts = cached
        if now - ts < timedelta(hours=1):
            active_state["collection"] = selected
            active_state["query_engine"] = query_engine
            return f"✅ Now chatting with: `{selected}`", [], []
    index = load_index_for_collection(selected)
    query_engine = index.as_query_engine(similarity_top_k=10, vector_store_query_mode="default")
    index_cache[selected] = (query_engine, now)
    active_state["collection"] = selected
    active_state["query_engine"] = query_engine
    return f"✅ Now chatting with: `{selected}`", [], []
# === Streaming Chat Handler ===
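# Generator-based handler wired to both the textbox submit and the Send button. It yields
# (chatbot_history, state, cleared_textbox) updates: first the user turn with a typing placeholder,
# then partial answers every few streamed chunks, and finally the answer with references and a timestamp.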
def chat_interface_stream(message: str, history: list) -> Generator[Tuple[list, list, str], None, None]:
history = history or []
message = (message or "").strip()
if not message:
yield history, history, ""
return
timestamp_user = datetime.now().strftime("%H:%M:%S")
user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}"
history.append((user_msg, "⏳ _Bot is typing..._"))
yield history, history, ""
prompt, top_nodes, err = rag_chain_prompt_and_sources(message)
if err:
history[-1] = (user_msg, f"🤖 **Bot**\n{err}")
yield history, history, ""
return
assistant_text = ""
chunk_count = 0
flush_every_n = 3
try:
for chunk in streaming_llm.stream_complete(prompt, max_tokens=MAX_OUTPUT_TOKENS):
assistant_text += chunk
chunk_count += 1
if chunk_count % flush_every_n == 0:
history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}")
yield history, history, ""
history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text}")
except Exception as e:
history[-1] = (user_msg, f"🤖 **Bot**\n⚠️ {str(e)}")
yield history, history, ""
return
references = get_clickable_references_from_response(top_nodes)
if references:
assistant_text += "\n\n📚 **Reference(s):**\n" + "\n".join(references)
timestamp_bot = datetime.now().strftime("%H:%M:%S")
history[-1] = (user_msg, f"🤖 **Bot**\n{assistant_text.strip()}\n\n⏱️ {timestamp_bot}")
yield history, history, ""
# === Fallback Synchronous Chat (not wired into the UI; kept as a non-streaming alternative) ===
def chat_interface_sync(message, history):
    history = history or []
    message = (message or "").strip()
    if not message:
        return history, history, ""
timestamp_user = datetime.now().strftime("%H:%M:%S")
user_msg = f"🧑 **You**\n{message}\n\n⏱️ {timestamp_user}"
bot_msg = "⏳ _Bot is typing..._"
history.append((user_msg, bot_msg))
try:
time.sleep(0.5)
prompt, top_nodes, err = rag_chain_prompt_and_sources(message)
if err:
timestamp_bot = datetime.now().strftime("%H:%M:%S")
history[-1] = (user_msg, f"🤖 **Bot**\n{err}\n\n⏱️ {timestamp_bot}")
return history, history, ""
resp = llm.complete(prompt, max_tokens=MAX_OUTPUT_TOKENS).text
references = get_clickable_references_from_response(top_nodes)
if references:
resp += "\n\n📚 **Reference(s):**\n" + "\n".join(references)
timestamp_bot = datetime.now().strftime("%H:%M:%S")
bot_msg = f"🤖 **Bot**\n{resp.strip()}\n\n⏱️ {timestamp_bot}"
history[-1] = (user_msg, bot_msg)
except Exception as e:
timestamp_bot = datetime.now().strftime("%H:%M:%S")
error_msg = f"🤖 **Bot**\n⚠️ {str(e)}\n\n⏱️ {timestamp_bot}"
history[-1] = (user_msg, error_msg)
return history, history, ""
# === Gradio UI ===
def launch_gradio():
with gr.Blocks() as demo:
gr.Markdown("# 💬 Demo IOPL Multi-Website Chatbot")
gr.Markdown("Choose a website to chat with.")
with gr.Row():
collection_dropdown = gr.Dropdown(choices=AVAILABLE_COLLECTIONS, label="Select Website to chat")
load_button = gr.Button("Load Website")
collection_status = gr.Markdown("")
chatbot = gr.Chatbot()
state = gr.State([])
with gr.Row(equal_height=True):
msg = gr.Textbox(placeholder="Ask your question...", show_label=False, scale=9)
send_btn = gr.Button("🚀 Send", scale=1)
load_button.click(
fn=handle_collection_change,
inputs=collection_dropdown,
outputs=[collection_status, chatbot, state]
)
msg.submit(chat_interface_stream, inputs=[msg, state], outputs=[chatbot, state, msg])
send_btn.click(chat_interface_stream, inputs=[msg, state], outputs=[chatbot, state, msg])
with gr.Row():
clear_btn = gr.Button("🧹 Clear Chat")
clear_btn.click(fn=lambda: ([], []), outputs=[chatbot, state])
return demo
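# Build the Blocks UI at module load and start the Gradio server; on Hugging Face Spaces the app
# is served as soon as this file runs.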
demo = launch_gradio()
demo.launch()