# NOTE: the "Spaces: Running" lines that appeared here were status-banner
# residue from a Hugging Face Spaces UI capture, not part of the program.
import gc
import os
import re
import subprocess
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from binance.client import Client

from model import Kronos, KronosPredictor, KronosTokenizer
# --- Configuration --- | |
# 适配使用本地模型的配置 | |
Config = { | |
# 项目根目录(保持不变,用于定位其他资源) | |
"REPO_PATH": Path(__file__).parent.resolve(), | |
# 本地模型存储路径(需将模型文件放在该目录下) | |
"LOCAL_MODEL_PATH": os.path.join(Path(__file__).parent.resolve(), "models"), | |
# 交易对和时间间隔(保持原配置) | |
"SYMBOL": "BTCUSDT", | |
"INTERVAL": "1h", | |
# 数据和预测参数(根据 Spaces 资源调整) | |
"HIST_POINTS": 360, | |
"PRED_HORIZON": 24, | |
"N_PREDICTIONS": 10, | |
"VOL_WINDOW": 24, | |
# 新增:缓存预测结果路径(使用 /tmp 目录) | |
"PREDICTION_CACHE": os.path.join("/tmp", "predictions_cache"), | |
# 新增:图表保存路径(确保在可写目录) | |
"CHART_PATH": os.path.join(Path(__file__).parent.resolve(), "prediction_chart.png") | |
} | |
# 确保缓存目录存在 | |
os.makedirs(Config["PREDICTION_CACHE"], exist_ok=True) | |
# 确保本地模型目录存在 | |
os.makedirs(Config["LOCAL_MODEL_PATH"], exist_ok=True) | |
def load_local_model(): | |
"""从本地加载Kronos模型和分词器""" | |
print("Loading local Kronos model...") | |
# 本地分词器路径 | |
tokenizer_path = os.path.join(Config["LOCAL_MODEL_PATH"], "tokenizer") | |
# 本地模型路径 | |
model_path = os.path.join(Config["LOCAL_MODEL_PATH"], "model") | |
# 检查路径是否存在 | |
if not os.path.exists(tokenizer_path): | |
raise FileNotFoundError(f"Tokenizer path {tokenizer_path} does not exist. Please place tokenizer files here.") | |
if not os.path.exists(model_path): | |
raise FileNotFoundError(f"Model path {model_path} does not exist. Please place model files here.") | |
tokenizer = KronosTokenizer.from_pretrained(tokenizer_path) | |
model = Kronos.from_pretrained(model_path) | |
tokenizer.eval() | |
model.eval() | |
predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512) | |
print("Local model loaded successfully.") | |
return predictor | |
def make_prediction(df, predictor): | |
"""Generates probabilistic forecasts using the Kronos model.""" | |
last_timestamp = df['timestamps'].max() | |
start_new_range = last_timestamp + pd.Timedelta(hours=1) | |
new_timestamps_index = pd.date_range( | |
start=start_new_range, | |
periods=Config["PRED_HORIZON"], | |
freq='H' | |
) | |
y_timestamp = pd.Series(new_timestamps_index, name='y_timestamp') | |
x_timestamp = df['timestamps'] | |
x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']] | |
with torch.no_grad(): | |
print("Making main prediction (T=1.0)...") | |
begin_time = time.time() | |
close_preds_main, volume_preds_main = predictor.predict( | |
df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp, | |
pred_len=Config["PRED_HORIZON"], T=1.0, top_p=0.95, | |
sample_count=Config["N_PREDICTIONS"], verbose=True | |
) | |
print(f"Main prediction completed in {time.time() - begin_time:.2f} seconds.") | |
close_preds_volatility = close_preds_main | |
return close_preds_main, volume_preds_main, close_preds_volatility | |
def fetch_binance_data(): | |
"""Fetches K-line data from the Binance public API.""" | |
symbol, interval = Config["SYMBOL"], Config["INTERVAL"] | |
limit = Config["HIST_POINTS"] + Config["VOL_WINDOW"] | |
print(f"Fetching {limit} bars of {symbol} {interval} data from Binance...") | |
# 关键修复:为当前线程创建并设置asyncio事件循环 | |
import asyncio | |
loop = asyncio.new_event_loop() | |
asyncio.set_event_loop(loop) | |
# 初始化Binance客户端 | |
client = Client() | |
klines = client.get_klines(symbol=symbol, interval=interval, limit=limit) | |
# 后续数据处理逻辑保持不变... | |
cols = ['open_time', 'open', 'high', 'low', 'close', 'volume', 'close_time', | |
'quote_asset_volume', 'number_of_trades', 'taker_buy_base_asset_volume', | |
'taker_buy_quote_asset_volume', 'ignore'] | |
df = pd.DataFrame(klines, columns=cols) | |
df = df[['open_time', 'open', 'high', 'low', 'close', 'volume', 'quote_asset_volume']] | |
df.rename(columns={'quote_asset_volume': 'amount', 'open_time': 'timestamps'}, inplace=True) | |
df['timestamps'] = pd.to_datetime(df['timestamps'], unit='ms') | |
for col in ['open', 'high', 'low', 'close', 'volume', 'amount']: | |
df[col] = pd.to_numeric(df[col]) | |
print("Data fetched successfully.") | |
return df | |
def calculate_metrics(hist_df, close_preds_df, v_close_preds_df): | |
""" | |
Calculates upside and volatility amplification probabilities for the 24h horizon. | |
""" | |
last_close = hist_df['close'].iloc[-1] | |
# 1. Upside Probability (for the 24-hour horizon) | |
final_hour_preds = close_preds_df.iloc[-1] | |
upside_prob = (final_hour_preds > last_close).mean() | |
# 2. Volatility Amplification Probability (over the 24-hour horizon) | |
hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1)) | |
historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std() | |
amplification_count = 0 | |
for col in v_close_preds_df.columns: | |
full_sequence = pd.concat([pd.Series([last_close]), v_close_preds_df[col]]).reset_index(drop=True) | |
pred_log_returns = np.log(full_sequence / full_sequence.shift(1)) | |
predicted_vol = pred_log_returns.std() | |
if predicted_vol > historical_vol: | |
amplification_count += 1 | |
vol_amp_prob = amplification_count / len(v_close_preds_df.columns) | |
print(f"Upside Probability (24h): {upside_prob:.2%}, Volatility Amplification Probability: {vol_amp_prob:.2%}") | |
return upside_prob, vol_amp_prob | |
def create_plot(hist_df, close_preds_df, volume_preds_df): | |
"""Generates and saves a comprehensive forecast chart.""" | |
print("Generating comprehensive forecast chart...") | |
fig, (ax1, ax2) = plt.subplots( | |
2, 1, figsize=(15, 10), sharex=True, | |
gridspec_kw={'height_ratios': [3, 1]} | |
) | |
hist_time = hist_df['timestamps'] | |
last_hist_time = hist_time.iloc[-1] | |
pred_time = pd.to_datetime([last_hist_time + timedelta(hours=i + 1) for i in range(len(close_preds_df))]) | |
ax1.plot(hist_time, hist_df['close'], color='royalblue', label='Historical Price', linewidth=1.5) | |
mean_preds = close_preds_df.mean(axis=1) | |
ax1.plot(pred_time, mean_preds, color='darkorange', linestyle='-', label='Mean Forecast') | |
ax1.fill_between(pred_time, close_preds_df.min(axis=1), close_preds_df.max(axis=1), color='darkorange', alpha=0.2, label='Forecast Range (Min-Max)') | |
ax1.set_title(f'{Config["SYMBOL"]} Probabilistic Price & Volume Forecast (Next {Config["PRED_HORIZON"]} Hours)', fontsize=16, weight='bold') | |
ax1.set_ylabel('Price (USDT)') | |
ax1.legend() | |
ax1.grid(True, which='both', linestyle='--', linewidth=0.5) | |
ax2.bar(hist_time, hist_df['volume'], color='skyblue', label='Historical Volume', width=0.03) | |
ax2.bar(pred_time, volume_preds_df.mean(axis=1), color='sandybrown', label='Mean Forecasted Volume', width=0.03) | |
ax2.set_ylabel('Volume') | |
ax2.set_xlabel('Time (UTC)') | |
ax2.legend() | |
ax2.grid(True, which='both', linestyle='--', linewidth=0.5) | |
separator_time = hist_time.iloc[-1] + timedelta(minutes=30) | |
for ax in [ax1, ax2]: | |
ax.axvline(x=separator_time, color='red', linestyle='--', linewidth=1.5, label='_nolegend_') | |
ax.tick_params(axis='x', rotation=30) | |
fig.tight_layout() | |
chart_path = Config["REPO_PATH"] / 'prediction_chart.png' | |
fig.savefig(chart_path, dpi=120) | |
plt.close(fig) | |
print(f"Chart saved to: {chart_path}") | |
def update_html(upside_prob, vol_amp_prob): | |
""" | |
Updates the index.html file with the latest metrics and timestamp. | |
""" | |
print("Updating index.html...") | |
html_path = Config["REPO_PATH"] / 'index.html' | |
now_utc_str = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S') | |
upside_prob_str = f'{upside_prob:.1%}' | |
vol_amp_prob_str = f'{vol_amp_prob:.1%}' | |
with open(html_path, 'r', encoding='utf-8') as f: | |
content = f.read() | |
# Robustly replace content using lambda functions | |
content = re.sub( | |
r'(<strong id="update-time">).*?(</strong>)', | |
lambda m: f'{m.group(1)}{now_utc_str}{m.group(2)}', | |
content | |
) | |
content = re.sub( | |
r'(<p class="metric-value" id="upside-prob">).*?(</p>)', | |
lambda m: f'{m.group(1)}{upside_prob_str}{m.group(2)}', | |
content | |
) | |
content = re.sub( | |
r'(<p class="metric-value" id="vol-amp-prob">).*?(</p>)', | |
lambda m: f'{m.group(1)}{vol_amp_prob_str}{m.group(2)}', | |
content | |
) | |
with open(html_path, 'w', encoding='utf-8') as f: | |
f.write(content) | |
print("HTML file updated successfully.") | |
def git_commit_and_push(commit_message): | |
"""Adds, commits, and pushes specified files to the Git repository.""" | |
print("Performing Git operations...") | |
try: | |
os.chdir(Config["REPO_PATH"]) | |
subprocess.run(['git', 'add', 'prediction_chart.png', 'index.html'], check=True, capture_output=True, text=True) | |
commit_result = subprocess.run(['git', 'commit', '-m', commit_message], check=True, capture_output=True, text=True) | |
print(commit_result.stdout) | |
push_result = subprocess.run(['git', 'push'], check=True, capture_output=True, text=True) | |
print(push_result.stdout) | |
print("Git push successful.") | |
except subprocess.CalledProcessError as e: | |
output = e.stdout if e.stdout else e.stderr | |
if "nothing to commit" in output or "Your branch is up to date" in output: | |
print("No new changes to commit or push.") | |
else: | |
print(f"A Git error occurred:\n--- STDOUT ---\n{e.stdout}\n--- STDERR ---\n{e.stderr}") | |
def main_task(model): | |
"""Executes one full update cycle.""" | |
print("\n" + "=" * 60 + f"\nStarting update task at {datetime.now(timezone.utc)}\n" + "=" * 60) | |
df_full = fetch_binance_data() | |
df_for_model = df_full.iloc[:-1] | |
close_preds, volume_preds, v_close_preds = make_prediction(df_for_model, model) | |
hist_df_for_plot = df_for_model.tail(Config["HIST_POINTS"]) | |
hist_df_for_metrics = df_for_model.tail(Config["VOL_WINDOW"]) | |
upside_prob, vol_amp_prob = calculate_metrics(hist_df_for_metrics, close_preds, v_close_preds) | |
create_plot(hist_df_for_plot, close_preds, volume_preds) | |
update_html(upside_prob, vol_amp_prob) | |
commit_message = f"Auto-update forecast for {datetime.now(timezone.utc):%Y-%m-%d %H:%M} UTC" | |
git_commit_and_push(commit_message) | |
# --- 新增的内存清理步骤 --- | |
# 显式删除大的DataFrame对象,帮助垃圾回收器 | |
del df_full, df_for_model, close_preds, volume_preds, v_close_preds | |
del hist_df_for_plot, hist_df_for_metrics | |
# 强制执行垃圾回收 | |
gc.collect() | |
# --- 内存清理结束 --- | |
print("-" * 60 + "\n--- Task completed successfully ---\n" + "-" * 60 + "\n") | |
def run_scheduler(model): | |
"""A continuous scheduler that runs the main task hourly.""" | |
while True: | |
now = datetime.now(timezone.utc) | |
next_run_time = (now + timedelta(hours=1)).replace(minute=0, second=5, microsecond=0) | |
sleep_seconds = (next_run_time - now).total_seconds() | |
if sleep_seconds > 0: | |
print(f"Current time: {now:%Y-%m-%d %H:%M:%S UTC}.") | |
print(f"Next run at: {next_run_time:%Y-%m-%d %H:%M:%S UTC}. Waiting for {sleep_seconds:.0f} seconds...") | |
time.sleep(sleep_seconds) | |
try: | |
main_task(model) | |
except Exception as e: | |
print(f"\n!!!!!! A critical error occurred in the main task !!!!!!!") | |
print(f"Error: {e}") | |
import traceback | |
traceback.print_exc() | |
print("Retrying in 5 minutes...") | |
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n") | |
time.sleep(300) | |
if __name__ == '__main__': | |
# 确保本地模型目录存在 | |
local_model_path = Path(Config["LOCAL_MODEL_PATH"]) | |
local_model_path.mkdir(parents=True, exist_ok=True) | |
loaded_model = load_local_model() | |
main_task(loaded_model) # Run once on startup | |
run_scheduler(loaded_model) # Start the schedule |