#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os
import json
import requests
import gradio as gr
from typing import Literal, List, Dict, Any
from pydantic import BaseModel, Field
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.schema import Document
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict

# Load environment variables
load_dotenv()

# Configuration
BASE_URL = "https://api.llama.com/v1"
LLAMA_API_KEY = os.environ.get('LLAMA_API_KEY')

# Initialize global variables
vectorstore = None
retriever = None
web_search_tool = None
app = None

class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""
    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question, choose to route it to web search or a vectorstore.",
    )

class GraphState(TypedDict):
    """Represents the state of our graph."""
    question: str
    generation: str
    web_search: str
    documents: List[str]
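
# Example of a state dict as it flows between nodes (illustrative values only):
#   {"question": "What is Wbizmanager?", "web_search": "No",
#    "documents": [...], "generation": "..."}
# At runtime, "documents" holds langchain Document objects, not plain strings.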

def initialize_system():
    """Initialize the RAG system with vectorstore and workflow."""
    global vectorstore, retriever, web_search_tool, app
    try:
        # Read configuration
        with open('wragby.json', 'r') as file:
            data = json.load(file)
        urls = data["urls"]

        # Build index
        docs = [WebBaseLoader(url).load() for url in urls]
        docs_list = [item for sublist in docs for item in sublist]
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=500, chunk_overlap=0
        )
        doc_splits = text_splitter.split_documents(docs_list)
        vectorstore = Chroma.from_documents(
            documents=doc_splits,
            collection_name="rag-chroma",
            embedding=HuggingFaceEmbeddings(),
        )
        retriever = vectorstore.as_retriever()
        # Initialize web search (max_results is TavilySearchResults' documented
        # option for limiting the number of results)
        web_search_tool = TavilySearchResults(max_results=3)
        # Build workflow
        app = build_workflow()
        return "✅ System initialized successfully!"
    except Exception as e:
        return f"❌ Error initializing system: {str(e)}"

def chat_completion(messages, model="Llama-4-Scout-17B-16E-Instruct-FP8", max_tokens=1024):
    """Make an API call to Llama."""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {LLAMA_API_KEY}",
    }
    payload = {
        "messages": messages,
        "model": model,
        "max_tokens": max_tokens,
        "stream": False,
    }
    response = requests.post(f"{BASE_URL}/chat/completions", headers=headers, json=payload)
    response.raise_for_status()  # surface HTTP errors to the callers' try/except blocks
    return response
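
# The parsing used throughout this script assumes the Llama API's
# chat-completion response shape, abridged:
#   {"completion_message": {"content": {"text": "..."}}}
# so every caller reads response.json()['completion_message']['content']['text'].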

def route_query(question: str) -> RouteQuery:
    """Route a user question using the Llama API with structured output."""
    system_message = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to the business Wragby Solutions, their product information, and customer sales.
Use the vectorstore for questions on these topics. Otherwise, use web search.
You must respond with a JSON object in this exact format:
{"datasource": "vectorstore"} or {"datasource": "web_search"}
Only respond with the JSON object, no additional text."""
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": question}
    ]
    try:
        response = chat_completion(messages, max_tokens=50)
        content = response.json()['completion_message']['content']['text'].strip()
        route_data = json.loads(content)
        return RouteQuery(**route_data)
    except Exception as e:
        print(f"Error parsing response: {e}")
        # Default to web search when the router output cannot be parsed
        return RouteQuery(datasource="web_search")

def format_docs(docs):
    """Format a list of documents into a single string."""
    if not docs:
        return ""
    formatted_docs = []
    for doc in docs:
        try:
            if hasattr(doc, 'page_content'):
                formatted_docs.append(doc.page_content)
            elif isinstance(doc, dict) and 'content' in doc:
                formatted_docs.append(doc['content'])
            elif isinstance(doc, dict) and 'page_content' in doc:
                formatted_docs.append(doc['page_content'])
            elif isinstance(doc, str):
                formatted_docs.append(doc)
            else:
                formatted_docs.append(str(doc))
        except Exception as e:
            print(f"Error processing document: {e}")
            formatted_docs.append(str(doc))
    return "\n\n".join(formatted_docs)

def rag_generate_answer(question: str, docs: list) -> str:
    """Generate an answer using RAG."""
    system_message = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise."""
    context = format_docs(docs)
    user_message = f"""Context: {context}

Question: {question}

Answer:"""
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message}
    ]
    try:
        response = chat_completion(messages, max_tokens=512)
        answer = response.json()['completion_message']['content']['text'].strip()
        return answer
    except Exception as e:
        print(f"Error generating RAG answer: {e}")
        return "I apologize, but I encountered an error while generating an answer."

def grade_answer_quality(question: str, generation: str) -> dict:
    """Grade whether an LLM generation addresses/resolves the user question."""
    system_message = """You are a grader assessing whether an answer addresses / resolves a question.
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question.
You must respond with exactly one word:
- yes (if the answer addresses and resolves the question)
- no (if the answer does not address or resolve the question)
Only respond with 'yes' or 'no', no additional text or explanation."""
    user_message = f"User question: \n\n {question} \n\n LLM generation: {generation}"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message}
    ]
    try:
        response = chat_completion(messages, max_tokens=10)
        content = response.json()['completion_message']['content']['text'].strip().lower()
        score = "yes" if "yes" in content else "no"
        return {"score": score}
    except Exception as e:
        print(f"Error calling Llama API for answer grading: {e}")
        return {"score": "no"}

def grade_hallucinations(documents: list, generation: str) -> dict:
    """Grade whether an LLM generation is grounded in the provided documents."""
    system_message = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts.
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts.
You must respond with exactly one word:
- yes (if the generation is grounded in the facts)
- no (if the generation contains hallucinations or unsupported claims)
Only respond with 'yes' or 'no', no additional text or explanation."""
    if isinstance(documents, list):
        if documents and hasattr(documents[0], 'page_content'):
            docs_text = format_docs(documents)
        else:
            docs_text = "\n\n".join(documents)
    else:
        docs_text = str(documents)
    user_message = f"Set of facts: \n\n {docs_text} \n\n LLM generation: {generation}"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message}
    ]
    try:
        response = chat_completion(messages, max_tokens=10)
        content = response.json()['completion_message']['content']['text'].strip().lower()
        score = "yes" if "yes" in content else "no"
        return {"score": score}
    except Exception as e:
        print(f"Error calling Llama API for hallucination grading: {e}")
        return {"score": "no"}

def grade_document_relevance(question: str, document: str) -> dict:
    """Grade the relevance of a retrieved document to a user question."""
    system_message = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question.
You must respond with exactly one word:
- yes (if the document is relevant)
- no (if the document is not relevant)
Only respond with 'yes' or 'no', no additional text or explanation."""
    user_message = f"Retrieved document: \n\n {document} \n\n User question: {question}"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message}
    ]
    try:
        response = chat_completion(messages, max_tokens=10)  # single-word verdict, matching the other graders
        content = response.json()['completion_message']['content']['text'].strip().lower()
        score = "yes" if "yes" in content else "no"
        return {"score": score}
    except Exception as e:
        print(f"Error calling Llama API for document grading: {e}")
        return {"score": "no"}

# Workflow Nodes

def retrieve(state):
    """Retrieve documents from the vectorstore."""
    print("---RETRIEVE---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state):
    """Generate an answer using RAG on the retrieved documents."""
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    generation = rag_generate_answer(question, documents)
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents(state):
    """Determine whether the retrieved documents are relevant to the question."""
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]
    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = grade_document_relevance(question, d.page_content)
        grade = score["score"]
        if grade.lower() == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            web_search = "Yes"
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

def web_search(state):
    """Web search based on the question."""
    print("---WEB SEARCH---", state)
    question = state["question"]
    documents = state.get("documents")
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    if documents is not None:
        documents.append(web_results)
    else:
        documents = [web_results]
    return {"documents": documents, "question": question}

def route_question(state):
    """Route the question to web search or RAG."""
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = route_query(question)
    if source.datasource == 'web_search':
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "websearch"
    elif source.datasource == 'vectorstore':
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"

def decide_to_generate(state):
    """Determine whether to generate an answer or add web search."""
    print("---ASSESS GRADED DOCUMENTS---")
    web_search = state["web_search"]
    if web_search == "Yes":
        print("---DECISION: NOT ALL DOCUMENTS ARE RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
        return "websearch"
    else:
        print("---DECISION: GENERATE---")
        return "generate"

def grade_generation_v_documents_and_question(state):
    """Determine whether the generation is grounded in the documents and answers the question."""
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]
    score = grade_hallucinations(documents, generation)
    grade = score["score"]
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        print("---GRADE GENERATION vs QUESTION---")
        score = grade_answer_quality(question, generation)
        grade = score["score"]
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

def build_workflow():
    """Build the RAG workflow graph."""
    workflow = StateGraph(GraphState)

    # Define the nodes
    workflow.add_node("websearch", web_search)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)

    # Build graph
    workflow.set_conditional_entry_point(
        route_question,
        {
            "websearch": "websearch",
            "vectorstore": "retrieve",
        },
    )
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "websearch": "websearch",
            "generate": "generate",
        },
    )
    workflow.add_edge("websearch", "generate")
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "websearch",
        },
    )
    return workflow.compile().with_config({"run_name": "Wragby Solutions Assistant"})
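
# Resulting flow: route_question sends the query either to retrieve ->
# grade_documents (vectorstore path) or straight to websearch; both paths
# converge on generate, which loops back on "not supported", falls back to
# websearch on "not useful", and ends on "useful".
#
# Hypothetical direct invocation (assumes initialize_system() has already run):
#   final_state = app.invoke({"question": "What services does Wragby offer?"})
#   print(final_state["generation"])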

def process_question(question: str, history: List[List[str]]) -> tuple:
    """Process a question through the RAG system and return the answer with sources."""
    if not question.strip():
        return history, "Please enter a question."
    if app is None:
        return history, "❌ System not initialized. Please click 'Initialize System' first."
    try:
        # Process through the workflow
        inputs = {"question": question}
        final_state = None
        for output in app.stream(inputs):
            for key, value in output.items():
                print(f"Finished running: {key}")
                final_state = value
        if final_state and "generation" in final_state:
            answer = final_state["generation"]

            # Get source information
            sources = []
            if "documents" in final_state and final_state["documents"]:
                for i, doc in enumerate(final_state["documents"][:3]):  # Show top 3 sources
                    if hasattr(doc, 'metadata') and 'source' in doc.metadata:
                        sources.append(f"Source {i+1}: {doc.metadata['source']}")
                    else:
                        sources.append(f"Source {i+1}: Retrieved document")

            # Format response with sources
            if sources:
                full_response = f"{answer}\n\n**Sources:**\n" + "\n".join(sources)
            else:
                full_response = answer

            # Update chat history
            history.append([question, full_response])
            return history, ""
        else:
            history.append([question, "I apologize, but I couldn't generate an answer for your question."])
            return history, ""
    except Exception as e:
        error_msg = f"❌ Error processing question: {str(e)}"
        history.append([question, error_msg])
        return history, ""

def clear_chat():
    """Clear the chat history."""
    return [], ""

# Create Gradio Interface
def create_gradio_app():
    """Create and configure the Gradio interface."""
    # Custom CSS for better styling
    css = """
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    .chat-container {
        height: 500px !important;
    }
    .title {
        text-align: center;
        color: #2D5AA0;
        margin-bottom: 20px;
    }
    .description {
        text-align: center;
        color: #666;
        margin-bottom: 30px;
    }
    """
    with gr.Blocks(css=css, title="Wragby Solutions Q&A Assistant") as demo:
        gr.HTML("""
        <div class="title">
            <h1>🤖 Wragby Solutions Q&A Assistant</h1>
        </div>
        <div class="description">
            <p>Ask questions about Wragby Solutions products, services, and business information.
            The system will search through company documents and the web to provide accurate answers.</p>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=3):
                # Chat interface
                chatbot = gr.Chatbot(
                    label="Chat History",
                    height=500,
                    show_label=True,
                    container=True,
                    elem_classes=["chat-container"]
                )
                with gr.Row():
                    question_input = gr.Textbox(
                        placeholder="Ask a question about Wragby Solutions...",
                        label="Your Question",
                        lines=2,
                        scale=4
                    )
                    submit_btn = gr.Button("Submit", variant="primary", scale=1)
                with gr.Row():
                    clear_btn = gr.Button("Clear Chat", variant="secondary")
            with gr.Column(scale=1):
                # System controls
                gr.HTML("<h3>System Controls</h3>")
                init_btn = gr.Button("Initialize System", variant="primary")
                init_status = gr.Textbox(
                    label="System Status",
                    value="Click 'Initialize System' to start",
                    interactive=False
                )
                gr.HTML("""
                <div style="margin-top: 20px; padding: 15px; background-color: #f0f8ff; border-radius: 5px;">
                    <h4>💡 Sample Questions:</h4>
                    <ul>
                        <li>What are the types of solutions offered by Wbizmanager?</li>
                        <li>How can SMBs use Wbizmanager?</li>
                        <li>What SAP solutions are available from Wragby?</li>
                        <li>Tell me about Wragby Solutions services</li>
                    </ul>
                </div>
                """)

        # Event handlers
        init_btn.click(
            fn=initialize_system,
            outputs=[init_status]
        )
        submit_btn.click(
            fn=process_question,
            inputs=[question_input, chatbot],
            outputs=[chatbot, question_input]
        )
        question_input.submit(
            fn=process_question,
            inputs=[question_input, chatbot],
            outputs=[chatbot, question_input]
        )
        clear_btn.click(
            fn=clear_chat,
            outputs=[chatbot, question_input]
        )
    return demo
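
# Wiring note: process_question returns (updated_history, "") on success, so
# the question textbox clears after each submit; initialize_system's status
# string lands in the read-only init_status box.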

# In[2]:


# Create and launch the Gradio app
demo = create_gradio_app()
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=True,  # creates a temporary public link
    debug=True
)