import re
from dataclasses import dataclass
from typing import List, Optional
from cot_reasoning import VisualizationConfig, wrap_text
@dataclass
class SelfRefineStep:
"""Data class representing a single step in self-refine reasoning"""
number: int
content: str
is_revised: bool = False
revision_of: Optional[int] = None
@dataclass
class SelfRefineResponse:
"""Data class representing a complete self-refine response"""
question: str
steps: List[SelfRefineStep]
answer: Optional[str] = None
revision_check: Optional[str] = None
revised_answer: Optional[str] = None
def parse_selfrefine_response(response_text: str, question: str) -> SelfRefineResponse:
"""
Parse self-refine response text to extract steps, answers, and revisions.
Args:
response_text: The raw response from the API
question: The original question
Returns:
SelfRefineResponse object containing all components
"""
# Extract initial steps
step_pattern = r'\s*(.*?)\s*'
steps = []
for match in re.finditer(step_pattern, response_text, re.DOTALL):
number = int(match.group(1))
content = match.group(2).strip()
steps.append(SelfRefineStep(number=number, content=content))
# Extract initial answer
answer_pattern = r'\s*(.*?)\s*'
answer_match = re.search(answer_pattern, response_text, re.DOTALL)
answer = answer_match.group(1).strip() if answer_match else None
# Extract revision check
check_pattern = r'\s*(.*?)\s*'
check_match = re.search(check_pattern, response_text, re.DOTALL)
revision_check = check_match.group(1).strip() if check_match else None
# Extract revised steps
revised_step_pattern = r'\s*(.*?)\s*'
for match in re.finditer(revised_step_pattern, response_text, re.DOTALL):
number = int(match.group(1))
revises = int(match.group(2))
content = match.group(3).strip()
steps.append(SelfRefineStep(
number=number,
content=content,
is_revised=True,
revision_of=revises
))
# Extract revised answer
revised_answer_pattern = r'\s*(.*?)\s*'
revised_answer_match = re.search(revised_answer_pattern, response_text, re.DOTALL)
revised_answer = revised_answer_match.group(1).strip() if revised_answer_match else None
return SelfRefineResponse(
question=question,
steps=steps,
answer=answer,
revision_check=revision_check,
revised_answer=revised_answer
)
def create_mermaid_diagram(sr_response: SelfRefineResponse, config: VisualizationConfig) -> str:
"""
Create a Mermaid diagram for self-refine reasoning.
Args:
sr_response: SelfRefineResponse object containing the reasoning steps
config: VisualizationConfig for text formatting
Returns:
Mermaid diagram markup as a string
"""
diagram = ['
', 'graph TD']
# Add question node
question_content = wrap_text(sr_response.question, config)
diagram.append(f' Q["{question_content}"]')
# Track original and revised steps
original_steps = [s for s in sr_response.steps if not s.is_revised]
revised_steps = [s for s in sr_response.steps if s.is_revised]
# Add original steps and connect them
prev_node = 'Q'
for step in original_steps:
node_id = f'S{step.number}'
content = wrap_text(step.content, config)
diagram.append(f' {node_id}["{content}"]')
diagram.append(f' {prev_node} --> {node_id}')
prev_node = node_id
# Add initial answer if present
if sr_response.answer:
answer_content = wrap_text(sr_response.answer, config)
diagram.append(f' A["{answer_content}"]')
diagram.append(f' {prev_node} --> A')
prev_node = 'A'
# Add revision check if present
if sr_response.revision_check:
check_content = wrap_text(sr_response.revision_check, config)
diagram.append(f' RC["{check_content}"]')
diagram.append(f' {prev_node} --> RC')
# Add revised steps if any
if revised_steps:
# Process each revision step
for i, step in enumerate(revised_steps):
rev_node_id = f'R{step.number}'
content = wrap_text(step.content, config)
diagram.append(f' {rev_node_id}["{content}"]')
# Connect from the revision check to problematic step, then to revision
if step.revision_of:
orig_node = f'S{step.revision_of}'
# Add connection from revision check to problematic step
diagram.append(f' RC --> {orig_node}')
# Add connection from problematic step to its revision
diagram.append(f' {orig_node} --> {rev_node_id}')
# Connect subsequent revised steps
if i < len(revised_steps) - 1:
next_node = f'R{revised_steps[i + 1].number}'
diagram.append(f' {rev_node_id} --> {next_node}')
# Add revised answer if present
if sr_response.revised_answer:
revised_content = wrap_text(sr_response.revised_answer, config)
diagram.append(f' RA["{revised_content}"]')
last_node = f'R{revised_steps[-1].number}' if revised_steps else 'RC'
diagram.append(f' {last_node} --> RA')
# Add styles
diagram.extend([
' classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
' classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
' classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
' classDef revision fill:#fff3cd,stroke:#ffc107,stroke-width:2px;',
' class Q question;',
' class A,RA answer;',
' class RC revision;'
])
# Style revision nodes
for step in revised_steps:
diagram.append(f' class R{step.number} revision;')
diagram.append('
')
return '\n'.join(diagram)