# Spaces: Running
from dataclasses import dataclass | |
from typing import List, Optional | |
import re | |
import textwrap | |
from cot_reasoning import VisualizationConfig | |
@dataclass
class BSNode:
    """Data class representing a node in the Beam Search tree.

    ``children`` uses ``None`` purely as a "not supplied" sentinel;
    ``__post_init__`` replaces it with a fresh list so that instances never
    share a single mutable default.
    """

    id: str  # node identifier as it appears in the response markup
    content: str  # text body of the node
    score: float  # score of this individual node
    parent_id: Optional[str] = None  # id of the parent node; None when the node has no parent
    children: Optional[List['BSNode']] = None  # replaced by [] in __post_init__ when omitted
    is_best_path: bool = False  # True once the node has been marked as part of the best path
    path_score: Optional[float] = None  # optional path_score value attached to the node

    def __post_init__(self):
        """Give each instance its own ``children`` list when none was passed."""
        if self.children is None:
            self.children = []
@dataclass
class BSResponse:
    """Data class representing a complete Beam Search response.

    ``result_nodes`` uses ``None`` as a "not supplied" sentinel;
    ``__post_init__`` replaces it with a fresh list so that instances never
    share a single mutable default.
    """

    question: str  # the question the response belongs to
    root: Optional['BSNode']  # root of the parsed tree; None when no parentless node exists
    answer: Optional[str] = None  # text of the answer section, if one was found
    best_score: Optional[float] = None  # path score reported alongside the answer
    result_nodes: Optional[List['BSNode']] = None  # replaced by [] in __post_init__ when omitted

    def __post_init__(self):
        """Give each instance its own ``result_nodes`` list when none was passed."""
        if self.result_nodes is None:
            self.result_nodes = []
def parse_bs_response(response_text: str, question: str) -> "BSResponse":
    """Parse Beam Search response text to extract nodes and build the tree.

    The text is scanned for ``<node id=... [parent=...] score=...
    [path_score=...]>...</node>`` elements and for an optional
    ``<answer>Best path (path_score: X): ...</answer>`` section.

    Args:
        response_text: Raw response containing the node/answer markup.
        question: The question, stored unchanged on the returned object.

    Returns:
        A ``BSResponse`` whose ``root`` is the last node found without a
        ``parent`` attribute (``None`` if there is none), whose
        ``result_nodes`` are the nodes whose id starts with ``result``, and
        whose ``answer`` / ``best_score`` come from the answer section when
        it is present.

    Raises:
        ValueError: If a ``score``/``path_score`` attribute or the answer's
            path score is not a valid float literal.
    """
    node_pattern = r'<node id="([^"]+)"(?:\s+parent="([^"]+)")?\s*score="([^"]+)"(?:\s+path_score="([^"]+)")?\s*>\s*(.*?)\s*</node>'
    nodes_dict = {}
    result_nodes = []

    # First pass: create all nodes, indexed by id.
    for match in re.finditer(node_pattern, response_text, re.DOTALL):
        node_id, parent_id, raw_score, raw_path_score, raw_content = match.groups()
        score = float(raw_score)
        path_score = float(raw_path_score) if raw_path_score else None
        node = BSNode(
            id=node_id,
            content=raw_content.strip(),
            score=score,
            parent_id=parent_id,
            path_score=path_score,
        )
        nodes_dict[node_id] = node
        if node_id.startswith('result'):
            result_nodes.append(node)

    # Second pass: attach every node to its parent. Nodes whose parent id is
    # unknown are left detached, exactly as before.
    root = None
    for node in nodes_dict.values():
        if node.parent_id is None:
            root = node
            continue
        parent = nodes_dict.get(node.parent_id)
        if parent is not None:
            parent.children.append(node)

    # Parse the answer section, if present.
    answer_pattern = r'<answer>\s*Best path \(path_score: ([^\)]+)\):\s*(.*?)\s*</answer>'
    answer_match = re.search(answer_pattern, response_text, re.DOTALL)
    answer = None
    best_score = None
    if answer_match:
        best_score = float(answer_match.group(1))
        answer = answer_match.group(2).strip()

        # Mark every path that ends in a node whose path_score equals the
        # reported best score.
        for node in nodes_dict.values():
            # Compare against None explicitly: a path_score of 0.0 is a real
            # value and must not be skipped as if it were missing.
            if node.path_score is None:
                continue
            if abs(node.path_score - best_score) >= 1e-6:
                continue
            current = node
            # Stop at an already-marked node: its ancestors are marked too,
            # and this also ends the walk if the markup has a parent cycle.
            while current is not None and not current.is_best_path:
                current.is_best_path = True
                current = nodes_dict.get(current.parent_id)

    return BSResponse(
        question=question,
        root=root,
        answer=answer,
        best_score=best_score,
        result_nodes=result_nodes,
    )
def create_mermaid_diagram(bs_response: "BSResponse", config: "VisualizationConfig") -> str:
    """Convert Beam Search response to Mermaid diagram.

    Args:
        bs_response: Parsed response whose question, tree, result nodes and
            answer are rendered.
        config: Passed through to ``wrap_text`` to size the node labels.

    Returns:
        A ``<div class="mermaid">`` block containing a ``graph TD``
        definition: the question node ``Q``, every node reachable from the
        root, an ``Answer`` node linked from all result nodes (only when an
        answer is present), and the style classes.
    """
    diagram = ['<div class="mermaid">', 'graph TD']

    # Question node
    question_content = wrap_text(bs_response.question, config)
    diagram.append(f' Q["{question_content}"]')

    def add_node_and_children(node: "BSNode", parent_id: Optional[str] = None):
        """Emit ``node``, its style class, its incoming edge, then recurse."""
        score_info = f"Score: {node.score:.2f}"
        # Explicit None check so that a path score of 0.0 is still shown.
        if node.path_score is not None:
            score_info += f"<br>Path Score: {node.path_score:.2f}"
        node_content = f"{wrap_text(node.content, config)}<br>{score_info}"

        # Style depends on node kind (result vs. intermediate) and on
        # whether the node lies on the best path.
        if node.id.startswith('result'):
            node_style = 'best_result' if node.is_best_path else 'result'
        else:
            node_style = 'best_intermediate' if node.is_best_path else 'intermediate'

        diagram.append(f' {node.id}["{node_content}"]')
        diagram.append(f' class {node.id} {node_style};')

        if parent_id:
            diagram.append(f' {parent_id} --> {node.id}')

        for child in node.children:
            add_node_and_children(child, node.id)

    # Tree structure, hanging off the question node
    if bs_response.root:
        diagram.append(f' Q --> {bs_response.root.id}')
        add_node_and_children(bs_response.root)

    # Final answer
    if bs_response.answer:
        # best_score may be None when a response object was built with an
        # answer but no score; formatting None with :.2f would raise.
        if bs_response.best_score is not None:
            answer_header = f"Final Answer (Path Score: {bs_response.best_score:.2f}):"
        else:
            answer_header = "Final Answer:"
        answer_content = wrap_text(
            f"{answer_header}<br>{bs_response.answer}",
            config
        )
        diagram.append(f' Answer["{answer_content}"]')
        # Connect all result nodes to the answer
        for result_node in bs_response.result_nodes:
            diagram.append(f' {result_node.id} --> Answer')
        diagram.append(' class Answer final_answer;')

    # Styles
    # NOTE(review): best_intermediate is currently identical to intermediate,
    # so best-path intermediate nodes are not visually distinguished —
    # confirm whether that is intended.
    diagram.extend([
        ' classDef intermediate fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        ' classDef best_intermediate fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        ' classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
        ' classDef result fill:#f3f4f6,stroke:#4b5563,stroke-width:2px;',
        ' classDef best_result fill:#bfdbfe,stroke:#3b82f6,stroke-width:2px;',
        ' classDef final_answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        ' class Q question;',
        ' linkStyle default stroke:#666,stroke-width:2px;'
    ])
    diagram.append('</div>')
    return '\n'.join(diagram)
def wrap_text(text: str, config: "VisualizationConfig") -> str:
    """Wrap text to fit within box constraints.

    Newlines become spaces and double quotes become single quotes, then the
    text is wrapped to ``config.max_chars_per_line`` columns. When more than
    ``config.max_lines`` lines result, the extra lines are dropped and the
    last kept line is cut so that it ends in ``...``.

    Args:
        text: Text to wrap.
        config: Object providing ``max_chars_per_line`` and ``max_lines``.

    Returns:
        The wrapped lines joined with ``<br>``; an empty string for empty or
        whitespace-only text.

    Raises:
        ValueError: Propagated from ``textwrap.wrap`` when
            ``config.max_chars_per_line`` is not positive.
    """
    # The caller embeds the result inside a double-quoted Mermaid label, so
    # embedded double quotes are swapped for single quotes.
    text = text.replace('\n', ' ').replace('"', "'")
    width = config.max_chars_per_line
    # At least one line is always kept; a limit of 0 or less used to raise
    # IndexError when the (empty) last line was accessed.
    max_lines = max(1, config.max_lines)

    wrapped_lines = textwrap.wrap(text, width=width)
    if len(wrapped_lines) > max_lines:
        wrapped_lines = wrapped_lines[:max_lines]
        # Reserve three columns for the ellipsis. Clamp at 0 so that widths
        # below 3 do not turn into a negative slice bound.
        keep = max(0, width - 3)
        wrapped_lines[-1] = wrapped_lines[-1][:keep] + "..."
    return "<br>".join(wrapped_lines)