Spaces:
Running
Running
import gc | |
import os | |
import re | |
import subprocess | |
import time | |
from datetime import datetime, timedelta, date | |
from pathlib import Path | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import torch | |
import baostock as bs | |
from pytz import timezone # 处理中国时区(Asia/Shanghai) | |
from model import KronosTokenizer, Kronos, KronosPredictor | |
# --- Configuration ---
# Global settings plus the mutable in-memory state shared by the tasks below.
Config = dict(
    # Directory containing this script; also the repository root used for git.
    REPO_PATH=Path(__file__).parent.resolve(),
    # Local Kronos weights; load_local_model expects "tokenizer" and "model" sub-directories.
    LOCAL_MODEL_PATH=os.path.join(Path(__file__).parent.resolve(), "models"),
    STOCK_CODE="sh.000001",
    # Baostock K-line frequency ("d" = daily bars).
    FREQUENCY="d",
    START_DATE="2022-01-01",
    # Number of future trading days to forecast.
    PRED_HORIZON=24,
    # Number of sampled forecast paths (passed as sample_count).
    N_PREDICTIONS=10,
    # History window for the volatility comparison; also sizes the data fetch.
    VOL_WINDOW=24,
    PREDICTION_CACHE=os.path.join("/tmp", "predictions_cache"),
    CHART_PATH=os.path.join("/tmp", "prediction_chart.png"),
    HTML_PATH=os.path.join("/tmp", "index.html"),
    # Business date of the last successful inference (a date, not a flag); None before the first run.
    LAST_INFERENCED_BUSINESS_DATE=None,
    # Results of the last inference, reused for the rest of the same business date.
    CACHED_RESULTS=dict.fromkeys(
        (
            "close_preds",
            "volume_preds",
            "v_close_preds",
            "upside_prob",
            "vol_amp_prob",
            "hist_df_for_plot",
        )
    ),
)

# Chinese font for chart text; added after Config exists so REPO_PATH can be reused.
Config["CHINESE_FONT_PATH"] = os.path.join(Config["REPO_PATH"], "fonts", "wqy-microhei.ttf")

# Create the working directories at import time.
os.makedirs(Config["PREDICTION_CACHE"], exist_ok=True)
os.makedirs(Config["LOCAL_MODEL_PATH"], exist_ok=True)
def get_china_time():
    """Return the current wall-clock time as a timezone-aware datetime in Asia/Shanghai."""
    shanghai = timezone("Asia/Shanghai")
    return datetime.now(tz=shanghai)
# -------------------------- Business-date helper (20:00 Beijing-time cutoff) --------------------------
def get_business_info():
    """Return the current business date and whether the 20:00 cutoff has passed.

    The business day rolls over at 20:00 Beijing time. Before 20:00 the
    business date is yesterday's calendar date; from 20:00 onwards it is today's.

    Returns:
        tuple: ``(current_business_date, is_after_20h)`` where the first item
        is a ``datetime.date`` and the second a ``bool``.
    """
    now_cn = get_china_time()
    is_after_20h = now_cn.hour >= 20
    reference = now_cn if is_after_20h else now_cn - timedelta(days=1)
    return reference.date(), is_after_20h
def load_local_model():
    """Load the Kronos tokenizer and model from ``Config["LOCAL_MODEL_PATH"]``.

    Both sub-directories, ``tokenizer`` and ``model``, must already exist.
    They are loaded with ``local_files_only=True`` and switched to eval mode.

    Returns:
        KronosPredictor: predictor on ``device="cpu"`` with ``max_context=512``.

    Raises:
        FileNotFoundError: If the tokenizer or model directory is missing.
    """
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 开始加载本地Kronos模型...")
    model_root = Config["LOCAL_MODEL_PATH"]
    tokenizer_dir = os.path.join(model_root, "tokenizer")
    weights_dir = os.path.join(model_root, "model")

    # Fail fast with a descriptive error before attempting to load anything.
    if not os.path.exists(tokenizer_dir):
        raise FileNotFoundError(f"分词器路径不存在：{tokenizer_dir}")
    if not os.path.exists(weights_dir):
        raise FileNotFoundError(f"模型路径不存在：{weights_dir}")

    tokenizer = KronosTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)
    model = Kronos.from_pretrained(weights_dir, local_files_only=True)
    # Inference only: put both modules in eval mode.
    for module in (tokenizer, model):
        module.eval()

    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 本地模型加载成功")
    return predictor
# -------------------------- Data fetching (end date = current business date) --------------------------
def fetch_stock_data():
    """Download daily K-line data for ``Config["STOCK_CODE"]`` from Baostock.

    The query covers ``Config["START_DATE"]`` up to the current business date
    (see ``get_business_info``). Only the most recent ``2 * Config["VOL_WINDOW"]``
    rows are kept.

    Returns:
        pandas.DataFrame: columns ``timestamps, open, high, low, close, volume,
        amount`` with a fresh RangeIndex. ``amount`` is an estimate (mean of
        OHLC times volume), not a value reported by Baostock.

    Raises:
        ConnectionError: If the Baostock login fails.
        ValueError: If the query fails or too few rows are returned.
    """
    china_now = get_china_time()
    business_date, _ = get_business_info()
    end_date = business_date.strftime("%Y-%m-%d")
    # NOTE(review): the required length is VOL_WINDOW twice (48 rows with the
    # default config); confirm this is the intended model context length.
    need_points = 2 * Config["VOL_WINDOW"]
    print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始获取{Config['STOCK_CODE']}日线数据（业务日结束日期：{end_date}）")

    login = bs.login()
    if login.error_code != '0':
        raise ConnectionError(f"Baostock登录失败：{login.error_msg}")

    try:
        query = bs.query_history_k_data_plus(
            code=Config["STOCK_CODE"],
            fields="date,open,high,low,close,volume",
            start_date=Config["START_DATE"],
            end_date=end_date,
            frequency=Config["FREQUENCY"],
            # Baostock adjustflag (per its docs; verify): "1" = backward-adjusted
            # (后复权), "2" = forward-adjusted (前复权), "3" = unadjusted.
            adjustflag="2",
        )
        if query.error_code != '0':
            raise ValueError(f"获取K线数据失败：{query.error_msg}")

        rows = []
        while query.next():
            rows.append(query.get_row_data())
        frame = pd.DataFrame(rows, columns=query.fields)

        # Baostock returns strings; coerce to numbers and drop rows that fail.
        value_cols = ['open', 'high', 'low', 'close', 'volume']
        for name in value_cols:
            frame[name] = pd.to_numeric(frame[name], errors='coerce')
        frame = frame.dropna(subset=value_cols)

        frame['timestamps'] = pd.to_datetime(frame['date'], format='%Y-%m-%d')
        # Turnover is approximated from the OHLC average and volume.
        frame['amount'] = (frame['open'] + frame['high'] + frame['low'] + frame['close']) / 4 * frame['volume']
        frame = frame[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]

        if len(frame) < need_points:
            raise ValueError(f"数据量不足（仅{len(frame)}个交易日），请提前START_DATE")
        frame = frame.tail(need_points).reset_index(drop=True)

        print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 股票数据获取成功，共{len(frame)}个交易日")
        print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 最新5条数据：\n{frame[['timestamps', 'open', 'close', 'volume']].tail()}")
        return frame
    finally:
        # Always release the Baostock session, including on errors.
        bs.logout()
        print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] Baostock已登出")
def make_prediction(df, predictor):
    """Sample Kronos forecasts for the next ``Config["PRED_HORIZON"]`` trading days.

    Args:
        df: History with a ``timestamps`` column and the feature columns
            ``open, high, low, close, volume, amount``.
        predictor: The ``KronosPredictor`` returned by ``load_local_model``.

    Returns:
        tuple: ``(close_preds, volume_preds, close_preds_for_volatility)``, as
        returned by ``predictor.predict``. The third item is the same object
        as the first, because the volatility metric reuses the close samples.
    """
    china_now = get_china_time()
    current_business_date, _ = get_business_info()
    print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始执行模型推理（业务日：{current_business_date}，预测未来{Config['PRED_HORIZON']}个交易日）")

    # Future timestamps are business days (Mon-Fri) after the last observed bar.
    # This matches the "trading days" horizon and the x-axis built by create_plot
    # (also freq='B'); calendar days would feed weekend dates to the model.
    # Exchange holidays are still not excluded.
    last_timestamp = df['timestamps'].max()
    future_index = pd.date_range(
        start=last_timestamp + pd.Timedelta(days=1),
        periods=Config["PRED_HORIZON"],
        freq='B',
    )
    y_timestamp = pd.Series(future_index, name='y_timestamp')
    x_timestamp = df['timestamps']
    x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        begin_time = time.time()
        close_preds_main, volume_preds_main = predictor.predict(
            df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
            pred_len=Config["PRED_HORIZON"], T=1.0, top_p=0.95,
            sample_count=Config["N_PREDICTIONS"], verbose=True,
        )
        infer_time = time.time() - begin_time
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 推理完成，耗时{infer_time:.2f}秒")

    # The volatility metric reuses the close-price samples.
    close_preds_volatility = close_preds_main
    return close_preds_main, volume_preds_main, close_preds_volatility
def calculate_metrics(hist_df, close_preds_df, v_close_preds_df):
    """Compute the upside probability and the volatility-amplification probability.

    Args:
        hist_df: Recent history with a ``close`` column; its last close is the
            reference price.
        close_preds_df: Sampled close paths, presumably one column per sample
            and one row per forecast step (the last row is the final step).
        v_close_preds_df: Sampled close paths used for the volatility test,
            one column per sample.

    Returns:
        tuple: ``(upside_prob, vol_amp_prob)``.
            ``upside_prob`` is the share of samples whose final-step close
            exceeds the last historical close.
            ``vol_amp_prob`` is the share of samples whose log-return std
            (starting from the last historical close) exceeds the std of
            historical log returns over the last ``Config["VOL_WINDOW"]`` rows.

    Raises:
        ValueError: If ``v_close_preds_df`` has no sample columns.
    """
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 开始计算预测指标...")
    horizon = Config["PRED_HORIZON"]

    # Upside: final forecast step versus the latest observed close.
    last_close = hist_df['close'].iloc[-1]
    upside_prob = (close_preds_df.iloc[-1] > last_close).mean()

    # Historical volatility: std of log returns over the trailing window.
    hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1))
    historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std()

    n_samples = len(v_close_preds_df.columns)
    if n_samples == 0:
        raise ValueError("v_close_preds_df has no sample columns; cannot compute vol_amp_prob")

    # Prepend the last observed close to every sample path so the first
    # forecast step also yields a log return, then take the std per column.
    anchor = pd.DataFrame([[last_close] * n_samples], columns=v_close_preds_df.columns)
    paths = pd.concat([anchor, v_close_preds_df], ignore_index=True)
    predicted_vol = np.log(paths / paths.shift(1)).std()
    vol_amp_prob = int((predicted_vol > historical_vol).sum()) / n_samples

    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 指标计算完成：")
    print(f" - {horizon}个交易日上涨概率：{upside_prob:.2%}")
    print(f" - {horizon}个交易日波动率放大概率：{vol_amp_prob:.2%}")
    return upside_prob, vol_amp_prob
def create_plot():
    """Render the cached forecast to ``Config["CHART_PATH"]`` as a PNG.

    Reads ``hist_df_for_plot``, ``close_preds`` and ``volume_preds`` from
    ``Config["CACHED_RESULTS"]`` and draws two panels that share the x-axis:

    - price: history, mean forecast and min-max band;
    - volume: history and mean forecast.

    If any of these cache entries is missing, it logs a warning and returns
    without drawing.
    """
    china_now = get_china_time()
    print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始生成预测图表（适配低版本matplotlib字体）")

    cache = Config["CACHED_RESULTS"]
    hist_df_for_plot = cache.get("hist_df_for_plot")
    close_preds = cache.get("close_preds")
    volume_preds = cache.get("volume_preds")
    # Same defensive check as update_html: nothing to draw without cached results.
    if hist_df_for_plot is None or close_preds is None or volume_preds is None:
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 警告：缓存中未找到绘图数据，无法生成图表")
        return

    from matplotlib.font_manager import FontProperties

    font_path = Config["CHINESE_FONT_PATH"]
    if os.path.exists(font_path):
        # Point FontProperties directly at the .ttf file; this does not rely on
        # the font being registered with matplotlib's font manager.
        chinese_font = FontProperties(fname=font_path)
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 成功加载.ttf字体：{font_path}")
    else:
        chinese_font = FontProperties(family='SimHei', size=10)
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 字体文件不存在，使用系统默认字体：SimHei")

    # Global defaults for text without an explicit font (e.g. tick labels).
    # They are set before the figure is created so its axes pick them up.
    plt.rcParams["font.family"] = ["sans-serif"]
    plt.rcParams["font.sans-serif"] = ["WenQuanYi Micro Hei", "SimHei", "Heiti TC"]
    plt.rcParams['axes.unicode_minus'] = False  # avoid missing-glyph boxes for the minus sign

    # Size and weight are baked into copies of the font. Passing fontsize= next
    # to fontproperties=/prop= is ignored or overridden depending on the
    # matplotlib version.
    title_font = chinese_font.copy()
    title_font.set_size(16)
    title_font.set_weight('bold')
    label_font = chinese_font.copy()
    label_font.set_size(12)
    legend_font = chinese_font.copy()
    legend_font.set_size(10)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

    hist_time = hist_df_for_plot['timestamps']
    last_hist_time = hist_time.max()
    # Forecast x-axis: PRED_HORIZON business days after the last historical bar.
    pred_time = pd.date_range(start=last_hist_time + pd.Timedelta(days=1), periods=Config["PRED_HORIZON"], freq='B')

    # Labels are attached to the artists themselves. Calling legend() with a bare
    # label list pairs the labels with the first N artists, which for bar charts
    # are individual bars, so both volume entries would be coloured green.

    # 1. Price panel
    ax1.plot(hist_time, hist_df_for_plot['close'], color='#00274C', linewidth=1.5,
             label='上证指数（后复权）')
    ax1.plot(pred_time, close_preds.mean(axis=1), color='#FF6B00', linestyle='-',
             label='预测均价')
    ax1.fill_between(pred_time, close_preds.min(axis=1), close_preds.max(axis=1),
                     color='#FF6B00', alpha=0.2, label='预测区间（最小-最大）')
    ax1.set_title(f'{Config["STOCK_CODE"]} 上证指数概率预测（未来{Config["PRED_HORIZON"]}个交易日）',
                  fontproperties=title_font)
    ax1.set_ylabel('价格（元）', fontproperties=label_font)
    ax1.legend(prop=legend_font)
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

    # 2. Volume panel (volume scaled by 1e8)
    # NOTE(review): Baostock documents volume in shares (股); confirm whether the
    # "亿手" unit in these labels should read "亿股".
    ax2.bar(hist_time, hist_df_for_plot['volume'] / 1e8, color='#00A86B', width=0.6,
            label='历史成交量（亿手）')
    ax2.bar(pred_time, volume_preds.mean(axis=1) / 1e8, color='#FF6B00', width=0.6,
            label='预测成交量（亿手）')
    ax2.set_ylabel('成交量（亿手）', fontproperties=label_font)
    ax2.set_xlabel('日期', fontproperties=label_font)
    ax2.legend(prop=legend_font)
    ax2.grid(True, which='both', linestyle='--', linewidth=0.5)

    # Dashed separator between history and forecast (kept out of the legends).
    separator_time = last_hist_time + pd.Timedelta(hours=12)
    for ax in (ax1, ax2):
        ax.axvline(x=separator_time, color='red', linestyle='--', linewidth=1.5, label='_nolegend_')
        ax.tick_params(axis='x', rotation=45)

    fig.tight_layout()
    chart_path = Path(Config["CHART_PATH"])
    if chart_path.exists():
        chart_path.chmod(0o666)  # make sure an existing file can be overwritten
    fig.savefig(chart_path, dpi=120, bbox_inches='tight')
    plt.close(fig)
    print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 图表生成完成，保存路径：{chart_path}")
def update_html():
    """Write the cached metrics and the current time into the published HTML page.

    On first use, the page at ``Config["HTML_PATH"]`` is created. It is copied
    from ``templates/index.html`` in the repository when that file exists;
    otherwise a built-in minimal page is written.

    The contents of the elements with ids ``update-time``, ``upside-prob`` and
    ``vol-amp-prob`` are then replaced. If the metrics are missing from
    ``Config["CACHED_RESULTS"]`` the page is left untouched. Failures while
    reading or rewriting the page are logged rather than raised.
    """
    china_now = get_china_time()
    ts = f"[{china_now:%Y-%m-%d %H:%M:%S}]"
    business_date, _ = get_business_info()
    print(f"{ts} 开始更新HTML页面（业务日：{business_date}）...")

    cached = Config["CACHED_RESULTS"]
    upside = cached.get("upside_prob")
    vol_amp = cached.get("vol_amp_prob")
    if upside is None or vol_amp is None:
        print(f"{ts} 警告：缓存中未找到指标数据，无法更新HTML")
        return

    upside_prob_str = f'{upside:.1%}'
    vol_amp_prob_str = f'{vol_amp:.1%}'
    now_cn_str = china_now.strftime('%Y-%m-%d %H:%M:%S')

    target = Path(Config["HTML_PATH"])
    template = Config["REPO_PATH"] / "templates" / "index.html"
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        if template.exists():
            import shutil
            shutil.copy2(template, target)
            print(f"{ts} 从项目模板复制HTML：{template} -> {target}")
        else:
            # Fallback page; its element ids must match the patterns below.
            base_html = """
<!DOCTYPE html>
<html>
<head>
<title>清华大模型Kronos上证指数预测</title>
<style>
body { max-width: 1200px; margin: 0 auto; padding: 20px; font-family: "WenQuanYi Micro Hei", Arial; }
.metric { margin: 20px 0; padding: 10px; background: #f5f5f5; border-radius: 5px; }
.metric-value { font-size: 1.2em; color: #0066cc; }
img { max-width: 100%; height: auto; }
h1 { color: #333; }
</style>
</head>
<body>
<h1>清华大学K线大模型Kronos上证指数（sh.000001）概率预测</h1>
<p>最后更新时间（中国时间）：<strong id="update-time">未更新</strong></p>
<p>同 步 网 站：<strong><a href="http://15115656.top" target="_blank">火狼工具站</a></strong></p>
<div class="metric">
<p>24个交易日上涨概率：<span class="metric-value" id="upside-prob">--%</span></p>
</div>
<div class="metric">
<p>波动率放大概率：<span class="metric-value" id="vol-amp-prob">--%</span></p>
</div>
<div><img src="/prediction_chart.png" alt="上证指数预测图表"></div>
</body>
</html>
"""
            target.write_text(base_html, encoding='utf-8')
            print(f"{ts} 在/tmp创建基础HTML：{target}")

    try:
        content = target.read_text(encoding='utf-8')
    except Exception as e:
        print(f"{ts} 读取HTML失败：{str(e)}")
        return

    # Replace only the text between each element's opening and closing tag.
    replacements = (
        (r'(<strong id="update-time">).*?(</strong>)', now_cn_str),
        (r'(<span class="metric-value" id="upside-prob">).*?(</span>)', upside_prob_str),
        (r'(<span class="metric-value" id="vol-amp-prob">).*?(</span>)', vol_amp_prob_str),
    )
    for pattern, value in replacements:
        content = re.sub(
            pattern=pattern,
            repl=lambda m, v=value: f'{m.group(1)}{v}{m.group(2)}',
            string=content,
        )

    try:
        target.write_text(content, encoding='utf-8')
        print(f"{ts} HTML更新完成，路径：{target}")
        print(f"[DEBUG] 上涨概率更新为：{upside_prob_str}")
        print(f"[DEBUG] 波动率概率更新为：{vol_amp_prob_str}")
    except Exception as e:
        print(f"{ts} 写入HTML失败：{str(e)}")
def git_commit_and_push():
    """Copy the chart and HTML into the repository, then git add/commit/push them.

    Returns early (after logging) if ``git`` is not installed or not on PATH.
    ``Config["CHART_PATH"]`` and ``Config["HTML_PATH"]`` are copied to the
    repository root when they exist.

    Git failures and permission errors are logged, not raised. A
    "nothing to commit" / "Your branch is up to date" result is reported as
    "no new content".
    """
    # Imported up front: both copy branches below need it. Importing it inside
    # only the chart branch left shutil unbound when just the HTML existed.
    import shutil

    china_now = get_china_time()
    current_business_date, _ = get_business_info()
    commit_message = f"Auto-update: 上证指数预测（业务日{current_business_date} 中国时间）"
    print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始执行Git提交操作，提交信息：{commit_message}")

    try:
        subprocess.run(['git', '--version'], check=True, capture_output=True, text=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git未安装或未在PATH中，跳过Git操作")
        return

    repo = Config["REPO_PATH"]

    def _git(*args):
        # Run a git subcommand inside the repository without changing the process cwd.
        return subprocess.run(['git', *args], cwd=repo, check=True, capture_output=True, text=True)

    try:
        chart_src = Config["CHART_PATH"]
        chart_dst = repo / "prediction_chart.png"
        html_src = Config["HTML_PATH"]
        html_dst = repo / "index.html"
        if os.path.exists(chart_src):
            shutil.copy2(chart_src, chart_dst)
            print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 图表复制到Git目录：{chart_dst}")
        if os.path.exists(html_src):
            shutil.copy2(html_src, html_dst)
            print(f"[{china_now:%Y-%m-%d %H:%M:%S}] HTML复制到Git目录：{html_dst}")

        _git('add', 'prediction_chart.png', 'index.html')
        commit_result = _git('commit', '-m', commit_message)
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git提交输出：\n{commit_result.stdout}")
        push_result = _git('push')
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git推送输出：\n{push_result.stdout}")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git操作完成")
    except subprocess.CalledProcessError as e:
        # git may report status on either stream, so inspect both.
        output = f"{e.stdout or ''}{e.stderr or ''}"
        if "nothing to commit" in output or "Your branch is up to date" in output:
            print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 无新内容需要提交或推送")
        else:
            print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git错误：\nSTDOUT: {e.stdout}\nSTDERR: {e.stderr}")
    except PermissionError as e:
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git权限错误：{str(e)}，跳过Git操作")
# -------------------------- Main task (driven by the 20:00 business date) --------------------------
def main_task(model):
    """Run one update cycle: inference at most once per business date.

    On the first call for a business date this function:

    1. fetches data;
    2. runs the model;
    3. computes the metrics;
    4. stores everything in ``Config["CACHED_RESULTS"]`` and records the date
       in ``Config["LAST_INFERENCED_BUSINESS_DATE"]``.

    Every call, including cache-reuse calls, then regenerates the chart and
    HTML and attempts a git commit/push.

    Errors on the inference path are logged, not raised. On the cache-reuse
    path they propagate to the caller.

    Args:
        model: The ``KronosPredictor`` returned by ``load_local_model``.
    """
    china_now = get_china_time()
    ts = f"[{china_now:%Y-%m-%d %H:%M:%S}]"
    current_business_date, is_after_20h = get_business_info()
    print(f"\n{ts} " + "=" * 60)
    print(f"{ts} 开始执行主任务")
    print(f"{ts} 当前业务日：{current_business_date}（北京时间{'20点后' if is_after_20h else '20点前'}）")

    # Inference already done for this business date: only refresh the outputs.
    if Config["LAST_INFERENCED_BUSINESS_DATE"] == current_business_date:
        print(f"{ts} 当前业务日（{current_business_date}）已完成推理，直接复用缓存结果")
        create_plot()
        update_html()
        git_commit_and_push()
        print(f"{ts} 主任务完成（复用缓存）")
        print(f"{ts} " + "=" * 60 + "\n")
        return

    try:
        # 1. Fetch data.
        df_full = fetch_stock_data()
        # The newest bar is dropped before modelling (the original intent was to
        # avoid data leakage). NOTE(review): daily bars for a completed business
        # date are presumably final; confirm that discarding the last one is intended.
        df_for_model = df_full.iloc[:-1]

        # 2. Inference.
        close_preds, volume_preds, v_close_preds = make_prediction(df_for_model, model)

        # 3. Metrics over the trailing VOL_WINDOW bars (the same slice is plotted).
        recent_hist = df_for_model.tail(Config["VOL_WINDOW"])
        upside_prob, vol_amp_prob = calculate_metrics(recent_hist, close_preds, v_close_preds)

        # 4. Cache the results for the rest of this business date.
        Config["CACHED_RESULTS"] = {
            "close_preds": close_preds,
            "volume_preds": volume_preds,
            "v_close_preds": v_close_preds,
            "upside_prob": upside_prob,
            "vol_amp_prob": vol_amp_prob,
            "hist_df_for_plot": recent_hist,
        }
        # Record the date itself (not a boolean) so the next business date re-runs inference.
        Config["LAST_INFERENCED_BUSINESS_DATE"] = current_business_date
        print(f"{ts} 业务日（{current_business_date}）推理结果已缓存，同业务日后续调用将复用")

        # 5-7. Chart, HTML, git.
        create_plot()
        update_html()
        git_commit_and_push()

        # 8. Release the large intermediate frames.
        del df_full, df_for_model
        gc.collect()
        print(f"{ts} 主任务完成（首次推理）")
        print(f"{ts} " + "=" * 60 + "\n")
    except Exception as e:
        if Config["LAST_INFERENCED_BUSINESS_DATE"] == current_business_date:
            # Inference and caching succeeded; a later step (chart/HTML/git) failed.
            # The marker stays set, so later calls reuse the cache without re-inferring.
            print(f"{ts} 主任务部分失败：业务日（{current_business_date}）推理结果已缓存，图表/HTML/Git步骤出错")
        else:
            # Marker not updated, so the next call retries inference.
            print(f"{ts} 主任务执行失败，业务日（{current_business_date}）推理标记为未完成")
        print(f"{ts} 错误信息：{str(e)}")
        import traceback
        traceback.print_exc()
        print(f"{ts} " + "=" * 60 + "\n")
# -------------------------- Scheduler (fires daily just after 20:00 Beijing time) --------------------------
def run_scheduler(model):
    """Loop forever, calling ``main_task`` shortly after 20:00 Beijing time each day.

    Each iteration:

    - sleeps until 20:00:05 today, or tomorrow once 20:00 has passed;
    - always sleeps at least 5 minutes;
    - then calls ``main_task``.

    If ``main_task`` raises, the error is logged and the loop waits another
    5 minutes before computing the next trigger.

    Args:
        model: The ``KronosPredictor`` passed through to ``main_task``.
    """
    china_tz = timezone("Asia/Shanghai")
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 定时器启动（中国时间），每天20点执行推理")
    # 20:00:05 leaves a 5 s margin past the cutoff used by get_business_info.
    trigger_clock = datetime.strptime("20:00:05", "%H:%M:%S").time()

    while True:
        china_now = get_china_time()
        current_business_date, is_after_20h = get_business_info()

        if is_after_20h:
            next_exec_date = (china_now + timedelta(days=1)).date()  # already past 20:00: tomorrow
        else:
            next_exec_date = china_now.date()  # before 20:00: today

        # pytz zones must be attached with localize(). Passing them via tzinfo=
        # uses the zone's earliest offset (LMT, +08:06 for Shanghai), which would
        # fire the trigger about 6 minutes before 20:00, i.e. still on the previous
        # business date.
        next_exec_time = china_tz.localize(datetime.combine(next_exec_date, trigger_clock))

        # Sleep at least 5 minutes; this also guards against a negative delay.
        sleep_seconds = max((next_exec_time - china_now).total_seconds(), 300)

        ts = f"[{china_now:%Y-%m-%d %H:%M:%S}]"
        print(f"\n{ts} 定时器状态：")
        print(f"{ts} 当前时间：{china_now:%Y-%m-%d %H:%M:%S}（中国时间）")
        print(f"{ts} 当前业务日：{current_business_date}（{'20点后' if is_after_20h else '20点前'}）")
        print(f"{ts} 下次执行时间：{next_exec_time:%Y-%m-%d %H:%M:%S}（中国时间）")
        print(f"{ts} 等待时间：{sleep_seconds:.0f}秒（约{sleep_seconds/3600:.1f}小时）")

        time.sleep(sleep_seconds)

        try:
            # No marker reset needed: main_task compares against the new business date.
            main_task(model)
        except Exception as e:
            print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 定时器触发任务失败：{str(e)}")
            import traceback
            traceback.print_exc()
            print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 5分钟后重试...")
            time.sleep(300)
if __name__ == '__main__':
    # Start-up sequence:
    #   1. load the model;
    #   2. run one full cycle immediately (no business date has been processed yet);
    #   3. hand over to the daily 20:00 scheduler, which never returns.
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 程序启动（中国时间）")
    loaded_model = load_local_model()
    main_task(loaded_model)
    run_scheduler(loaded_model)