# ReasonGraph / cot_reasoning.py
# ZongqianLi: Upload 16 files (7eda955, verified)
import re
import requests
import textwrap
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class CoTStep:
    """One numbered reasoning step of a chain-of-thought (CoT) response."""

    # Value of the step's ``number`` attribute, as an integer.
    number: int
    # Text of the step.
    content: str
@dataclass
class CoTResponse:
    """A complete chain-of-thought (CoT) response: question, steps, answer."""

    # The question that was asked.
    question: str
    # The reasoning steps that belong to this response.
    steps: List[CoTStep]
    # The final answer; stays None when no answer is available.
    answer: Optional[str] = None
@dataclass
class VisualizationConfig:
    """Text-layout settings used when rendering CoT text into diagram nodes."""

    # Maximum number of characters on one wrapped line.
    max_chars_per_line: int = 40
    # Maximum number of wrapped lines kept per node.
    max_lines: int = 4
    # Marker string for text that had to be cut short.
    truncation_suffix: str = "..."
class AnthropicAPI:
    """Small client for the Anthropic Messages HTTP endpoint.

    The instance stores the credentials, model name, endpoint URL and the
    request headers, and exposes ``generate_response`` to send a single
    user message and return the text of the reply.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "claude-3-opus-20240229",
        *,
        timeout: Optional[float] = 120.0,
    ):
        """Store connection settings.

        Args:
            api_key: Key sent in the ``x-api-key`` header.
            model: Model identifier placed in every request body.
            timeout: Seconds to wait for the HTTP request before giving up.
                ``None`` disables the timeout (wait indefinitely).
        """
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self.base_url = "https://api.anthropic.com/v1/messages"
        self.headers = {
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }

    def generate_response(self, prompt: str, max_tokens: int = 1024, prompt_format: Optional[str] = None) -> str:
        """Send ``prompt`` to the API and return the text of the reply.

        Args:
            prompt: The question or raw prompt text.
            max_tokens: Value sent as ``max_tokens`` in the request body.
            prompt_format: Optional template containing ``{question}``. When
                given (and non-empty) the prompt is substituted into it; when
                omitted the prompt is sent verbatim, i.e. the built-in
                default template of ``_format_prompt`` is NOT applied here.

        Returns:
            The ``text`` field of the first content item in the JSON reply.

        Raises:
            Exception: For any failure (network error, timeout, HTTP error
                status, unexpected response shape). The original exception
                is attached as ``__cause__``.
        """
        formatted_prompt = self._format_prompt(prompt, prompt_format) if prompt_format else prompt
        data = {
            "model": self.model,
            "messages": [{"role": "user", "content": formatted_prompt}],
            "max_tokens": max_tokens
        }
        try:
            # Without an explicit timeout, requests waits forever on a
            # stalled connection.
            response = requests.post(
                self.base_url,
                headers=self.headers,
                json=data,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()["content"][0]["text"]
        except Exception as e:
            # Keep the generic exception type and message callers already
            # see, but chain the cause so the real traceback is not lost.
            raise Exception(f"API call failed: {str(e)}") from e

    def _format_prompt(self, question: str, prompt_format: Optional[str] = None) -> str:
        """Build the prompt text for ``question``.

        Args:
            question: Text substituted for ``{question}``.
            prompt_format: Optional ``str.format`` template with a
                ``{question}`` placeholder. When falsy, the built-in
                step-by-step XML template is used.

        Returns:
            The formatted prompt string.
        """
        if prompt_format:
            return prompt_format.format(question=question)
        # Default format if none provided
        return f"""Please answer the question using the following format, with each step clearly marked:
Question: {question}
Let's solve this step by step:
<step number="1">
[First step of reasoning]
</step>
<step number="2">
[Second step of reasoning]
</step>
<step number="3">
[Third step of reasoning]
</step>
... (add more steps as needed)
<answer>
[Final answer]
</answer>
Note:
1. Each step must be wrapped in XML tags <step>
2. Each step must have a number attribute
3. The final answer must be wrapped in <answer> tags
"""
def wrap_text(text: str, config: VisualizationConfig) -> str:
    """Wrap text so it fits inside a diagram node.

    Newlines are flattened to spaces and double quotes are replaced by
    single quotes, then the text is wrapped with ``textwrap.wrap`` and the
    resulting lines are joined with ``<br>``. When the text needs more than
    ``config.max_lines`` lines, the extra lines are dropped and the last
    kept line is shortened so that it ends with the truncation suffix.

    Args:
        text: Raw text to lay out.
        config: Provides ``max_chars_per_line``, ``max_lines`` and
            ``truncation_suffix`` (``"..."`` is used if the attribute is
            missing).

    Returns:
        The wrapped lines joined by ``<br>``; an empty string when there is
        nothing to show (empty text, or ``max_lines`` of zero or less).
    """
    flattened = text.replace('\n', ' ').replace('"', "'")
    lines = textwrap.wrap(flattened, width=config.max_chars_per_line)

    if len(lines) > config.max_lines:
        suffix = getattr(config, "truncation_suffix", "...")
        # Reserve room for the suffix; never let the slice bound go negative
        # (a negative bound would chop characters from the end instead).
        keep = max(config.max_chars_per_line - len(suffix), 0)
        lines = lines[:max(config.max_lines, 0)]
        # With max_lines <= 0 nothing is left to decorate.
        if lines:
            lines[-1] = lines[-1][:keep] + suffix

    return "<br>".join(lines)
def parse_cot_response(response_text: str, question: str) -> CoTResponse:
    """
    Turn raw CoT response text into a structured CoTResponse.

    Every ``<step number="N"> ... </step>`` element becomes a CoTStep and
    the first ``<answer> ... </answer>`` element, if any, becomes the answer.
    Surrounding whitespace is stripped from step and answer text.

    Args:
        response_text: The raw response from the API
        question: The original question

    Returns:
        CoTResponse holding the question, the steps ordered by their number,
        and the answer (None when no answer element is found)
    """
    step_regex = re.compile(r'<step number="(\d+)">\s*(.*?)\s*</step>', re.DOTALL)
    answer_regex = re.compile(r'<answer>\s*(.*?)\s*</answer>', re.DOTALL)

    # Collect the steps in document order first ...
    found_steps = [
        CoTStep(number=int(raw_number), content=body.strip())
        for raw_number, body in step_regex.findall(response_text)
    ]
    # ... then order them by their declared number (stable for ties).
    ordered_steps = sorted(found_steps, key=lambda step: step.number)

    answer_hit = answer_regex.search(response_text)
    if answer_hit is None:
        final_answer = None
    else:
        final_answer = answer_hit.group(1).strip()

    return CoTResponse(question=question, steps=ordered_steps, answer=final_answer)
def create_mermaid_diagram(cot_response: CoTResponse, config: VisualizationConfig) -> str:
    """
    Render a CoT response as Mermaid flowchart markup inside a <div>.

    The question becomes node ``Q``, each step becomes node ``S<number>``
    and the answer (when present) becomes node ``A``. Nodes are linked in
    order: question, steps, answer. All node text goes through wrap_text.

    Args:
        cot_response: CoTResponse object containing the reasoning steps
        config: VisualizationConfig for text formatting

    Returns:
        Mermaid diagram markup as a string
    """
    steps = cot_response.steps
    lines = ['<div class="mermaid">', 'graph TD']

    # Question node
    question_label = wrap_text(cot_response.question, config)
    lines.append(f' Q["{question_label}"]')

    # Step nodes, chained one after another starting from the question
    if steps:
        node_ids = [f'S{step.number}' for step in steps]
        lines.append(f' Q --> {node_ids[0]}')
        last_index = len(steps) - 1
        for index, (node_id, step) in enumerate(zip(node_ids, steps)):
            step_label = wrap_text(step.content, config)
            lines.append(f' {node_id}["{step_label}"]')
            if index != last_index:
                lines.append(f' {node_id} --> {node_ids[index + 1]}')

    # Answer node, attached to the last step (or to the question if no steps)
    if cot_response.answer:
        answer_label = wrap_text(cot_response.answer, config)
        lines.append(f' A["{answer_label}"]')
        tail_id = f'S{steps[-1].number}' if steps else 'Q'
        lines.append(f' {tail_id} --> A')

    # Styling directives, then close the container
    lines += [
        ' classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        ' classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
        ' classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        ' class Q question;',
        ' class A answer;',
        ' linkStyle default stroke:#666,stroke-width:2px;',
        '</div>',
    ]
    return '\n'.join(lines)