Spaces:
Running
Running
Update update_predictions.py
Browse files- update_predictions.py +548 -0
update_predictions.py
CHANGED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import subprocess
|
| 5 |
+
import time
|
| 6 |
+
from datetime import datetime, timedelta
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import torch
|
| 13 |
+
import baostock as bs
|
| 14 |
+
from pytz import timezone # 处理中国时区(Asia/Shanghai)
|
| 15 |
+
|
| 16 |
+
from model import KronosTokenizer, Kronos, KronosPredictor
|
| 17 |
+
|
| 18 |
+
# --- Configuration ---
|
| 19 |
+
Config = {
|
| 20 |
+
"REPO_PATH": Path(__file__).parent.resolve(),
|
| 21 |
+
"LOCAL_MODEL_PATH": os.path.join(Path(__file__).parent.resolve(), "models"),
|
| 22 |
+
"STOCK_CODE": "sh.000001",
|
| 23 |
+
"FREQUENCY": "d",
|
| 24 |
+
"START_DATE": "2022-01-01",
|
| 25 |
+
"PRED_HORIZON": 24,
|
| 26 |
+
"N_PREDICTIONS": 10,
|
| 27 |
+
"VOL_WINDOW": 24,
|
| 28 |
+
"PREDICTION_CACHE": os.path.join("/tmp", "predictions_cache"),
|
| 29 |
+
"CHART_PATH": os.path.join("/tmp", "prediction_chart.png"),
|
| 30 |
+
"HTML_PATH": os.path.join("/tmp", "index.html"),
|
| 31 |
+
# 先不定义CHINESE_FONT_PATH,避免引用未完成的Config
|
| 32 |
+
"IS_TODAY_INFERENCED": False,
|
| 33 |
+
"CACHED_RESULTS": {
|
| 34 |
+
"close_preds": None,
|
| 35 |
+
"volume_preds": None,
|
| 36 |
+
"v_close_preds": None,
|
| 37 |
+
"upside_prob": None,
|
| 38 |
+
"vol_amp_prob": None,
|
| 39 |
+
"hist_df_for_plot": None
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
# 补充定义中文字体路径(此时Config已完全定义)
|
| 44 |
+
Config["CHINESE_FONT_PATH"] = os.path.join(Config["REPO_PATH"], "fonts", "wqy-microhei.ttf")
|
| 45 |
+
|
| 46 |
+
# 创建必要目录
|
| 47 |
+
os.makedirs(Config["PREDICTION_CACHE"], exist_ok=True)
|
| 48 |
+
os.makedirs(Config["LOCAL_MODEL_PATH"], exist_ok=True)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_china_time():
|
| 52 |
+
"""获取当前中国时间(Asia/Shanghai时区),返回datetime对象"""
|
| 53 |
+
china_tz = timezone("Asia/Shanghai")
|
| 54 |
+
return datetime.now(china_tz)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_local_model():
|
| 58 |
+
"""加载本地Kronos模型,添加字体加载日志"""
|
| 59 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 开始加载本地Kronos模型...")
|
| 60 |
+
tokenizer_path = os.path.join(Config["LOCAL_MODEL_PATH"], "tokenizer")
|
| 61 |
+
model_path = os.path.join(Config["LOCAL_MODEL_PATH"], "model")
|
| 62 |
+
|
| 63 |
+
# 检查模型文件是否存在
|
| 64 |
+
if not os.path.exists(tokenizer_path):
|
| 65 |
+
raise FileNotFoundError(f"分词器路径不存在:{tokenizer_path}")
|
| 66 |
+
if not os.path.exists(model_path):
|
| 67 |
+
raise FileNotFoundError(f"模型路径不存在:{model_path}")
|
| 68 |
+
|
| 69 |
+
# 加载模型和分词器
|
| 70 |
+
tokenizer = KronosTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
|
| 71 |
+
model = Kronos.from_pretrained(model_path, local_files_only=True)
|
| 72 |
+
tokenizer.eval()
|
| 73 |
+
model.eval()
|
| 74 |
+
predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
|
| 75 |
+
|
| 76 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 本地模型加载成功")
|
| 77 |
+
return predictor
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def fetch_stock_data():
|
| 81 |
+
"""获取股票数据(每日更新一次,中国时间),添加数据获取日志"""
|
| 82 |
+
china_now = get_china_time()
|
| 83 |
+
end_date = china_now.strftime("%Y-%m-%d") # 按中国时间取结束日期
|
| 84 |
+
need_points = Config["VOL_WINDOW"] + Config["VOL_WINDOW"] # 历史数据+波动率计算窗口
|
| 85 |
+
|
| 86 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始获取{Config['STOCK_CODE']}日线数据(结束日期:{end_date})")
|
| 87 |
+
lg = bs.login()
|
| 88 |
+
if lg.error_code != '0':
|
| 89 |
+
raise ConnectionError(f"Baostock登录失败:{lg.error_msg}")
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
# 调用baostock获取K线数据
|
| 93 |
+
fields = "date,open,high,low,close,volume"
|
| 94 |
+
rs = bs.query_history_k_data_plus(
|
| 95 |
+
code=Config["STOCK_CODE"],
|
| 96 |
+
fields=fields,
|
| 97 |
+
start_date=Config["START_DATE"],
|
| 98 |
+
end_date=end_date,
|
| 99 |
+
frequency=Config["FREQUENCY"],
|
| 100 |
+
adjustflag="2" # 后复权
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
if rs.error_code != '0':
|
| 104 |
+
raise ValueError(f"获取K线数据失败:{rs.error_msg}")
|
| 105 |
+
|
| 106 |
+
# 处理数据
|
| 107 |
+
data_list = []
|
| 108 |
+
while rs.next():
|
| 109 |
+
data_list.append(rs.get_row_data())
|
| 110 |
+
df = pd.DataFrame(data_list, columns=rs.fields)
|
| 111 |
+
|
| 112 |
+
# 数值列转换
|
| 113 |
+
numeric_cols = ['open', 'high', 'low', 'close', 'volume']
|
| 114 |
+
for col in numeric_cols:
|
| 115 |
+
df[col] = pd.to_numeric(df[col], errors='coerce')
|
| 116 |
+
df = df.dropna(subset=numeric_cols)
|
| 117 |
+
|
| 118 |
+
# 添加时间戳和成交额列
|
| 119 |
+
df['timestamps'] = pd.to_datetime(df['date'], format='%Y-%m-%d')
|
| 120 |
+
df['amount'] = (df['open'] + df['high'] + df['low'] + df['close']) / 4 * df['volume']
|
| 121 |
+
df = df[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]
|
| 122 |
+
|
| 123 |
+
# 检查数据量
|
| 124 |
+
if len(df) < need_points:
|
| 125 |
+
raise ValueError(f"数据量不足(仅{len(df)}个交易日),请提前START_DATE")
|
| 126 |
+
df = df.tail(need_points).reset_index(drop=True)
|
| 127 |
+
|
| 128 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 股票数据获取成功,共{len(df)}个交易日")
|
| 129 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 最新5条数据:\n{df[['timestamps', 'open', 'close', 'volume']].tail()}")
|
| 130 |
+
return df
|
| 131 |
+
|
| 132 |
+
finally:
|
| 133 |
+
bs.logout()
|
| 134 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] Baostock已登出")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def make_prediction(df, predictor):
|
| 138 |
+
"""执行模型推理,仅当天首次调用时运行,添加推理日志"""
|
| 139 |
+
china_now = get_china_time()
|
| 140 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始执行模型推理(预测未来{Config['PRED_HORIZON']}个交易日)")
|
| 141 |
+
|
| 142 |
+
# 准备时间戳
|
| 143 |
+
last_timestamp = df['timestamps'].max()
|
| 144 |
+
start_new_range = last_timestamp + pd.Timedelta(days=1)
|
| 145 |
+
new_timestamps_index = pd.date_range(
|
| 146 |
+
start=start_new_range,
|
| 147 |
+
periods=Config["PRED_HORIZON"],
|
| 148 |
+
freq='D'
|
| 149 |
+
)
|
| 150 |
+
y_timestamp = pd.Series(new_timestamps_index, name='y_timestamp')
|
| 151 |
+
x_timestamp = df['timestamps']
|
| 152 |
+
x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]
|
| 153 |
+
|
| 154 |
+
# 推理(禁用梯度计算,节省资源)
|
| 155 |
+
with torch.no_grad():
|
| 156 |
+
begin_time = time.time()
|
| 157 |
+
close_preds_main, volume_preds_main = predictor.predict(
|
| 158 |
+
df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
|
| 159 |
+
pred_len=Config["PRED_HORIZON"], T=1.0, top_p=0.95,
|
| 160 |
+
sample_count=Config["N_PREDICTIONS"], verbose=True
|
| 161 |
+
)
|
| 162 |
+
infer_time = time.time() - begin_time
|
| 163 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 推理完成,耗时{infer_time:.2f}秒")
|
| 164 |
+
|
| 165 |
+
# 波动率预测复用收盘价预测结果(保持原逻辑)
|
| 166 |
+
close_preds_volatility = close_preds_main
|
| 167 |
+
return close_preds_main, volume_preds_main, close_preds_volatility
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def calculate_metrics(hist_df, close_preds_df, v_close_preds_df):
|
| 171 |
+
"""计算上涨概率和波动率放大概率,添加指标计算日志"""
|
| 172 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 开始计算预测指标...")
|
| 173 |
+
|
| 174 |
+
# 上涨概率(最后一个预测日相对于最新收盘价)
|
| 175 |
+
last_close = hist_df['close'].iloc[-1]
|
| 176 |
+
final_day_preds = close_preds_df.iloc[-1]
|
| 177 |
+
upside_prob = (final_day_preds > last_close).mean()
|
| 178 |
+
|
| 179 |
+
# 波动率放大概率(预测波动率vs历史波动率)
|
| 180 |
+
hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1))
|
| 181 |
+
historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std()
|
| 182 |
+
|
| 183 |
+
amplification_count = 0
|
| 184 |
+
for col in v_close_preds_df.columns:
|
| 185 |
+
full_sequence = pd.concat([pd.Series([last_close]), v_close_preds_df[col]]).reset_index(drop=True)
|
| 186 |
+
pred_log_returns = np.log(full_sequence / full_sequence.shift(1))
|
| 187 |
+
predicted_vol = pred_log_returns.std()
|
| 188 |
+
if predicted_vol > historical_vol:
|
| 189 |
+
amplification_count += 1
|
| 190 |
+
vol_amp_prob = amplification_count / len(v_close_preds_df.columns)
|
| 191 |
+
|
| 192 |
+
# 打印指标日志
|
| 193 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 指标计算完成:")
|
| 194 |
+
print(f" - 24个交易日上涨概率:{upside_prob:.2%}")
|
| 195 |
+
print(f" - 24个交易日波动率放大概率:{vol_amp_prob:.2%}")
|
| 196 |
+
return upside_prob, vol_amp_prob
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def create_plot():
|
| 201 |
+
china_now = get_china_time()
|
| 202 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始生成预测图表(适配低版本matplotlib字体)")
|
| 203 |
+
|
| 204 |
+
# 从缓存获取数据(原有逻辑不变)
|
| 205 |
+
hist_df_for_plot = Config["CACHED_RESULTS"]["hist_df_for_plot"]
|
| 206 |
+
close_preds = Config["CACHED_RESULTS"]["close_preds"]
|
| 207 |
+
volume_preds = Config["CACHED_RESULTS"]["volume_preds"]
|
| 208 |
+
|
| 209 |
+
# -------------------------- 新增:创建画布和子图 --------------------------
|
| 210 |
+
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
|
| 211 |
+
# -----------------------------------------------------------------------------
|
| 212 |
+
|
| 213 |
+
# -------------------------- 修正:低版本matplotlib字体处理 --------------------------
|
| 214 |
+
from matplotlib.font_manager import FontProperties
|
| 215 |
+
font_path = Config["CHINESE_FONT_PATH"]
|
| 216 |
+
|
| 217 |
+
# 检查字体文件是否存在
|
| 218 |
+
if os.path.exists(font_path):
|
| 219 |
+
# 直接通过FontProperties指定字体文件路径(兼容低版本matplotlib)
|
| 220 |
+
chinese_font = FontProperties(fname=font_path)
|
| 221 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 成功加载.ttf字体:{font_path}")
|
| 222 |
+
else:
|
| 223 |
+
# 字体文件不存在时的 fallback 逻辑
|
| 224 |
+
chinese_font = FontProperties(family='SimHei', size=10)
|
| 225 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 字体文件不存在,使用系统默认字体:SimHei")
|
| 226 |
+
|
| 227 |
+
# 全局设置字体(确保坐标轴刻度等默认文本也能显示中文)
|
| 228 |
+
plt.rcParams["font.family"] = ["sans-serif"]
|
| 229 |
+
plt.rcParams["font.sans-serif"] = ["WenQuanYi Micro Hei", "SimHei", "Heiti TC"]
|
| 230 |
+
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
|
| 231 |
+
# -----------------------------------------------------------------------------
|
| 232 |
+
|
| 233 |
+
# 绘图时,为所有中文文本显式指定字体(关键)
|
| 234 |
+
# 1. 价格子图
|
| 235 |
+
hist_time = hist_df_for_plot['timestamps']
|
| 236 |
+
ax1.plot(hist_time, hist_df_for_plot['close'], color='#00274C', linewidth=1.5)
|
| 237 |
+
mean_preds = close_preds.mean(axis=1)
|
| 238 |
+
# 生成预测时间序列(假设预测是在历史最后一个时间���后的24个交易日)
|
| 239 |
+
last_hist_time = hist_time.max()
|
| 240 |
+
pred_time = pd.date_range(start=last_hist_time + pd.Timedelta(days=1), periods=Config["PRED_HORIZON"], freq='B')
|
| 241 |
+
ax1.plot(pred_time, mean_preds, color='#FF6B00', linestyle='-')
|
| 242 |
+
ax1.fill_between(pred_time, close_preds.min(axis=1), close_preds.max(axis=1),
|
| 243 |
+
color='#FF6B00', alpha=0.2)
|
| 244 |
+
# 中文标题/标签指定字体
|
| 245 |
+
ax1.set_title(f'{Config["STOCK_CODE"]} 上证指数概率预测(未来{Config["PRED_HORIZON"]}个交易日)',
|
| 246 |
+
fontsize=16, weight='bold', fontproperties=chinese_font)
|
| 247 |
+
ax1.set_ylabel('价格(元)', fontsize=12, fontproperties=chinese_font)
|
| 248 |
+
# 图例指定字体
|
| 249 |
+
ax1.legend(['上证指数(后复权)', '预测均价', '预测区间(最小-最大)'],
|
| 250 |
+
fontsize=10, prop=chinese_font)
|
| 251 |
+
ax1.grid(True, which='both', linestyle='--', linewidth=0.5)
|
| 252 |
+
|
| 253 |
+
# 2. 成交量子图(同理指定字体)
|
| 254 |
+
ax2.bar(hist_time, hist_df_for_plot['volume']/1e8, color='#00A86B', width=0.6)
|
| 255 |
+
ax2.bar(pred_time, volume_preds.mean(axis=1)/1e8, color='#FF6B00', width=0.6)
|
| 256 |
+
ax2.set_ylabel('成交量(亿手)', fontsize=12, fontproperties=chinese_font)
|
| 257 |
+
ax2.set_xlabel('日期', fontsize=12, fontproperties=chinese_font)
|
| 258 |
+
ax2.legend(['历史成交量(亿手)', '预测成交量(亿手)'],
|
| 259 |
+
fontsize=10, prop=chinese_font)
|
| 260 |
+
ax2.grid(True, which='both', linestyle='--', linewidth=0.5)
|
| 261 |
+
|
| 262 |
+
# 添加分割线(区分历史和预测数据)
|
| 263 |
+
separator_time = last_hist_time + pd.Timedelta(hours=12)
|
| 264 |
+
for ax in [ax1, ax2]:
|
| 265 |
+
ax.axvline(x=separator_time, color='red', linestyle='--', linewidth=1.5, label='_nolegend_')
|
| 266 |
+
ax.tick_params(axis='x', rotation=45)
|
| 267 |
+
|
| 268 |
+
# 保存图表
|
| 269 |
+
fig.tight_layout()
|
| 270 |
+
chart_path = Path(Config["CHART_PATH"])
|
| 271 |
+
if chart_path.exists():
|
| 272 |
+
chart_path.chmod(0o666) # 确保可写权限
|
| 273 |
+
fig.savefig(chart_path, dpi=120, bbox_inches='tight')
|
| 274 |
+
plt.close(fig)
|
| 275 |
+
|
| 276 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 图表生成完成,保存路径:{chart_path}")
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def update_html():
|
| 280 |
+
"""更新HTML页面,复用当天缓存的指标,添加HTML更新日志"""
|
| 281 |
+
china_now = get_china_time()
|
| 282 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始更新HTML页面...")
|
| 283 |
+
|
| 284 |
+
# 1. 从缓存获取指标(增加空值判断,避免报错)
|
| 285 |
+
upside_prob = Config["CACHED_RESULTS"].get("upside_prob")
|
| 286 |
+
vol_amp_prob = Config["CACHED_RESULTS"].get("vol_amp_prob")
|
| 287 |
+
|
| 288 |
+
# 处理缓存为空的情况
|
| 289 |
+
if upside_prob is None or vol_amp_prob is None:
|
| 290 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 警告:缓存中未找到指标数据,无法更新HTML")
|
| 291 |
+
return
|
| 292 |
+
|
| 293 |
+
# 格式化指标(保留1位小数百分比)
|
| 294 |
+
upside_prob_str = f'{upside_prob:.1%}'
|
| 295 |
+
vol_amp_prob_str = f'{vol_amp_prob:.1%}'
|
| 296 |
+
now_cn_str = china_now.strftime('%Y-%m-%d %H:%M:%S')
|
| 297 |
+
|
| 298 |
+
# 2. 初始化HTML(不存在则创建基础模板)
|
| 299 |
+
html_path = Path(Config["HTML_PATH"])
|
| 300 |
+
src_html_path = Config["REPO_PATH"] / "templates" / "index.html"
|
| 301 |
+
|
| 302 |
+
if not html_path.exists():
|
| 303 |
+
html_path.parent.mkdir(parents=True, exist_ok=True)
|
| 304 |
+
if src_html_path.exists():
|
| 305 |
+
# 复制项目模板
|
| 306 |
+
import shutil
|
| 307 |
+
shutil.copy2(src_html_path, html_path)
|
| 308 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 从项目模板复制HTML:{src_html_path} -> {html_path}")
|
| 309 |
+
else:
|
| 310 |
+
# 创建基础中文HTML(确保指标对应的id与正则匹配)
|
| 311 |
+
base_html = """
|
| 312 |
+
<!DOCTYPE html>
|
| 313 |
+
<html>
|
| 314 |
+
<head>
|
| 315 |
+
<title>清华大模型Kronos上证指数预测</title>
|
| 316 |
+
<style>
|
| 317 |
+
body { max-width: 1200px; margin: 0 auto; padding: 20px; font-family: "WenQuanYi Micro Hei", Arial; }
|
| 318 |
+
.metric { margin: 20px 0; padding: 10px; background: #f5f5f5; border-radius: 5px; }
|
| 319 |
+
.metric-value { font-size: 1.2em; color: #0066cc; }
|
| 320 |
+
img { max-width: 100%; height: auto; }
|
| 321 |
+
h1 { color: #333; }
|
| 322 |
+
</style>
|
| 323 |
+
</head>
|
| 324 |
+
<body>
|
| 325 |
+
<h1>清华大学K线大模型Kronos上证指数(sh.000001)概率预测</h1>
|
| 326 |
+
<p>最后更新时间(中国时间):<strong id="update-time">未更新</strong></p>
|
| 327 |
+
<p>同 步 网 站:<strong><a href="http://15115656.top" target="_blank">火狼工具站</a></strong></p>
|
| 328 |
+
<div class="metric">
|
| 329 |
+
<p>24个交易日上涨概率:<span class="metric-value" id="upside-prob">--%</span></p>
|
| 330 |
+
</div>
|
| 331 |
+
<div class="metric">
|
| 332 |
+
<p>24个交易日波动率放大概率:<span class="metric-value" id="vol-amp-prob">--%</span></p>
|
| 333 |
+
</div>
|
| 334 |
+
<div><img src="/prediction_chart.png" alt="上证指数预测图表"></div>
|
| 335 |
+
</body>
|
| 336 |
+
</html>
|
| 337 |
+
"""
|
| 338 |
+
with open(html_path, 'w', encoding='utf-8') as f:
|
| 339 |
+
f.write(base_html)
|
| 340 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 在/tmp创建基础HTML:{html_path}")
|
| 341 |
+
|
| 342 |
+
# 3. 读取HTML内容(确保读取成功)
|
| 343 |
+
try:
|
| 344 |
+
with open(html_path, 'r', encoding='utf-8') as f:
|
| 345 |
+
content = f.read()
|
| 346 |
+
except Exception as e:
|
| 347 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 读取HTML失败:{str(e)}")
|
| 348 |
+
return
|
| 349 |
+
|
| 350 |
+
# 4. 正则替换(关键:确保re.sub()参数完整)
|
| 351 |
+
# 替换更新时间
|
| 352 |
+
content = re.sub(
|
| 353 |
+
pattern=r'(<strong id="update-time">).*?(</strong>)',
|
| 354 |
+
repl=lambda m: f'{m.group(1)}{now_cn_str}{m.group(2)}',
|
| 355 |
+
string=content
|
| 356 |
+
)
|
| 357 |
+
# 替换上涨概率(id="upside-prob",与HTML模板对应)
|
| 358 |
+
content = re.sub(
|
| 359 |
+
pattern=r'(<span class="metric-value" id="upside-prob">).*?(</span>)',
|
| 360 |
+
repl=lambda m: f'{m.group(1)}{upside_prob_str}{m.group(2)}',
|
| 361 |
+
string=content
|
| 362 |
+
)
|
| 363 |
+
# 替换波动率放大概率(id="vol-amp-prob",与HTML模板对应)
|
| 364 |
+
content = re.sub(
|
| 365 |
+
pattern=r'(<span class="metric-value" id="vol-amp-prob">).*?(</span>)',
|
| 366 |
+
repl=lambda m: f'{m.group(1)}{vol_amp_prob_str}{m.group(2)}',
|
| 367 |
+
string=content
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# 5. 写入更新后的HTML
|
| 371 |
+
try:
|
| 372 |
+
with open(html_path, 'w', encoding='utf-8') as f:
|
| 373 |
+
f.write(content)
|
| 374 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] HTML更新完成,路径:{html_path}")
|
| 375 |
+
# 验证替换结果(调试用)
|
| 376 |
+
print(f"[DEBUG] 上涨概率更新为:{upside_prob_str}")
|
| 377 |
+
print(f"[DEBUG] 波动率概率更新为:{vol_amp_prob_str}")
|
| 378 |
+
except Exception as e:
|
| 379 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 写入HTML失败:{str(e)}")
|
| 380 |
+
|
| 381 |
+
def git_commit_and_push():
|
| 382 |
+
"""Git提交(仅当Git存在时执行),添加Git操作日志"""
|
| 383 |
+
china_now = get_china_time()
|
| 384 |
+
commit_message = f"Auto-update: 上证指数预测({china_now:%Y-%m-%d 中国时间})"
|
| 385 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始执行Git提交操作,提交信息:{commit_message}")
|
| 386 |
+
|
| 387 |
+
# 检查Git是否安装
|
| 388 |
+
try:
|
| 389 |
+
subprocess.run(['git', '--version'], check=True, capture_output=True, text=True)
|
| 390 |
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
| 391 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git未安装或未在PATH中,跳过Git操作")
|
| 392 |
+
return
|
| 393 |
+
|
| 394 |
+
# 执行Git操作
|
| 395 |
+
try:
|
| 396 |
+
os.chdir(Config["REPO_PATH"])
|
| 397 |
+
# 复制图表和HTML到Git跟踪目录(若需要)
|
| 398 |
+
chart_src = Config["CHART_PATH"]
|
| 399 |
+
chart_dst = Config["REPO_PATH"] / "prediction_chart.png"
|
| 400 |
+
html_src = Config["HTML_PATH"]
|
| 401 |
+
html_dst = Config["REPO_PATH"] / "index.html"
|
| 402 |
+
|
| 403 |
+
if os.path.exists(chart_src):
|
| 404 |
+
import shutil
|
| 405 |
+
shutil.copy2(chart_src, chart_dst)
|
| 406 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 图表复制到Git目录:{chart_dst}")
|
| 407 |
+
if os.path.exists(html_src):
|
| 408 |
+
shutil.copy2(html_src, html_dst)
|
| 409 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] HTML复制到Git目录:{html_dst}")
|
| 410 |
+
|
| 411 |
+
# Git add
|
| 412 |
+
subprocess.run(['git', 'add', 'prediction_chart.png', 'index.html'], check=True, capture_output=True, text=True)
|
| 413 |
+
# Git commit
|
| 414 |
+
commit_result = subprocess.run(['git', 'commit', '-m', commit_message], check=True, capture_output=True, text=True)
|
| 415 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git提交输出:\n{commit_result.stdout}")
|
| 416 |
+
# Git push
|
| 417 |
+
push_result = subprocess.run(['git', 'push'], check=True, capture_output=True, text=True)
|
| 418 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git推送输出:\n{push_result.stdout}")
|
| 419 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git操作完成")
|
| 420 |
+
|
| 421 |
+
except subprocess.CalledProcessError as e:
|
| 422 |
+
output = e.stdout if e.stdout else e.stderr
|
| 423 |
+
if "nothing to commit" in output or "Your branch is up to date" in output:
|
| 424 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 无新内容需要提交或推送")
|
| 425 |
+
else:
|
| 426 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git错误:\nSTDOUT: {e.stdout}\nSTDERR: {e.stderr}")
|
| 427 |
+
except PermissionError as e:
|
| 428 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git权限错误:{str(e)},跳过Git操作")
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def main_task(model):
|
| 432 |
+
"""主任务:控制每日仅执行一次推理,当天复用缓存"""
|
| 433 |
+
china_now = get_china_time()
|
| 434 |
+
print(f"\n[{china_now:%Y-%m-%d %H:%M:%S}] " + "="*60)
|
| 435 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始执行主任务")
|
| 436 |
+
|
| 437 |
+
# 检查当天是否已完成推理(中国时间)
|
| 438 |
+
if Config["IS_TODAY_INFERENCED"]:
|
| 439 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 今日(中国时间)已完成推理,直接复用缓存结果")
|
| 440 |
+
# 复用缓存生成图表和HTML
|
| 441 |
+
create_plot()
|
| 442 |
+
update_html()
|
| 443 |
+
git_commit_and_push()
|
| 444 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 主任务完成(复用缓存)")
|
| 445 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] " + "="*60 + "\n")
|
| 446 |
+
return
|
| 447 |
+
|
| 448 |
+
# 当天首次执行:获取数据→推理→缓存结果→生成图表→更新HTML→Git提交
|
| 449 |
+
try:
|
| 450 |
+
# 1. 获取股票数据
|
| 451 |
+
df_full = fetch_stock_data()
|
| 452 |
+
df_for_model = df_full.iloc[:-1] # 排除最后一行避免数据泄漏
|
| 453 |
+
|
| 454 |
+
# 2. 执行推理
|
| 455 |
+
close_preds, volume_preds, v_close_preds = make_prediction(df_for_model, model)
|
| 456 |
+
|
| 457 |
+
# 3. 计算指标
|
| 458 |
+
hist_df_for_metrics = df_for_model.tail(Config["VOL_WINDOW"])
|
| 459 |
+
upside_prob, vol_amp_prob = calculate_metrics(hist_df_for_metrics, close_preds, v_close_preds)
|
| 460 |
+
|
| 461 |
+
# 4. 缓存结果(当天复用)
|
| 462 |
+
hist_df_for_plot = df_for_model.tail(Config["VOL_WINDOW"]) # 用于绘图的历史数据
|
| 463 |
+
Config["CACHED_RESULTS"] = {
|
| 464 |
+
"close_preds": close_preds,
|
| 465 |
+
"volume_preds": volume_preds,
|
| 466 |
+
"v_close_preds": v_close_preds,
|
| 467 |
+
"upside_prob": upside_prob,
|
| 468 |
+
"vol_amp_prob": vol_amp_prob,
|
| 469 |
+
"hist_df_for_plot": hist_df_for_plot
|
| 470 |
+
}
|
| 471 |
+
# 标记当天已完成推理
|
| 472 |
+
Config["IS_TODAY_INFERENCED"] = True
|
| 473 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 今日推理结果已缓存,后续调用将复用")
|
| 474 |
+
|
| 475 |
+
# 5. 生成图表
|
| 476 |
+
create_plot()
|
| 477 |
+
|
| 478 |
+
# 6. 更新HTML
|
| 479 |
+
update_html()
|
| 480 |
+
|
| 481 |
+
# 7. Git提交
|
| 482 |
+
git_commit_and_push()
|
| 483 |
+
|
| 484 |
+
# 8. 内存回收
|
| 485 |
+
del df_full, df_for_model, hist_df_for_metrics
|
| 486 |
+
gc.collect()
|
| 487 |
+
|
| 488 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 主任务完成(首次推理)")
|
| 489 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] " + "="*60 + "\n")
|
| 490 |
+
|
| 491 |
+
except Exception as e:
|
| 492 |
+
# 异常时不标记为“已推理”,下次调用重试
|
| 493 |
+
Config["IS_TODAY_INFERENCED"] = False
|
| 494 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 主任务执行失败,今日推理标记为未完成")
|
| 495 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 错误信息:{str(e)}")
|
| 496 |
+
import traceback
|
| 497 |
+
traceback.print_exc()
|
| 498 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] " + "="*60 + "\n")
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def run_scheduler(model):
|
| 502 |
+
"""定时器:中国时间每天0点触发主任务,其他时间5分钟检查一次"""
|
| 503 |
+
china_tz = timezone("Asia/Shanghai")
|
| 504 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 定时器启动(中国时间),每天0点执行推理")
|
| 505 |
+
|
| 506 |
+
while True:
|
| 507 |
+
china_now = get_china_time()
|
| 508 |
+
# 计算次日0点(中国时间)
|
| 509 |
+
next_midnight = (china_now + timedelta(days=1)).replace(
|
| 510 |
+
hour=0, minute=0, second=5, microsecond=0, tzinfo=china_tz
|
| 511 |
+
)
|
| 512 |
+
# 计算等待时间(秒)
|
| 513 |
+
sleep_seconds = (next_midnight - china_now).total_seconds()
|
| 514 |
+
|
| 515 |
+
# 打印等待日志
|
| 516 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 当前时间:{china_now:%Y-%m-%d %H:%M:%S}(中国时间)")
|
| 517 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 下次执行时间:{next_midnight:%Y-%m-%d %H:%M:%S}(中国时间)")
|
| 518 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 等待时间:{sleep_seconds:.0f}秒(约{sleep_seconds/3600:.1f}小时)")
|
| 519 |
+
|
| 520 |
+
# 等待到次日0点
|
| 521 |
+
time.sleep(sleep_seconds)
|
| 522 |
+
|
| 523 |
+
# 到达0点,执行主任务
|
| 524 |
+
try:
|
| 525 |
+
main_task(model)
|
| 526 |
+
# 任务完成后,重置“当天已推理”标记(避免跨天复用)
|
| 527 |
+
Config["IS_TODAY_INFERENCED"] = False
|
| 528 |
+
except Exception as e:
|
| 529 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 定时器触发任务失败:{str(e)}")
|
| 530 |
+
import traceback
|
| 531 |
+
traceback.print_exc()
|
| 532 |
+
print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 5分钟后重试...")
|
| 533 |
+
time.sleep(300) # 重试间隔5分钟
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
if __name__ == '__main__':
|
| 537 |
+
# 初始化:加载模型→执行一次主任务→启动定时器
|
| 538 |
+
china_now = get_china_time()
|
| 539 |
+
print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 程序启动(中国时间)")
|
| 540 |
+
|
| 541 |
+
# 加载模型
|
| 542 |
+
loaded_model = load_local_model()
|
| 543 |
+
|
| 544 |
+
# 首次执行主任务(若当天未执行)
|
| 545 |
+
main_task(loaded_model)
|
| 546 |
+
|
| 547 |
+
# 启动定时器(中国时间每天0点执行)
|
| 548 |
+
run_scheduler(loaded_model)
|