Spaces:
Running
Running
| import gc | |
| import os | |
| import re | |
| import subprocess | |
| import time | |
| from datetime import datetime, timedelta, date | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import baostock as bs | |
| from pytz import timezone # 处理中国时区(Asia/Shanghai) | |
| from model import KronosTokenizer, Kronos, KronosPredictor | |
# --- Configuration ---
# Module-wide settings plus the in-memory cache of the most recent inference run.
Config = {
    # Directory that contains this script; project-relative paths are built from it.
    "REPO_PATH": Path(__file__).parent.resolve(),
    "LOCAL_MODEL_PATH": os.path.join(Path(__file__).parent.resolve(), "models"),
    "STOCK_CODE": "sh.000001",
    "FREQUENCY": "d",
    "START_DATE": "2022-01-01",
    "PRED_HORIZON": 24,
    "N_PREDICTIONS": 10,
    "VOL_WINDOW": 24,
    "PREDICTION_CACHE": os.path.join("/tmp", "predictions_cache"),
    "CHART_PATH": os.path.join("/tmp", "prediction_chart.png"),
    "HTML_PATH": os.path.join("/tmp", "index.html"),
    # Business date of the last completed inference (a date rather than a
    # boolean flag); stays None until main_task stores one.
    "LAST_INFERENCED_BUSINESS_DATE": None,
    # Results of the last inference; every slot starts out as None.
    "CACHED_RESULTS": dict.fromkeys(
        (
            "close_preds",
            "volume_preds",
            "v_close_preds",
            "upside_prob",
            "vol_amp_prob",
            "hist_df_for_plot",
        )
    ),
}

# The Chinese font path depends on REPO_PATH, so it is added once Config exists.
Config["CHINESE_FONT_PATH"] = os.path.join(Config["REPO_PATH"], "fonts", "wqy-microhei.ttf")

# Make sure the working directories exist before anything tries to use them.
for _required_dir in (Config["PREDICTION_CACHE"], Config["LOCAL_MODEL_PATH"]):
    os.makedirs(_required_dir, exist_ok=True)
def get_china_time():
    """Return the current time as a timezone-aware datetime in Asia/Shanghai."""
    return datetime.now(timezone("Asia/Shanghai"))
# -------------------------- Business-day helper --------------------------
def get_business_info(now=None, cutoff_hour=20):
    """Work out the current business date relative to a daily cut-off hour.

    Before the cut-off the business date is the previous calendar day; from
    the cut-off onwards it is the current calendar day.

    Args:
        now: Datetime to evaluate. When omitted, ``get_china_time()`` is used,
            which preserves the original zero-argument behaviour. Passing the
            value explicitly lets a caller evaluate the same instant it logs.
        cutoff_hour: Hour of day (0-23) at which the business date rolls over.
            Defaults to 20, the value that was previously hard-coded.

    Returns:
        tuple: ``(current_business_date, is_after_cutoff)`` where the first
        item is a ``date`` and the second is True once ``now.hour`` has
        reached ``cutoff_hour``.
    """
    if now is None:
        now = get_china_time()

    is_after_cutoff = now.hour >= cutoff_hour
    if is_after_cutoff:
        current_business_date = now.date()
    else:
        current_business_date = (now - timedelta(days=1)).date()
    return current_business_date, is_after_cutoff
def load_local_model():
    """Load the Kronos tokenizer and model from disk and wrap them in a predictor.

    Both artefacts are read from sub-directories of ``Config["LOCAL_MODEL_PATH"]``
    with ``local_files_only=True``.

    Returns:
        KronosPredictor: predictor built with ``device="cpu"`` and
        ``max_context=512``.

    Raises:
        FileNotFoundError: if the tokenizer or the model directory is missing.
    """
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 开始加载本地Kronos模型...")

    model_root = Config["LOCAL_MODEL_PATH"]
    tokenizer_dir = os.path.join(model_root, "tokenizer")
    weights_dir = os.path.join(model_root, "model")

    # Fail early, tokenizer first, when an artefact directory is absent.
    required = (
        ("分词器路径不存在", tokenizer_dir),
        ("模型路径不存在", weights_dir),
    )
    for problem, location in required:
        if not os.path.exists(location):
            raise FileNotFoundError(f"{problem}:{location}")

    tokenizer = KronosTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)
    model = Kronos.from_pretrained(weights_dir, local_files_only=True)

    # Put both objects into evaluation mode before building the predictor.
    for component in (tokenizer, model):
        component.eval()

    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 本地模型加载成功")
    return predictor
# -------------------------- Data download (ends at the business date) --------------------------
def fetch_stock_data():
    """Download daily K-line data from baostock up to the current business date.

    The query runs from ``Config["START_DATE"]`` to the business date returned
    by ``get_business_info()``. Rows with non-numeric OHLCV values are dropped,
    an ``amount`` column is derived, and only the most recent
    ``2 * Config["VOL_WINDOW"]`` rows are returned.

    Returns:
        pandas.DataFrame: columns ``timestamps, open, high, low, close,
        volume, amount`` with a fresh 0-based index.

    Raises:
        ConnectionError: if the baostock login fails.
        ValueError: if the query reports an error or too few rows remain.
    """
    started_at = get_china_time()
    business_date, _ = get_business_info()
    end_date = business_date.strftime("%Y-%m-%d")
    need_points = 2 * Config["VOL_WINDOW"]
    print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 开始获取{Config['STOCK_CODE']}日线数据(业务日结束日期:{end_date})")

    login_result = bs.login()
    if login_result.error_code != '0':
        raise ConnectionError(f"Baostock登录失败:{login_result.error_msg}")

    try:
        # NOTE(review): the original comment labelled adjustflag "2" as
        # backward-adjusted (后复权); the baostock documentation appears to
        # list 2 as forward-adjusted (前复权) - confirm which is intended.
        rs = bs.query_history_k_data_plus(
            code=Config["STOCK_CODE"],
            fields="date,open,high,low,close,volume",
            start_date=Config["START_DATE"],
            end_date=end_date,
            frequency=Config["FREQUENCY"],
            adjustflag="2",
        )
        if rs.error_code != '0':
            raise ValueError(f"获取K线数据失败:{rs.error_msg}")

        # Drain the result cursor row by row.
        rows = []
        while rs.next():
            rows.append(rs.get_row_data())
        frame = pd.DataFrame(rows, columns=rs.fields)

        # Coerce the OHLCV columns to numbers and drop rows that fail to parse.
        numeric_cols = ['open', 'high', 'low', 'close', 'volume']
        for name in numeric_cols:
            frame[name] = pd.to_numeric(frame[name], errors='coerce')
        frame = frame.dropna(subset=numeric_cols)

        # Parse the date column and derive amount as mean(OHLC) * volume.
        frame['timestamps'] = pd.to_datetime(frame['date'], format='%Y-%m-%d')
        frame['amount'] = (frame['open'] + frame['high'] + frame['low'] + frame['close']) / 4 * frame['volume']
        frame = frame[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]

        if len(frame) < need_points:
            raise ValueError(f"数据量不足(仅{len(frame)}个交易日),请提前START_DATE")

        frame = frame.tail(need_points).reset_index(drop=True)
        print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 股票数据获取成功,共{len(frame)}个交易日")
        print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 最新5条数据:\n{frame[['timestamps', 'open', 'close', 'volume']].tail()}")
        return frame
    finally:
        # Always release the baostock session, on success and on failure alike.
        bs.logout()
        print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] Baostock已登出")
def make_prediction(df, predictor):
    """Run the predictor over the history in ``df`` and return sampled forecasts.

    Args:
        df: History frame with a ``timestamps`` column plus ``open, high, low,
            close, volume, amount``.
        predictor: Object exposing ``predict(df=, x_timestamp=, y_timestamp=,
            pred_len=, T=, top_p=, sample_count=, verbose=)``.

    Returns:
        tuple: ``(close_preds, volume_preds, close_preds_for_volatility)``.
        The third item is the same object as the first; the volatility metric
        reuses the close-price forecast.
    """
    china_now = get_china_time()
    current_business_date, _ = get_business_info()
    print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始执行模型推理(业务日:{current_business_date},预测未来{Config['PRED_HORIZON']}个交易日)")

    # Build the future timestamps on a business-day ('B') calendar.  The
    # horizon is expressed in trading days and create_plot() already draws the
    # forecast on a 'B' axis; the previous calendar-day ('D') range fed the
    # model weekend dates that never occur in the history.  'B' skips
    # Saturdays and Sundays only - exchange holidays are not removed.
    last_timestamp = df['timestamps'].max()
    future_index = pd.date_range(
        start=last_timestamp + pd.Timedelta(days=1),
        periods=Config["PRED_HORIZON"],
        freq='B',
    )
    y_timestamp = pd.Series(future_index, name='y_timestamp')
    x_timestamp = df['timestamps']
    x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]

    # Inference only: disable gradient tracking to save memory and time.
    with torch.no_grad():
        begin_time = time.perf_counter()
        close_preds_main, volume_preds_main = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=Config["PRED_HORIZON"],
            T=1.0,
            top_p=0.95,
            sample_count=Config["N_PREDICTIONS"],
            verbose=True,
        )
        infer_time = time.perf_counter() - begin_time
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 推理完成,耗时{infer_time:.2f}秒")

    # The volatility forecast reuses the close-price samples.
    close_preds_volatility = close_preds_main
    return close_preds_main, volume_preds_main, close_preds_volatility
def calculate_metrics(hist_df, close_preds_df, v_close_preds_df):
    """Compute the upside probability and the volatility-amplification probability.

    Args:
        hist_df: History frame with a ``close`` column; its last row supplies
            the reference close.
        close_preds_df: Forecast closes, one column per sampled path; the last
            row is compared against the reference close.
        v_close_preds_df: Forecast closes used for the volatility comparison,
            one column per sampled path.

    Returns:
        tuple: ``(upside_prob, vol_amp_prob)``: the share of paths whose final
        close is above the last historical close, and the share of paths whose
        log-return standard deviation exceeds the historical one.

    Raises:
        ValueError: if ``v_close_preds_df`` has no columns (this used to
            surface as a bare ZeroDivisionError).
    """
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 开始计算预测指标...")

    path_columns = list(v_close_preds_df.columns)
    if not path_columns:
        raise ValueError("预测结果为空,无法计算波动率放大概率")

    # Upside probability: final forecast day versus the latest close.
    last_close = hist_df['close'].iloc[-1]
    final_day_preds = close_preds_df.iloc[-1]
    upside_prob = (final_day_preds > last_close).mean()

    # Historical volatility: std of log returns over the last VOL_WINDOW rows.
    hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1))
    historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std()

    def _path_volatility(path):
        """Std of log returns of one path, anchored at the last historical close."""
        full_sequence = pd.concat([pd.Series([last_close]), path]).reset_index(drop=True)
        return np.log(full_sequence / full_sequence.shift(1)).std()

    amplification_count = sum(
        1 for col in path_columns
        if _path_volatility(v_close_preds_df[col]) > historical_vol
    )
    vol_amp_prob = amplification_count / len(path_columns)

    # Report with the configured horizon rather than a hard-coded "24".
    horizon = Config["PRED_HORIZON"]
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 指标计算完成:")
    print(f" - {horizon}个交易日上涨概率:{upside_prob:.2%}")
    print(f" - {horizon}个交易日波动率放大概率:{vol_amp_prob:.2%}")
    return upside_prob, vol_amp_prob
def create_plot():
    """Render the cached history and forecast to ``Config["CHART_PATH"]``.

    Reads ``hist_df_for_plot``, ``close_preds`` and ``volume_preds`` from
    ``Config["CACHED_RESULTS"]`` and draws two stacked panels sharing the x
    axis: closing price with the forecast mean and min-max band, and volume.

    Raises:
        ValueError: if any of the three cached inputs is still None.
    """
    from matplotlib.font_manager import FontProperties

    china_now = get_china_time()
    stamp = f"{china_now:%Y-%m-%d %H:%M:%S}"
    print(f"[{stamp}] 开始生成预测图表(适配低版本matplotlib字体)")

    cache = Config["CACHED_RESULTS"]
    hist_df_for_plot = cache["hist_df_for_plot"]
    close_preds = cache["close_preds"]
    volume_preds = cache["volume_preds"]
    # Fail with a clear message instead of a TypeError from indexing None.
    if hist_df_for_plot is None or close_preds is None or volume_preds is None:
        raise ValueError("预测缓存为空,无法生成图表")

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
    # This function runs inside a long-lived scheduler, so the figure must be
    # released even when drawing or saving raises; otherwise it stays
    # registered with pyplot and leaks memory.
    try:
        # Prefer the bundled .ttf via FontProperties(fname=...); fall back to
        # a font looked up by family name when the file is missing.
        font_path = Config["CHINESE_FONT_PATH"]
        if os.path.exists(font_path):
            chinese_font = FontProperties(fname=font_path)
            print(f"[{stamp}] 成功加载.ttf字体:{font_path}")
        else:
            chinese_font = FontProperties(family='SimHei', size=10)
            print(f"[{stamp}] 字体文件不存在,使用系统默认字体:SimHei")

        # Global defaults so that text without an explicit font (for example
        # tick labels) can also render Chinese glyphs.
        plt.rcParams["font.family"] = ["sans-serif"]
        plt.rcParams["font.sans-serif"] = ["WenQuanYi Micro Hei", "SimHei", "Heiti TC"]
        plt.rcParams['axes.unicode_minus'] = False  # keep minus signs renderable

        # 1. Price panel.
        hist_time = hist_df_for_plot['timestamps']
        ax1.plot(hist_time, hist_df_for_plot['close'], color='#00274C', linewidth=1.5)
        mean_preds = close_preds.mean(axis=1)
        # Forecast x-axis: business days following the last historical timestamp.
        last_hist_time = hist_time.max()
        pred_time = pd.date_range(
            start=last_hist_time + pd.Timedelta(days=1),
            periods=Config["PRED_HORIZON"],
            freq='B',
        )
        ax1.plot(pred_time, mean_preds, color='#FF6B00', linestyle='-')
        ax1.fill_between(pred_time, close_preds.min(axis=1), close_preds.max(axis=1),
                         color='#FF6B00', alpha=0.2)
        ax1.set_title(f'{Config["STOCK_CODE"]} 上证指数概率预测(未来{Config["PRED_HORIZON"]}个交易日)',
                      fontsize=16, weight='bold', fontproperties=chinese_font)
        ax1.set_ylabel('价格(元)', fontsize=12, fontproperties=chinese_font)
        ax1.legend(['上证指数(后复权)', '预测均价', '预测区间(最小-最大)'],
                   fontsize=10, prop=chinese_font)
        ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

        # 2. Volume panel (values divided by 1e8 for display).
        ax2.bar(hist_time, hist_df_for_plot['volume'] / 1e8, color='#00A86B', width=0.6)
        ax2.bar(pred_time, volume_preds.mean(axis=1) / 1e8, color='#FF6B00', width=0.6)
        ax2.set_ylabel('成交量(亿手)', fontsize=12, fontproperties=chinese_font)
        ax2.set_xlabel('日期', fontsize=12, fontproperties=chinese_font)
        ax2.legend(['历史成交量(亿手)', '预测成交量(亿手)'],
                   fontsize=10, prop=chinese_font)
        ax2.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Vertical separator between history and forecast; added after the
        # legends and labelled '_nolegend_' so it never appears in them.
        separator_time = last_hist_time + pd.Timedelta(hours=12)
        for ax in (ax1, ax2):
            ax.axvline(x=separator_time, color='red', linestyle='--', linewidth=1.5, label='_nolegend_')
            ax.tick_params(axis='x', rotation=45)

        fig.tight_layout()
        chart_path = Path(Config["CHART_PATH"])
        if chart_path.exists():
            chart_path.chmod(0o666)  # make an existing chart writable before overwriting
        fig.savefig(chart_path, dpi=120, bbox_inches='tight')
    finally:
        plt.close(fig)

    print(f"[{stamp}] 图表生成完成,保存路径:{chart_path}")
def update_html():
    """Write the cached metrics and the current time into the HTML page.

    If ``Config["HTML_PATH"]`` does not exist it is first created, either by
    copying ``templates/index.html`` from the repository or, when that is
    missing, from a built-in page. The update time and the two probabilities
    are then substituted into the elements with ids ``update-time``,
    ``upside-prob`` and ``vol-amp-prob``.

    Returns early (after logging) when the cache holds no metrics or when the
    page cannot be read; a write failure is logged rather than raised.
    """
    china_now = get_china_time()
    current_business_date, _ = get_business_info()
    stamp = f"{china_now:%Y-%m-%d %H:%M:%S}"
    print(f"[{stamp}] 开始更新HTML页面(业务日:{current_business_date})...")

    # 1. Metrics come from the cache; without them there is nothing to write.
    cached = Config["CACHED_RESULTS"]
    upside_prob = cached.get("upside_prob")
    vol_amp_prob = cached.get("vol_amp_prob")
    if upside_prob is None or vol_amp_prob is None:
        print(f"[{stamp}] 警告:缓存中未找到指标数据,无法更新HTML")
        return

    # Percentages with one decimal place, plus the displayed timestamp.
    upside_prob_str = f'{upside_prob:.1%}'
    vol_amp_prob_str = f'{vol_amp_prob:.1%}'
    now_cn_str = china_now.strftime('%Y-%m-%d %H:%M:%S')

    # 2. Create the page on first use.
    html_path = Path(Config["HTML_PATH"])
    src_html_path = Config["REPO_PATH"] / "templates" / "index.html"
    if not html_path.exists():
        html_path.parent.mkdir(parents=True, exist_ok=True)
        if src_html_path.exists():
            import shutil
            shutil.copy2(src_html_path, html_path)
            print(f"[{stamp}] 从项目模板复制HTML:{src_html_path} -> {html_path}")
        else:
            # Built-in page; its element ids must match the patterns used below.
            base_html = """
<!DOCTYPE html>
<html>
<head>
<title>清华大模型Kronos上证指数预测</title>
<style>
body { max-width: 1200px; margin: 0 auto; padding: 20px; font-family: "WenQuanYi Micro Hei", Arial; }
.metric { margin: 20px 0; padding: 10px; background: #f5f5f5; border-radius: 5px; }
.metric-value { font-size: 1.2em; color: #0066cc; }
img { max-width: 100%; height: auto; }
h1 { color: #333; }
</style>
</head>
<body>
<h1>清华大学K线大模型Kronos上证指数(sh.000001)概率预测</h1>
<p>最后更新时间(中国时间):<strong id="update-time">未更新</strong></p>
<p>同 步 网 站:<strong><a href="http://15115656.top" target="_blank">火狼工具站</a></strong></p>
<div class="metric">
<p>24个交易日上涨概率:<span class="metric-value" id="upside-prob">--%</span></p>
</div>
<div class="metric">
<p>波动率放大概率:<span class="metric-value" id="vol-amp-prob">--%</span></p>
</div>
<div><img src="/prediction_chart.png" alt="上证指数预测图表"></div>
</body>
</html>
"""
            with open(html_path, 'w', encoding='utf-8') as handle:
                handle.write(base_html)
            print(f"[{stamp}] 在/tmp创建基础HTML:{html_path}")

    # 3. Load the current page.
    try:
        with open(html_path, 'r', encoding='utf-8') as handle:
            content = handle.read()
    except Exception as e:
        print(f"[{stamp}] 读取HTML失败:{str(e)}")
        return

    # 4. Swap the text between each opening/closing tag pair.  A callable
    # replacement is used so the inserted value is never parsed for
    # backslash or group references.
    substitutions = (
        (r'(<strong id="update-time">).*?(</strong>)', now_cn_str),
        (r'(<span class="metric-value" id="upside-prob">).*?(</span>)', upside_prob_str),
        (r'(<span class="metric-value" id="vol-amp-prob">).*?(</span>)', vol_amp_prob_str),
    )
    for pattern, value in substitutions:
        content = re.sub(
            pattern=pattern,
            repl=lambda m, value=value: f'{m.group(1)}{value}{m.group(2)}',
            string=content,
        )

    # 5. Persist the updated page.
    try:
        with open(html_path, 'w', encoding='utf-8') as handle:
            handle.write(content)
        print(f"[{stamp}] HTML更新完成,路径:{html_path}")
        print(f"[DEBUG] 上涨概率更新为:{upside_prob_str}")
        print(f"[DEBUG] 波动率概率更新为:{vol_amp_prob_str}")
    except Exception as e:
        print(f"[{stamp}] 写入HTML失败:{str(e)}")
def git_commit_and_push():
    """Copy the chart and HTML into the repository, then commit and push them.

    Does nothing (beyond logging) when ``git`` is not available. Git failures
    are logged rather than raised; "nothing to commit" / "up to date" results
    are reported as a no-op.
    """
    # Imported at the top of the function: it used to live inside the first
    # ``if`` below, so when only the HTML file existed ``shutil`` was an
    # unbound local and the HTML copy raised UnboundLocalError.
    import shutil

    china_now = get_china_time()
    current_business_date, _ = get_business_info()
    stamp = f"{china_now:%Y-%m-%d %H:%M:%S}"
    commit_message = f"Auto-update: 上证指数预测(业务日{current_business_date} 中国时间)"
    print(f"[{stamp}] 开始执行Git提交操作,提交信息:{commit_message}")

    # Skip everything when git is missing or not runnable.
    try:
        subprocess.run(['git', '--version'], check=True, capture_output=True, text=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        print(f"[{stamp}] Git未安装或未在PATH中,跳过Git操作")
        return

    repo_path = Config["REPO_PATH"]
    # Every git command runs with cwd=repo_path instead of os.chdir(), so the
    # working directory of the whole process is left untouched.
    git_kwargs = {"check": True, "capture_output": True, "text": True, "cwd": repo_path}

    try:
        # Stage the generated artefacts inside the tracked directory.
        artefacts = (
            ("图表", Config["CHART_PATH"], repo_path / "prediction_chart.png"),
            ("HTML", Config["HTML_PATH"], repo_path / "index.html"),
        )
        for label, src, dst in artefacts:
            if os.path.exists(src):
                shutil.copy2(src, dst)
                print(f"[{stamp}] {label}复制到Git目录:{dst}")

        subprocess.run(['git', 'add', 'prediction_chart.png', 'index.html'], **git_kwargs)

        commit_result = subprocess.run(['git', 'commit', '-m', commit_message], **git_kwargs)
        print(f"[{stamp}] Git提交输出:\n{commit_result.stdout}")

        push_result = subprocess.run(['git', 'push'], **git_kwargs)
        print(f"[{stamp}] Git推送输出:\n{push_result.stdout}")
        print(f"[{stamp}] Git操作完成")
    except subprocess.CalledProcessError as e:
        # Inspect both streams: the benign messages may land in either one,
        # and the old check ignored stderr whenever stdout was non-empty.
        output = (e.stdout or "") + (e.stderr or "")
        if "nothing to commit" in output or "Your branch is up to date" in output:
            print(f"[{stamp}] 无新内容需要提交或推送")
        else:
            print(f"[{stamp}] Git错误:\nSTDOUT: {e.stdout}\nSTDERR: {e.stderr}")
    except PermissionError as e:
        print(f"[{stamp}] Git权限错误:{str(e)},跳过Git操作")
# -------------------------- Main task (keyed on the business date) --------------------------
def main_task(model):
    """Run one refresh cycle: infer once per business date, otherwise reuse the cache.

    When ``Config["LAST_INFERENCED_BUSINESS_DATE"]`` already equals the current
    business date, the chart, HTML page and git commit are regenerated from
    the cached results. Otherwise data is fetched, the model is run, metrics
    are computed and cached, and the same outputs are produced.

    Args:
        model: Predictor passed through to ``make_prediction``.

    Exceptions raised on the full-inference path are caught, logged with a
    traceback and not re-raised; the cache-reuse path is not guarded.
    """
    china_now = get_china_time()
    current_business_date, is_after_20h = get_business_info()
    tag = f"[{china_now:%Y-%m-%d %H:%M:%S}]"
    rule = "=" * 60

    print(f"\n{tag} {rule}")
    print(f"{tag} 开始执行主任务")
    print(f"{tag} 当前业务日:{current_business_date}(北京时间{'20点后' if is_after_20h else '20点前'})")

    # Same business date as the last inference: rebuild outputs from the cache.
    already_done = Config["LAST_INFERENCED_BUSINESS_DATE"] == current_business_date
    if already_done:
        print(f"{tag} 当前业务日({current_business_date})已完成推理,直接复用缓存结果")
        create_plot()
        update_html()
        git_commit_and_push()
        print(f"{tag} 主任务完成(复用缓存)")
        print(f"{tag} {rule}\n")
        return

    # New business date: run the full pipeline.
    try:
        # 1. Fetch data; the last row is held back from the model input.
        df_full = fetch_stock_data()
        df_for_model = df_full.iloc[:-1]

        # 2. Inference.
        close_preds, volume_preds, v_close_preds = make_prediction(df_for_model, model)

        # 3. Metrics over the trailing VOL_WINDOW rows.
        window = Config["VOL_WINDOW"]
        hist_df_for_metrics = df_for_model.tail(window)
        upside_prob, vol_amp_prob = calculate_metrics(hist_df_for_metrics, close_preds, v_close_preds)

        # 4. Cache everything later calls on the same business date will need.
        hist_df_for_plot = df_for_model.tail(window)
        Config["CACHED_RESULTS"] = {
            "close_preds": close_preds,
            "volume_preds": volume_preds,
            "v_close_preds": v_close_preds,
            "upside_prob": upside_prob,
            "vol_amp_prob": vol_amp_prob,
            "hist_df_for_plot": hist_df_for_plot,
        }
        # Record the date itself (not a boolean) as the "already inferred" marker.
        Config["LAST_INFERENCED_BUSINESS_DATE"] = current_business_date
        print(f"{tag} 业务日({current_business_date})推理结果已缓存,同业务日后续调用将复用")

        # 5-7. Chart, HTML page, git commit.
        create_plot()
        update_html()
        git_commit_and_push()

        # 8. Drop the large intermediates and collect garbage.
        del df_full, df_for_model, hist_df_for_metrics
        gc.collect()

        print(f"{tag} 主任务完成(首次推理)")
        print(f"{tag} {rule}\n")
    except Exception as e:
        print(f"{tag} 主任务执行失败,业务日({current_business_date})推理标记为未完成")
        print(f"{tag} 错误信息:{str(e)}")
        import traceback
        traceback.print_exc()
        print(f"{tag} {rule}\n")
# -------------------------- Scheduler (fires at 20:00 China time) --------------------------
def run_scheduler(model):
    """Loop forever, running ``main_task`` shortly after 20:00 Asia/Shanghai each day.

    Each iteration computes the next 20:00:05 trigger, sleeps until then (at
    least 300 seconds) and calls ``main_task``. If that call raises, the error
    is logged and the loop pauses a further 300 seconds.

    Args:
        model: Predictor forwarded to ``main_task``.
    """
    china_tz = timezone("Asia/Shanghai")
    # Five seconds past the hour, so the wake-up lands safely after the boundary.
    trigger_clock = datetime.strptime("20:00:05", "%H:%M:%S").time()
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 定时器启动(中国时间),每天20点执行推理")

    while True:
        china_now = get_china_time()
        current_business_date, is_after_20h = get_business_info()

        # Already past 20:00 today -> next run is tomorrow; otherwise today.
        if is_after_20h:
            next_exec_date = (china_now + timedelta(days=1)).date()
        else:
            next_exec_date = china_now.date()

        # Attach the zone with localize().  Passing a pytz zone as ``tzinfo=``
        # to datetime.combine() applies the zone's historical LMT offset
        # (+08:06 for Asia/Shanghai), which moved the trigger about six
        # minutes early, before the 20:00 business-date boundary.
        next_exec_time = china_tz.localize(datetime.combine(next_exec_date, trigger_clock))

        # Never sleep less than five minutes, even if the delta is small or negative.
        sleep_seconds = (next_exec_time - china_now).total_seconds()
        sleep_seconds = max(sleep_seconds, 300)

        print(f"\n[{china_now:%Y-%m-%d %H:%M:%S}] 定时器状态:")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 当前时间:{china_now:%Y-%m-%d %H:%M:%S}(中国时间)")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 当前业务日:{current_business_date}({'20点后' if is_after_20h else '20点前'})")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 下次执行时间:{next_exec_time:%Y-%m-%d %H:%M:%S}(中国时间)")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 等待时间:{sleep_seconds:.0f}秒(约{sleep_seconds/3600:.1f}小时)")

        time.sleep(sleep_seconds)

        # NOTE(review): main_task() catches its own exceptions on the
        # full-inference path, so a failed inference does not reach this
        # handler and is next attempted at the following trigger - confirm
        # whether an earlier retry is wanted.
        try:
            main_task(model)
        except Exception as e:
            print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 定时器触发任务失败:{str(e)}")
            import traceback
            traceback.print_exc()
            print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 5分钟后重试...")
            time.sleep(300)
if __name__ == '__main__':
    # Startup sequence: log the start, load the model, run the task once
    # immediately, then hand control to the never-returning scheduler loop.
    china_now = get_china_time()
    print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 程序启动(中国时间)")

    loaded_model = load_local_model()

    # One immediate run; it performs inference only when the current business
    # date has not been processed yet.
    main_task(loaded_model)

    run_scheduler(loaded_model)