Spaces:
Running
Running
import re | |
from dataclasses import dataclass | |
from typing import List, Optional | |
from cot_reasoning import VisualizationConfig, wrap_text | |
@dataclass
class SelfRefineStep:
    """Data class representing a single step in self-refine reasoning"""
    # Step number taken from the tag's `number` attribute.
    number: int
    # Whitespace-trimmed text inside the step tag.
    content: str
    # True for steps parsed from <revised_step> tags, False for original <step> tags.
    is_revised: bool = False
    # For revised steps: the number of the original step being revised
    # (the tag's `revises` attribute); None for original steps.
    revision_of: Optional[int] = None


@dataclass
class SelfRefineResponse:
    """Data class representing a complete self-refine response"""
    # The question the response answers.
    question: str
    # Original steps followed by revised steps, each group in document order.
    steps: List[SelfRefineStep]
    # Contents of <answer>, or None if the tag is absent.
    answer: Optional[str] = None
    # Contents of <revision_check>, or None if the tag is absent.
    revision_check: Optional[str] = None
    # Contents of <revised_answer>, or None if the tag is absent.
    revised_answer: Optional[str] = None
def parse_selfrefine_response(response_text: str, question: str) -> SelfRefineResponse:
    """
    Build a SelfRefineResponse from the tagged text of a self-refine reply.

    Recognised markup: <step number="N">, <answer>, <revision_check>,
    <revised_step number="N" revises="M">, and <revised_answer>. The
    resulting step list holds every original step followed by every revised
    step, each group in document order. For the single-valued tags only the
    first occurrence is used; a missing tag yields None.

    Args:
        response_text: The raw response from the API
        question: The original question

    Returns:
        SelfRefineResponse object containing all components
    """
    def first_tag_body(pattern: str) -> Optional[str]:
        # Whitespace-trimmed capture of the first match, or None if absent.
        found = re.search(pattern, response_text, re.DOTALL)
        if found is None:
            return None
        return found.group(1).strip()

    # re.findall with several groups yields one tuple of groups per match.
    initial_steps = [
        SelfRefineStep(number=int(num), content=body.strip())
        for num, body in re.findall(
            r'<step number="(\d+)">\s*(.*?)\s*</step>',
            response_text,
            re.DOTALL,
        )
    ]

    revisions = [
        SelfRefineStep(
            number=int(num),
            content=body.strip(),
            is_revised=True,
            revision_of=int(target),
        )
        for num, target, body in re.findall(
            r'<revised_step number="(\d+)" revises="(\d+)">\s*(.*?)\s*</revised_step>',
            response_text,
            re.DOTALL,
        )
    ]

    return SelfRefineResponse(
        question=question,
        steps=initial_steps + revisions,
        answer=first_tag_body(r'<answer>\s*(.*?)\s*</answer>'),
        revision_check=first_tag_body(r'<revision_check>\s*(.*?)\s*</revision_check>'),
        revised_answer=first_tag_body(r'<revised_answer>\s*(.*?)\s*</revised_answer>'),
    )
def _mermaid_label(text: str, config: VisualizationConfig) -> str:
    """
    Wrap text with wrap_text and make it safe inside a quoted Mermaid label.

    A literal double quote would end the ``["..."]`` node label early and
    break the diagram, so every ``"`` is replaced with Mermaid's ``#quot;``
    entity code after wrapping.
    """
    return wrap_text(text, config).replace('"', '#quot;')


def create_mermaid_diagram(sr_response: SelfRefineResponse, config: VisualizationConfig) -> str:
    """
    Create a Mermaid diagram for self-refine reasoning.

    Layout: question -> original steps (in order) -> initial answer ->
    revision check. From the revision check an edge goes to each original
    step that was revised, and from that step to its revision. Revisions are
    chained in order. The revised answer hangs off the last revision, or off
    the last node of the main chain when there are no revisions.

    Edges are only drawn to nodes that are actually declared. A missing
    revision check, or a revision of an unknown step, therefore does not
    produce stray unlabeled placeholder nodes such as ``RC`` or ``S5``.

    Args:
        sr_response: SelfRefineResponse object containing the reasoning steps
        config: VisualizationConfig for text formatting

    Returns:
        Mermaid diagram markup as a string, wrapped in a
        ``<div class="mermaid">`` element
    """
    diagram = ['<div class="mermaid">', 'graph TD']

    # Add question node
    question_content = _mermaid_label(sr_response.question, config)
    diagram.append(f' Q["{question_content}"]')

    # Track original and revised steps
    original_steps = [s for s in sr_response.steps if not s.is_revised]
    revised_steps = [s for s in sr_response.steps if s.is_revised]
    # Numbers of steps that get a declared S<n> node; linking to any other
    # S<n> would make Mermaid invent an unlabeled node.
    original_numbers = {s.number for s in original_steps}

    # Add original steps and connect them
    prev_node = 'Q'
    for step in original_steps:
        node_id = f'S{step.number}'
        content = _mermaid_label(step.content, config)
        diagram.append(f' {node_id}["{content}"]')
        diagram.append(f' {prev_node} --> {node_id}')
        prev_node = node_id

    # Add initial answer if present
    if sr_response.answer:
        answer_content = _mermaid_label(sr_response.answer, config)
        diagram.append(f' A["{answer_content}"]')
        diagram.append(f' {prev_node} --> A')
        prev_node = 'A'

    # Add revision check if present
    has_check = bool(sr_response.revision_check)
    if has_check:
        check_content = _mermaid_label(sr_response.revision_check, config)
        diagram.append(f' RC["{check_content}"]')
        diagram.append(f' {prev_node} --> RC')
        prev_node = 'RC'

    # Add revised steps: each is fed from the step it revises (which is
    # itself flagged by the revision check) and chained to the next revision.
    linked_from_check = set()
    for i, step in enumerate(revised_steps):
        rev_node_id = f'R{step.number}'
        content = _mermaid_label(step.content, config)
        diagram.append(f' {rev_node_id}["{content}"]')
        if step.revision_of in original_numbers:
            orig_node = f'S{step.revision_of}'
            # Draw RC --> S<n> once, even if several revisions target S<n>.
            if has_check and orig_node not in linked_from_check:
                diagram.append(f' RC --> {orig_node}')
                linked_from_check.add(orig_node)
            diagram.append(f' {orig_node} --> {rev_node_id}')
        elif has_check and step.revision_of is not None:
            # The revised step is unknown: attach the revision to the check
            # rather than inventing an undeclared S<n> node.
            diagram.append(f' RC --> {rev_node_id}')
        # Connect subsequent revised steps
        if i + 1 < len(revised_steps):
            next_node = f'R{revised_steps[i + 1].number}'
            diagram.append(f' {rev_node_id} --> {next_node}')

    # Add revised answer if present
    if sr_response.revised_answer:
        revised_content = _mermaid_label(sr_response.revised_answer, config)
        diagram.append(f' RA["{revised_content}"]')
        # prev_node is RC when the check exists, otherwise the last declared
        # node of the main chain (A, the last step, or Q).
        last_node = f'R{revised_steps[-1].number}' if revised_steps else prev_node
        diagram.append(f' {last_node} --> RA')

    # Add styles
    diagram.extend([
        ' classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        ' classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
        ' classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        ' classDef revision fill:#fff3cd,stroke:#ffc107,stroke-width:2px;',
        ' class Q question;',
        ' class A,RA answer;',
        ' class RC revision;'
    ])
    # Style revision nodes
    for step in revised_steps:
        diagram.append(f' class R{step.number} revision;')
    diagram.append('</div>')
    return '\n'.join(diagram)