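"""LangGraph-based chat agent.

Routes requests through a language model, executes any tools the model calls,
optionally logs tool calls to timestamped JSON files, and can summarize the
tool output of a conversation thread.
"""
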
import json
import operator
import re
from pathlib import Path
from dotenv import load_dotenv
from datetime import datetime
from typing import List, Dict, Any, TypedDict, Annotated, Optional
from langgraph.graph import StateGraph, END
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage, HumanMessage, AIMessage
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
_ = load_dotenv()
class ToolCallLog(TypedDict):
"""
A TypedDict representing a log entry for a tool call.
Attributes:
timestamp (str): The timestamp of when the tool call was made.
tool_call_id (str): The unique identifier for the tool call.
name (str): The name of the tool that was called.
args (Any): The arguments passed to the tool.
content (str): The content or result of the tool call.
"""
timestamp: str
tool_call_id: str
name: str
args: Any
content: str
class AgentState(TypedDict):
"""
A TypedDict representing the state of an agent.
Attributes:
messages (Annotated[List[AnyMessage], operator.add]): A list of messages
representing the conversation history. The operator.add annotation
indicates that new messages should be appended to this list.
"""
messages: Annotated[List[AnyMessage], operator.add]
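
# Illustrative note (not from the original file): with the operator.add reducer,
# LangGraph concatenates the message lists returned by successive nodes instead
# of overwriting state["messages"]. Roughly:
#
#     operator.add([HumanMessage(content="hi")], [AIMessage(content="hello")])
#     # -> [HumanMessage(content="hi"), AIMessage(content="hello")]
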
class ChatAgent:
"""
A class representing an agent that processes requests and executes tools based on
language model responses.
Attributes:
model (BaseLanguageModel): The language model used for processing.
tools (Dict[str, BaseTool]): A dictionary of available tools.
checkpointer (Any): Manages and persists the agent's state.
        prompts (Dict[str, str]): Named system prompts for the agent.
        workflow (StateGraph): The compiled workflow for the agent's processing.
        log_tools (bool): Whether to log tool calls.
        log_path (Path): Directory for tool-call logs (set only when log_tools is True).
"""
def __init__(
self,
model: BaseLanguageModel,
tools: List[BaseTool],
checkpointer: Any = None,
        prompts: Optional[Dict[str, str]] = None,
log_tools: bool = True,
log_dir: Optional[str] = "logs",
):
"""
Initialize the Agent.
Args:
model (BaseLanguageModel): The language model to use.
tools (List[BaseTool]): A list of available tools.
checkpointer (Any, optional): State persistence manager. Defaults to None.
            prompts (Optional[Dict[str, str]], optional): Named system prompts. Defaults to None.
log_tools (bool, optional): Whether to log tool calls. Defaults to True.
log_dir (str, optional): Directory to save logs. Defaults to 'logs'.
"""
        self.prompts = prompts or {}
self.log_tools = log_tools
self.checkpointer = checkpointer
if self.log_tools:
self.log_path = Path(log_dir or "logs")
            self.log_path.mkdir(parents=True, exist_ok=True)
# Define the agent workflow
workflow = StateGraph(AgentState)
        workflow.add_node("process", self.process_request)
        workflow.add_node("execute", self.execute_tools)
        # Note: summarize_message takes a thread_id (not an AgentState), so it is
        # invoked directly rather than registered as a graph node.
workflow.add_conditional_edges(
"process", self.has_tool_calls, {True: "execute", False: END}
)
workflow.add_edge("execute", "process")
workflow.set_entry_point("process")
self.workflow = workflow.compile(checkpointer=self.checkpointer)
self.tools = {t.name: t for t in tools}
self.model = model.bind_tools(tools)
def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
"""
Process the request using the language model.
Args:
state (AgentState): The current state of the agent.
Returns:
Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
"""
        messages = state["messages"]
        if "MEDICAL_ASSISTANT" in self.prompts:
            messages = [SystemMessage(content=self.prompts["MEDICAL_ASSISTANT"])] + messages
        response = self.model.invoke(messages)
        return {"messages": [response]}
def has_tool_calls(self, state: AgentState) -> bool:
"""
Check if the response contains any tool calls.
Args:
state (AgentState): The current state of the agent.
Returns:
bool: True if tool calls exist, False otherwise.
"""
response = state["messages"][-1]
return len(response.tool_calls) > 0
def execute_tools(self, state: AgentState) -> Dict[str, List[ToolMessage]]:
"""
Execute tool calls from the model's response.
Args:
state (AgentState): The current state of the agent.
Returns:
Dict[str, List[ToolMessage]]: A dictionary containing tool execution results.
"""
tool_calls = state["messages"][-1].tool_calls
results = []
        for call in tool_calls:
            if call["name"] not in self.tools:
                print(f"Invalid tool requested: {call['name']}")
                result = "invalid tool, please retry"
            else:
                result = self.tools[call["name"]].invoke(call["args"])
            results.append(
                ToolMessage(
                    tool_call_id=call["id"],
                    name=call["name"],
                    # ToolMessage has no `args` field, so the call arguments are
                    # kept in additional_kwargs for _save_tool_calls to log.
                    additional_kwargs={"args": call["args"]},
                    content=str(result),
                )
            )
        self._save_tool_calls(results)
        return {"messages": results}
    def summarize_message(self, thread_id: str) -> Dict[str, Any]:
        """
        Summarize the tool messages in a thread using the language model.

        Args:
            thread_id (str): Identifier of the conversation thread to summarize.

        Returns:
            Dict[str, Any]: The parsed JSON summary produced by the model.
        """
        history = self.workflow.get_state({"configurable": {"thread_id": thread_id}})
        all_messages = history.values["messages"]
        chat_messages = [
            {
                "type": "Tool message",
                "tool_name": m.name,
                "content": m.content,
            }
            for m in all_messages
            if isinstance(m, ToolMessage)
        ]
        summarize_prompt = ChatPromptTemplate.from_messages([
            ("system", self.prompts["SUMMARIZE_SYSTEM_PROMPT"]),
            ("human", self.prompts["SUMMARIZE_USER_PROMPT"]),
        ])
        llm_chain = summarize_prompt | self.model | StrOutputParser()
        llm_output = llm_chain.invoke({"chat_messages": chat_messages})
        # Strip any Markdown code fences before parsing the JSON payload.
        cleaned_output = re.sub(r"```(?:json)?", "", llm_output, flags=re.IGNORECASE).strip()
        return json.loads(cleaned_output)
def _save_tool_calls(self, tool_calls: List[ToolMessage]) -> None:
"""
Save tool calls to a JSON file with timestamp-based naming.
Args:
tool_calls (List[ToolMessage]): List of tool calls to save.
"""
if not self.log_tools:
return
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = self.log_path / f"tool_calls_{timestamp}.json"
logs: List[ToolCallLog] = []
        for call in tool_calls:
            log_entry: ToolCallLog = {
                "tool_call_id": call.tool_call_id,
                "name": call.name,
                "args": call.additional_kwargs.get("args"),
                "content": call.content,
                "timestamp": datetime.now().isoformat(),
            }
            logs.append(log_entry)
with open(filename, "w") as f:
json.dump(logs, f, indent=4)
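
    # Example of a single log entry written by _save_tool_calls (values are
    # illustrative, not taken from a real run):
    #   {
    #     "tool_call_id": "call_abc123",
    #     "name": "echo",
    #     "args": {"text": "hello"},
    #     "content": "hello",
    #     "timestamp": "2024-01-01T12:00:00"
    #   }
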
def normalise_messages(self, raw: List[Any]) -> List[Dict[str, Any]]:
"""Turn mixed LangGraph snapshots into a uniform list of dicts."""
out: List[Dict[str, Any]] = []
for m in raw:
# ── LangChain message objects ──────────────────────────────
if isinstance(m, HumanMessage):
out.append({"role": "human", "content": m.content})
elif isinstance(m, AIMessage):
out.append({"role": "ai", "content": m.content})
elif isinstance(m, ToolMessage):
out.append(
{
"role": "tool",
"name": m.name, # e.g. "chest_xray_classifier"
"content": m.content,
}
)
elif isinstance(m, SystemMessage):
out.append({"role": "system", "content": m.content})
# ── Bare dicts coming from the client / snapshots ──────────
elif isinstance(m, dict):
role = m.get("role")
payload = m.get("content")
                # content might be a list of blocks (e.g. [{"type": "text", "text": "hello"}])
                if isinstance(payload, list):
                    # Flatten the text blocks; keep the raw structure if none exist.
                    text_only = " ".join(
                        blk.get("text", "")
                        for blk in payload
                        if isinstance(blk, dict) and blk.get("type") == "text"
                    )
                    payload = text_only if text_only else payload
out.append({"role": role, "content": payload})
# ── Fallback: stringify anything unknown ──────────────────
else:
out.append({"role": "unknown", "content": str(m)})
return out
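

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original class): wires ChatAgent to
    # an OpenAI chat model and an in-memory checkpointer. The echo tool, model
    # name, prompt text, and thread id are illustrative placeholders.
    from langchain_core.tools import tool
    from langchain_openai import ChatOpenAI
    from langgraph.checkpoint.memory import MemorySaver

    @tool
    def echo(text: str) -> str:
        """Echo the input text back unchanged."""
        return text

    agent = ChatAgent(
        model=ChatOpenAI(model="gpt-4o-mini"),
        tools=[echo],
        checkpointer=MemorySaver(),
        prompts={"MEDICAL_ASSISTANT": "You are a helpful medical assistant."},
    )
    config = {"configurable": {"thread_id": "demo-thread"}}
    result = agent.workflow.invoke(
        {"messages": [HumanMessage(content="Echo back: hello")]},
        config=config,
    )
    print(result["messages"][-1].content)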