import json
import operator
import re
from datetime import datetime
from itertools import chain
from pathlib import Path
from typing import Annotated, Any, Dict, List, Optional, TypedDict

from dotenv import load_dotenv
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph

# Load environment variables from a .env file at import time.
_ = load_dotenv()


class ToolCallLog(TypedDict):
    """
    A TypedDict representing a log entry for a tool call.

    Attributes:
        timestamp (str): The timestamp of when the tool call was logged.
        tool_call_id (str): The unique identifier for the tool call.
        name (str): The name of the tool that was called.
        args (Any): The arguments passed to the tool.
        content (str): The content or result of the tool call.
    """

    timestamp: str
    tool_call_id: str
    name: str
    args: Any
    content: str


class AgentState(TypedDict):
    """
    A TypedDict representing the state of an agent.

    Attributes:
        messages (Annotated[List[AnyMessage], operator.add]): A list of messages
            representing the conversation history. The operator.add annotation
            indicates that new messages should be appended to this list.
    """

    messages: Annotated[List[AnyMessage], operator.add]


class ChatAgent:
    """
    An agent that processes requests with a language model and executes the
    tools the model asks for.

    Attributes:
        model: The language model with the tools bound via ``bind_tools``.
        tools (Dict[str, BaseTool]): Available tools, keyed by tool name.
        checkpointer (Any): Object passed to ``StateGraph.compile`` to persist state.
        prompts (Dict[str, str]): Prompt texts. ``process_request`` reads the
            ``"MEDICAL_ASSISTANT"`` key; ``summarize_message`` reads
            ``"SUMMARIZE_SYSTEM_PROMPT"`` and ``"SUMMARIZE_USER_PROMPT"``.
        workflow: The compiled workflow for the agent's processing.
        log_tools (bool): Whether to log tool calls.
        log_path (Path): Directory where tool call logs are written.
    """

    def __init__(
        self,
        model: BaseLanguageModel,
        tools: List[BaseTool],
        checkpointer: Any = None,
        prompts: Optional[Dict[str, str]] = None,
        log_tools: bool = True,
        log_dir: Optional[str] = "logs",
    ):
        """
        Initialize the agent.

        Args:
            model (BaseLanguageModel): The language model to use.
            tools (List[BaseTool]): A list of available tools.
            checkpointer (Any, optional): State persistence manager. Defaults to None.
            prompts (Dict[str, str], optional): Prompt texts keyed by name.
                Defaults to None, which is treated as an empty dict.
            log_tools (bool, optional): Whether to log tool calls. Defaults to True.
            log_dir (str, optional): Directory to save logs. Defaults to 'logs'.
        """
        # A fresh dict per instance; a `{}` default would be shared across calls.
        self.prompts = prompts if prompts is not None else {}
        self.log_tools = log_tools
        self.checkpointer = checkpointer

        # Always record the path so that enabling `log_tools` later still works;
        # only touch the filesystem when logging is actually requested.
        self.log_path = Path(log_dir or "logs")
        if self.log_tools:
            self.log_path.mkdir(parents=True, exist_ok=True)

        self.tools = {t.name: t for t in tools}
        self.model = model.bind_tools(tools)

        # Define the agent workflow: process -> (execute -> process)* -> END
        workflow = StateGraph(AgentState)
        workflow.add_node("process", self.process_request)
        workflow.add_node("execute", self.execute_tools)
        # NOTE(review): no edge leads to "summarize" and `summarize_message`
        # takes a thread id rather than a state; presumably it is meant to be
        # called directly by the host application — confirm before relying on
        # this node.
        workflow.add_node("summarize", self.summarize_message)
        workflow.add_conditional_edges(
            "process", self.has_tool_calls, {True: "execute", False: END}
        )
        workflow.add_edge("execute", "process")
        workflow.set_entry_point("process")
        self.workflow = workflow.compile(checkpointer=self.checkpointer)

    def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
        """
        Process the request using the language model.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
        """
        messages = state["messages"]
        # Prepend the system prompt only when it is configured, so an agent
        # constructed with just the summarize prompts does not fail here.
        system_prompt = self.prompts.get("MEDICAL_ASSISTANT")
        if system_prompt is not None:
            messages = [SystemMessage(content=system_prompt)] + messages
        response = self.model.invoke(messages)
        return {"messages": [response]}

    def has_tool_calls(self, state: AgentState) -> bool:
        """
        Check if the latest message contains any tool calls.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            bool: True if tool calls exist, False otherwise (including when the
            latest message has no ``tool_calls`` attribute).
        """
        response = state["messages"][-1]
        return bool(getattr(response, "tool_calls", None))

    def execute_tools(self, state: AgentState) -> Dict[str, List[ToolMessage]]:
        """
        Execute tool calls from the model's response.

        Unknown tool names are not raised as errors; a ToolMessage telling the
        model to retry is returned for them instead.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            Dict[str, List[ToolMessage]]: A dictionary containing tool execution results.
        """
        tool_calls = state["messages"][-1].tool_calls
        results = []
        for call in tool_calls:
            tool = self.tools.get(call["name"])
            if tool is None:
                print("\n....invalid tool....")
                result = "invalid tool, please retry"
            else:
                result = tool.invoke(call["args"])

            results.append(
                ToolMessage(
                    tool_call_id=call["id"],
                    name=call["name"],
                    # Extra field carried along so `_save_tool_calls` can log it.
                    args=call["args"],
                    content=str(result),
                )
            )

        self._save_tool_calls(results)
        print("Returning to model processing!")
        return {"messages": results}

    def summarize_message(self, thread_id: str) -> Any:
        """
        Summarize the tool outputs of a conversation thread using the language model.

        Reads the current state of the thread from the compiled workflow, keeps
        only the ToolMessage entries, and sends them to the model with the
        ``SUMMARIZE_SYSTEM_PROMPT`` / ``SUMMARIZE_USER_PROMPT`` prompts. Markdown
        code fences are stripped from the reply before it is parsed as JSON.

        Args:
            thread_id (str): Identifier of the thread whose state is summarized.

        Returns:
            Any: The model's reply parsed with ``json.loads``.

        Raises:
            KeyError: If either summarize prompt is missing from ``self.prompts``.
            json.JSONDecodeError: If the model's reply is not valid JSON.
        """
        snapshot = self.workflow.get_state({"configurable": {"thread_id": thread_id}})
        # A thread with no saved state has no "messages" entry yet.
        all_messages = snapshot.values.get("messages", [])

        chat_messages = [
            {"type": "Tool message", "tool_name": m.name, "content": m.content}
            for m in all_messages
            if isinstance(m, ToolMessage)
        ]

        summarize_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.prompts["SUMMARIZE_SYSTEM_PROMPT"]),
                ("human", self.prompts["SUMMARIZE_USER_PROMPT"]),
            ]
        )
        # NOTE(review): `self.model` is the tool-bound model; if the model answers
        # with a tool call instead of text the parsed output may be empty —
        # confirm whether the unbound model should be used here.
        llm_chain = summarize_prompt | self.model | StrOutputParser()
        llm_output = llm_chain.invoke({"chat_messages": chat_messages})

        # Drop ``` / ```json fences that models often wrap JSON replies in.
        cleaned_output = re.sub(
            r"```(?:json)?", "", llm_output, flags=re.IGNORECASE
        ).strip()
        return json.loads(cleaned_output)

    def _save_tool_calls(self, tool_calls: List[ToolMessage]) -> None:
        """
        Save tool calls to a JSON file with timestamp-based naming.

        Does nothing when ``log_tools`` is false.

        Args:
            tool_calls (List[ToolMessage]): List of tool calls to save.
        """
        if not self.log_tools:
            return

        # Covers the case where logging was enabled after construction.
        self.log_path.mkdir(parents=True, exist_ok=True)

        # Microseconds keep two tool rounds within the same second from
        # overwriting each other's log file.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = self.log_path / f"tool_calls_{timestamp}.json"

        logs: List[ToolCallLog] = [
            {
                "tool_call_id": call.tool_call_id,
                "name": call.name,
                "args": getattr(call, "args", None),
                "content": call.content,
                "timestamp": datetime.now().isoformat(),
            }
            for call in tool_calls
        ]

        with open(filename, "w", encoding="utf-8") as f:
            json.dump(logs, f, indent=4)

    @staticmethod
    def _text_blocks(blocks: List[Any]) -> List[str]:
        """Collect the text of plain-string and ``{"type": "text"}`` content blocks."""
        texts: List[str] = []
        for blk in blocks:
            if isinstance(blk, str):
                texts.append(blk)
            elif isinstance(blk, dict) and blk.get("type") == "text":
                text = blk.get("text")
                if isinstance(text, str):
                    texts.append(text)
        return texts

    def normalise_messages(self, raw: List[Any]) -> List[Dict[str, Any]]:
        """
        Turn mixed LangGraph snapshots into a uniform list of dicts.

        Args:
            raw (List[Any]): LangChain message objects, bare dicts with
                ``role``/``content`` keys, or arbitrary other objects.

        Returns:
            List[Dict[str, Any]]: One dict per input item with ``role`` and
            ``content`` keys (plus ``name`` for tool messages). Unknown objects
            are stringified under the role ``"unknown"``.
        """
        out: List[Dict[str, Any]] = []

        for m in raw:
            # ── LangChain message objects ──────────────────────────────
            if isinstance(m, HumanMessage):
                out.append({"role": "human", "content": m.content})
            elif isinstance(m, AIMessage):
                out.append({"role": "ai", "content": m.content})
            elif isinstance(m, ToolMessage):
                out.append(
                    {
                        "role": "tool",
                        "name": m.name,  # e.g. "chest_xray_classifier"
                        "content": m.content,
                    }
                )
            elif isinstance(m, SystemMessage):
                out.append({"role": "system", "content": m.content})

            # ── Bare dicts coming from the client / snapshots ──────────
            elif isinstance(m, dict):
                role = m.get("role")
                payload = m.get("content")
                # content might be a list of blocks
                # (e.g. [{"type": "text", "text": "hello"}])
                if isinstance(payload, list):
                    # Flatten the text blocks; keep the raw structure when
                    # there is no text to show.
                    text_only = " ".join(self._text_blocks(payload))
                    payload = text_only if text_only else payload
                out.append({"role": role, "content": payload})

            # ── Fallback: stringify anything unknown ──────────────────
            else:
                out.append({"role": "unknown", "content": str(m)})

        return out