import os
import sys
import logging
from functools import wraps

from flask import Flask, request, jsonify
import torch
import pandas as pd
from huggingface_hub import hf_hub_download

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Add parent directory to sys.path to allow imports from 'model'
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    from model.kronos import Kronos, KronosTokenizer, KronosPredictor
except ImportError as e:
    logging.error(f"Could not import from model.kronos: {e}")
    sys.exit(1)
# --- Globals ---
app = Flask(__name__)
predictor = None
model_name_global = "kronos-base"  # Use key now
API_KEY = os.environ.get("KRONOS_API_KEY")

AVAILABLE_MODELS = {
    'kronos-mini': {
        'name': 'Kronos-mini',
        'model_id': 'NeoQuasar/Kronos-mini',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k',
        'context_length': 2048,
        'params': '4.1M',
        'description': 'Lightweight model, suitable for fast prediction'
    },
    'kronos-small': {
        'name': 'Kronos-small',
        'model_id': 'NeoQuasar/Kronos-small',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
        'context_length': 512,
        'params': '24.7M',
        'description': 'Small model, balanced performance and speed'
    },
    'kronos-base': {
        'name': 'Kronos-base',
        'model_id': 'NeoQuasar/Kronos-base',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
        'context_length': 512,
        'params': '102.3M',
        'description': 'Base model, provides better prediction quality'
    }
}
# --- Helper Functions ---
def download_model_from_hf(model_name, local_dir="."):
    """Downloads a model from Hugging Face Hub."""
    logging.info(f"Downloading model '{model_name}' from Hugging Face Hub...")
    try:
        hf_hub_download(repo_id=model_name, filename="config.json", local_dir=local_dir)
        hf_hub_download(repo_id=model_name, filename="pytorch_model.bin", local_dir=local_dir)
        logging.info("Model downloaded successfully.")
        return True
    except Exception as e:
        logging.error(f"Failed to download model: {e}")
        return False
# --- API Authentication ---
def require_api_key(f):
    """Decorator to protect routes with an API key."""
    @wraps(f)  # Preserve the wrapped function's name so Flask's endpoint names don't collide
    def decorated_function(*args, **kwargs):
        # If KRONOS_API_KEY is not set on the server, skip authentication (for local/dev)
        if not API_KEY:
            logging.warning("API key not set. Skipping authentication.")
            return f(*args, **kwargs)
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({'error': 'Authorization header is missing or invalid. Use Bearer token.'}), 401
        token = auth_header.split(' ')[1]
        if token != API_KEY:
            return jsonify({'error': 'Invalid API Key.'}), 401
        return f(*args, **kwargs)
    return decorated_function
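# Illustrative authenticated request (the host and port are assumptions, not
# defined in this file; the route path matches the endpoint declared below):
#
#   curl -X POST http://localhost:5000/api/load-model \
#        -H "Authorization: Bearer $KRONOS_API_KEY" \
#        -H "Content-Type: application/json" \
#        -d '{"model_key": "kronos-base"}'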
# --- API Endpoints ---
# Note: the '/api/load-model' path is referenced by the error message in
# predict_from_data below; the other route paths and the choice of which
# routes are protected are assumptions, since the original decorators were
# not preserved.
@app.route('/api/load-model', methods=['POST'])
@require_api_key
def load_model_endpoint():
    """Loads the prediction model into memory."""
    global predictor, model_name_global
    json_data = request.get_json(silent=True) or {}
    model_key = json_data.get('model_key', model_name_global)  # Changed to model_key
    force_reload = json_data.get('force_reload', False)

    # Validate that the requested model is in the allowed list
    if model_key not in AVAILABLE_MODELS:
        return jsonify({
            'error': "Invalid model_key. Please choose from the allowed models.",
            'allowed_models': list(AVAILABLE_MODELS.keys())
        }), 400

    if predictor and not force_reload and model_name_global == model_key:
        return jsonify({'status': 'Model already loaded.'})

    try:
        model_config = AVAILABLE_MODELS[model_key]
        model_id = model_config['model_id']
        tokenizer_id = model_config['tokenizer_id']
        logging.info(f"Attempting to load model: {model_id}")
        logging.info(f"Attempting to load tokenizer: {tokenizer_id}")

        # Determine device
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logging.info(f"Using device: {device}")

        # --- Proxy Setup ---
        # Check for proxy settings in environment variables, similar to the webui fix
        proxies = {
            "http": os.environ.get("HTTP_PROXY"),
            "https": os.environ.get("HTTPS_PROXY"),
        }
        # Filter out unset entries
        proxies = {k: v for k, v in proxies.items() if v}
        if proxies:
            logging.info(f"Using proxies: {proxies}")

        # Load model and tokenizer with proxy support
        model = Kronos.from_pretrained(model_id, proxies=proxies if proxies else None)
        tokenizer = KronosTokenizer.from_pretrained(tokenizer_id, proxies=proxies if proxies else None)

        # Create the predictor wrapper
        predictor = KronosPredictor(model, tokenizer, device=device)
        model_name_global = model_key
        logging.info(f"Model '{model_config['name']}' loaded successfully.")
        return jsonify({
            'status': f"Model '{model_config['name']}' loaded successfully.",
            'model_info': model_config
        })
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        return jsonify({'error': str(e)}), 500
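# Example success response from the endpoint above (values depend on the
# selected model_key; abbreviated for illustration):
#   {
#     "status": "Model 'Kronos-base' loaded successfully.",
#     "model_info": {"name": "Kronos-base", "model_id": "NeoQuasar/Kronos-base", ...}
#   }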
@app.route('/api/model-status', methods=['GET'])  # Route path is an assumption
@require_api_key
def model_status():
    """Checks whether a model is loaded."""
    if predictor:
        return jsonify({
            'status': 'loaded',
            'model_key': model_name_global,
            'model_info': AVAILABLE_MODELS.get(model_name_global)
        })
    else:
        return jsonify({'status': 'not_loaded'})
@app.route('/api/models', methods=['GET'])  # Route path is an assumption
@require_api_key
def get_available_models():
    """Returns the list of available models and their details."""
    return jsonify(AVAILABLE_MODELS)
@app.route('/api/predict', methods=['POST'])  # Route path is an assumption
@require_api_key
def predict_from_data():
    """
    Receives raw K-line data in the request body, makes a prediction,
    and returns the results.
    """
    if not predictor:
        return jsonify({'error': 'Model not loaded. Please call /api/load-model first.'}), 400

    data = request.get_json(silent=True)
    if not data or 'k_lines' not in data:
        return jsonify({'error': 'Missing "k_lines" in request body.'}), 400

    k_lines = data['k_lines']
    params = data.get('prediction_params', {})
    pred_len = params.get('pred_len', 120)

    try:
        # Define column names based on the standard Binance kline format.
        # Only the first 6 columns are needed by the model.
        columns = [
            'timestamp', 'open', 'high', 'low', 'close', 'volume',
            'close_time', 'quote_asset_volume', 'number_of_trades',
            'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'ignore'
        ]
        # Keep at most the first 12 columns in case extra fields are provided
        k_lines_standardized = [line[:12] for line in k_lines]
        df = pd.DataFrame(k_lines_standardized, columns=columns)

        # --- Data Type Conversion ---
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        numeric_cols = ['open', 'high', 'low', 'close', 'volume']
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col])

        # Keep only the columns the model needs
        df_model_input = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
        logging.info(f"Making prediction with pred_len={pred_len} on data with shape {df_model_input.shape}")

        # --- Timestamp Generation for the Predictor ---
        # The predictor requires both historical and future timestamps.
        x_timestamp = df_model_input['timestamp']
        # Assuming the K-line interval is consistent, derive it from the
        # last two points to generate future timestamps.
        if len(x_timestamp) > 1:
            interval = x_timestamp.iloc[-1] - x_timestamp.iloc[-2]
        else:
            # With a single data point, fall back to a 1-minute interval
            interval = pd.Timedelta(minutes=1)
        y_timestamp = pd.date_range(
            start=x_timestamp.iloc[-1] + interval,
            periods=pred_len,
            freq=interval
        )
        # Convert the DatetimeIndex to a Series to prevent a '.dt' accessor error inside the model
        y_timestamp = pd.Series(y_timestamp, name='timestamp')

        # Make the prediction using the predictor wrapper
        pred_df = predictor.predict(
            df=df_model_input,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            verbose=False  # Keep logs clean
        )

        # Format results for the JSON response
        prediction_results = pred_df.to_dict(orient='records')
        return jsonify({
            'success': True,
            'prediction_params': {'pred_len': pred_len},
            'prediction_results': prediction_results
        })
    except Exception as e:
        logging.error(f"Prediction failed: {e}")
        return jsonify({'error': f'An error occurred during prediction: {str(e)}'}), 500
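# Illustrative request body for the predict endpoint ('k_lines' rows follow
# the Binance kline format parsed above; numeric strings are accepted because
# the handler runs pd.to_numeric on the OHLCV columns):
#   {
#     "k_lines": [[1700000000000, "37000.0", "37100.0", "36900.0", "37050.0",
#                  "120.5", 1700000059999, "4460000", 1500, "60.2", "2230000", "0"], ...],
#     "prediction_params": {"pred_len": 120}
#   }

if __name__ == '__main__':
    # Minimal entry point sketch so the server can be launched directly;
    # the host and port are assumptions and may differ from the original deployment.
    app.run(host='0.0.0.0', port=5000)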