from typing import List, TypedDict
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain_core.runnables import RunnableLambda
from langchain_ollama.chat_models import ChatOllama
from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import StateGraph, END
from qdrant_client import QdrantClient
import os
from dotenv import load_dotenv
load_dotenv()
# In-memory chat histories, keyed by session/thread id (cleared on restart)
session_histories: dict[str, list] = {}
OLLAMA_MODEL = "mistral:latest"
COLLECTION_NAME = "wellness_docs"
EMBEDDING_MODEL = "intfloat/e5-large-v2"
QDRANT_URL = os.getenv('QDRANT_URL')
OLLAMA_URL = os.getenv('OLLAMA_URL')
QDRANT_PORT = 6333
llm = ChatOllama(model=OLLAMA_MODEL, temperature=0.1, base_url=OLLAMA_URL)
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
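# Note: e5-family embedding models are trained with "query: " / "passage: "
# prefixes; if documents were ingested with the "passage: " prefix, queries
# should be embedded as f"query: {query}" for best retrieval quality.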
try:
client = QdrantClient(url=QDRANT_URL, port=QDRANT_PORT)
vector_store = QdrantVectorStore(
client=client,
collection_name=COLLECTION_NAME,
embedding=embeddings,
)
except Exception as e:
    raise RuntimeError(f"Failed to connect to Qdrant: {e}") from e
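# Fail fast if the collection is missing (a sketch, assuming a recent
# qdrant-client and that ingestion happens in a separate script):
if not client.collection_exists(COLLECTION_NAME):
    raise RuntimeError(f"Qdrant collection '{COLLECTION_NAME}' not found; run ingestion first.")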
class GraphState(TypedDict):
    """
    State carried through the LangGraph workflow for one chat turn:
    the user's input, the running message history, the generated
    response, and the raw tool results.
    """
    input: str
    history: List[BaseMessage]  # past Human/AI messages for this session
    response: str
    tool_results: dict
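# Illustrative shape of the state after one turn (example values only):
# {"input": "Do you sell sleep supplements?",
#  "history": [HumanMessage(...), AIMessage(...)],
#  "response": "...",
#  "tool_results": {"Retrieve": [{"content": "...", "source": "...", "score": 0.83}]}}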
template = """
You are Auro-Chat, a helpful assistant trained on wellness products, blogs, and FAQs.
IMPORTANT: Always refer to the **actual** prior conversation shown below. Do NOT invent past questions.
The conversation history uses this format:
- Human: the user's query
- AI: your response to that query
Use it to recall what the user has already asked.
Conversation History:
{history}
Contextual Knowledge:
{agent_scratchpad}
Keep your answer concise (maximum 3 sentences).
Question: {input}
Answer:
"""
@tool('Retrieve')
def retrieve_tool(query: str) -> List[dict]:
    """
    Retrieves relevant content from the wellness knowledge base, including
    blog posts, product descriptions, and FAQs.
    Returns a list of dicts with the matched content, its source, and its
    similarity score.
    """
    docs = vector_store.similarity_search_with_score(query, k=10)
    # Keep only strong matches; with cosine similarity, higher scores mean
    # closer matches, so 0.8 acts as a relevance cutoff.
    return [
        {"content": doc.page_content, "source": doc.metadata.get("source", "unknown"), "score": score}
        for doc, score in docs if score > 0.8
    ]
all_tools = [retrieve_tool]
tool_descriptions = "\n".join(f"{tool.name}: {tool.description}" for tool in all_tools)
tool_names = ", ".join(tool.name for tool in all_tools)
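# tool_descriptions / tool_names are not referenced by the current template;
# presumably kept for a future ReAct-style prompt that enumerates the tools.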
def retrieve_node(state: GraphState) -> GraphState:
    """Run each registered tool against the user query and store the results in state."""
    query = state['input']
    tool_results = {}
    for t in all_tools:  # `t` avoids shadowing the imported `tool` decorator
        try:
            tool_results[t.name] = t.invoke({'query': query})
        except Exception as e:
            tool_results[t.name] = [{'content': f"Tool {t.name} failed: {e}", "source": "system"}]
    state['tool_results'] = tool_results
    return state
def generate_answer(state: GraphState) -> GraphState:
    """
    Formats the prompt with the conversation history and retrieved context,
    calls the LLM, and appends the new turn to the session history.
    """
query = state['input']
history = state.get('history', [])
history_text = "\n".join(
f"Human: {m.content}" if isinstance(m, HumanMessage) else f"AI: {m.content}"
for m in history
)
intermediate_steps = state.get('tool_results', {})
steps_string = "\n".join(
f"{tool_name}:\n" +
"\n---\n".join(
f"{entry.get('content', '')}\n(Source: {entry.get('source', 'unknown')})"
for entry in (tool_output if isinstance(tool_output, list) else [tool_output])
)
for tool_name, tool_output in intermediate_steps.items()
)
    prompt_input = template.format(
        input=query,
        agent_scratchpad=steps_string,
        history=history_text
    )
llm_response = llm.invoke(prompt_input)
    state['response'] = llm_response.content if hasattr(llm_response, 'content') else str(llm_response)
    # Append to the local `history` (fetched with .get) so a missing
    # 'history' key cannot raise a KeyError.
    history.append(HumanMessage(content=query))
    history.append(AIMessage(content=state['response']))
    state['history'] = history
return state
graph = StateGraph(GraphState)
# Add nodes to the graph
graph.add_node("route_tool", RunnableLambda(retrieve_node))
graph.add_node("generate_response", RunnableLambda(generate_answer))
# Define the flow of the graph
graph.set_entry_point("route_tool")
graph.add_edge("route_tool", "generate_response")
graph.add_edge("generate_response", END)
app = graph.compile()
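# Alternative (sketch only, not wired in here): LangGraph checkpointers can
# persist per-thread state instead of the module-level session_histories dict:
#   from langgraph.checkpoint.memory import MemorySaver
#   app = graph.compile(checkpointer=MemorySaver())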
def get_response(query: str, config: dict) -> dict:
    """Run one chat turn for the session identified by config['configurable']['thread_id']."""
    session_id = config['configurable']['thread_id']
    history = session_histories.get(session_id, [])
input_data = {
"input": query,
"history": history
}
result = app.invoke(input_data, config=config)
session_histories[session_id] = result.get("history", [])
return result
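if __name__ == "__main__":
    # Minimal usage sketch (assumes Qdrant, Ollama, and the embedding model are
    # reachable; the thread_id and question below are arbitrary examples).
    cfg = {"configurable": {"thread_id": "demo-session"}}
    result = get_response("What products do you recommend for better sleep?", cfg)
    print(result["response"])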