from typing import List, TypedDict

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain_core.runnables import RunnableLambda
from langchain_ollama.chat_models import ChatOllama
from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import StateGraph, END
from qdrant_client import QdrantClient
import os
from dotenv import load_dotenv

load_dotenv()

# In-memory chat history, keyed by session (thread) id
session_histories: dict[str, list] = {}

OLLAMA_MODEL = "mistral:latest"
COLLECTION_NAME = "wellness_docs"
EMBEDDING_MODEL = "intfloat/e5-large-v2"
QDRANT_URL = os.getenv("QDRANT_URL")
OLLAMA_URL = os.getenv("OLLAMA_URL")
QDRANT_PORT = 6333

# Pass OLLAMA_URL through; if the env var is unset, ChatOllama falls back to its default
llm = ChatOllama(model=OLLAMA_MODEL, base_url=OLLAMA_URL, temperature=0.1)
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

try:
    client = QdrantClient(url=QDRANT_URL, port=QDRANT_PORT)
    vector_store = QdrantVectorStore(
        client=client,
        collection_name=COLLECTION_NAME,
        embedding=embeddings,
    )
except Exception as e:
    raise RuntimeError(f"Failed to connect to Qdrant: {e}") from e
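
# Optional sanity check (an addition, not in the original): fail fast with a
# readable error if the collection has not been created yet. `collection_exists`
# is available in recent qdrant-client releases.
if not client.collection_exists(COLLECTION_NAME):
    raise RuntimeError(f"Qdrant collection '{COLLECTION_NAME}' does not exist")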

class GraphState(TypedDict):
    """
    State carried through the LangGraph workflow: the user's input, the
    conversation history, the generated response, and raw tool results.
    """
    input: str
    history: List[BaseMessage]  # list of past messages
    response: str
    tool_results: dict

template = """
You are Auro-Chat, a helpful assistant trained on wellness products, blogs, and FAQs.
IMPORTANT: Always refer to the **actual** prior conversation shown below. Do NOT invent past questions.
The conversation history is formatted as:
- Human: the user's query
- AI: your response to that query
Use it to know what the user has already asked.
Keep your answer concise (maximum 3 sentences).

Conversation History:
{history}

Contextual Knowledge:
{agent_scratchpad}

Question: {input}
Answer:
"""

@tool
def retrieve_tool(query: str) -> List[dict]:
    """
    Retrieves relevant content from the wellness knowledge base, including blog
    posts, product descriptions, and FAQs. Returns a list of dicts with the
    matched content, its source, and its similarity score.
    """
    docs = vector_store.similarity_search_with_score(query, k=10)
    return [
        {"content": doc.page_content, "source": doc.metadata.get("source", "unknown"), "score": score}
        for doc, score in docs
        if score > 0.8  # keep only high-similarity hits
    ]
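
# Illustrative direct invocation of the decorated tool (query is hypothetical):
#   retrieve_tool.invoke({"query": "Which teas help with sleep?"})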

all_tools = [retrieve_tool]
# Rendered tool metadata for the prompt; not currently injected into `template`.
tool_descriptions = "\n".join(f"{t.name}: {t.description}" for t in all_tools)
tool_names = ", ".join(t.name for t in all_tools)

def retrieve_node(state: GraphState) -> GraphState:
    """Run every registered tool against the user query and stash the results."""
    query = state["input"]
    tool_results = {}
    for t in all_tools:
        try:
            tool_results[t.name] = t.invoke({"query": query})
        except Exception as e:
            tool_results[t.name] = [{"content": f"Tool {t.name} failed: {e}", "source": "system"}]
    state["tool_results"] = tool_results
    return state

def generate_answer(state: GraphState) -> GraphState:
    """
    Generates an answer to the query using the LLM, the retrieved context,
    and the conversation history.
    """
    query = state["input"]
    history = state.get("history", [])
    history_text = "\n".join(
        f"Human: {m.content}" if isinstance(m, HumanMessage) else f"AI: {m.content}"
        for m in history
    )
    # Flatten tool outputs into a scratchpad block for the prompt
    intermediate_steps = state.get("tool_results", {})
    steps_string = "\n".join(
        f"{tool_name}:\n"
        + "\n---\n".join(
            f"{entry.get('content', '')}\n(Source: {entry.get('source', 'unknown')})"
            for entry in (tool_output if isinstance(tool_output, list) else [tool_output])
        )
        for tool_name, tool_output in intermediate_steps.items()
    )
    prompt_input = template.format(
        input=query,
        agent_scratchpad=steps_string,
        history=history_text,
    )
    llm_response = llm.invoke(prompt_input)
    state["response"] = llm_response.content if hasattr(llm_response, "content") else str(llm_response)
    state["history"].append(HumanMessage(content=query))
    state["history"].append(AIMessage(content=state["response"]))
    return state

graph = StateGraph(GraphState)

# Add nodes to the graph
graph.add_node("route_tool", RunnableLambda(retrieve_node))
graph.add_node("generate_response", RunnableLambda(generate_answer))

# Define the flow of the graph
graph.set_entry_point("route_tool")
graph.add_edge("route_tool", "generate_response")
graph.add_edge("generate_response", END)

app = graph.compile()
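
# Optional debugging aid (assumes the `grandalf` package is installed; left
# commented out so the module stays side-effect free on import):
#   app.get_graph().print_ascii()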

def get_response(query: str, config: dict) -> dict:
    """Invoke the graph for one turn, persisting history per thread_id."""
    session_id = config["configurable"]["thread_id"]
    history = session_histories.get(session_id, [])
    input_data = {
        "input": query,
        "history": history,
    }
    result = app.invoke(input_data, config=config)
    session_histories[session_id] = result.get("history", [])
    return result
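
# Minimal usage sketch (the thread_id and question are illustrative, not from
# the original):
if __name__ == "__main__":
    demo_config = {"configurable": {"thread_id": "demo-session"}}
    out = get_response("Do you have any products for better sleep?", demo_config)
    print(out["response"])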