import os
import sys
import logging
from functools import wraps

from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
import torch
import pandas as pd
from huggingface_hub import hf_hub_download

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Add parent directory to sys.path to allow imports from 'model'
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
    from model.kronos import Kronos, KronosTokenizer, KronosPredictor
except ImportError as e:
    logging.error(f"Could not import from model.kronos: {e}")
    sys.exit(1)

# --- Globals ---
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

predictor = None
model_name_global = "kronos-base"  # Key into AVAILABLE_MODELS for the currently loaded model
API_KEY = os.environ.get("KRONOS_API_KEY")

AVAILABLE_MODELS = {
    'kronos-mini': {
        'name': 'Kronos-mini',
        'model_id': 'NeoQuasar/Kronos-mini',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k',
        'context_length': 2048,
        'params': '4.1M',
        'description': 'Lightweight model, suitable for fast prediction'
    },
    'kronos-small': {
        'name': 'Kronos-small',
        'model_id': 'NeoQuasar/Kronos-small',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
        'context_length': 512,
        'params': '24.7M',
        'description': 'Small model, balanced performance and speed'
    },
    'kronos-base': {
        'name': 'Kronos-base',
        'model_id': 'NeoQuasar/Kronos-base',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
        'context_length': 512,
        'params': '102.3M',
        'description': 'Base model, provides better prediction quality'
    }
}


# --- Helper Functions ---
def download_model_from_hf(model_name, local_dir="."):
    """Downloads a model from Hugging Face Hub."""
    logging.info(f"Downloading model '{model_name}' from Hugging Face Hub...")
    try:
        hf_hub_download(repo_id=model_name, filename="config.json", local_dir=local_dir)
        hf_hub_download(repo_id=model_name, filename="pytorch_model.bin", local_dir=local_dir)
        logging.info("Model downloaded successfully.")
        return True
    except Exception as e:
        logging.error(f"Failed to download model: {e}")
        return False


# --- API Authentication ---
def require_api_key(f):
    """Decorator to protect routes with an API key."""
    @wraps(f)
    def decorated_function(*args, **kwargs):
        # If KRONOS_API_KEY is not set on the server, skip authentication (for local/dev)
        if not API_KEY:
            logging.warning("API key not set. Skipping authentication.")
            return f(*args, **kwargs)

        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({'error': 'Authorization header is missing or invalid. Use Bearer token.'}), 401

        token = auth_header.split(' ')[1]
        if token != API_KEY:
            return jsonify({'error': 'Invalid API Key.'}), 401

        return f(*args, **kwargs)
    return decorated_function
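
# Protected routes expect the API key as a Bearer token (illustrative header;
# the value must match the server's KRONOS_API_KEY):
#
#   Authorization: Bearer <KRONOS_API_KEY>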


# --- API Endpoints ---
@app.route('/')
def index():
    """Serves the index.html file for the visualizer."""
    return send_from_directory('.', 'index.html')


@app.route('/api/load-model', methods=['POST'])
@require_api_key
def load_model_endpoint():
    """Loads the prediction model into memory."""
    global predictor, model_name_global

    json_data = request.get_json(silent=True) or {}  # Tolerate an empty body and fall back to defaults
    model_key = json_data.get('model_key', model_name_global)
    force_reload = json_data.get('force_reload', False)

    # Validate that the requested model is in the allowed list
    if model_key not in AVAILABLE_MODELS:
        return jsonify({
            'error': "Invalid model_key. Please choose from the allowed models.",
            'allowed_models': list(AVAILABLE_MODELS.keys())
        }), 400

    if predictor and not force_reload and model_name_global == model_key:
        return jsonify({'status': 'Model already loaded.'})

    try:
        model_config = AVAILABLE_MODELS[model_key]
        model_id = model_config['model_id']
        tokenizer_id = model_config['tokenizer_id']

        logging.info(f"Attempting to load model: {model_id}")
        logging.info(f"Attempting to load tokenizer: {tokenizer_id}")

        # Determine device
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logging.info(f"Using device: {device}")

        # --- Proxy Setup ---
        # Check for proxy settings in environment variables, similar to the webui fix
        proxies = {
            "http": os.environ.get("HTTP_PROXY"),
            "https": os.environ.get("HTTPS_PROXY"),
        }
        # Filter out None values
        proxies = {k: v for k, v in proxies.items() if v}
        if proxies:
            logging.info(f"Using proxies: {proxies}")

        # Load the model and tokenizer with proxy support
        model = Kronos.from_pretrained(model_id, proxies=proxies if proxies else None)
        tokenizer = KronosTokenizer.from_pretrained(tokenizer_id, proxies=proxies if proxies else None)

        # Create the predictor wrapper
        predictor = KronosPredictor(model, tokenizer, device=device)
        model_name_global = model_key

        logging.info(f"Model '{model_config['name']}' loaded successfully.")
        return jsonify({
            'status': f"Model '{model_config['name']}' loaded successfully.",
            'model_info': model_config
        })
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        return jsonify({'error': str(e)}), 500


@app.route('/api/model-status', methods=['GET'])
def model_status():
    """Checks if the model is loaded."""
    if predictor:
        return jsonify({
            'status': 'loaded',
            'model_key': model_name_global,
            'model_info': AVAILABLE_MODELS.get(model_name_global)
        })
    else:
        return jsonify({'status': 'not_loaded'})


@app.route('/api/available-models', methods=['GET'])
def get_available_models():
    """Returns the list of available models and their details."""
    return jsonify(AVAILABLE_MODELS)


@app.route('/api/predict', methods=['POST'])
@require_api_key
def predict():
    """
    Receives raw K-line data in the request body, makes a prediction,
    and returns the results.
    """
    if not predictor:
        return jsonify({'error': 'Model not loaded. Please call /api/load-model first.'}), 400
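
    # Expected request body (illustrative values; each k_lines entry is a
    # Binance-style kline array, of which only the first six fields are used):
    #
    #   {
    #     "k_lines": [
    #       [1700000000000, "42000.0", "42100.0", "41900.0", "42050.0", "12.3", ...],
    #       ...
    #     ],
    #     "prediction_params": {"pred_len": 120}
    #   }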

    data = request.get_json(force=True)
    if not data or 'k_lines' not in data:
        return jsonify({'error': 'Missing "k_lines" in request body.'}), 400

    k_lines = data['k_lines']
    params = data.get('prediction_params', {})
    pred_len = params.get('pred_len', 120)

    try:
        # Define column names based on the standard Binance kline format.
        # Only the first 6 columns are needed for the model.
        columns = [
            'timestamp', 'open', 'high', 'low', 'close', 'volume',
            'close_time', 'quote_asset_volume', 'number_of_trades',
            'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'ignore'
        ]
        # Ensure we only use the first 12 columns if more are provided
        k_lines_standardized = [line[:12] for line in k_lines]

        df = pd.DataFrame(k_lines_standardized, columns=columns)

        # --- Data Type Conversion ---
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        numeric_cols = ['open', 'high', 'low', 'close', 'volume']
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col])

        # Keep only the columns the model needs
        df_model_input = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]

        logging.info(f"Making prediction with pred_len={pred_len} on data with shape {df_model_input.shape}")

        # --- Timestamp Generation for Predictor ---
        # The predictor requires both historical and future timestamps.
        x_timestamp = df_model_input['timestamp']

        # Assuming the K-line interval is consistent, calculate the interval
        # from the last two points to generate future timestamps.
        if len(x_timestamp) > 1:
            interval = x_timestamp.iloc[-1] - x_timestamp.iloc[-2]
        else:
            # With only one data point, fall back to a 1-minute interval
            interval = pd.Timedelta(minutes=1)

        y_timestamp = pd.date_range(
            start=x_timestamp.iloc[-1] + interval,
            periods=pred_len,
            freq=interval
        )
        # Convert the DatetimeIndex to a Series to prevent a '.dt' accessor error inside the model
        y_timestamp = pd.Series(y_timestamp, name='timestamp')

        # Make the prediction using the predictor wrapper
        pred_df = predictor.predict(
            df=df_model_input,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            verbose=False  # Keep logs clean
        )

        # --- Format results to match the input format ---
        pred_df_reset = pred_df.reset_index()

        # Convert timestamps to Unix milliseconds. Integer floor-division avoids the
        # float round-trip, which loses precision at nanosecond epoch magnitudes.
        pred_df_reset['timestamp'] = pred_df_reset['timestamp'].astype('int64') // 10**6

        # Reorder columns to match the desired output format: [timestamp, open, high, low, close, volume]
        output_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
        pred_df_formatted = pred_df_reset[output_columns]

        # Convert to a list of lists for the JSON response
        prediction_results = pred_df_formatted.values.tolist()

        return jsonify({
            'success': True,
            'prediction_params': {'pred_len': pred_len},
            'prediction_results': prediction_results
        })
    except Exception as e:
        logging.error(f"Prediction failed: {e}")
        return jsonify({'error': f'An error occurred during prediction: {str(e)}'}), 500


if __name__ == '__main__':
    # This block is for local debugging purposes.
    # The production server will use a WSGI server like Gunicorn.
    app.run(host='0.0.0.0', port=7860, debug=True)
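
# End-to-end usage sketch (assumes this module is saved as app.py; values are
# illustrative). A single worker keeps the in-process `predictor` global shared
# across requests:
#
#   gunicorn -w 1 -b 0.0.0.0:7860 app:app
#
#   # 1. Load a model:
#   curl -X POST http://localhost:7860/api/load-model \
#        -H "Authorization: Bearer $KRONOS_API_KEY" \
#        -H "Content-Type: application/json" \
#        -d '{"model_key": "kronos-base"}'
#
#   # 2. Request a forecast from historical klines:
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Authorization: Bearer $KRONOS_API_KEY" \
#        -H "Content-Type: application/json" \
#        -d '{"k_lines": [[1700000000000, "42000.0", "42100.0", "41900.0", "42050.0", "12.3"]], "prediction_params": {"pred_len": 120}}'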