# Spaces: Running
import gc | |
import os | |
import re | |
import subprocess | |
import time | |
from datetime import datetime, timezone, timedelta | |
from pathlib import Path | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import torch | |
import baostock as bs | |
from model import KronosTokenizer, Kronos, KronosPredictor | |
# --- Configuration ---
# Directory that contains this script; repository-relative paths derive from it.
_BASE_DIR = Path(__file__).parent.resolve()

Config = {
    # Repository checkout: index.html is read/written here and Git runs here.
    "REPO_PATH": _BASE_DIR,
    # Directory expected to hold the "tokenizer" and "model" sub-folders.
    "LOCAL_MODEL_PATH": os.path.join(_BASE_DIR, "models"),
    # Baostock query parameters.
    "STOCK_CODE": "sh.000001",
    "FREQUENCY": "d",
    "START_DATE": "2022-01-01",
    # NOTE: evaluated once, when this module is imported; it does not advance
    # by itself while the process keeps running.
    "END_DATE": datetime.now().strftime("%Y-%m-%d"),
    # Number of historical rows kept for plotting.
    "HIST_POINTS": 360,
    # Number of future steps requested from the predictor.
    "PRED_HORIZON": 24,
    # Number of sampled prediction paths.
    "N_PREDICTIONS": 10,
    # Number of trailing rows used for the historical-volatility baseline.
    "VOL_WINDOW": 24,
    "PREDICTION_CACHE": os.path.join("/tmp", "predictions_cache"),
    # The chart is written under /tmp (the original author notes that /tmp is
    # writable by default on Hugging Face Spaces).
    "CHART_PATH": os.path.join("/tmp", "prediction_chart.png"),
}

# Make sure the directories this script writes to / reads from exist.
for _required_dir in (Config["PREDICTION_CACHE"], Config["LOCAL_MODEL_PATH"]):
    os.makedirs(_required_dir, exist_ok=True)
# --- Helper functions: load_local_model, fetch_stock_data, make_prediction, calculate_metrics ---
def load_local_model():
    """Load the Kronos tokenizer and model from disk and wrap them in a predictor.

    The tokenizer is read from ``<LOCAL_MODEL_PATH>/tokenizer`` and the model
    from ``<LOCAL_MODEL_PATH>/model``; both are loaded with
    ``local_files_only=True`` and switched to eval mode.

    Returns:
        A ``KronosPredictor`` built with ``device="cpu"`` and
        ``max_context=512``.

    Raises:
        FileNotFoundError: If the tokenizer or the model path does not exist.
    """
    print("Loading local Kronos model...")
    base_dir = Config["LOCAL_MODEL_PATH"]
    tokenizer_path = os.path.join(base_dir, "tokenizer")
    model_path = os.path.join(base_dir, "model")

    # Validate both locations up front (tokenizer first, then model) so a
    # missing artefact produces a clear error instead of a loader failure.
    for label, required_path in (("分词器", tokenizer_path), ("模型", model_path)):
        if not os.path.exists(required_path):
            raise FileNotFoundError(f"{label}路径不存在:{required_path}")

    tokenizer = KronosTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    model = Kronos.from_pretrained(model_path, local_files_only=True)
    for module in (tokenizer, model):
        module.eval()

    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
    print("Local model loaded successfully.")
    return predictor
def fetch_stock_data():
    """Download daily OHLCV bars for the configured symbol from Baostock.

    The query runs from ``Config["START_DATE"]`` up to the configured end date
    or today, whichever is later. ``Config["END_DATE"]`` is computed once at
    import time, so in the long-running scheduler it would otherwise stay
    frozen at the start-up date and new trading days would never be fetched.

    Returns:
        pandas.DataFrame: The most recent ``HIST_POINTS + VOL_WINDOW`` rows
        with columns ``timestamps, open, high, low, close, volume, amount``
        and a fresh 0-based index. ``amount`` is not taken from Baostock; it
        is approximated as the mean of open/high/low/close times volume.

    Raises:
        ConnectionError: If the Baostock login fails.
        ValueError: If the query reports an error or returns fewer usable
            rows than required.
    """
    stock_code = Config["STOCK_CODE"]
    frequency = Config["FREQUENCY"]
    start_date = Config["START_DATE"]
    # ISO "YYYY-MM-DD" strings order lexicographically, so max() picks the
    # later date. This assumes END_DATE keeps that format.
    today = datetime.now().strftime("%Y-%m-%d")
    end_date = max(Config["END_DATE"], today)
    need_points = Config["HIST_POINTS"] + Config["VOL_WINDOW"]
    print(f"Fetching {need_points}个交易日的{stock_code}日线数据...")

    lg = bs.login()
    if lg.error_code != '0':
        raise ConnectionError(f"Baostock登录失败:{lg.error_msg}")
    try:
        fields = "date,open,high,low,close,volume"
        # Baostock documents adjustflag "2" as forward-adjusted prices (前复权).
        rs = bs.query_history_k_data_plus(
            code=stock_code,
            fields=fields,
            start_date=start_date,
            end_date=end_date,
            frequency=frequency,
            adjustflag="2",
        )
        if rs.error_code != '0':
            raise ValueError(f"获取K线数据失败:{rs.error_msg}")

        # The result set is a cursor; drain it row by row.
        data_list = []
        while rs.next():
            data_list.append(rs.get_row_data())
        df = pd.DataFrame(data_list, columns=rs.fields)

        # Values are converted with to_numeric; rows that fail to parse are dropped.
        numeric_cols = ['open', 'high', 'low', 'close', 'volume']
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        # Explicit copy so the column assignments below act on an independent frame.
        df = df.dropna(subset=numeric_cols).copy()

        df['timestamps'] = pd.to_datetime(df['date'], format='%Y-%m-%d')
        df['amount'] = (df['open'] + df['high'] + df['low'] + df['close']) / 4 * df['volume']
        df = df[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]

        if len(df) < need_points:
            raise ValueError(f"数据量不足(仅{len(df)}个交易日),请提前start_date")
        df = df.tail(need_points).reset_index(drop=True)

        print(f"A股数据获取成功,共{len(df)}个交易日数据")
        print(df[['timestamps', 'open', 'close', 'volume', 'amount']].tail())
        return df
    finally:
        # Always release the Baostock session, even when the query fails.
        bs.logout()
def make_prediction(df, predictor):
    """Run the predictor on historical bars and return the sampled forecasts.

    Future timestamps are generated on business days (Mon-Fri). The history
    consists of trading-day bars and every label in this script speaks of
    trading days, so calendar-day timestamps would feed weekend dates to the
    model. Exchange holidays are not excluded.

    Args:
        df: Historical frame with columns ``timestamps, open, high, low,
            close, volume, amount``.
        predictor: Object exposing ``predict(df=..., x_timestamp=...,
            y_timestamp=..., pred_len=..., T=..., top_p=..., sample_count=...,
            verbose=...)`` that returns a ``(close_preds, volume_preds)`` pair.

    Returns:
        tuple: ``(close_preds_main, volume_preds_main, close_preds_volatility)``.
        The third element is the same object as the first; no separate
        volatility run is performed.
    """
    last_timestamp = df['timestamps'].max()
    # With freq='B' a start that falls on a weekend rolls forward to Monday.
    new_timestamps_index = pd.date_range(
        start=last_timestamp + pd.Timedelta(days=1),
        periods=Config["PRED_HORIZON"],
        freq='B',
    )
    y_timestamp = pd.Series(new_timestamps_index, name='y_timestamp')
    x_timestamp = df['timestamps']
    x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]

    # Inference only: gradients are not needed.
    with torch.no_grad():
        print("Making main prediction (T=1.0)...")
        begin_time = time.time()
        close_preds_main, volume_preds_main = predictor.predict(
            df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
            pred_len=Config["PRED_HORIZON"], T=1.0, top_p=0.95,
            sample_count=Config["N_PREDICTIONS"], verbose=True
        )
        print(f"Main prediction completed in {time.time() - begin_time:.2f} seconds.")

    # The volatility metric reuses the main sample paths.
    close_preds_volatility = close_preds_main
    return close_preds_main, volume_preds_main, close_preds_volatility
def calculate_metrics(hist_df, close_preds_df, v_close_preds_df):
    """Compute the upside probability and volatility-amplification probability.

    Args:
        hist_df: Historical frame with a ``close`` column; its last row is the
            reference close.
        close_preds_df: Predicted closes, one row per future step and one
            column per sampled path.
        v_close_preds_df: Predicted closes used for the volatility comparison,
            in the same layout.

    Returns:
        tuple: ``(upside_prob, vol_amp_prob)``. ``upside_prob`` is the share
        of paths whose final predicted close exceeds the last historical
        close. ``vol_amp_prob`` is the share of paths whose log-return
        standard deviation exceeds that of the last ``VOL_WINDOW`` historical
        log returns.

    Raises:
        ValueError: If either prediction frame has no rows or no columns.
            Without this guard the failure would be an opaque ``IndexError``
            or ``ZeroDivisionError``.
    """
    if close_preds_df.empty or v_close_preds_df.empty:
        raise ValueError("Prediction frames are empty; cannot compute metrics.")

    last_close = hist_df['close'].iloc[-1]
    final_day_preds = close_preds_df.iloc[-1]
    upside_prob = (final_day_preds > last_close).mean()

    hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1))
    historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std()

    # Prepend the last observed close to every path so the first predicted
    # step also contributes a log return, then handle all paths at once.
    paths = v_close_preds_df.reset_index(drop=True)
    n_paths = paths.shape[1]
    anchor_row = pd.DataFrame([[last_close] * n_paths], columns=paths.columns)
    full_paths = pd.concat([anchor_row, paths], ignore_index=True)
    pred_log_returns = np.log(full_paths / full_paths.shift(1))
    predicted_vols = pred_log_returns.std()
    # NaN volatilities compare as False and therefore do not count.
    amplification_count = int((predicted_vols > historical_vol).sum())
    vol_amp_prob = amplification_count / n_paths

    # Report the actual number of predicted steps instead of a hard-coded 24.
    horizon = len(close_preds_df)
    print(f"上证指数上涨概率({horizon}个交易日):{upside_prob:.2%}")
    print(f"波动率放大概率({horizon}个交易日):{vol_amp_prob:.2%}")
    return upside_prob, vol_amp_prob
# --- Key change: create_plot (addresses CJK font rendering and file-permission issues) ---
def create_plot(hist_df, close_preds_df, volume_preds_df):
    """Render the price/volume forecast chart and save it to ``Config["CHART_PATH"]``.

    The upper panel shows historical closes, the mean predicted close and the
    min-max band across sampled paths. The lower panel shows historical and
    mean predicted volume. A dashed red line separates history from forecast.

    Args:
        hist_df: Historical frame with ``timestamps``, ``close`` and
            ``volume`` columns.
        close_preds_df: Predicted closes, one row per future step and one
            column per sampled path.
        volume_preds_df: Predicted volumes in the same layout.

    Returns:
        None. The chart is written to disk and the figure is always closed.
    """
    print("Generating comprehensive forecast chart...")
    # Prefer CJK-capable fonts so the Chinese labels render; matplotlib falls
    # back through the list in order.
    plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei', 'SimHei', 'DejaVu Sans']
    # Render the minus sign correctly with non-default fonts.
    plt.rcParams['axes.unicode_minus'] = False

    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(15, 10), sharex=True,
        gridspec_kw={'height_ratios': [3, 1]}
    )
    # Close the figure even if plotting or saving fails, so a long-running
    # scheduler does not accumulate open figures.
    try:
        hist_time = hist_df['timestamps']
        last_hist_time = hist_time.iloc[-1]
        # The forecast axis uses business days to match the trading-day
        # horizon in the title. Exchange holidays are not excluded.
        pred_time = pd.date_range(
            start=last_hist_time + pd.Timedelta(days=1),
            periods=len(close_preds_df),
            freq='B',
        )

        # Price panel. The data is fetched with adjustflag "2", which Baostock
        # documents as forward-adjusted (前复权), so the label says so.
        ax1.plot(hist_time, hist_df['close'], color='#00274C', label='上证指数(前复权)', linewidth=1.5)
        mean_preds = close_preds_df.mean(axis=1)
        ax1.plot(pred_time, mean_preds, color='#FF6B00', linestyle='-', label='预测均价')
        ax1.fill_between(pred_time, close_preds_df.min(axis=1), close_preds_df.max(axis=1),
                         color='#FF6B00', alpha=0.2, label='预测区间(最小-最大)')
        ax1.set_title(f'{Config["STOCK_CODE"]} 上证指数 概率预测(未来{Config["PRED_HORIZON"]}个交易日)',
                      fontsize=16, weight='bold')
        ax1.set_ylabel('价格(元)')
        ax1.legend()
        ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Volume panel (values scaled by 1e8).
        ax2.bar(hist_time, hist_df['volume']/1e8, color='#00A86B', label='历史成交量(亿手)', width=0.6)
        ax2.bar(pred_time, volume_preds_df.mean(axis=1)/1e8, color='#FF6B00', label='预测成交量(亿手)', width=0.6)
        ax2.set_ylabel('成交量(亿手)')
        ax2.set_xlabel('日期')
        ax2.legend()
        ax2.grid(True, which='both', linestyle='--', linewidth=0.5)

        # History/forecast separator, placed half a day after the last bar.
        separator_time = last_hist_time + timedelta(hours=12)
        for ax in [ax1, ax2]:
            ax.axvline(x=separator_time, color='red', linestyle='--', linewidth=1.5, label='_nolegend_')
            ax.tick_params(axis='x', rotation=45)
        fig.tight_layout()

        chart_path = Path(Config["CHART_PATH"])
        # mkdir(exist_ok=True) already tolerates an existing directory.
        chart_path.parent.mkdir(parents=True, exist_ok=True)
        # Loosen permissions on a pre-existing chart before overwriting it.
        if chart_path.exists():
            chart_path.chmod(0o666)
        fig.savefig(chart_path, dpi=120, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"图表已保存至:{chart_path}(权限:{oct(os.stat(chart_path).st_mode)[-3:]})")
# --- Publishing and scheduling: update_html, git_commit_and_push, main_task, run_scheduler ---
def update_html(upside_prob, vol_amp_prob):
    """Rewrite ``index.html`` in the repository with the latest metrics.

    Replaces the text inside the ``update-time``, ``upside-prob`` and
    ``vol-amp-prob`` elements, then swaps three English labels for their
    Chinese counterparts. The file is read and written as UTF-8.

    Args:
        upside_prob: Upside probability, rendered as a percentage with one
            decimal place.
        vol_amp_prob: Volatility-amplification probability, rendered the
            same way.
    """
    print("Updating index.html...")
    html_path = Config["REPO_PATH"] / 'index.html'
    timestamp_text = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')

    # (pattern, new inner text): each pattern captures the opening and closing
    # tag so only the text between them is replaced.
    element_updates = (
        (r'(<strong id="update-time">).*?(</strong>)', timestamp_text),
        (r'(<p class="metric-value" id="upside-prob">).*?(</p>)', f'{upside_prob:.1%}'),
        (r'(<p class="metric-value" id="vol-amp-prob">).*?(</p>)', f'{vol_amp_prob:.1%}'),
    )
    # (old label, new label) literal replacements.
    label_swaps = (
        ("24h Upside Probability", "24个交易日上涨概率"),
        ("Volatility Amplification Probability", "波动率放大概率"),
        ("BTCUSDT", "上证指数(sh.000001)"),
    )

    page = html_path.read_text(encoding='utf-8')
    for pattern, inner_text in element_updates:
        # A callable replacement keeps the inserted text literal (no
        # backslash/group-reference interpretation).
        page = re.sub(
            pattern,
            lambda m, inner_text=inner_text: f'{m.group(1)}{inner_text}{m.group(2)}',
            page,
        )
    for old_label, new_label in label_swaps:
        page = page.replace(old_label, new_label)
    html_path.write_text(page, encoding='utf-8')
    print("HTML文件更新成功")
def git_commit_and_push(commit_message):
    """Copy the chart into the repository, then add, commit and push.

    Git commands run with ``cwd`` set to ``Config["REPO_PATH"]`` rather than
    changing the process-wide working directory, so the rest of the program
    is unaffected. Git failures are reported on stdout and not re-raised; a
    "nothing to commit" / "up to date" outcome is treated as benign.

    Args:
        commit_message: Message passed to ``git commit -m``.
    """
    import shutil

    print("Performing Git operations...")
    repo_path = Config["REPO_PATH"]

    def _git(*args):
        """Run one git sub-command inside the repository, raising on failure."""
        return subprocess.run(
            ['git', *args], cwd=repo_path,
            check=True, capture_output=True, text=True,
        )

    try:
        # The chart is rendered under /tmp, which is outside the repository;
        # copy it next to index.html so Git can track it.
        chart_src = Config["CHART_PATH"]
        chart_dst = os.path.join(repo_path, "prediction_chart.png")
        if os.path.exists(chart_src):
            shutil.copy2(chart_src, chart_dst)
            print(f"图表已从 {chart_src} 复制到 {chart_dst}")

        _git('add', 'prediction_chart.png', 'index.html')
        commit_result = _git('commit', '-m', commit_message)
        print(commit_result.stdout)
        push_result = _git('push')
        print(push_result.stdout)
        print("Git push successful.")
    except subprocess.CalledProcessError as e:
        # Inspect both streams: the benign message may be on either one.
        output = f"{e.stdout or ''}\n{e.stderr or ''}"
        if "nothing to commit" in output or "Your branch is up to date" in output:
            print("No new changes to commit or push.")
        else:
            print(f"A Git error occurred:\n--- STDOUT ---\n{e.stdout}\n--- STDERR ---\n{e.stderr}")
def main_task(model):
    """Run one full update cycle: fetch, predict, score, plot, publish.

    Args:
        model: Predictor object handed through to ``make_prediction``.
    """
    rule = "=" * 60
    print(f"\n{rule}\nStarting update task at {datetime.now(timezone.utc)}\n{rule}")

    full_history = fetch_stock_data()
    # NOTE(review): the most recent row is dropped before modelling. This is
    # presumably inherited from a setup where the latest bar could still be
    # incomplete; confirm it is still wanted for daily bars.
    model_input = full_history.iloc[:-1]
    close_paths, volume_paths, vol_close_paths = make_prediction(model_input, model)

    plot_history = model_input.tail(Config["HIST_POINTS"])
    metrics_history = model_input.tail(Config["VOL_WINDOW"])
    upside_prob, vol_amp_prob = calculate_metrics(metrics_history, close_paths, vol_close_paths)

    create_plot(plot_history, close_paths, volume_paths)
    update_html(upside_prob, vol_amp_prob)
    git_commit_and_push(
        f"Auto-update 上证指数预测 {datetime.now(timezone.utc):%Y-%m-%d %H:%M} UTC"
    )

    # Drop the large frames explicitly and collect, to keep memory low between
    # scheduled runs.
    del full_history, model_input, close_paths, volume_paths, vol_close_paths
    del plot_history, metrics_history
    gc.collect()

    dashes = "-" * 60
    print(f"{dashes}\n--- Task completed successfully ---\n{dashes}\n")
def run_scheduler(model, retry_delay_seconds=300, max_retries=3):
    """Run ``main_task`` once a day at 00:00:05 UTC, forever.

    When a run fails it is retried after ``retry_delay_seconds``, up to
    ``max_retries`` times, before the scheduler goes back to waiting for the
    next daily slot. Previously the loop slept five minutes after a failure
    and then slept until the following day, so the announced retry never
    took place.

    Args:
        model: Predictor object handed through to ``main_task``.
        retry_delay_seconds: Pause between attempts after a failure.
        max_retries: Number of extra attempts after the first failed one.
    """
    import traceback

    attempts_allowed = max_retries + 1
    while True:
        now = datetime.now(timezone.utc)
        # Next slot: five seconds past midnight UTC on the following day.
        next_run_time = (now + timedelta(days=1)).replace(hour=0, minute=0, second=5, microsecond=0)
        sleep_seconds = (next_run_time - now).total_seconds()
        if sleep_seconds > 0:
            print(f"Current time: {now:%Y-%m-%d %H:%M:%S UTC}.")
            print(f"Next run at: {next_run_time:%Y-%m-%d %H:%M:%S UTC}. Waiting for {sleep_seconds:.0f} seconds...")
            time.sleep(sleep_seconds)

        for attempt in range(1, attempts_allowed + 1):
            try:
                main_task(model)
                break
            except Exception as e:
                # Top-level boundary: report everything and keep the
                # scheduler alive.
                will_retry = attempt < attempts_allowed
                print("\n!!!!!! A critical error occurred in the main task !!!!!!!")
                print(f"Error: {e}")
                traceback.print_exc()
                if will_retry:
                    print(f"Retrying in {retry_delay_seconds / 60:g} minutes...")
                else:
                    print("Giving up until the next scheduled run.")
                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")
                if will_retry:
                    time.sleep(retry_delay_seconds)
if __name__ == '__main__':
    # Script entry point: make sure the model directory exists, load the
    # predictor once, run an immediate update, then hand over to the scheduler.
    Path(Config["LOCAL_MODEL_PATH"]).mkdir(parents=True, exist_ok=True)
    predictor = load_local_model()
    main_task(predictor)
    run_scheduler(predictor)