ZongqianLi committed
Commit 30fabb4 · verified · 1 Parent(s): 0ea39d6

Upload 15 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ static/assets/banner-bg.jpg filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,25 @@
+ # Use the official Python base image
+ FROM python:3.11.8-slim
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Copy the dependency file
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the project files
+ COPY . .
+
+ # Expose the application port
+ EXPOSE 7860
+
+ # Set environment variables for Flask
+ ENV FLASK_APP=app.py
+ ENV FLASK_RUN_HOST=0.0.0.0
+ ENV FLASK_RUN_PORT=7860
+
+ # Run the application
+ CMD ["flask", "run"]
api_base.py ADDED
@@ -0,0 +1,291 @@
+ """
+ Base API module for handling different API providers.
+ This module provides a unified interface for interacting with various API
+ providers such as Anthropic, OpenAI, Google Gemini, and Together AI.
+ """
+
+ from abc import ABC, abstractmethod
+ import logging
+ import requests
+ from openai import OpenAI
+ from typing import Optional, Dict, Any, List
+ from dataclasses import dataclass
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ @dataclass
+ class APIResponse:
+     """Standardized API response structure"""
+     text: str
+     raw_response: Any
+     usage: Dict[str, int]
+     model: str
+
+ class APIError(Exception):
+     """Custom exception for API-related errors"""
+     def __init__(self, message: str, provider: str, status_code: Optional[int] = None):
+         self.message = message
+         self.provider = provider
+         self.status_code = status_code
+         super().__init__(f"{provider} API Error: {message} (Status: {status_code})")
+
+ class BaseAPI(ABC):
+     """Abstract base class for API interactions"""
+
+     def __init__(self, api_key: str, model: str):
+         self.api_key = api_key
+         self.model = model
+         self.provider_name = "base"  # Override in subclasses
+
+     @abstractmethod
+     def generate_response(self, prompt: str, max_tokens: int = 1024,
+                           prompt_format: Optional[str] = None) -> str:
+         """Generate a response using the API"""
+         pass
+
+     def _format_prompt(self, question: str, prompt_format: Optional[str] = None) -> str:
+         """Format the prompt, using the custom format if one is provided"""
+         if prompt_format:
+             return prompt_format.format(question=question)
+
+         # Default format if none is provided
+         return f"""Please answer the question using the following format, with each step clearly marked:
+
+ Question: {question}
+
+ Let's solve this step by step:
+ <step number="1">
+ [First step of reasoning]
+ </step>
+ <step number="2">
+ [Second step of reasoning]
+ </step>
+ <step number="3">
+ [Third step of reasoning]
+ </step>
+ ... (add more steps as needed)
+ <answer>
+ [Final answer]
+ </answer>
+
+ Note:
+ 1. Each step must be wrapped in XML tags <step>
+ 2. Each step must have a number attribute
+ 3. The final answer must be wrapped in <answer> tags
+ """
+
+     def _handle_error(self, error: Exception, context: str = "") -> None:
+         """Standardized error handling: log the error, then raise APIError"""
+         error_msg = f"{self.provider_name} API error in {context}: {str(error)}"
+         logger.error(error_msg)
+         raise APIError(str(error), self.provider_name)
+
+ class AnthropicAPI(BaseAPI):
+     """Class to handle interactions with the Anthropic API"""
+
+     def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
+         super().__init__(api_key, model)
+         self.provider_name = "Anthropic"
+         self.base_url = "https://api.anthropic.com/v1/messages"
+         self.headers = {
+             "x-api-key": api_key,
+             "anthropic-version": "2023-06-01",
+             "content-type": "application/json"
+         }
+
+     def generate_response(self, prompt: str, max_tokens: int = 1024,
+                           prompt_format: Optional[str] = None) -> str:
+         """Generate a response using the Anthropic API"""
+         try:
+             formatted_prompt = self._format_prompt(prompt, prompt_format)
+             data = {
+                 "model": self.model,
+                 "messages": [{"role": "user", "content": formatted_prompt}],
+                 "max_tokens": max_tokens
+             }
+
+             logger.info(f"Sending request to Anthropic API with model {self.model}")
+             response = requests.post(self.base_url, headers=self.headers, json=data)
+             response.raise_for_status()
+
+             response_data = response.json()
+             return response_data["content"][0]["text"]
+
+         except requests.exceptions.RequestException as e:
+             self._handle_error(e, "request")
+         except (KeyError, IndexError) as e:
+             self._handle_error(e, "response parsing")
+         except Exception as e:
+             self._handle_error(e, "unexpected")
+
+ class OpenAIAPI(BaseAPI):
+     """Class to handle interactions with the OpenAI API"""
+
+     def __init__(self, api_key: str, model: str = "gpt-4-turbo-preview"):
+         super().__init__(api_key, model)
+         self.provider_name = "OpenAI"
+         try:
+             self.client = OpenAI(api_key=api_key)
+         except Exception as e:
+             self._handle_error(e, "initialization")
+
+     def generate_response(self, prompt: str, max_tokens: int = 1024,
+                           prompt_format: Optional[str] = None) -> str:
+         """Generate a response using the OpenAI API"""
+         try:
+             formatted_prompt = self._format_prompt(prompt, prompt_format)
+
+             logger.info(f"Sending request to OpenAI API with model {self.model}")
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[{"role": "user", "content": formatted_prompt}],
+                 max_tokens=max_tokens
+             )
+
+             return response.choices[0].message.content
+
+         except Exception as e:
+             self._handle_error(e, "request or response processing")
+
+ class GeminiAPI(BaseAPI):
+     """Class to handle interactions with the Google Gemini API"""
+
+     def __init__(self, api_key: str, model: str = "gemini-2.0-flash"):
+         super().__init__(api_key, model)
+         self.provider_name = "Gemini"
+         try:
+             from google import genai
+             self.client = genai.Client(api_key=api_key)
+         except Exception as e:
+             self._handle_error(e, "initialization")
+
+     def generate_response(self, prompt: str, max_tokens: int = 1024,
+                           prompt_format: Optional[str] = None) -> str:
+         """Generate a response using the Gemini API"""
+         try:
+             from google.genai import types
+             formatted_prompt = self._format_prompt(prompt, prompt_format)
+
+             logger.info(f"Sending request to Gemini API with model {self.model}")
+             response = self.client.models.generate_content(
+                 model=self.model,
+                 contents=[formatted_prompt],
+                 config=types.GenerateContentConfig(
+                     max_output_tokens=max_tokens,
+                     temperature=0.7
+                 )
+             )
+
+             if not response.text:
+                 raise APIError("Empty response from Gemini API", self.provider_name)
+
+             return response.text
+
+         except Exception as e:
+             self._handle_error(e, "request or response processing")
+
+ class TogetherAPI(BaseAPI):
+     """Class to handle interactions with the Together AI API"""
+
+     def __init__(self, api_key: str, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"):
+         super().__init__(api_key, model)
+         self.provider_name = "Together"
+         try:
+             from together import Together
+             self.client = Together(api_key=api_key)
+         except Exception as e:
+             self._handle_error(e, "initialization")
+
+     def generate_response(self, prompt: str, max_tokens: int = 1024,
+                           prompt_format: Optional[str] = None) -> str:
+         """Generate a response using the Together AI API"""
+         try:
+             formatted_prompt = self._format_prompt(prompt, prompt_format)
+
+             logger.info(f"Sending request to Together AI API with model {self.model}")
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[{"role": "user", "content": formatted_prompt}],
+                 max_tokens=max_tokens
+             )
+
+             # Robust response extraction
+             if hasattr(response, 'choices') and response.choices:
+                 return response.choices[0].message.content
+             elif hasattr(response, 'text'):
+                 return response.text
+             else:
+                 # The response does not match any expected structure
+                 raise APIError("Unexpected response format from Together AI", self.provider_name)
+
+         except Exception as e:
+             self._handle_error(e, "request or response processing")
+
+ class APIFactory:
+     """Factory class for creating API instances"""
+
+     _providers = {
+         "anthropic": {
+             "class": AnthropicAPI,
+             "default_model": "claude-3-opus-20240229"
+         },
+         "openai": {
+             "class": OpenAIAPI,
+             "default_model": "gpt-4-turbo-preview"
+         },
+         "google": {
+             "class": GeminiAPI,
+             "default_model": "gemini-2.0-flash"
+         },
+         "together": {
+             "class": TogetherAPI,
+             "default_model": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
+         }
+     }
+
+     @classmethod
+     def supported_providers(cls) -> List[str]:
+         """Get the list of supported providers"""
+         return list(cls._providers.keys())
+
+     @classmethod
+     def create_api(cls, provider: str, api_key: str, model: Optional[str] = None) -> BaseAPI:
+         """Factory method to create the appropriate API instance"""
+         provider = provider.lower()
+         if provider not in cls._providers:
+             raise ValueError(f"Unsupported provider: {provider}. "
+                              f"Supported providers are: {', '.join(cls.supported_providers())}")
+
+         provider_info = cls._providers[provider]
+         api_class = provider_info["class"]
+         model = model or provider_info["default_model"]
+
+         logger.info(f"Creating API instance for provider: {provider}, model: {model}")
+         return api_class(api_key=api_key, model=model)
+
+ def create_api(provider: str, api_key: str, model: Optional[str] = None) -> BaseAPI:
+     """Convenience function to create an API instance"""
+     return APIFactory.create_api(provider, api_key, model)
+
+ # Example usage:
+ if __name__ == "__main__":
+     # Example with Anthropic
+     anthropic_api = create_api("anthropic", "your-api-key")
+
+     # Example with OpenAI
+     openai_api = create_api("openai", "your-api-key", "gpt-4")
+
+     # Example with Gemini (registered under the "google" provider key)
+     gemini_api = create_api("google", "your-api-key", "gemini-2.0-flash")
+
+     # Example with Together AI
+     together_api = create_api("together", "your-api-key")
+
+     # Get supported providers
+     providers = APIFactory.supported_providers()
+     print(f"Supported providers: {providers}")
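Reviewer note: the `_providers` registry is what keeps `APIFactory` extensible — adding a provider only needs a `BaseAPI` subclass and one registry entry. A minimal sketch of the pattern; `EchoAPI` and the `echo` key are hypothetical and only for illustration (and it pokes the private `_providers` dict directly):

    # Hypothetical provider used only to illustrate the registry pattern.
    from typing import Optional
    from api_base import APIFactory, BaseAPI

    class EchoAPI(BaseAPI):
        """Toy provider that returns the formatted prompt unchanged."""
        def __init__(self, api_key: str, model: str = "echo-1"):
            super().__init__(api_key, model)
            self.provider_name = "Echo"

        def generate_response(self, prompt: str, max_tokens: int = 1024,
                              prompt_format: Optional[str] = None) -> str:
            # No network call: just echo the formatted prompt back.
            return self._format_prompt(prompt, prompt_format)

    APIFactory._providers["echo"] = {"class": EchoAPI, "default_model": "echo-1"}
    print(APIFactory.supported_providers())  # [..., 'echo']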
api_keys.json ADDED
@@ -0,0 +1,6 @@
+ {
+     "anthropic": "",
+     "openai": "",
+     "google": "",
+     "together": ""
+ }
app.py ADDED
@@ -0,0 +1,272 @@
+ from flask import Flask, render_template, request, jsonify
+ from api_base import create_api  # API factory for all providers
+ from cot_reasoning import (
+     VisualizationConfig,
+     create_mermaid_diagram as create_cot_diagram,
+     parse_cot_response
+ )
+ from tot_reasoning import (
+     create_mermaid_diagram as create_tot_diagram,
+     parse_tot_response
+ )
+ from l2m_reasoning import (
+     create_mermaid_diagram as create_l2m_diagram,
+     parse_l2m_response
+ )
+ from selfconsistency_reasoning import (
+     create_mermaid_diagram as create_scr_diagram,
+     parse_scr_response
+ )
+ from selfrefine_reasoning import (
+     create_mermaid_diagram as create_srf_diagram,
+     parse_selfrefine_response
+ )
+ from bs_reasoning import (
+     create_mermaid_diagram as create_bs_diagram,
+     parse_bs_response
+ )
+ from configs import config
+ import logging
+ import re
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Initialize Flask app
+ app = Flask(__name__)
+
+ @app.route('/')
+ def index():
+     """Render the main page"""
+     return render_template('index.html')
+
+ @app.route('/config')
+ def get_config():
+     """Get the initial configuration"""
+     return jsonify(config.get_initial_values())
+
+ @app.route('/method-config/<method_id>')
+ def get_method_config(method_id):
+     """Get the configuration for a specific method"""
+     method_config = config.get_method_config(method_id)
+     if method_config:
+         return jsonify(method_config)
+     return jsonify({"error": "Method not found"}), 404
+
+ @app.route('/provider-api-key/<provider>')
+ def get_provider_api_key(provider):
+     """Get the default API key for a specific provider"""
+     try:
+         api_key = config.general.get_default_api_key(provider)
+         return jsonify({
+             'success': True,
+             'api_key': api_key
+         })
+     except Exception as e:
+         logger.error(f"Error getting API key for provider {provider}: {str(e)}")
+         return jsonify({
+             'success': False,
+             'error': str(e)
+         }), 500
+
+ @app.route('/select-method', methods=['POST'])
+ def select_method():
+     """Let the model select the most appropriate reasoning method"""
+     try:
+         data = request.json
+         if not data:
+             return jsonify({'success': False, 'error': 'No data provided'}), 400
+
+         # Extract parameters
+         api_key = data.get('api_key')
+         provider = data.get('provider', 'anthropic')
+         model = data.get('model')
+         question = data.get('question')
+
+         if not all([api_key, model, question]):
+             return jsonify({'success': False, 'error': 'Missing required parameters'}), 400
+
+         # Create the selection prompt
+         methods = config.methods
+         prompt = f"""Given this question: "{question}"
+
+ Please select the most appropriate reasoning method from the following options to solve it:
+
+ {chr(10).join(f'- {method_id}: {method_config.name}' for method_id, method_config in methods.items())}
+
+ Consider the characteristics of each method and the nature of the question.
+ Output your selection in exactly this format:
+ <selected_method>method_id</selected_method>
+ where method_id is strictly one of: {', '.join(methods.keys())}.
+ Do not use any method or words that are not in: {', '.join(methods.keys())}."""
+
+         # Get the model's selection
+         try:
+             api = create_api(provider, api_key, model)
+             response = api.generate_response(prompt, max_tokens=100)
+
+             # Extract the method ID with a simple regex
+             match = re.search(r'<selected_method>(\w+)</selected_method>', response)
+             if match and match.group(1) in methods:
+                 selected_method = match.group(1)
+                 return jsonify({
+                     'success': True,
+                     'selected_method': selected_method,
+                     'raw_response': response
+                 })
+             else:
+                 return jsonify({
+                     'success': False,
+                     'error': 'Invalid method selection in response'
+                 }), 400
+
+         except Exception as e:
+             return jsonify({
+                 'success': False,
+                 'error': f'API call failed: {str(e)}'
+             }), 500
+
+     except Exception as e:
+         logger.error(f"Error in method selection: {str(e)}")
+         return jsonify({
+             'success': False,
+             'error': str(e)
+         }), 500
+
+ @app.route('/process', methods=['POST'])
+ def process():
+     """Process the reasoning request"""
+     try:
+         # Get request data
+         data = request.json
+         if not data:
+             return jsonify({
+                 'success': False,
+                 'error': 'No data provided'
+             }), 400
+
+         # Extract parameters
+         api_key = data.get('api_key')
+         if not api_key:
+             return jsonify({
+                 'success': False,
+                 'error': 'API key is required'
+             }), 400
+
+         question = data.get('question')
+         if not question:
+             return jsonify({
+                 'success': False,
+                 'error': 'Question is required'
+             }), 400
+
+         # Get optional parameters with defaults
+         provider = data.get('provider', 'anthropic')  # Provider for the API factory
+         model = data.get('model', config.general.available_models[0])
+         max_tokens = int(data.get('max_tokens', config.general.max_tokens))
+         prompt_format = data.get('prompt_format')
+         chars_per_line = int(data.get('chars_per_line', config.general.chars_per_line))
+         max_lines = int(data.get('max_lines', config.general.max_lines))
+         reasoning_method = data.get('reasoning_method', 'cot')
+
+         # Initialize the API with the factory function
+         try:
+             api = create_api(provider, api_key, model)
+         except Exception as e:
+             return jsonify({
+                 'success': False,
+                 'error': f'Failed to initialize API: {str(e)}'
+             }), 400
+
+         # Get the model response
+         logger.info(f"Generating response for question using {provider} {model}")
+         try:
+             raw_response = api.generate_response(
+                 question,
+                 max_tokens=max_tokens,
+                 prompt_format=prompt_format
+             )
+         except Exception as e:
+             return jsonify({
+                 'success': False,
+                 'error': f'API call failed: {str(e)}'
+             }), 500
+
+         # Create the visualization config
+         viz_config = VisualizationConfig(
+             max_chars_per_line=chars_per_line,
+             max_lines=max_lines
+         )
+
+         # Generate a visualization based on the reasoning method
+         visualization = None
+         try:
+             if reasoning_method == 'cot':
+                 result = parse_cot_response(raw_response, question)
+                 visualization = create_cot_diagram(result, viz_config)
+             elif reasoning_method == 'tot':
+                 result = parse_tot_response(raw_response, question)
+                 visualization = create_tot_diagram(result, viz_config)
+             elif reasoning_method == 'l2m':
+                 result = parse_l2m_response(raw_response, question)
+                 visualization = create_l2m_diagram(result, viz_config)
+             elif reasoning_method == 'scr':
+                 result = parse_scr_response(raw_response, question)
+                 visualization = create_scr_diagram(result, viz_config)
+             elif reasoning_method == 'srf':
+                 result = parse_selfrefine_response(raw_response, question)
+                 visualization = create_srf_diagram(result, viz_config)
+             elif reasoning_method == 'bs':
+                 result = parse_bs_response(raw_response, question)
+                 visualization = create_bs_diagram(result, viz_config)
+
+             logger.info("Successfully generated visualization")
+         except Exception as viz_error:
+             logger.error(f"Visualization generation failed: {str(viz_error)}")
+             # Continue without a visualization
+
+         # Return the successful response
+         return jsonify({
+             'success': True,
+             'raw_output': raw_response,
+             'visualization': visualization
+         })
+
+     except Exception as e:
+         # Log the error and return an error response
+         logger.error(f"Error processing request: {str(e)}")
+         return jsonify({
+             'success': False,
+             'error': str(e)
+         }), 500
+
+ @app.errorhandler(404)
+ def not_found_error(error):
+     """Handle 404 errors"""
+     return jsonify({
+         'success': False,
+         'error': 'Resource not found'
+     }), 404
+
+ @app.errorhandler(500)
+ def internal_error(error):
+     """Handle 500 errors"""
+     return jsonify({
+         'success': False,
+         'error': 'Internal server error'
+     }), 500
+
+ if __name__ == '__main__':
+     try:
+         # Run the application (the Dockerfile instead serves it via `flask run` on port 7860)
+         app.run(
+             host='0.0.0.0',
+             port=5001,
+             debug=False  # Disable debug mode in production
+         )
+     except Exception as e:
+         logger.error(f"Failed to start application: {str(e)}")
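Reviewer note: a minimal sketch of exercising the `/process` route from Python, assuming the server is running locally on port 5001 as in the `__main__` block above; the API key is a placeholder, and `model` is omitted so it falls back to `config.general.available_models[0]`:

    import requests

    payload = {
        "api_key": "your-api-key",  # placeholder
        "provider": "anthropic",
        "question": "What is 17 * 24?",
        "reasoning_method": "cot",
    }
    resp = requests.post("http://localhost:5001/process", json=payload)
    data = resp.json()
    if data["success"]:
        print(data["raw_output"])     # raw model response
        print(data["visualization"])  # Mermaid markup, or None if parsing failed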
bs_reasoning.py ADDED
@@ -0,0 +1,187 @@
+ from dataclasses import dataclass
+ from typing import List, Optional
+ import re
+ import textwrap
+ from cot_reasoning import VisualizationConfig
+
+ @dataclass
+ class BSNode:
+     """Data class representing a node in the Beam Search tree"""
+     id: str
+     content: str
+     score: float
+     parent_id: Optional[str] = None
+     children: Optional[List['BSNode']] = None
+     is_best_path: bool = False
+     path_score: Optional[float] = None
+
+     def __post_init__(self):
+         if self.children is None:
+             self.children = []
+
+ @dataclass
+ class BSResponse:
+     """Data class representing a complete Beam Search response"""
+     question: str
+     root: BSNode
+     answer: Optional[str] = None
+     best_score: Optional[float] = None
+     result_nodes: Optional[List[BSNode]] = None
+
+     def __post_init__(self):
+         if self.result_nodes is None:
+             self.result_nodes = []
+
+ def parse_bs_response(response_text: str, question: str) -> BSResponse:
+     """Parse Beam Search response text to extract nodes and build the tree"""
+     # Parse nodes
+     node_pattern = r'<node id="([^"]+)"(?:\s+parent="([^"]+)")?\s*score="([^"]+)"(?:\s+path_score="([^"]+)")?\s*>\s*(.*?)\s*</node>'
+     nodes_dict = {}
+     result_nodes = []
+
+     # First pass: create all nodes
+     for match in re.finditer(node_pattern, response_text, re.DOTALL):
+         node_id = match.group(1)
+         parent_id = match.group(2)
+         score = float(match.group(3))
+         path_score = float(match.group(4)) if match.group(4) else None
+         content = match.group(5).strip()
+
+         node = BSNode(
+             id=node_id,
+             content=content,
+             score=score,
+             parent_id=parent_id,
+             path_score=path_score
+         )
+         nodes_dict[node_id] = node
+
+         # Collect result nodes
+         if node_id.startswith('result'):
+             result_nodes.append(node)
+
+     # Second pass: build tree relationships
+     root = None
+     for node in nodes_dict.values():
+         if node.parent_id is None:
+             root = node
+         else:
+             parent = nodes_dict.get(node.parent_id)
+             if parent:
+                 parent.children.append(node)
+
+     # Parse the answer if present
+     answer_pattern = r'<answer>\s*Best path \(path_score: ([^\)]+)\):\s*(.*?)\s*</answer>'
+     answer_match = re.search(answer_pattern, response_text, re.DOTALL)
+     answer = None
+     best_score = None
+
+     if answer_match:
+         best_score = float(answer_match.group(1))
+         answer = answer_match.group(2).strip()
+
+     # Mark the best path based on path_score (guard against a missing answer,
+     # and compare with "is not None" so a legitimate 0.0 score is not skipped)
+     if best_score is not None:
+         for node in nodes_dict.values():
+             if node.path_score is not None and abs(node.path_score - best_score) < 1e-6:
+                 # Mark all nodes on the path, from this node back to the root
+                 current = node
+                 while current:
+                     current.is_best_path = True
+                     current = nodes_dict.get(current.parent_id)
+
+     return BSResponse(
+         question=question,
+         root=root,
+         answer=answer,
+         best_score=best_score,
+         result_nodes=result_nodes
+     )
+
+ def create_mermaid_diagram(bs_response: BSResponse, config: VisualizationConfig) -> str:
+     """Convert a Beam Search response to a Mermaid diagram"""
+     diagram = ['<div class="mermaid">', 'graph TD']
+
+     # Add the question node
+     question_content = wrap_text(bs_response.question, config)
+     diagram.append(f'    Q["{question_content}"]')
+
+     def add_node_and_children(node: BSNode, parent_id: Optional[str] = None):
+         # Format the content to include scores
+         score_info = f"Score: {node.score:.2f}"
+         if node.path_score is not None:
+             score_info += f"<br>Path Score: {node.path_score:.2f}"
+         node_content = f"{wrap_text(node.content, config)}<br>{score_info}"
+
+         # Determine the node style based on node type and best-path membership
+         if node.id.startswith('result'):
+             node_style = 'best_result' if node.is_best_path else 'result'
+         else:
+             node_style = 'best_intermediate' if node.is_best_path else 'intermediate'
+
+         # Add the node
+         diagram.append(f'    {node.id}["{node_content}"]')
+         diagram.append(f'    class {node.id} {node_style};')
+
+         # Add the connection from the parent
+         if parent_id:
+             diagram.append(f'    {parent_id} --> {node.id}')
+
+         # Process children
+         for child in node.children:
+             add_node_and_children(child, node.id)
+
+     # Build the tree structure
+     if bs_response.root:
+         diagram.append(f'    Q --> {bs_response.root.id}')
+         add_node_and_children(bs_response.root)
+
+     # Add the final answer
+     if bs_response.answer:
+         answer_content = wrap_text(
+             f"Final Answer (Path Score: {bs_response.best_score:.2f}):<br>{bs_response.answer}",
+             config
+         )
+         diagram.append(f'    Answer["{answer_content}"]')
+
+         # Connect all result nodes to the answer
+         for result_node in bs_response.result_nodes:
+             diagram.append(f'    {result_node.id} --> Answer')
+
+         diagram.append('    class Answer final_answer;')
+
+     # Add styles
+     diagram.extend([
+         '    classDef intermediate fill:#f9f9f9,stroke:#333,stroke-width:2px;',
+         '    classDef best_intermediate fill:#f9f9f9,stroke:#333,stroke-width:2px;',
+         '    classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
+         '    classDef result fill:#f3f4f6,stroke:#4b5563,stroke-width:2px;',
+         '    classDef best_result fill:#bfdbfe,stroke:#3b82f6,stroke-width:2px;',
+         '    classDef final_answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
+         '    class Q question;',
+         '    linkStyle default stroke:#666,stroke-width:2px;'
+     ])
+
+     diagram.append('</div>')
+     return '\n'.join(diagram)
+
+ def wrap_text(text: str, config: VisualizationConfig) -> str:
+     """Wrap text to fit within the node box constraints"""
+     text = text.replace('\n', ' ').replace('"', "'")
+     wrapped_lines = textwrap.wrap(text, width=config.max_chars_per_line)
+
+     if len(wrapped_lines) > config.max_lines:
+         # Option 1: simply truncate and add an ellipsis to the last kept line
+         wrapped_lines = wrapped_lines[:config.max_lines]
+         wrapped_lines[-1] = wrapped_lines[-1][:config.max_chars_per_line - 3] + "..."
+
+         # Option 2 (alternative): include part of the next line to show continuity
+         # original_next_line = wrapped_lines[config.max_lines] if len(wrapped_lines) > config.max_lines else ""
+         # wrapped_lines = wrapped_lines[:config.max_lines - 1]
+         # wrapped_lines.append(original_next_line[:config.max_chars_per_line - 3] + "...")
+
+     return "<br>".join(wrapped_lines)
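Reviewer note: the parser can be exercised without any API call by feeding it a synthetic response in the node format that `BeamSearchConfig` prompts for (the scores below are made up):

    from bs_reasoning import parse_bs_response

    sample = '''
    <node id="root" score="0.9">Analyze the problem</node>
    <node id="approach1" parent="root" score="0.8">Try direct computation</node>
    <node id="result1" parent="approach1" score="0.7" path_score="2.4">Direct result</node>
    <answer>Best path (path_score: 2.4): Use direct computation.</answer>
    '''
    parsed = parse_bs_response(sample, "Toy question")
    print(parsed.best_score)                     # 2.4
    print([n.id for n in parsed.result_nodes])   # ['result1']
    print(parsed.root.children[0].is_best_path)  # True (approach1 lies on the best path)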
configs.py ADDED
@@ -0,0 +1,334 @@
+ from dataclasses import dataclass, field
+ from typing import Dict, Any, Optional, List
+ import os
+ import json
+ import logging
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ def load_api_keys_from_file(file_path: str = "api_keys.json") -> Dict[str, str]:
+     """Load API keys from a JSON file"""
+     try:
+         if os.path.exists(file_path):
+             with open(file_path, 'r') as f:
+                 return json.load(f)
+         else:
+             logger.warning(f"API keys file {file_path} not found")
+             return {}
+     except Exception as e:
+         logger.error(f"Error loading API keys from {file_path}: {str(e)}")
+         return {}
+
+ @dataclass
+ class GeneralConfig:
+     """General configuration parameters that are method-independent"""
+     available_models: List[str] = field(default_factory=lambda: [
+         # Anthropic models
+         "claude-3-haiku-20240307",
+         "claude-3-sonnet-20240229",
+         "claude-3-opus-20240229",
+         # OpenAI models
+         "gpt-4-turbo-preview",
+         "gpt-4",
+         "gpt-3.5-turbo",
+         # Gemini models
+         "gemini-2.0-flash",
+         # Together AI models
+         # "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+         "meta-llama/Llama-3.3-70B-Instruct-Turbo",
+         "meta-llama/Meta-Llama-3.1-405B-Instruct-Lite-Pro",
+         "deepseek-ai/DeepSeek-V3",
+         "mistralai/Mixtral-8x22B-Instruct-v0.1",
+         "Qwen/Qwen2.5-72B-Instruct-Turbo",
+         # "microsoft/WizardLM-2-8x22B",
+         # "databricks/dbrx-instruct",
+         # "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
+     ])
+     model_providers: Dict[str, str] = field(default_factory=lambda: {
+         "claude-3-haiku-20240307": "anthropic",
+         "claude-3-sonnet-20240229": "anthropic",
+         "claude-3-opus-20240229": "anthropic",
+         "gpt-4-turbo-preview": "openai",
+         "gpt-4": "openai",
+         "gpt-3.5-turbo": "openai",
+         "gemini-2.0-flash": "google",
+         # "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": "together",
+         "meta-llama/Llama-3.3-70B-Instruct-Turbo": "together",
+         "meta-llama/Meta-Llama-3.1-405B-Instruct-Lite-Pro": "together",
+         "deepseek-ai/DeepSeek-V3": "together",
+         "mistralai/Mixtral-8x22B-Instruct-v0.1": "together",
+         "Qwen/Qwen2.5-72B-Instruct-Turbo": "together",
+         # "microsoft/WizardLM-2-8x22B": "together",
+         # "databricks/dbrx-instruct": "together",
+         # "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "together",
+     })
+     providers: List[str] = field(default_factory=lambda: ["anthropic", "openai", "google", "together"])
+     max_tokens: int = 2048
+     chars_per_line: int = 40
+     max_lines: int = 8
+
+     def __post_init__(self):
+         """Load API keys after initialization"""
+         self.provider_api_keys = load_api_keys_from_file()
+
+     def get_default_api_key(self, provider: str) -> str:
+         """Get the default API key for a specific provider"""
+         return self.provider_api_keys.get(provider, "")
+
+ @dataclass
+ class ChainOfThoughtsConfig:
+     """Configuration specific to the Chain of Thoughts method"""
+     name: str = "Chain of Thoughts"
+     prompt_format: str = '''Please answer the question using Chain-of-Thought reasoning in the following format, with each step clearly marked:
+
+ Question: {question}
+
+ Let's solve this step by step:
+ <step number="1">
+ [First step of reasoning]
+ </step>
+ ... (add more steps as needed)
+ <answer>
+ [Final answer]
+ </answer>'''
+     example_question: str = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
+
+ @dataclass
+ class TreeOfThoughtsConfig:
+     """Configuration specific to the Tree of Thoughts method"""
+     name: str = "Tree of Thoughts"
+     prompt_format: str = '''Please answer the question using Tree of Thoughts reasoning. Consider multiple possible approaches and explore their consequences. Feel free to create as many branches and sub-branches as needed for thorough exploration. Use the following format:
+
+ Question: {question}
+
+ Let's explore different paths of reasoning:
+ <node id="root">
+ [Initial analysis of the problem]
+ </node>
+
+ [Add main approaches with unique IDs (approach1, approach2, etc.)]
+ <node id="approach1" parent="root">
+ [First main approach to solve the problem]
+ </node>
+
+ [For each approach, add as many sub-branches as needed using parent references]
+ <node id="approach1.1" parent="approach1">
+ [Exploration of a sub-path]
+ </node>
+
+ [Continue adding nodes and exploring paths as needed. You can create deeper levels by extending the ID pattern (e.g., approach1.1.1)]
+
+ <answer>
+ Based on exploring all paths:
+ - [Explain which path(s) led to the best solution and why]
+ - [State the final answer]
+ </answer>'''
+     example_question: str = "Using the numbers 3, 3, 8, and 8, find a way to make exactly 24 using basic arithmetic operations (addition, subtraction, multiplication, division). Each number must be used exactly once, and you can use parentheses to control the order of operations."
+
+ @dataclass
+ class LeastToMostConfig:
+     """Configuration specific to the Least-to-Most method"""
+     name: str = "Least to Most"
+     prompt_format: str = '''Please solve this question using the Least-to-Most approach. First break down the complex question into simpler sub-questions, then solve them in order from simplest to most complex.
+
+ Question: {question}
+
+ Let's solve this step by step:
+ <step number="1">
+ <question>[First sub-question - should be the simplest]</question>
+ <reasoning>[Reasoning process for this sub-question]</reasoning>
+ <answer>[Answer to this sub-question]</answer>
+ </step>
+ ... (add more steps as needed)
+ <final_answer>
+ [Final answer that combines the insights from all steps]
+ </final_answer>'''
+     example_question: str = "How do I create a personal website?"
+
+ @dataclass
+ class SelfRefineConfig:
+     """Configuration specific to the Self-Refine method"""
+     name: str = "Self-Refine"
+     prompt_format: str = '''Please solve this question step by step, then check your work and revise any mistakes. Use the following format:
+
+ Question: {question}
+
+ Let's solve this step by step:
+ <step number="1">
+ [First step of reasoning]
+ </step>
+ ... (add more steps as needed)
+ <answer>
+ [Initial answer]
+ </answer>
+
+ Now, let's check our work:
+ <revision_check>
+ [Examine each step for errors or improvements]
+ </revision_check>
+
+ [If any revisions are needed, add revised steps:]
+ <revised_step number="[new_step_number]" revises="[original_step_number]">
+ [Corrected reasoning]
+ </revised_step>
+ ... (add more revised steps if needed)
+
+ [If the answer changes, add the revised answer:]
+ <revised_answer>
+ [Updated final answer]
+ </revised_answer>'''
+     example_question: str = "Write a one-sentence story and then improve it by refining it."
+
+ @dataclass
+ class SelfConsistencyConfig:
+     """Configuration specific to the Self-consistency method"""
+     name: str = "Self-consistency"
+     prompt_format: str = '''Please solve the question using multiple independent reasoning paths. Generate 3 different Chain-of-Thought solutions and provide the final answer based on majority voting.
+
+ Question: {question}
+
+ Path 1:
+ <step number="1">
+ [First step of reasoning]
+ </step>
+ ... (add more steps as needed)
+ <answer>
+ [Path 1's answer]
+ </answer>
+
+ Path 2:
+ ... (repeat the same format for all 3 paths)
+
+ Note: Each path should be independent and may arrive at different answers. The final answer will be determined by majority voting.'''
+     example_question: str = "How many r are there in strawberrrrrrrrry?"
+
+ @dataclass
+ class BeamSearchConfig:
+     """Configuration specific to the Beam Search method"""
+     name: str = "Beam Search"
+     prompt_format: str = '''Please solve this question using Beam Search reasoning. For each step:
+ 1. Explore multiple paths fully regardless of intermediate scores
+ 2. Assign a score between 0 and 1 to each node based on how promising that step is
+ 3. Calculate path_score for each result by summing scores along the path from root to result
+ 4. The final choice will be based on the highest cumulative path score
+
+ Question: {question}
+
+ <node id="root" score="[score]">
+ [Initial analysis - Break down the key aspects of the problem]
+ </node>
+
+ # First approach branch
+ <node id="approach1" parent="root" score="[score]">
+ [First approach - Outline the general strategy]
+ </node>
+
+ <node id="impl1.1" parent="approach1" score="[score]">
+ [Implementation 1.1 - Detail the specific steps and methods]
+ </node>
+
+ <node id="result1.1" parent="impl1.1" score="[score]" path_score="[sum of scores from root to here]">
+ [Result 1.1 - Describe concrete outcome and effectiveness]
+ </node>
+
+ <node id="impl1.2" parent="approach1" score="[score]">
+ [Implementation 1.2 - Detail alternative steps and methods]
+ </node>
+
+ <node id="result1.2" parent="impl1.2" score="[score]" path_score="[sum of scores from root to here]">
+ [Result 1.2 - Describe concrete outcome and effectiveness]
+ </node>
+
+ # Second approach branch
+ <node id="approach2" parent="root" score="[score]">
+ [Second approach - Outline an alternative general strategy]
+ </node>
+
+ <node id="impl2.1" parent="approach2" score="[score]">
+ [Implementation 2.1 - Detail the specific steps and methods]
+ </node>
+
+ <node id="result2.1" parent="impl2.1" score="[score]" path_score="[sum of scores from root to here]">
+ [Result 2.1 - Describe concrete outcome and effectiveness]
+ </node>
+
+ <node id="impl2.2" parent="approach2" score="[score]">
+ [Implementation 2.2 - Detail alternative steps and methods]
+ </node>
+
+ <node id="result2.2" parent="impl2.2" score="[score]" path_score="[sum of scores from root to here]">
+ [Result 2.2 - Describe concrete outcome and effectiveness]
+ </node>
+
+ <answer>
+ Best path (path_score: [highest_path_score]):
+ [Identify the path with the highest cumulative score]
+ [Explain why this path is most effective]
+ [Provide the final synthesized solution]
+ </answer>'''
+     example_question: str = "Give me two suggestions for transitioning from a journalist to a book editor."
+
+ class ReasoningConfig:
+     """Main configuration class that manages both general and method-specific configs"""
+     def __init__(self):
+         self.general = GeneralConfig()
+         self.methods = {
+             "cot": ChainOfThoughtsConfig(),
+             "tot": TreeOfThoughtsConfig(),
+             "scr": SelfConsistencyConfig(),
+             "srf": SelfRefineConfig(),
+             "l2m": LeastToMostConfig(),
+             "bs": BeamSearchConfig(),
+         }
+
+     def get_method_config(self, method_id: str) -> Optional[dict]:
+         """Get the configuration for a specific method"""
+         method = self.methods.get(method_id)
+         if method:
+             return {
+                 "name": method.name,
+                 "prompt_format": method.prompt_format,
+                 "example_question": method.example_question
+             }
+         return None
+
+     def get_initial_values(self) -> dict:
+         """Get initial values for the UI"""
+         return {
+             "general": {
+                 "available_models": self.general.available_models,
+                 "model_providers": self.general.model_providers,
+                 "providers": self.general.providers,
+                 "max_tokens": self.general.max_tokens,
+                 "default_api_key": self.general.get_default_api_key(self.general.providers[0]),
+                 "visualization": {
+                     "chars_per_line": self.general.chars_per_line,
+                     "max_lines": self.general.max_lines
+                 }
+             },
+             "methods": {
+                 method_id: {
+                     "name": method_config.name,
+                     "prompt_format": method_config.prompt_format,
+                     "example_question": method_config.example_question
+                 }
+                 for method_id, method_config in self.methods.items()
+             }
+         }
+
+     def add_method(self, method_id: str, config: Any) -> None:
+         """Add a new reasoning method configuration"""
+         if method_id not in self.methods:
+             self.methods[method_id] = config
+         else:
+             raise ValueError(f"Method {method_id} already exists")
+
+ # Create the global config instance
+ config = ReasoningConfig()
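Reviewer note: a short sketch of how the global `config` instance is consumed — these are the same calls the Flask routes in app.py make:

    from configs import config

    print(config.general.max_tokens)  # 2048
    print(config.general.providers)   # ['anthropic', 'openai', 'google', 'together']

    cot = config.get_method_config("cot")
    print(cot["name"])                # Chain of Thoughts
    # The prompt templates use a single {question} placeholder:
    print(cot["prompt_format"].format(question="What is 2 + 2?"))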
cot_reasoning.py ADDED
@@ -0,0 +1,193 @@
+ import re
+ import requests
+ import textwrap
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+ @dataclass
+ class CoTStep:
+     """Data class representing a single CoT step"""
+     number: int
+     content: str
+
+ @dataclass
+ class CoTResponse:
+     """Data class representing a complete CoT response"""
+     question: str
+     steps: List[CoTStep]
+     answer: Optional[str] = None
+
+ @dataclass
+ class VisualizationConfig:
+     """Configuration for CoT visualization"""
+     max_chars_per_line: int = 40
+     max_lines: int = 4
+     truncation_suffix: str = "..."
+
+ class AnthropicAPI:
+     """Class to handle interactions with the Anthropic API"""
+     def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
+         self.api_key = api_key
+         self.model = model
+         self.base_url = "https://api.anthropic.com/v1/messages"
+         self.headers = {
+             "x-api-key": api_key,
+             "anthropic-version": "2023-06-01",
+             "content-type": "application/json"
+         }
+
+     def generate_response(self, prompt: str, max_tokens: int = 1024,
+                           prompt_format: Optional[str] = None) -> str:
+         """Generate a response using the Anthropic API"""
+         formatted_prompt = self._format_prompt(prompt, prompt_format) if prompt_format else prompt
+         data = {
+             "model": self.model,
+             "messages": [{"role": "user", "content": formatted_prompt}],
+             "max_tokens": max_tokens
+         }
+
+         try:
+             response = requests.post(self.base_url, headers=self.headers, json=data)
+             response.raise_for_status()
+             return response.json()["content"][0]["text"]
+         except Exception as e:
+             raise Exception(f"API call failed: {str(e)}")
+
+     def _format_prompt(self, question: str, prompt_format: Optional[str] = None) -> str:
+         """Format the prompt, using the custom format if one is provided"""
+         if prompt_format:
+             return prompt_format.format(question=question)
+
+         # Default format if none is provided
+         return f"""Please answer the question using the following format, with each step clearly marked:
+
+ Question: {question}
+
+ Let's solve this step by step:
+ <step number="1">
+ [First step of reasoning]
+ </step>
+ <step number="2">
+ [Second step of reasoning]
+ </step>
+ <step number="3">
+ [Third step of reasoning]
+ </step>
+ ... (add more steps as needed)
+ <answer>
+ [Final answer]
+ </answer>
+
+ Note:
+ 1. Each step must be wrapped in XML tags <step>
+ 2. Each step must have a number attribute
+ 3. The final answer must be wrapped in <answer> tags
+ """
+
+ def wrap_text(text: str, config: VisualizationConfig) -> str:
+     """
+     Wrap text to fit within box constraints with proper line breaks.
+
+     Args:
+         text: The text to wrap
+         config: VisualizationConfig containing formatting parameters
+
+     Returns:
+         Wrapped text with <br> line breaks
+     """
+     # Clean the text first
+     text = text.replace('\n', ' ').replace('"', "'")
+
+     # Wrap the text into lines
+     wrapped_lines = textwrap.wrap(text, width=config.max_chars_per_line)
+
+     # Limit the number of lines, truncating the last kept line if necessary
+     if len(wrapped_lines) > config.max_lines:
+         wrapped_lines = wrapped_lines[:config.max_lines]
+         wrapped_lines[-1] = wrapped_lines[-1][:config.max_chars_per_line - 3] + config.truncation_suffix
+
+     # Join with <br> for HTML line breaks in Mermaid
+     return "<br>".join(wrapped_lines)
+
+ def parse_cot_response(response_text: str, question: str) -> CoTResponse:
+     """
+     Parse CoT response text to extract steps and the final answer.
+
+     Args:
+         response_text: The raw response from the API
+         question: The original question
+
+     Returns:
+         CoTResponse object containing question, steps, and answer
+     """
+     # Extract all steps
+     step_pattern = r'<step number="(\d+)">\s*(.*?)\s*</step>'
+     steps = []
+     for match in re.finditer(step_pattern, response_text, re.DOTALL):
+         number = int(match.group(1))
+         content = match.group(2).strip()
+         steps.append(CoTStep(number=number, content=content))
+
+     # Extract the answer
+     answer_pattern = r'<answer>\s*(.*?)\s*</answer>'
+     answer_match = re.search(answer_pattern, response_text, re.DOTALL)
+     answer = answer_match.group(1).strip() if answer_match else None
+
+     # Sort steps by number
+     steps.sort(key=lambda x: x.number)
+
+     return CoTResponse(question=question, steps=steps, answer=answer)
+
+ def create_mermaid_diagram(cot_response: CoTResponse, config: VisualizationConfig) -> str:
+     """
+     Convert CoT steps to a Mermaid diagram with text wrapping.
+
+     Args:
+         cot_response: CoTResponse object containing the reasoning steps
+         config: VisualizationConfig for text formatting
+
+     Returns:
+         Mermaid diagram markup as a string
+     """
+     diagram = ['<div class="mermaid">', 'graph TD']
+
+     # Add the question node
+     question_content = wrap_text(cot_response.question, config)
+     diagram.append(f'    Q["{question_content}"]')
+
+     # Add steps with wrapped text and connect them
+     if cot_response.steps:
+         # Connect the question to the first step
+         diagram.append(f'    Q --> S{cot_response.steps[0].number}')
+
+         # Add all steps
+         for i, step in enumerate(cot_response.steps):
+             content = wrap_text(step.content, config)
+             node_id = f'S{step.number}'
+             diagram.append(f'    {node_id}["{content}"]')
+
+             # Connect steps sequentially
+             if i < len(cot_response.steps) - 1:
+                 next_id = f'S{cot_response.steps[i + 1].number}'
+                 diagram.append(f'    {node_id} --> {next_id}')
+
+     # Add the final answer node
+     if cot_response.answer:
+         answer = wrap_text(cot_response.answer, config)
+         diagram.append(f'    A["{answer}"]')
+         if cot_response.steps:
+             diagram.append(f'    S{cot_response.steps[-1].number} --> A')
+         else:
+             diagram.append('    Q --> A')
+
+     # Add styles for better visualization
+     diagram.extend([
+         '    classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
+         '    classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
+         '    classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
+         '    class Q question;',
+         '    class A answer;',
+         '    linkStyle default stroke:#666,stroke-width:2px;'
+     ])
+
+     diagram.append('</div>')
+     return '\n'.join(diagram)
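Reviewer note: the parse-and-visualize pipeline can be checked end to end on a hand-written response, with no API key needed:

    from cot_reasoning import (VisualizationConfig, create_mermaid_diagram,
                               parse_cot_response)

    raw = '''
    <step number="1">Half of 48 is 24, so May sales were 24.</step>
    <step number="2">April plus May: 48 + 24 = 72.</step>
    <answer>72</answer>
    '''
    parsed = parse_cot_response(raw, "How many clips did Natalia sell in total?")
    print(len(parsed.steps), parsed.answer)  # 2 72
    print(create_mermaid_diagram(parsed, VisualizationConfig()))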
l2m_reasoning.py ADDED
@@ -0,0 +1,144 @@
+ import re
+ import textwrap
+ from dataclasses import dataclass
+ from typing import List, Optional
+ from cot_reasoning import VisualizationConfig
+
+ @dataclass
+ class L2MStep:
+     """Data class representing a single L2M step"""
+     number: int
+     question: str   # The sub-question for this step
+     reasoning: str  # The reasoning process
+     answer: str     # The answer to this sub-question
+
+ @dataclass
+ class L2MResponse:
+     """Data class representing a complete L2M response"""
+     main_question: str
+     steps: List[L2MStep]
+     final_answer: Optional[str] = None
+
+ def parse_l2m_response(response_text: str, question: str) -> L2MResponse:
+     """
+     Parse L2M response text to extract steps and the final answer.
+
+     Args:
+         response_text: The raw response from the API
+         question: The original question
+
+     Returns:
+         L2MResponse object containing the main question, steps, and final answer
+     """
+     # Extract all steps
+     step_pattern = r'<step number="(\d+)">\s*<question>(.*?)</question>\s*<reasoning>(.*?)</reasoning>\s*<answer>(.*?)</answer>\s*</step>'
+     steps = []
+
+     for match in re.finditer(step_pattern, response_text, re.DOTALL):
+         number = int(match.group(1))
+         sub_question = match.group(2).strip()
+         reasoning = match.group(3).strip()
+         answer = match.group(4).strip()
+         steps.append(L2MStep(
+             number=number,
+             question=sub_question,
+             reasoning=reasoning,
+             answer=answer
+         ))
+
+     # Extract the final answer
+     final_answer_pattern = r'<final_answer>\s*(.*?)\s*</final_answer>'
+     final_answer_match = re.search(final_answer_pattern, response_text, re.DOTALL)
+     final_answer = final_answer_match.group(1).strip() if final_answer_match else None
+
+     # Sort steps by number
+     steps.sort(key=lambda x: x.number)
+
+     return L2MResponse(main_question=question, steps=steps, final_answer=final_answer)
+
+ def wrap_text(text: str, max_chars: int = 40, max_lines: int = 4) -> str:
+     """Wrap text to fit within box constraints with proper line breaks."""
+     text = text.replace('\n', ' ').replace('"', "'")
+     wrapped_lines = textwrap.wrap(text, width=max_chars)
+
+     if len(wrapped_lines) > max_lines:
+         wrapped_lines = wrapped_lines[:max_lines]
+         wrapped_lines[-1] = wrapped_lines[-1][:max_chars - 3] + "..."
+
+     return "<br>".join(wrapped_lines)
+
+ def create_mermaid_diagram(l2m_response: L2MResponse, config: VisualizationConfig) -> str:
+     """
+     Convert L2M steps to a Mermaid diagram.
+
+     Args:
+         l2m_response: L2MResponse object containing the reasoning steps
+         config: VisualizationConfig for text formatting
+
+     Returns:
+         Mermaid diagram markup as a string
+     """
+     diagram = ['<div class="mermaid">', 'graph TD']
+
+     # Add the main question node
+     question_content = wrap_text(l2m_response.main_question, config.max_chars_per_line, config.max_lines)
+     diagram.append(f'    Q["{question_content}"]')
+
+     # Add the decomposition node
+     diagram.append('    D["Problem Decomposition"]')
+     diagram.append('    Q --> D')
+
+     # Add all step nodes with sub-questions, reasoning, and answers
+     if l2m_response.steps:
+         # Connect the decomposition to the first step
+         diagram.append(f'    D --> S{l2m_response.steps[0].number}')
+
+         for i, step in enumerate(l2m_response.steps):
+             # Create the sub-question node
+             sq_content = wrap_text(f"Q{step.number}: {step.question}", config.max_chars_per_line, config.max_lines)
+             sq_id = f'S{step.number}'
+             diagram.append(f'    {sq_id}["{sq_content}"]')
+
+             # Create the reasoning node
+             r_content = wrap_text(step.reasoning, config.max_chars_per_line, config.max_lines)
+             r_id = f'R{step.number}'
+             diagram.append(f'    {r_id}["{r_content}"]')
+
+             # Create the answer node
+             a_content = wrap_text(f"A{step.number}: {step.answer}", config.max_chars_per_line, config.max_lines)
+             a_id = f'A{step.number}'
+             diagram.append(f'    {a_id}["{a_content}"]')
+
+             # Connect the nodes
+             diagram.append(f'    {sq_id} --> {r_id}')
+             diagram.append(f'    {r_id} --> {a_id}')
+
+             # Connect to the next step, if one exists
+             if i < len(l2m_response.steps) - 1:
+                 next_id = f'S{l2m_response.steps[i + 1].number}'
+                 diagram.append(f'    {a_id} --> {next_id}')
+
+     # Add the final answer node, if one exists
+     if l2m_response.final_answer:
+         final_content = wrap_text(f"Final: {l2m_response.final_answer}", config.max_chars_per_line, config.max_lines)
+         diagram.append(f'    F["{final_content}"]')
+         if l2m_response.steps:
+             diagram.append(f'    A{l2m_response.steps[-1].number} --> F')
+         else:
+             diagram.append('    D --> F')
+
+     # Add styles (note: these class lists only cover steps S1 through S5)
+     diagram.extend([
+         '    classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
+         '    classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
+         '    classDef reasoning fill:#f9f9f9,stroke:#333,stroke-width:2px;',
+         '    classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
+         '    classDef decomp fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px;',
+         '    class Q,S1,S2,S3,S4,S5 question;',
+         '    class R1,R2,R3,R4,R5 reasoning;',
+         '    class A1,A2,A3,A4,A5,F answer;',
+         '    class D decomp;',
+         '    linkStyle default stroke:#666,stroke-width:2px;'
+     ])
+
+     diagram.append('</div>')
+     return '\n'.join(diagram)
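Reviewer note: the same kind of offline check works for the L2M parser, since each step carries its own sub-question, reasoning, and answer tags:

    from l2m_reasoning import parse_l2m_response

    raw = '''
    <step number="1">
    <question>What is half of 48?</question>
    <reasoning>48 / 2 = 24</reasoning>
    <answer>24</answer>
    </step>
    <final_answer>72 clips in total</final_answer>
    '''
    parsed = parse_l2m_response(raw, "How many clips in total?")
    print(parsed.steps[0].answer)  # 24
    print(parsed.final_answer)     # 72 clips in total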
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ requests==2.31.0
+ openai==1.63.2
+ together==1.4.1
+ flask==3.1.0
+ google==3.0.0
+ google-genai==1.2.0
+ google-generativeai==0.8.4
selfconsistency_reasoning.py ADDED
@@ -0,0 +1,154 @@
+ import re
+ from dataclasses import dataclass
+ from typing import List, Optional, Dict
+ from collections import Counter
+ from cot_reasoning import CoTStep, CoTResponse, VisualizationConfig, wrap_text
+
+ @dataclass
+ class SCRPath:
+     """Data class representing a single self-consistency reasoning path"""
+     path_id: int
+     steps: List[CoTStep]
+     answer: Optional[str] = None
+
+ @dataclass
+ class SCRResponse:
+     """Data class representing a complete self-consistency response"""
+     question: str
+     paths: List[SCRPath]
+     final_answer: Optional[str] = None
+     vote_counts: Optional[Dict[str, int]] = None
+
+ def parse_scr_response(response_text: str, question: str, num_paths: int = 5) -> SCRResponse:
+     """
+     Parse self-consistency response text to extract multiple reasoning paths and answers.
+
+     Args:
+         response_text: The raw response from the API containing multiple paths
+         question: The original question
+         num_paths: Expected number of reasoning paths (informational only;
+             the parser discovers paths directly from the text)
+
+     Returns:
+         SCRResponse object containing all paths and the aggregated answer
+     """
+     # Split the response into individual paths
+     path_pattern = r'Path\s+(\d+):(.*?)(?=Path\s+\d+:|$)'
+     path_matches = re.finditer(path_pattern, response_text, re.DOTALL)
+
+     paths = []
+     answers = []
+
+     for match in path_matches:
+         path_id = int(match.group(1))
+         path_content = match.group(2).strip()
+
+         # Extract the steps for this path
+         step_pattern = r'<step number="(\d+)">\s*(.*?)\s*</step>'
+         steps = []
+         for step_match in re.finditer(step_pattern, path_content, re.DOTALL):
+             number = int(step_match.group(1))
+             content = step_match.group(2).strip()
+             steps.append(CoTStep(number=number, content=content))
+
+         # Extract the answer for this path
+         answer_pattern = r'<answer>\s*(.*?)\s*</answer>'
+         answer_match = re.search(answer_pattern, path_content, re.DOTALL)
+         answer = answer_match.group(1).strip() if answer_match else None
+
+         if answer:
+             answers.append(answer)
+
+         # Sort steps by number
+         steps.sort(key=lambda x: x.number)
+
+         paths.append(SCRPath(path_id=path_id, steps=steps, answer=answer))
+
+     # Determine the final answer through majority voting
+     vote_counts = Counter(answers)
+     final_answer = vote_counts.most_common(1)[0][0] if vote_counts else None
+
+     return SCRResponse(
+         question=question,
+         paths=paths,
+         final_answer=final_answer,
+         vote_counts=dict(vote_counts)
+     )
+
+ def create_mermaid_diagram(scr_response: SCRResponse, config: VisualizationConfig) -> str:
+     """
+     Convert self-consistency paths to a Mermaid diagram.
+
+     Args:
+         scr_response: SCRResponse object containing multiple reasoning paths
+         config: VisualizationConfig for text formatting
+
+     Returns:
+         Mermaid diagram markup as a string
+     """
+     diagram = ['<div class="mermaid">', 'graph TD']
+
+     # Add the question node
+     question_content = wrap_text(scr_response.question, config)
+     diagram.append(f'    Q["{question_content}"]')
+
+     # Process each path
+     for path in scr_response.paths:
+         path_id = f'P{path.path_id}'
+
+         # Add the path label
+         diagram.append(f'    {path_id}["Path {path.path_id}"]')
+         diagram.append(f'    Q --> {path_id}')
+
+         # Add the steps for this path
+         prev_node = path_id
+         for step in path.steps:
+             content = wrap_text(step.content, config)
+             node_id = f'P{path.path_id}S{step.number}'
+             diagram.append(f'    {node_id}["{content}"]')
+             diagram.append(f'    {prev_node} --> {node_id}')
+             prev_node = node_id
+
+         # Add the path answer
+         if path.answer:
+             answer_content = wrap_text(path.answer, config)
+             answer_id = f'A{path.path_id}'
+             diagram.append(f'    {answer_id}["{answer_content}"]')
+             diagram.append(f'    {prev_node} --> {answer_id}')
+
+     # Add the final answer with vote counts (plain separators, since
+     # wrap_text flattens newlines into spaces anyway)
+     if scr_response.final_answer and scr_response.vote_counts:
+         vote_info = [f"{ans}: {count} votes" for ans, count in scr_response.vote_counts.items()]
+         final_content = wrap_text(
+             f"Final Answer (by voting): {scr_response.final_answer} | " +
+             "Vote Distribution: " + "; ".join(vote_info),
+             config
+         )
+         diagram.append(f'    F["{final_content}"]')
+
+         # Connect all path answers to the final answer
+         for path in scr_response.paths:
+             if path.answer:
+                 diagram.append(f'    A{path.path_id} --> F')
+
+     # Add styles
+     diagram.extend([
+         '    classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
+         '    classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
+         '    classDef path fill:#fff3e0,stroke:#f57c00,stroke-width:2px;',
+         '    classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
+         '    classDef final fill:#d4edda,stroke:#28a745,stroke-width:2px;',
+         '    class Q question;',
+         '    class F final;'
+     ])
+
+     # Apply the path style to all path nodes
+     for path in scr_response.paths:
+         diagram.append(f'    class P{path.path_id} path;')
+
+     diagram.append('</div>')
+     return '\n'.join(diagram)
147
+
148
+ # Apply answer style to all answer nodes
149
+ for path in scr_response.paths:
150
+ if path.answer:
151
+ diagram.append(f' class A{path.path_id} answer;')
152
+
153
+ diagram.append('</div>')
154
+ return '\n'.join(diagram)
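For reference, a minimal usage sketch of this module. The response text is a toy example written in the tag format parse_scr_response expects; it assumes VisualizationConfig's fields have usable defaults (otherwise pass the chars-per-line and max-lines values explicitly):

# Sketch: two toy paths that agree, so majority voting picks "5".
from cot_reasoning import VisualizationConfig
from selfconsistency_reasoning import parse_scr_response, create_mermaid_diagram

raw = '''Path 1:
<step number="1">2 apples plus 3 apples</step>
<answer>5</answer>
Path 2:
<step number="1">Count 2, then 3 more</step>
<answer>5</answer>'''

scr = parse_scr_response(raw, question="How many apples?", num_paths=2)
print(scr.final_answer)   # 5
print(scr.vote_counts)    # {'5': 2}
markup = create_mermaid_diagram(scr, VisualizationConfig())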
selfrefine_reasoning.py ADDED
@@ -0,0 +1,165 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional
4
+ from cot_reasoning import VisualizationConfig, wrap_text
5
+
6
+ @dataclass
7
+ class SelfRefineStep:
8
+ """Data class representing a single step in self-refine reasoning"""
9
+ number: int
10
+ content: str
11
+ is_revised: bool = False
12
+ revision_of: Optional[int] = None
13
+
14
+ @dataclass
15
+ class SelfRefineResponse:
16
+ """Data class representing a complete self-refine response"""
17
+ question: str
18
+ steps: List[SelfRefineStep]
19
+ answer: Optional[str] = None
20
+ revision_check: Optional[str] = None
21
+ revised_answer: Optional[str] = None
22
+
23
+ def parse_selfrefine_response(response_text: str, question: str) -> SelfRefineResponse:
24
+ """
25
+ Parse self-refine response text to extract steps, answers, and revisions.
26
+
27
+ Args:
28
+ response_text: The raw response from the API
29
+ question: The original question
30
+
31
+ Returns:
32
+ SelfRefineResponse object containing all components
33
+ """
34
+ # Extract initial steps
35
+ step_pattern = r'<step number="(\d+)">\s*(.*?)\s*</step>'
36
+ steps = []
37
+ for match in re.finditer(step_pattern, response_text, re.DOTALL):
38
+ number = int(match.group(1))
39
+ content = match.group(2).strip()
40
+ steps.append(SelfRefineStep(number=number, content=content))
41
+
42
+ # Extract initial answer
43
+ answer_pattern = r'<answer>\s*(.*?)\s*</answer>'
44
+ answer_match = re.search(answer_pattern, response_text, re.DOTALL)
45
+ answer = answer_match.group(1).strip() if answer_match else None
46
+
47
+ # Extract revision check
48
+ check_pattern = r'<revision_check>\s*(.*?)\s*</revision_check>'
49
+ check_match = re.search(check_pattern, response_text, re.DOTALL)
50
+ revision_check = check_match.group(1).strip() if check_match else None
51
+
52
+ # Extract revised steps
53
+ revised_step_pattern = r'<revised_step number="(\d+)" revises="(\d+)">\s*(.*?)\s*</revised_step>'
54
+ for match in re.finditer(revised_step_pattern, response_text, re.DOTALL):
55
+ number = int(match.group(1))
56
+ revises = int(match.group(2))
57
+ content = match.group(3).strip()
58
+ steps.append(SelfRefineStep(
59
+ number=number,
60
+ content=content,
61
+ is_revised=True,
62
+ revision_of=revises
63
+ ))
64
+
65
+ # Extract revised answer
66
+ revised_answer_pattern = r'<revised_answer>\s*(.*?)\s*</revised_answer>'
67
+ revised_answer_match = re.search(revised_answer_pattern, response_text, re.DOTALL)
68
+ revised_answer = revised_answer_match.group(1).strip() if revised_answer_match else None
69
+
70
+ return SelfRefineResponse(
71
+ question=question,
72
+ steps=steps,
73
+ answer=answer,
74
+ revision_check=revision_check,
75
+ revised_answer=revised_answer
76
+ )
77
+
78
+ def create_mermaid_diagram(sr_response: SelfRefineResponse, config: VisualizationConfig) -> str:
79
+ """
80
+ Create a Mermaid diagram for self-refine reasoning.
81
+
82
+ Args:
83
+ sr_response: SelfRefineResponse object containing the reasoning steps
84
+ config: VisualizationConfig for text formatting
85
+
86
+ Returns:
87
+ Mermaid diagram markup as a string
88
+ """
89
+ diagram = ['<div class="mermaid">', 'graph TD']
90
+
91
+ # Add question node
92
+ question_content = wrap_text(sr_response.question, config)
93
+ diagram.append(f' Q["{question_content}"]')
94
+
95
+ # Track original and revised steps
96
+ original_steps = [s for s in sr_response.steps if not s.is_revised]
97
+ revised_steps = [s for s in sr_response.steps if s.is_revised]
98
+
99
+ # Add original steps and connect them
100
+ prev_node = 'Q'
101
+ for step in original_steps:
102
+ node_id = f'S{step.number}'
103
+ content = wrap_text(step.content, config)
104
+ diagram.append(f' {node_id}["{content}"]')
105
+ diagram.append(f' {prev_node} --> {node_id}')
106
+ prev_node = node_id
107
+
108
+ # Add initial answer if present
109
+ if sr_response.answer:
110
+ answer_content = wrap_text(sr_response.answer, config)
111
+ diagram.append(f' A["{answer_content}"]')
112
+ diagram.append(f' {prev_node} --> A')
113
+ prev_node = 'A'
114
+
115
+ # Add revision check if present
116
+ if sr_response.revision_check:
117
+ check_content = wrap_text(sr_response.revision_check, config)
118
+ diagram.append(f' RC["{check_content}"]')
119
+ diagram.append(f' {prev_node} --> RC')
120
+
121
+ # Add revised steps if any
122
+ if revised_steps:
123
+ # Process each revision step
124
+ for i, step in enumerate(revised_steps):
125
+ rev_node_id = f'R{step.number}'
126
+ content = wrap_text(step.content, config)
127
+ diagram.append(f' {rev_node_id}["{content}"]')
128
+
129
+ # Connect from the revision check to problematic step, then to revision
130
+ if step.revision_of:
131
+ orig_node = f'S{step.revision_of}'
132
+ # Add connection from revision check to problematic step
133
+ diagram.append(f' RC --> {orig_node}')
134
+ # Add connection from problematic step to its revision
135
+ diagram.append(f' {orig_node} --> {rev_node_id}')
136
+
137
+ # Connect subsequent revised steps
138
+ if i < len(revised_steps) - 1:
139
+ next_node = f'R{revised_steps[i + 1].number}'
140
+ diagram.append(f' {rev_node_id} --> {next_node}')
141
+
142
+ # Add revised answer if present
143
+ if sr_response.revised_answer:
144
+ revised_content = wrap_text(sr_response.revised_answer, config)
145
+ diagram.append(f' RA["{revised_content}"]')
146
+ last_node = f'R{revised_steps[-1].number}' if revised_steps else ('RC' if sr_response.revision_check else prev_node)
147
+ diagram.append(f' {last_node} --> RA')
148
+
149
+ # Add styles
150
+ diagram.extend([
151
+ ' classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
152
+ ' classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
153
+ ' classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
154
+ ' classDef revision fill:#fff3cd,stroke:#ffc107,stroke-width:2px;',
155
+ ' class Q question;',
156
+ ' class A,RA answer;',
157
+ ' class RC revision;'
158
+ ])
159
+
160
+ # Style revision nodes
161
+ for step in revised_steps:
162
+ diagram.append(f' class R{step.number} revision;')
163
+
164
+ diagram.append('</div>')
165
+ return '\n'.join(diagram)
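A matching usage sketch for this module, on a toy response that includes one revision in the tag format parse_selfrefine_response expects:

# Sketch: the initial answer is wrong, the revision check flags step 2,
# and a revised step plus revised answer correct it.
from selfrefine_reasoning import parse_selfrefine_response

raw = '''<step number="1">15% of 80 is 12</step>
<step number="2">80 + 12 = 96</step>
<answer>96</answer>
<revision_check>Step 2 should subtract the discount, not add it.</revision_check>
<revised_step number="3" revises="2">80 - 12 = 68</revised_step>
<revised_answer>68</revised_answer>'''

sr = parse_selfrefine_response(raw, question="What is 80 after a 15% discount?")
print(sr.answer)            # 96
print(sr.revised_answer)    # 68
print([(s.number, s.revision_of) for s in sr.steps if s.is_revised])  # [(3, 2)]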
static/assets/banner-bg.jpg ADDED

Git LFS Details

  • SHA256: 8ce7b7253ba11a169681a8d6786b898e504992f3a4e17eeeee2c99fe69261ad1
  • Pointer size: 132 Bytes
  • Size of remote file: 4.19 MB
static/assets/idea.png ADDED
templates/index.html ADDED
@@ -0,0 +1,835 @@
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <title>ReasonGraph</title>
5
+ <link rel="icon" href="static/assets/idea.png" type="image/png">
6
+ <link rel="shortcut icon" href="favicon.ico" type="image/x-icon">
7
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/mermaid.min.js"></script>
8
+ <style>
9
+ body {
10
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
11
+ margin: 0;
12
+ padding: 20px;
13
+ background-color: #f5f5f5;
14
+ }
15
+
16
+ .banner {
17
+ margin: 3px 20px;
18
+ border-radius: 8px;
19
+ background-image: url("{{ url_for('static', filename='assets/banner-bg.jpg') }}");
20
+ background-size: cover;
21
+ background-position: center;
22
+ box-shadow: 0 1px 3px rgba(0,0,0,0.1);
23
+ height: 280px;
24
+ display: flex;
25
+ align-items: center;
26
+ justify-content: center;
27
+ position: relative;
28
+ }
29
+
30
+ .banner::before {
31
+ content: '';
32
+ position: absolute;
33
+ top: 0;
34
+ left: 0;
35
+ right: 0;
36
+ bottom: 0;
37
+ background: rgba(0, 0, 0, 0.4);
38
+ border-radius: 8px;
39
+ }
40
+
41
+ .banner-content {
42
+ position: relative;
43
+ z-index: 1;
44
+ text-align: center;
45
+ display: flex;
46
+ flex-direction: column;
47
+ gap: 24px;
48
+ width: 100%;
49
+ max-width: 800px;
50
+ padding: 0 20px;
51
+ }
52
+
53
+ .banner h1 {
54
+ color: white;
55
+ font-size: 32px;
56
+ font-weight: 600;
57
+ margin: 0;
58
+ text-shadow: 0 2px 4px rgba(0,0,0,0.2);
59
+ }
60
+
61
+ .search-area {
62
+ display: flex;
63
+ flex-direction: column;
64
+ gap: 16px;
65
+ width: 100%;
66
+ }
67
+
68
+ .search-input-container {
69
+ position: relative;
70
+ width: 100%;
71
+ }
72
+
73
+ .search-input {
74
+ width: 100%;
75
+ padding: 12px;
76
+ border: 2px solid rgba(255, 255, 255, 0.2);
77
+ border-radius: 8px;
78
+ background: rgba(255, 255, 255, 0.9);
79
+ font-size: 16px;
80
+ resize: none;
81
+ outline: none;
82
+ transition: border-color 0.2s;
83
+ }
84
+
85
+ .search-input:focus {
86
+ border-color: rgba(255, 255, 255, 0.5);
87
+ }
88
+
89
+ .search-buttons {
90
+ display: flex;
91
+ gap: 10px;
92
+ justify-content: center;
93
+ align-items: center;
94
+ margin: 0 auto;
95
+ max-width: 800px; /* Match search input max-width */
96
+ width: 100%;
97
+ padding: 0 20px;
98
+ }
99
+
100
+ .search-buttons .param-input {
101
+ width: 200px !important; /* Override the default param-input width */
102
+ padding: 8px 16px;
103
+ background-color: rgba(255, 255, 255, 0.9);
104
+ border: none;
105
+ border-radius: 6px;
106
+ font-size: 14px;
107
+ font-weight: 500;
108
+ height: 35px; /* Match button height */
109
+ flex: none; /* Override flex property */
110
+ }
111
+
112
+ .search-buttons button {
113
+ width: auto;
114
+ min-width: 120px;
115
+ padding: 8px 16px;
116
+ background-color: rgba(37, 99, 235, 0.8); /* Semi-transparent blue */
117
+ color: white; /* White text */
118
+ border: none;
119
+ border-radius: 6px;
120
+ font-size: 14px;
121
+ font-weight: 500;
122
+ cursor: pointer;
123
+ transition: all 0.2s;
124
+ height: 35px; /* Fixed height */
125
+ line-height: 1; /* Ensure text vertical centering */
126
+ }
127
+
128
+ .search-buttons button:hover {
129
+ background-color: rgba(37, 99, 235, 0.9); /* Slightly darker blue on hover */
130
+ transform: translateY(-1px);
131
+ }
132
+
133
+ .links {
134
+ display: flex;
135
+ justify-content: center;
136
+ gap: 0;
137
+ white-space: nowrap;
138
+ }
139
+
140
+ .links a {
141
+ color: white;
142
+ text-decoration: none;
143
+ font-size: 16px;
144
+ opacity: 0.9;
145
+ transition: opacity 0.2s;
146
+ }
147
+
148
+ .links a:hover {
149
+ opacity: 1;
150
+ text-decoration: underline;
151
+ }
152
+
153
+ .container {
154
+ display: flex;
155
+ min-height: 100vh;
156
+ gap: 20px;
157
+ padding: 20px;
158
+ }
159
+
160
+ .column {
161
+ flex: 1;
162
+ background: white;
163
+ border-radius: 8px;
164
+ box-shadow: 0 1px 3px rgba(0,0,0,0.1);
165
+ padding: 20px;
166
+ }
167
+
168
+ h2 {
169
+ margin-top: 0;
170
+ margin-bottom: 20px;
171
+ color: #1f2937;
172
+ font-size: 18px;
173
+ font-weight: 600;
174
+ }
175
+
176
+ .param-group {
177
+ display: flex;
178
+ margin-bottom: 15px;
179
+ border: 1px solid #e5e7eb;
180
+ border-radius: 4px;
181
+ overflow: hidden;
182
+ }
183
+
184
+ .param-label {
185
+ width: 180px;
186
+ padding: 10px 15px;
187
+ background-color: #f8f9fa;
188
+ border-right: 1px solid #e5e7eb;
189
+ font-size: 14px;
190
+ line-height: 1.5;
191
+ color: #374151;
192
+ }
193
+
194
+ .param-input {
195
+ flex: 1;
196
+ padding: 10px 15px;
197
+ border: none;
198
+ font-size: 14px;
199
+ line-height: 1.5;
200
+ outline: none;
201
+ background: white;
202
+ }
203
+
204
+ select.param-input {
205
+ cursor: pointer;
206
+ padding-right: 30px;
207
+ appearance: none;
208
+ background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' fill='%236b7280' viewBox='0 0 16 16'%3E%3Cpath d='M8 10l4-4H4l4 4z'/%3E%3C/svg%3E");
209
+ background-repeat: no-repeat;
210
+ background-position: right 10px center;
211
+ }
212
+
213
+ textarea.param-input {
214
+ resize: vertical;
215
+ min-height: 80px;
216
+ }
217
+
218
+ .error-message {
219
+ color: #dc2626;
220
+ font-size: 14px;
221
+ margin-top: 5px;
222
+ display: none;
223
+ position: absolute;
224
+ bottom: -20px;
225
+ left: 0;
226
+ background: white;
227
+ padding: 2px 8px;
228
+ border-radius: 4px;
229
+ box-shadow: 0 1px 2px rgba(0,0,0,0.1);
230
+ }
231
+
232
+ .output-section {
233
+ margin-top: 20px;
234
+ padding: 0;
235
+ background: white;
236
+ }
237
+
238
+ .output-section h3 {
239
+ margin: 0 0 15px 0;
240
+ color: #1f2937;
241
+ font-size: 18px;
242
+ font-weight: 600;
243
+ }
244
+
245
+ .output-wrapper {
246
+ overflow: auto;
247
+ height: 100px;
248
+ min-height: 100px;
249
+ max-height: 1000px;
250
+ border: 1px solid #e5e7eb;
251
+ border-radius: 4px;
252
+ padding: 15px;
253
+ resize: vertical;
254
+ background-color: #f8f9fa;
255
+ }
256
+
257
+ #raw-output {
258
+ white-space: pre-wrap;
259
+ word-break: break-word;
260
+ margin: 0;
261
+ font-family: 'Menlo', 'Monaco', 'Courier New', monospace;
262
+ font-size: 13px;
263
+ line-height: 1.5;
264
+ color: #1f2937;
265
+ }
266
+
267
+ button:disabled {
268
+ background-color: #9ca3af;
269
+ cursor: not-allowed;
270
+ }
271
+
272
+ .zoom-controls {
273
+ display: flex;
274
+ align-items: center;
275
+ gap: 10px;
276
+ margin-bottom: 10px;
277
+ }
278
+
279
+ .zoom-button {
280
+ padding: 5px 10px;
281
+ background-color: #f3f4f6;
282
+ border: 1px solid #e5e7eb;
283
+ border-radius: 4px;
284
+ cursor: pointer;
285
+ font-size: 14px;
286
+ color: #374151;
287
+ width: auto;
288
+ }
289
+
290
+ .zoom-button:hover {
291
+ background-color: #e5e7eb;
292
+ }
293
+
294
+ .zoom-level {
295
+ font-size: 14px;
296
+ color: #374151;
297
+ min-width: 50px;
298
+ text-align: center;
299
+ }
300
+
301
+ #mermaid-container {
302
+ transform-origin: top left; /* Changed from 'top center' */
303
+ transition: transform 0.2s ease;
304
+ width: 100%;
305
+ display: block; /* Changed from 'flex' */
306
+ justify-content: flex-start; /* Changed from 'center' */
307
+ }
308
+
309
+
310
+ .visualization-wrapper {
311
+ overflow: auto;
312
+ height: 370px;
313
+ min-height: 100px;
314
+ max-height: 1000px;
315
+ border: 1px solid #e5e7eb;
316
+ border-radius: 4px;
317
+ padding: 0;
318
+ resize: vertical;
319
+ background-color: #f8f9fa;
320
+ }
321
+
322
+ .mermaid {
323
+ padding: 0;
324
+ border-radius: 4px;
325
+ }
326
+
327
+ .has-visualization .placeholder-visualization {
328
+ display: none;
329
+ }
330
+ </style>
331
+ </head>
332
+ <body>
333
+ <div class="banner">
334
+ <div class="banner-content">
335
+ <h1>ReasonGraph: Visualisation of Reasoning Paths</h1>
336
+ <div class="search-area">
337
+ <div class="search-input-container">
338
+ <textarea id="question" class="search-input" placeholder="Enter your question here..." rows="1"></textarea>
339
+ <div class="error-message" id="question-error">Please enter a question</div>
340
+ </div>
341
+ <div class="search-buttons">
342
+ <select class="param-input" id="reasoning-method">
343
+ <!-- Populated dynamically -->
344
+ </select>
345
+ <button onclick="metaReasoning()" id="meta-btn">Meta Reasoning</button>
346
+ <button onclick="processQuestion()" id="process-btn">Start Reasoning</button>
347
+ </div>
348
+ </div>
349
+ <div class="links">
350
+ <a href="https://github.com/ZongqianLi/ReasonGraph"><u>GitHub</u> |&nbsp;</a><a href="https://arxiv.org/abs/2503.03979"><u>Paper</u> |&nbsp;</a><a href="mailto:[email protected]"><u>Email</u></a>
351
+ </div>
352
+ </div>
353
+ </div>
354
+
355
+ <div class="container">
356
+ <div class="column">
357
+ <h2>Reasoning Settings</h2>
358
+
359
+ <div class="param-group">
360
+ <div class="param-label">API Provider</div>
361
+ <select class="param-input" id="api-provider" onchange="handleProviderChange(this.value)">
362
+ <!-- Populated dynamically -->
363
+ </select>
364
+ </div>
365
+
366
+ <div class="param-group">
367
+ <div class="param-label">Model</div>
368
+ <select class="param-input" id="model">
369
+ <!-- Populated dynamically -->
370
+ </select>
371
+ </div>
372
+
373
+ <div class="param-group">
374
+ <div class="param-label">Max Tokens</div>
375
+ <input type="number" class="param-input" id="max-tokens">
376
+ </div>
377
+
378
+ <div class="param-group">
379
+ <div class="param-label">API Key</div>
380
+ <input type="password" class="param-input" id="api-key">
381
+ <div class="error-message" id="api-key-error">Please enter a valid API key</div>
382
+ </div>
383
+
384
+ <div class="param-group">
385
+ <div class="param-label">Custom Prompt Format</div>
386
+ <textarea class="param-input" id="prompt-format" rows="6"></textarea>
387
+ </div>
388
+
389
+ <div class="output-section">
390
+ <h3>Raw Model Output</h3>
391
+ <div class="output-wrapper">
392
+ <pre id="raw-output">Output will appear here...</pre>
393
+ </div>
394
+ </div>
395
+ </div>
396
+
397
+ <div class="column">
398
+ <h2>Visualization Settings</h2>
399
+
400
+ <div class="param-group">
401
+ <div class="param-label">Characters Per Line</div>
402
+ <input type="number" class="param-input" id="chars-per-line">
403
+ </div>
404
+
405
+ <div class="param-group">
406
+ <div class="param-label">Maximum Lines</div>
407
+ <input type="number" class="param-input" id="max-lines">
408
+ </div>
409
+
410
+ <div class="output-section">
411
+ <h3>Visualization Results</h3>
412
+ <div class="zoom-controls">
413
+ <button class="zoom-button" onclick="adjustZoom(-0.1)">-</button>
414
+ <div class="zoom-level" id="zoom-level">100%</div>
415
+ <button class="zoom-button" onclick="adjustZoom(0.1)">+</button>
416
+ <button class="zoom-button" onclick="resetZoom()">Reset</button>
417
+ <button class="zoom-button" onclick="downloadDiagram()">Download</button>
418
+ </div>
419
+ <div class="visualization-wrapper">
420
+ <div id="mermaid-container">
421
+ <div id="mermaid-diagram"></div>
422
+ </div>
423
+ </div>
424
+ </div>
425
+ </div>
426
+ </div>
427
+
428
+ <script>
429
+ // Initialize Mermaid
430
+ mermaid.initialize({
431
+ startOnLoad: true,
432
+ theme: 'default',
433
+ securityLevel: 'loose',
434
+ flowchart: {
435
+ curve: 'basis',
436
+ padding: 15
437
+ }
438
+ });
439
+
440
+ // Store current configuration
441
+ let currentConfig = null;
442
+
443
+ // Zoom control variables
444
+ let currentZoom = 1;
445
+ const MIN_ZOOM = 0.1;
446
+ const MAX_ZOOM = 5;
447
+
448
+ // Initialize zoom lock flag
449
+ window.isZoomLocked = false;
450
+
451
+ // Handle API Provider change
452
+ async function handleProviderChange(provider) {
453
+ try {
454
+ // Update model list
455
+ updateModelList();
456
+
457
+ // Get and update API key
458
+ const response = await fetch(`/provider-api-key/${provider}`);
459
+ const result = await response.json();
460
+
461
+ if (result.success) {
462
+ document.getElementById('api-key').value = result.api_key;
463
+ } else {
464
+ console.error('Failed to get API key:', result.error);
465
+ }
466
+ } catch (error) {
467
+ console.error('Error updating provider settings:', error);
468
+ }
469
+ }
470
+
471
+ // Load initial configuration
472
+ async function loadConfig() {
473
+ try {
474
+ const response = await fetch('/config');
475
+ currentConfig = await response.json();
476
+
477
+ // Populate API providers
478
+ const providerSelect = document.getElementById('api-provider');
479
+ currentConfig.general.providers.forEach(provider => {
480
+ const option = document.createElement('option');
481
+ option.value = provider;
482
+ option.textContent = provider.charAt(0).toUpperCase() + provider.slice(1);
483
+ providerSelect.appendChild(option);
484
+ });
485
+
486
+ // Populate reasoning methods
487
+ const methodSelect = document.getElementById('reasoning-method');
488
+ Object.entries(currentConfig.methods).forEach(([id, methodConfig]) => {
489
+ const option = document.createElement('option');
490
+ option.value = id;
491
+ option.textContent = methodConfig.name;
492
+ methodSelect.appendChild(option);
493
+ });
494
+
495
+ // Initial provider setup
496
+ await handleProviderChange(currentConfig.general.providers[0]);
497
+
498
+ // Set other initial values
499
+ document.getElementById('max-tokens').value = currentConfig.general.max_tokens;
500
+ document.getElementById('chars-per-line').value = currentConfig.general.visualization.chars_per_line;
501
+ document.getElementById('max-lines').value = currentConfig.general.visualization.max_lines;
502
+
503
+ // Set initial prompt format and example question
504
+ const defaultMethod = methodSelect.value;
505
+ const methodConfig = currentConfig.methods[defaultMethod];
506
+ updatePromptFormat(methodConfig.prompt_format);
507
+ updateExampleQuestion(methodConfig.example_question);
508
+ } catch (error) {
509
+ console.error('Failed to load configuration:', error);
510
+ showError('Failed to load configuration. Please refresh the page.');
511
+ }
512
+ }
513
+
514
+ // Update model list based on selected provider
515
+ function updateModelList() {
516
+ const provider = document.getElementById('api-provider').value;
517
+ const modelSelect = document.getElementById('model');
518
+ modelSelect.innerHTML = ''; // Clear current options
519
+
520
+ const models = currentConfig.general.available_models;
521
+ const providers = currentConfig.general.model_providers;
522
+
523
+ models.forEach(model => {
524
+ if (providers[model] === provider) {
525
+ const option = document.createElement('option');
526
+ option.value = model;
527
+ option.textContent = model;
528
+ modelSelect.appendChild(option);
529
+ }
530
+ });
531
+ }
532
+
533
+ // Update prompt format when method changes
534
+ document.getElementById('reasoning-method').addEventListener('change', async (event) => {
535
+ try {
536
+ const response = await fetch(`/method-config/${event.target.value}`);
537
+ const methodConfig = await response.json();
538
+ updatePromptFormat(methodConfig.prompt_format);
539
+ updateExampleQuestion(methodConfig.example_question);
540
+ } catch (error) {
541
+ console.error('Failed to load method configuration:', error);
542
+ showError('Failed to update method configuration.');
543
+ }
544
+ });
545
+
546
+ function updatePromptFormat(format) {
547
+ document.getElementById('prompt-format').value = format;
548
+ }
549
+
550
+ function updateExampleQuestion(question) {
551
+ document.getElementById('question').value = question;
552
+ }
553
+
554
+ function adjustZoom(delta) {
555
+ // Do nothing if zooming is locked
556
+ if (window.isZoomLocked) return;
557
+
558
+ const newZoom = Math.min(Math.max(currentZoom + delta, MIN_ZOOM), MAX_ZOOM);
559
+ if (newZoom !== currentZoom) {
560
+ currentZoom = newZoom;
561
+ applyZoom();
562
+ }
563
+ }
564
+
565
+ function resetZoom() {
566
+ // Do nothing if zooming is locked
567
+ if (window.isZoomLocked) return;
568
+
569
+ currentZoom = 1;
570
+ applyZoom();
571
+ }
572
+
573
+ function applyZoom() {
574
+ const container = document.getElementById('mermaid-container');
575
+ container.style.transform = `scale(${currentZoom})`;
576
+
577
+ // Update zoom level display
578
+ const percentage = Math.round(currentZoom * 100);
579
+ document.getElementById('zoom-level').textContent = `${percentage}%`;
580
+ }
581
+
582
+ function lockVisualization() {
583
+ window.isZoomLocked = true;
584
+ const zoomButtons = document.querySelectorAll('.zoom-button');
585
+ zoomButtons.forEach(button => button.disabled = true);
586
+ document.querySelector('.visualization-wrapper').style.pointerEvents = 'none';
587
+ }
588
+
589
+ function unlockVisualization() {
590
+ window.isZoomLocked = false;
591
+ const zoomButtons = document.querySelectorAll('.zoom-button');
592
+ zoomButtons.forEach(button => button.disabled = false);
593
+ document.querySelector('.visualization-wrapper').style.pointerEvents = 'auto';
594
+ }
595
+
596
+ async function downloadDiagram() {
597
+ // Do nothing if zooming is locked
598
+ if (window.isZoomLocked) return;
599
+
600
+ const diagramContainer = document.getElementById('mermaid-diagram');
601
+ if (!diagramContainer || !diagramContainer.querySelector('svg')) {
602
+ alert('No diagram available to download');
603
+ return;
604
+ }
605
+
606
+ try {
607
+ // Get the SVG element
608
+ const svg = diagramContainer.querySelector('svg');
609
+
610
+ // Create a copy of the SVG to modify
611
+ const svgCopy = svg.cloneNode(true);
612
+
613
+ // Ensure the SVG has proper dimensions
614
+ const bbox = svg.getBBox();
615
+ svgCopy.setAttribute('width', bbox.width);
616
+ svgCopy.setAttribute('height', bbox.height);
617
+ svgCopy.setAttribute('viewBox', `${bbox.x} ${bbox.y} ${bbox.width} ${bbox.height}`);
618
+
619
+ // Convert SVG to string
620
+ const serializer = new XMLSerializer();
621
+ const svgString = serializer.serializeToString(svgCopy);
622
+
623
+ // Create blob and download link
624
+ const blob = new Blob([svgString], {type: 'image/svg+xml'});
625
+ const url = URL.createObjectURL(blob);
626
+
627
+ // Create temporary link and trigger download
628
+ const link = document.createElement('a');
629
+ link.href = url;
630
+ link.download = 'reasoning_diagram.svg';
631
+ document.body.appendChild(link);
632
+ link.click();
633
+
634
+ // Cleanup
635
+ document.body.removeChild(link);
636
+ URL.revokeObjectURL(url);
637
+ } catch (error) {
638
+ console.error('Error downloading diagram:', error);
639
+ alert('Failed to download diagram');
640
+ }
641
+ }
642
+
643
+ function validateInputs() {
644
+ const apiKey = document.getElementById('api-key').value.trim();
645
+ const question = document.getElementById('question').value.trim();
646
+
647
+ let isValid = true;
648
+
649
+ // Validate API Key
650
+ if (!apiKey) {
651
+ document.getElementById('api-key-error').style.display = 'block';
652
+ isValid = false;
653
+ } else {
654
+ document.getElementById('api-key-error').style.display = 'none';
655
+ }
656
+
657
+ // Validate Question
658
+ if (!question) {
659
+ document.getElementById('question-error').style.display = 'block';
660
+ isValid = false;
661
+ } else {
662
+ document.getElementById('question-error').style.display = 'none';
663
+ }
664
+
665
+ return isValid;
666
+ }
667
+
668
+ function showError(message) {
669
+ const rawOutput = document.getElementById('raw-output');
670
+ rawOutput.textContent = `Error: ${message}`;
671
+ rawOutput.style.color = '#dc2626';
672
+ }
673
+
674
+ // Process question
675
+ async function processQuestion(isMetaReasoning = false) {
676
+ if (!validateInputs()) {
677
+ return;
678
+ }
679
+
680
+ // Reset Zoom before processing question
681
+ resetZoom();
682
+ const processButton = document.getElementById('process-btn');
683
+ const metaButton = document.getElementById('meta-btn');
684
+ const rawOutput = document.getElementById('raw-output');
685
+
686
+ processButton.disabled = true;
687
+ metaButton.disabled = true;
688
+ processButton.textContent = 'Processing...';
689
+ rawOutput.textContent = 'Loading...';
690
+ rawOutput.style.color = '#1f2937';
691
+
692
+ // Lock visualization
693
+ lockVisualization();
694
+
695
+ const data = {
696
+ provider: document.getElementById('api-provider').value,
697
+ api_key: document.getElementById('api-key').value,
698
+ model: document.getElementById('model').value,
699
+ max_tokens: parseInt(document.getElementById('max-tokens').value),
700
+ question: document.getElementById('question').value,
701
+ prompt_format: document.getElementById('prompt-format').value,
702
+ reasoning_method: document.getElementById('reasoning-method').value,
703
+ chars_per_line: parseInt(document.getElementById('chars-per-line').value),
704
+ max_lines: parseInt(document.getElementById('max-lines').value)
705
+ };
706
+
707
+ try {
708
+ const response = await fetch('/process', {
709
+ method: 'POST',
710
+ headers: {
711
+ 'Content-Type': 'application/json'
712
+ },
713
+ body: JSON.stringify(data)
714
+ });
715
+
716
+ const result = await response.json();
717
+
718
+ if (result.success) {
719
+ rawOutput.textContent = result.raw_output;
720
+ rawOutput.style.color = '#1f2937';
721
+
722
+ if (result.visualization) {
723
+ const container = document.getElementById('mermaid-diagram');
724
+ container.innerHTML = result.visualization;
725
+ document.getElementById('mermaid-container').classList.add('has-visualization');
726
+ resetZoom();
727
+ mermaid.init();
728
+ }
729
+ } else {
730
+ showError(result.error || 'Unknown error occurred');
731
+ }
732
+ } catch (error) {
733
+ showError('Failed to process request: ' + error.message);
734
+ } finally {
735
+ // Unlock visualization
736
+ unlockVisualization();
737
+
738
+ processButton.disabled = false;
739
+ metaButton.disabled = false;
740
+ processButton.textContent = 'Start Reasoning';
741
+ if (isMetaReasoning) {
742
+ metaButton.textContent = 'Meta Reasoning';
743
+ }
744
+ }
745
+ }
746
+
747
+ // Meta Reasoning function
748
+ async function metaReasoning() {
749
+ const metaButton = document.getElementById('meta-btn');
750
+ const rawOutput = document.getElementById('raw-output');
751
+
752
+ try {
753
+ metaButton.disabled = true;
754
+ metaButton.textContent = 'Selecting Method...';
755
+ rawOutput.textContent = 'Analyzing question to select best method...';
756
+
757
+ // Get current parameters
758
+ const data = {
759
+ provider: document.getElementById('api-provider').value,
760
+ api_key: document.getElementById('api-key').value,
761
+ model: document.getElementById('model').value,
762
+ question: document.getElementById('question').value
763
+ };
764
+
765
+ // Call the method selection endpoint
766
+ const response = await fetch('/select-method', {
767
+ method: 'POST',
768
+ headers: {
769
+ 'Content-Type': 'application/json'
770
+ },
771
+ body: JSON.stringify(data)
772
+ });
773
+
774
+ const result = await response.json();
775
+
776
+ if (result.success) {
777
+ // Set the selected method
778
+ const methodSelect = document.getElementById('reasoning-method');
779
+ methodSelect.value = result.selected_method;
780
+
781
+ // Fetch and update the corresponding method configuration
782
+ const methodResponse = await fetch(`/method-config/${result.selected_method}`);
783
+ const methodConfig = await methodResponse.json();
784
+
785
+ if (methodConfig) {
786
+ // Update the prompt format
787
+ updatePromptFormat(methodConfig.prompt_format);
788
+
789
+ // Update example question if needed
790
+ if (document.getElementById('question').value === '') {
791
+ updateExampleQuestion(methodConfig.example_question);
792
+ }
793
+
794
+ console.log(`Selected reasoning method: ${methodConfig.name}`);
795
+
796
+ // Update button to show method was selected
797
+ metaButton.textContent = 'Method Selected';
798
+ // Process the question with the selected method
799
+ await processQuestion(true);
800
+ } else {
801
+ showError('Failed to load method configuration');
802
+ metaButton.textContent = 'Meta Reasoning';
803
+ }
804
+ } else {
805
+ showError(result.error || 'Failed to select method');
806
+ metaButton.textContent = 'Meta Reasoning';
807
+ }
808
+ } catch (error) {
809
+ console.error('Meta reasoning error:', error);
810
+ showError('Failed to execute meta reasoning');
811
+ metaButton.disabled = false;
812
+ metaButton.textContent = 'Meta Reasoning';
813
+ }
814
+ }
815
+
816
+ // Add event listener for mouse wheel zoom
817
+ document.querySelector('.visualization-wrapper').addEventListener('wheel', function(e) {
818
+ // Do nothing if zooming is locked
819
+ if (window.isZoomLocked) {
820
+ e.preventDefault();
821
+ return;
822
+ }
823
+
824
+ if (e.ctrlKey) {
825
+ e.preventDefault(); // Prevent default zoom
826
+ const delta = e.deltaY > 0 ? -0.1 : 0.1;
827
+ adjustZoom(delta);
828
+ }
829
+ });
830
+
831
+ // Load configuration when page loads
832
+ document.addEventListener('DOMContentLoaded', loadConfig);
833
+ </script>
834
+ </body>
835
+ </html>
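The page above drives a Flask backend through five endpoints (/config, /provider-api-key/&lt;provider&gt;, /method-config/&lt;method_id&gt;, /process, /select-method). app.py is not part of this section, so the following is only a sketch of the route surface and the JSON shapes the JavaScript reads; the route paths come from the fetch() calls above, while every payload value here is a placeholder assumption, not the real configuration:

# Sketch of the backend contract inferred from the fetch() calls above.
from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route('/config')
def get_config():
    # loadConfig() reads general.providers, general.available_models,
    # general.model_providers, general.max_tokens, general.visualization,
    # and a methods mapping with name/prompt_format/example_question.
    return jsonify({
        "general": {
            "providers": ["anthropic", "openai"],
            "available_models": ["example-model"],
            "model_providers": {"example-model": "anthropic"},
            "max_tokens": 1024,
            "visualization": {"chars_per_line": 40, "max_lines": 4},
        },
        "methods": {
            "cot": {"name": "Chain of Thought",
                    "prompt_format": "<step ...>", "example_question": "..."},
        },
    })

@app.route('/provider-api-key/<provider>')
def get_provider_api_key(provider):
    # handleProviderChange() expects {"success": bool, "api_key": str}.
    return jsonify({"success": True, "api_key": ""})

@app.route('/method-config/<method_id>')
def get_method_config(method_id):
    return jsonify({"name": "Chain of Thought",
                    "prompt_format": "<step ...>", "example_question": "..."})

@app.route('/process', methods=['POST'])
def process():
    data = request.get_json()
    # processQuestion() expects success, raw_output, and optional visualization.
    return jsonify({"success": True, "raw_output": "...", "visualization": ""})

@app.route('/select-method', methods=['POST'])
def select_method():
    # metaReasoning() expects success and a selected_method id.
    return jsonify({"success": True, "selected_method": "cot"})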
tot_reasoning.py ADDED
@@ -0,0 +1,137 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional
3
+ import re
4
+ import textwrap
5
+ from cot_reasoning import VisualizationConfig
6
+
7
+ @dataclass
8
+ class ToTNode:
9
+ """Data class representing a node in the Tree of Thoughts"""
10
+ id: str
11
+ content: str
12
+ parent_id: Optional[str] = None
13
+ children: List['ToTNode'] = None
14
+ is_answer: bool = False
15
+
16
+ def __post_init__(self):
17
+ if self.children is None:
18
+ self.children = []
19
+
20
+ @dataclass
21
+ class ToTResponse:
22
+ """Data class representing a complete ToT response"""
23
+ question: str
24
+ root: ToTNode
25
+ answer: Optional[str] = None
26
+
27
+ def parse_tot_response(response_text: str, question: str) -> ToTResponse:
28
+ """Parse ToT response text to extract nodes and build the tree"""
29
+ # Parse nodes
30
+ node_pattern = r'<node id="([^"]+)"(?:\s+parent="([^"]+)")?\s*>\s*(.*?)\s*</node>'
31
+ nodes_dict = {}
32
+
33
+ # First pass: create all nodes
34
+ for match in re.finditer(node_pattern, response_text, re.DOTALL):
35
+ node_id = match.group(1)
36
+ parent_id = match.group(2)
37
+ content = match.group(3).strip()
38
+
39
+ node = ToTNode(id=node_id, content=content, parent_id=parent_id)
40
+ nodes_dict[node_id] = node
41
+
42
+ # Second pass: build tree relationships
43
+ root = None
44
+ for node in nodes_dict.values():
45
+ if node.parent_id is None:
46
+ root = node
47
+ else:
48
+ parent = nodes_dict.get(node.parent_id)
49
+ if parent:
50
+ parent.children.append(node)
51
+
52
+ # Parse answer if present
53
+ answer_pattern = r'<answer>\s*(.*?)\s*</answer>'
54
+ answer_match = re.search(answer_pattern, response_text, re.DOTALL)
55
+ answer = answer_match.group(1).strip() if answer_match else None
56
+
57
+ if answer:
58
+ # Mark the node leading to the answer
59
+ for node in nodes_dict.values():
60
+ if node.content.strip() in answer.strip():
61
+ node.is_answer = True
62
+
63
+ return ToTResponse(question=question, root=root, answer=answer)
64
+
65
+ def create_mermaid_diagram(tot_response: ToTResponse, config: VisualizationConfig) -> str:
66
+ """Convert ToT response to Mermaid diagram"""
67
+ diagram = ['<div class="mermaid">', 'graph TD']
68
+
69
+ # Add question node
70
+ question_content = wrap_text(tot_response.question, config)
71
+ diagram.append(f' Q["{question_content}"]')
72
+
73
+ # Track leaf nodes for connecting to answer
74
+ leaf_nodes = []
75
+
76
+ def add_node_and_children(node: ToTNode, parent_id: Optional[str] = None):
77
+ content = wrap_text(node.content, config)
78
+ node_style = 'answer' if node.is_answer else 'default'
79
+
80
+ # Add node
81
+ diagram.append(f' {node.id}["{content}"]')
+ diagram.append(f' class {node.id} {node_style};') # apply the style chosen above
82
+
83
+ # Add connection from parent
84
+ if parent_id:
85
+ diagram.append(f' {parent_id} --> {node.id}')
86
+
87
+ # Process children
88
+ if node.children:
89
+ for child in node.children:
90
+ add_node_and_children(child, node.id)
91
+ else:
92
+ # This is a leaf node
93
+ leaf_nodes.append(node.id)
94
+
95
+ # Build tree structure
96
+ if tot_response.root:
97
+ diagram.append(f' Q --> {tot_response.root.id}')
98
+ add_node_and_children(tot_response.root)
99
+
100
+ # Add final answer node if answer exists
101
+ if tot_response.answer:
102
+ answer_content = wrap_text(tot_response.answer, config)
103
+ diagram.append(f' Answer["{answer_content}"]')
104
+ # Connect all leaf nodes to the answer
105
+ for leaf_id in leaf_nodes:
106
+ diagram.append(f' {leaf_id} --> Answer')
107
+ diagram.append(' class Answer final_answer;')
108
+
109
+ # Add styles
110
+ diagram.extend([
111
+ ' classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
112
+ ' classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
113
+ ' classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
114
+ ' classDef final_answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
115
+ ' class Q question;',
116
+ ' linkStyle default stroke:#666,stroke-width:2px;'
117
+ ])
118
+
119
+ diagram.append('</div>')
120
+ return '\n'.join(diagram)
121
+
122
+ def wrap_text(text: str, config: VisualizationConfig) -> str:
123
+ """Wrap text to fit within box constraints"""
124
+ text = text.replace('\n', ' ').replace('"', "'")
125
+ wrapped_lines = textwrap.wrap(text, width=config.max_chars_per_line)
126
+
127
+ if len(wrapped_lines) > config.max_lines:
128
+ # Option 1: Simply truncate and add ellipsis to the last line
129
+ wrapped_lines = wrapped_lines[:config.max_lines]
130
+ wrapped_lines[-1] = wrapped_lines[-1][:config.max_chars_per_line-3] + "..."
131
+
132
+ # Option 2 (alternative): Include part of the next line to show continuity
133
+ # original_next_line = wrapped_lines[config.max_lines] if len(wrapped_lines) > config.max_lines else ""
134
+ # wrapped_lines = wrapped_lines[:config.max_lines-1]
135
+ # wrapped_lines.append(original_next_line[:config.max_chars_per_line-3] + "...")
136
+
137
+ return "<br>".join(wrapped_lines)