Spaces:
Running
Running
from flask import Flask, render_template, request, jsonify | |
from api_base import create_api # New import for API factory | |
from plain_text_reasoning import ( | |
create_mermaid_diagram as create_plain_diagram, | |
parse_plain_text_response | |
) | |
from cot_reasoning import ( | |
VisualizationConfig, | |
create_mermaid_diagram as create_cot_diagram, | |
parse_cot_response | |
) | |
from tot_reasoning import ( | |
create_mermaid_diagram as create_tot_diagram, | |
parse_tot_response | |
) | |
from l2m_reasoning import ( | |
create_mermaid_diagram as create_l2m_diagram, | |
parse_l2m_response | |
) | |
from selfconsistency_reasoning import ( | |
create_mermaid_diagram as create_scr_diagram, | |
parse_scr_response | |
) | |
from selfrefine_reasoning import ( | |
create_mermaid_diagram as create_srf_diagram, | |
parse_selfrefine_response | |
) | |
from bs_reasoning import ( | |
create_mermaid_diagram as create_bs_diagram, | |
parse_bs_response | |
) | |
from configs import config | |
import logging | |
# Logging setup: INFO threshold; each record shows timestamp, logger name,
# severity and message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the handlers in this file.
logger = logging.getLogger(__name__)

# The Flask application object.
app = Flask(__name__)
def index():
    """Serve the main page by rendering the ``index.html`` template."""
    # NOTE(review): no route decorator is visible on this view in this chunk;
    # confirm where it is registered.
    page = render_template('index.html')
    return page
def index_direct():
    """Serve the main page for requests that address ``index.html`` directly.

    Renders the same ``index.html`` template as the main page view.
    """
    page = render_template('index.html')
    return page
def index_cn():
    """Serve the Chinese main page by rendering ``index_cn.html``."""
    page = render_template('index_cn.html')
    return page
def get_config():
    """Return the values from ``config.get_initial_values()`` as JSON."""
    initial_values = config.get_initial_values()
    return jsonify(initial_values)
def get_method_config(method_id):
    """Return the configuration of one reasoning method as JSON.

    Responds with a 404 JSON error when ``config.get_method_config`` yields a
    falsy value for ``method_id``.
    """
    found = config.get_method_config(method_id)
    if not found:
        # Guard clause: nothing usable was returned for this identifier.
        return jsonify({"error": "Method not found"}), 404
    return jsonify(found)
def get_provider_api_key(provider):
    """Return the default API key configured for ``provider`` as JSON.

    If the lookup raises, the error is logged and a 500 JSON response carrying
    the exception message is returned.
    """
    try:
        payload = {
            'success': True,
            'api_key': config.general.get_default_api_key(provider),
        }
        return jsonify(payload)
    except Exception as e:
        message = str(e)
        logger.error(f"Error getting API key for provider {provider}: {message}")
        failure = {'success': False, 'error': message}
        return jsonify(failure), 500
def save_api_key():
    """Store a provider's API key in the in-memory configuration only.

    Expects a JSON object body with string fields ``provider`` and
    ``api_key``. Nothing is written to disk.

    Returns:
        ``{'success': True}`` on success. A 400 JSON error when the body is
        missing, malformed, not a JSON object, or lacks a usable ``provider``
        or ``api_key``. A 500 JSON error if storing the key raises.
    """
    try:
        # silent=True turns a missing or malformed JSON body into None so it
        # is reported as a client error (400) instead of being caught below
        # and reported as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({
                'success': False,
                'error': 'No data provided'
            }), 400
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'error': 'Request body must be a JSON object'
            }), 400

        provider = data.get('provider')
        api_key = data.get('api_key')
        if not provider or not api_key:
            return jsonify({
                'success': False,
                'error': 'Provider and API key are required'
            }), 400
        # The provider is used as a dictionary key, so a non-string (e.g. a
        # JSON list) would otherwise fail with an unhashable-type error.
        if not isinstance(provider, str) or not isinstance(api_key, str):
            return jsonify({
                'success': False,
                'error': 'Provider and API key must be strings'
            }), 400

        # Update API key in config (this updates the in-memory API keys only)
        config.general.provider_api_keys[provider] = api_key
        logger.info(f"Saved API key for provider: {provider} (in memory only)")
        return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Error saving API key: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
def select_method():
    """Ask the model to choose the most appropriate reasoning method.

    Expects a JSON object body with ``api_key``, ``model`` and ``question``;
    ``provider`` is optional and defaults to ``'anthropic'``.

    Returns:
        A JSON response. On success it carries ``selected_method`` (one of the
        keys of ``config.methods``) and the model's ``raw_response``. A 400 is
        returned for a missing/invalid body, missing parameters, or a model
        reply that does not name a known method. A 500 is returned when the
        API call or any other step raises.
    """
    import re

    try:
        # silent=True turns a missing or malformed JSON body into None so it
        # is reported as a client error (400) rather than a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'success': False, 'error': 'No data provided'}), 400
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'error': 'Request body must be a JSON object'
            }), 400

        # Extract parameters
        api_key = data.get('api_key')
        provider = data.get('provider', 'anthropic')
        model = data.get('model')
        question = data.get('question')
        if not all([api_key, model, question]):
            return jsonify({'success': False, 'error': 'Missing required parameters'}), 400

        # Create the selection prompt
        methods = config.methods
        method_ids = ', '.join(methods.keys())
        # The loop variable is deliberately not named ``config`` so it cannot
        # be confused with the module-level configuration object.
        method_list = '\n'.join(
            f'- {method_id}: {method_cfg.name}'
            for method_id, method_cfg in methods.items()
        )
        prompt = f"""Given this question: "{question}"
Please select the most appropriate reasoning method from the following options to solve it:
{method_list}
Consider the characteristics of each method and the nature of the question.
Output your selection in exactly this format:
<selected_method>method_id</selected_method>
where method_id is strictly one of: {method_ids}.
Do not use the method or words that are not in {method_ids}."""

        # Get model's selection
        try:
            api = create_api(provider, api_key, model)
            response = api.generate_response(prompt, max_tokens=100)
            # Extract the method ID from the tagged part of the reply and
            # accept it only if it is a known method.
            match = re.search(r'<selected_method>(\w+)</selected_method>', response)
            if match and match.group(1) in methods:
                return jsonify({
                    'success': True,
                    'selected_method': match.group(1),
                    'raw_response': response
                })
            return jsonify({
                'success': False,
                'error': 'Invalid method selection in response'
            }), 400
        except Exception as e:
            # Log the failure so it is visible server-side, not only to the
            # client that receives the error payload.
            logger.error(f"API call failed during method selection: {str(e)}")
            return jsonify({
                'success': False,
                'error': f'API call failed: {str(e)}'
            }), 500
    except Exception as e:
        logger.error(f"Error in method selection: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
def process():
    """Process a reasoning request: query the model and build a visualization.

    Expects a JSON object body. ``api_key`` and ``question`` are required.
    Optional fields: ``provider`` (default ``'anthropic'``), ``model``
    (default: first entry of ``config.general.available_models``),
    ``max_tokens``, ``chars_per_line`` and ``max_lines`` (defaults taken from
    ``config.general``), ``prompt_format``, and ``reasoning_method``
    (default ``'cot'``).

    Returns:
        A JSON response with ``raw_output`` and ``visualization`` on success.
        ``visualization`` is ``None`` for the ``'plain'`` method, for an
        unrecognized method, or when diagram generation fails (the failure is
        logged and the raw output is still returned). A 400 is returned for
        invalid input or an API initialization failure; a 500 for a failed
        model call or any unexpected error.
    """
    try:
        # silent=True turns a missing or malformed JSON body into None so it
        # is reported as a client error (400) rather than a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({
                'success': False,
                'error': 'No data provided'
            }), 400
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'error': 'Request body must be a JSON object'
            }), 400

        # Required parameters
        api_key = data.get('api_key')
        if not api_key:
            return jsonify({
                'success': False,
                'error': 'API key is required'
            }), 400
        question = data.get('question')
        if not question:
            return jsonify({
                'success': False,
                'error': 'Question is required'
            }), 400

        # Optional parameters with defaults
        provider = data.get('provider', 'anthropic')
        # Look up the default model only when the client did not send one, so
        # an empty ``available_models`` list cannot break requests that name
        # their model explicitly.
        if 'model' in data:
            model = data['model']
        else:
            model = config.general.available_models[0]
        prompt_format = data.get('prompt_format')
        reasoning_method = data.get('reasoning_method', 'cot')
        # Non-numeric values are a client mistake: answer 400 instead of
        # letting the conversion error fall through to the generic 500.
        try:
            max_tokens = int(data.get('max_tokens', config.general.max_tokens))
            chars_per_line = int(data.get('chars_per_line', config.general.chars_per_line))
            max_lines = int(data.get('max_lines', config.general.max_lines))
        except (TypeError, ValueError):
            return jsonify({
                'success': False,
                'error': 'max_tokens, chars_per_line and max_lines must be integers'
            }), 400

        # Initialize API with factory function
        try:
            api = create_api(provider, api_key, model)
        except Exception as e:
            return jsonify({
                'success': False,
                'error': f'Failed to initialize API: {str(e)}'
            }), 400

        # Get model response
        logger.info(f"Generating response for question using {provider} {model}")
        try:
            raw_response = api.generate_response(
                question,
                max_tokens=max_tokens,
                prompt_format=prompt_format
            )
        except Exception as e:
            return jsonify({
                'success': False,
                'error': f'API call failed: {str(e)}'
            }), 500

        # Create visualization config
        viz_config = VisualizationConfig(
            max_chars_per_line=chars_per_line,
            max_lines=max_lines
        )

        # Reasoning method -> (response parser, diagram builder)
        visualizers = {
            'cot': (parse_cot_response, create_cot_diagram),
            'tot': (parse_tot_response, create_tot_diagram),
            'l2m': (parse_l2m_response, create_l2m_diagram),
            'scr': (parse_scr_response, create_scr_diagram),
            'srf': (parse_selfrefine_response, create_srf_diagram),
            'bs': (parse_bs_response, create_bs_diagram),
        }

        # Generate visualization based on reasoning method. Any failure here
        # is logged and the raw output is still returned.
        visualization = None
        try:
            if reasoning_method == 'plain':
                # Plain text is parsed but produces no diagram.
                parse_plain_text_response(raw_response, question)
            else:
                handlers = visualizers.get(reasoning_method)
                if handlers is None:
                    logger.warning(
                        f"Unknown reasoning method {reasoning_method!r}; "
                        "no visualization generated"
                    )
                else:
                    parse_response, create_diagram = handlers
                    result = parse_response(raw_response, question)
                    visualization = create_diagram(result, viz_config)
                    logger.info("Successfully generated visualization")
        except Exception as viz_error:
            logger.error(f"Visualization generation failed: {str(viz_error)}")
            # Continue without visualization

        # Return successful response
        return jsonify({
            'success': True,
            'raw_output': raw_response,
            'visualization': visualization
        })
    except Exception as e:
        # Log the error and return error response
        logger.error(f"Error processing request: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
def not_found_error(error):
    """Build the JSON body and 404 status returned for "not found" errors."""
    # NOTE(review): no error-handler registration is visible in this chunk;
    # confirm this function is wired up to the app.
    body = {'success': False, 'error': 'Resource not found'}
    return jsonify(body), 404
def internal_error(error):
    """Build the JSON body and 500 status returned for internal errors."""
    # NOTE(review): no error-handler registration is visible in this chunk;
    # confirm this function is wired up to the app.
    body = {'success': False, 'error': 'Internal server error'}
    return jsonify(body), 500
if __name__ == '__main__':
    # Serve on all interfaces, port 5001, with debug mode off. A failure
    # raised by ``app.run`` is logged rather than propagated.
    run_options = {'host': '0.0.0.0', 'port': 5001, 'debug': False}
    try:
        app.run(**run_options)
    except Exception as e:
        logger.error(f"Failed to start application: {str(e)}")