import re
import textwrap
from dataclasses import dataclass
from typing import List, Optional

try:
    import requests
except ImportError:  # pragma: no cover - only when requests is absent
    # requests is needed only by AnthropicAPI.generate_response; keeping the
    # import optional lets the parsing and diagram helpers below be used
    # without it.
    requests = None  # type: ignore[assignment]


@dataclass
class CoTStep:
    """A single numbered chain-of-thought (CoT) reasoning step."""

    number: int  # value of the step's ``number`` attribute in the model output
    content: str  # text between the step's opening and closing tags, stripped


@dataclass
class CoTResponse:
    """A parsed CoT response: the question, its steps, and the final answer."""

    question: str
    steps: List[CoTStep]  # sorted by ``number`` when built by parse_cot_response
    answer: Optional[str] = None  # None when no answer tag was found


@dataclass
class VisualizationConfig:
    """Text-layout limits for node labels in the Mermaid diagram."""

    max_chars_per_line: int = 40  # width passed to textwrap.wrap
    max_lines: int = 4  # lines kept per label before truncating
    truncation_suffix: str = "..."  # appended to the last kept line on truncation


# Default chain-of-thought prompt (a ``str.format`` template with a
# ``{question}`` placeholder). The tags it asks the model to emit are exactly
# the ones _STEP_PATTERN / _ANSWER_PATTERN extract, so keep them in sync.
DEFAULT_COT_PROMPT = """Please answer the question using the following format, with each step clearly marked:

Question: {question}

Let's solve this step by step:
<step number="1">[First step of reasoning]</step>
<step number="2">[Second step of reasoning]</step>
<step number="3">[Third step of reasoning]</step>
... (add more steps as needed)
<answer>[Final answer]</answer>

Note:
1. Each step must be wrapped in <step> XML tags
2. Each step must have a number attribute
3. The final answer must be wrapped in <answer> tags
"""

# DOTALL lets step/answer bodies span several lines; the lazy ``(.*?)`` stops
# at the first closing tag so adjacent steps are not merged.
_STEP_PATTERN = re.compile(
    r"<step\s+number\s*=\s*[\"']?(\d+)[\"']?\s*>\s*(.*?)\s*</step>", re.DOTALL
)
_ANSWER_PATTERN = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.DOTALL)


class AnthropicAPI:
    """Minimal client for the Anthropic Messages endpoint (``/v1/messages``)."""

    def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
        self.api_key = api_key
        self.model = model
        self.base_url = "https://api.anthropic.com/v1/messages"
        self.headers = {
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        }

    def generate_response(
        self,
        prompt: str,
        max_tokens: int = 1024,
        prompt_format: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> str:
        """Send *prompt* as one user message and return the reply text.

        Args:
            prompt: The user prompt. Sent verbatim unless *prompt_format* is
                given.
            max_tokens: Forwarded as the request's ``max_tokens`` field.
            prompt_format: Optional ``str.format`` template with a
                ``{question}`` placeholder; when truthy, *prompt* is
                substituted into it first. DEFAULT_COT_PROMPT is *not*
                applied automatically — pass it here to get tagged CoT output.
            timeout: Seconds to wait, forwarded to ``requests.post``.
                ``None`` (the default) waits indefinitely.

        Returns:
            ``response.json()["content"][0]["text"]``.

        Raises:
            ImportError: If the ``requests`` package is not installed.
            Exception: Wrapping any HTTP, network, or response-shape error;
                the original error is chained as ``__cause__``.
        """
        formatted_prompt = (
            self._format_prompt(prompt, prompt_format) if prompt_format else prompt
        )
        data = {
            "model": self.model,
            "messages": [{"role": "user", "content": formatted_prompt}],
            "max_tokens": max_tokens,
        }
        if requests is None:
            raise ImportError("The 'requests' package is required for API calls")

        try:
            response = requests.post(
                self.base_url, headers=self.headers, json=data, timeout=timeout
            )
            response.raise_for_status()
            return response.json()["content"][0]["text"]
        except Exception as e:
            # Keep the original traceback reachable for debugging.
            raise Exception(f"API call failed: {e}") from e

    def _format_prompt(self, question: str, prompt_format: Optional[str] = None) -> str:
        """Fill *question* into *prompt_format*, or DEFAULT_COT_PROMPT if falsy."""
        template = prompt_format if prompt_format else DEFAULT_COT_PROMPT
        return template.format(question=question)


def wrap_text(text: str, config: Optional[VisualizationConfig] = None) -> str:
    """Wrap *text* into a multi-line Mermaid node label.

    Newlines are flattened to spaces and double quotes become single quotes,
    because labels are emitted inside ``["..."]`` where a raw ``"`` would end
    the label early. Lines are joined with ``<br>``.

    If wrapping yields more than ``config.max_lines`` lines, the rest are
    dropped and the last kept line is shortened so that it plus
    ``config.truncation_suffix`` fits within ``config.max_chars_per_line``.

    Args:
        text: Label text.
        config: Layout limits; defaults to ``VisualizationConfig()``.

    Returns:
        The wrapped label (empty string for empty input).
    """
    if config is None:
        config = VisualizationConfig()
    text = text.replace("\n", " ").replace('"', "'")
    lines = textwrap.wrap(text, width=config.max_chars_per_line)
    if len(lines) > config.max_lines:
        # max(..., 0) guards against a negative slice keeping lines from the end.
        lines = lines[: max(config.max_lines, 0)]
        if lines:
            suffix = config.truncation_suffix
            keep = max(config.max_chars_per_line - len(suffix), 0)
            lines[-1] = lines[-1][:keep] + suffix
    return "<br>".join(lines)


def parse_cot_response(response_text: str, question: str) -> CoTResponse:
    """Extract numbered steps and the final answer from tagged model output.

    Steps are ``<step number="N">...</step>`` elements and the answer is the
    first ``<answer>...</answer>`` element, as requested by DEFAULT_COT_PROMPT.

    Args:
        response_text: Raw text returned by the API.
        question: The original question, stored on the result unchanged.

    Returns:
        A CoTResponse whose steps are sorted by number (stable for ties) and
        whose answer is None when no answer tag is present.
    """
    steps = [
        CoTStep(number=int(match.group(1)), content=match.group(2).strip())
        for match in _STEP_PATTERN.finditer(response_text)
    ]
    steps.sort(key=lambda step: step.number)

    answer_match = _ANSWER_PATTERN.search(response_text)
    answer = answer_match.group(1).strip() if answer_match else None

    return CoTResponse(question=question, steps=steps, answer=answer)


def create_mermaid_diagram(
    cot_response: CoTResponse, config: Optional[VisualizationConfig] = None
) -> str:
    """Render *cot_response* as a top-down Mermaid flowchart.

    The question becomes node ``Q``, each step ``S<number>``, and the answer
    (if truthy) ``A``. Nodes are chained in list order: Q -> first step -> ...
    -> last step -> A, or Q -> A when there are no steps. Labels pass through
    wrap_text.

    Args:
        cot_response: Parsed question, steps, and optional answer.
        config: Label layout limits; defaults to ``VisualizationConfig()``.

    Returns:
        The diagram inside ``<div class="mermaid">`` ... ``</div>``, one
        statement per line.
    """
    if config is None:
        config = VisualizationConfig()
    steps = cot_response.steps

    diagram = ['<div class="mermaid">', "graph TD"]
    diagram.append(f'    Q["{wrap_text(cot_response.question, config)}"]')

    if steps:
        diagram.append(f"    Q --> S{steps[0].number}")
    # NOTE: node IDs come from step numbers, so duplicate numbers share a node.
    for index, step in enumerate(steps):
        node_id = f"S{step.number}"
        diagram.append(f'    {node_id}["{wrap_text(step.content, config)}"]')
        if index + 1 < len(steps):
            diagram.append(f"    {node_id} --> S{steps[index + 1].number}")

    if cot_response.answer:
        diagram.append(f'    A["{wrap_text(cot_response.answer, config)}"]')
        # Hang the answer off the last step, or directly off the question.
        source = f"S{steps[-1].number}" if steps else "Q"
        diagram.append(f"    {source} --> A")

    diagram.extend([
        "    classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;",
        "    classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;",
        "    classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;",
        "    class Q question;",
        "    class A answer;",
        "    linkStyle default stroke:#666,stroke-width:2px;",
    ])
    diagram.append("</div>")
    return "\n".join(diagram)