"""Daily probabilistic forecast updater for the Shanghai Composite index.

The script loads a local Kronos model, downloads daily K-line data through
baostock, samples several forecast paths, derives two summary probabilities,
renders a chart, updates an HTML page and pushes both artifacts with Git.
After one immediate run it keeps running once per day shortly after 00:00 UTC.
"""

import gc
import os
import re
import shutil
import subprocess
import time
import traceback
from datetime import datetime, timezone, timedelta
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import baostock as bs
from matplotlib import font_manager

from model import KronosTokenizer, Kronos, KronosPredictor

# --- Configuration ---
Config = {
    "REPO_PATH": Path(__file__).parent.resolve(),
    "LOCAL_MODEL_PATH": os.path.join(Path(__file__).parent.resolve(), "models"),
    "STOCK_CODE": "sh.000001",
    "FREQUENCY": "d",
    "START_DATE": "2022-01-01",
    # None means "today, resolved on every fetch". A date string computed at
    # import time would freeze the query window for the whole lifetime of the
    # scheduler process. An explicit "YYYY-MM-DD" string is still honoured.
    "END_DATE": None,
    "HIST_POINTS": 360,
    "PRED_HORIZON": 24,
    "N_PREDICTIONS": 10,
    "VOL_WINDOW": 24,
    "PREDICTION_CACHE": os.path.join("/tmp", "predictions_cache"),
    "CHART_PATH": os.path.join("/tmp", "prediction_chart.png"),
    # The HTML file is written under /tmp to avoid permission problems.
    "HTML_PATH": os.path.join("/tmp", "index.html"),
    # Retry policy used by run_scheduler() after a failed run.
    "RETRY_DELAY_SECONDS": 300,
    "MAX_RETRIES": 3,
}

os.makedirs(Config["PREDICTION_CACHE"], exist_ok=True)
os.makedirs(Config["LOCAL_MODEL_PATH"], exist_ok=True)

# Ids of the HTML elements whose text update_html() rewrites.
# NOTE(review): these ids are assumed; they must match the ids used in
# templates/index.html -- confirm against the project template.
_UPDATE_TIME_ID = "update-time"
_UPSIDE_PROB_ID = "upside-prob"
_VOL_AMP_PROB_ID = "vol-amp-prob"

# Minimal fallback page, used only when the project ships no template.
_BASE_HTML = """<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <title>上证指数预测</title>
</head>
<body>
    <h1>上证指数(sh.000001)概率预测</h1>
    <p>最后更新时间:<strong id="update-time">未更新</strong>(UTC)</p>
    <p>24个交易日上涨概率:<strong id="upside-prob">--%</strong></p>
    <p>波动率放大概率:<strong id="vol-amp-prob">--%</strong></p>
    <img src="prediction_chart.png" alt="预测图表">
</body>
</html>
"""


def load_local_model():
    """Load the Kronos tokenizer and model from LOCAL_MODEL_PATH.

    Returns:
        A KronosPredictor bound to the CPU with a context of 512 points.

    Raises:
        FileNotFoundError: If the tokenizer or the model directory is missing.
    """
    print("Loading local Kronos model...")
    tokenizer_path = os.path.join(Config["LOCAL_MODEL_PATH"], "tokenizer")
    model_path = os.path.join(Config["LOCAL_MODEL_PATH"], "model")

    if not os.path.exists(tokenizer_path):
        raise FileNotFoundError(f"分词器路径不存在:{tokenizer_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型路径不存在:{model_path}")

    tokenizer = KronosTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    model = Kronos.from_pretrained(model_path, local_files_only=True)
    tokenizer.eval()
    model.eval()

    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
    print("Local model loaded successfully.")
    return predictor


def fetch_stock_data():
    """Download daily K-line data for STOCK_CODE through baostock.

    Returns:
        A DataFrame with the columns timestamps, open, high, low, close,
        volume and amount, holding the most recent
        HIST_POINTS + VOL_WINDOW rows.

    Raises:
        ConnectionError: If the baostock login fails.
        ValueError: If the query fails or returns too few rows.
    """
    stock_code = Config["STOCK_CODE"]
    frequency = Config["FREQUENCY"]
    start_date = Config["START_DATE"]
    # Resolve "today" on every call so each scheduled run sees fresh data.
    end_date = Config["END_DATE"] or datetime.now().strftime("%Y-%m-%d")
    need_points = Config["HIST_POINTS"] + Config["VOL_WINDOW"]

    print(f"Fetching {need_points}个交易日的{stock_code}日线数据...")

    lg = bs.login()
    if lg.error_code != '0':
        raise ConnectionError(f"Baostock登录失败:{lg.error_msg}")

    try:
        fields = "date,open,high,low,close,volume"
        rs = bs.query_history_k_data_plus(
            code=stock_code,
            fields=fields,
            start_date=start_date,
            end_date=end_date,
            frequency=frequency,
            adjustflag="2"
        )
        if rs.error_code != '0':
            raise ValueError(f"获取K线数据失败:{rs.error_msg}")

        data_list = []
        while rs.next():
            data_list.append(rs.get_row_data())
        df = pd.DataFrame(data_list, columns=rs.fields)

        # Values are converted to numbers; unparsable rows are dropped.
        numeric_cols = ['open', 'high', 'low', 'close', 'volume']
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        df = df.dropna(subset=numeric_cols)

        df['timestamps'] = pd.to_datetime(df['date'], format='%Y-%m-%d')
        # Turnover is approximated as the mean OHLC price times the volume.
        df['amount'] = (df['open'] + df['high'] + df['low'] + df['close']) / 4 * df['volume']
        df = df[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]

        if len(df) < need_points:
            raise ValueError(f"数据量不足(仅{len(df)}个交易日),请提前start_date")
        df = df.tail(need_points).reset_index(drop=True)

        print(f"A股数据获取成功,共{len(df)}个交易日数据")
        print(df[['timestamps', 'open', 'close', 'volume', 'amount']].tail())
        return df
    finally:
        # Always release the baostock session, even when the query fails.
        bs.logout()


def _future_trading_days(last_timestamp, periods):
    """Return the next `periods` business days after `last_timestamp`.

    Weekends are skipped; exchange holidays are not known here, so the dates
    are an approximation of the real trading calendar.
    """
    return pd.date_range(
        start=last_timestamp + pd.Timedelta(days=1),
        periods=periods,
        freq='B'
    )


def make_prediction(df, predictor):
    """Sample N_PREDICTIONS forecast paths for the next PRED_HORIZON days.

    Args:
        df: Historical data as returned by fetch_stock_data().
        predictor: The object returned by load_local_model().

    Returns:
        A tuple (close_preds, volume_preds, close_preds_volatility). The third
        element is the same object as the first.
    """
    last_timestamp = df['timestamps'].max()
    # The horizon is counted in trading days, so forecast timestamps skip
    # weekends instead of advancing one calendar day at a time.
    new_timestamps_index = _future_trading_days(last_timestamp, Config["PRED_HORIZON"])
    y_timestamp = pd.Series(new_timestamps_index, name='y_timestamp')
    x_timestamp = df['timestamps']
    x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]

    with torch.no_grad():
        print("Making main prediction (T=1.0)...")
        begin_time = time.time()
        close_preds_main, volume_preds_main = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=Config["PRED_HORIZON"],
            T=1.0,
            top_p=0.95,
            sample_count=Config["N_PREDICTIONS"],
            verbose=True
        )
        print(f"Main prediction completed in {time.time() - begin_time:.2f} seconds.")

    close_preds_volatility = close_preds_main
    return close_preds_main, volume_preds_main, close_preds_volatility


def calculate_metrics(hist_df, close_preds_df, v_close_preds_df):
    """Derive the upside and volatility-amplification probabilities.

    Args:
        hist_df: Recent history; its last close is the reference price.
        close_preds_df: Forecast closes, one column per sampled path.
        v_close_preds_df: Forecast closes used for the volatility comparison.

    Returns:
        A tuple (upside_prob, vol_amp_prob):
        upside_prob is the share of paths whose final close is above the last
        historical close; vol_amp_prob is the share of paths whose log-return
        standard deviation exceeds the historical one.
    """
    last_close = hist_df['close'].iloc[-1]
    final_day_preds = close_preds_df.iloc[-1]
    upside_prob = (final_day_preds > last_close).mean()

    hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1))
    historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std()

    amplification_count = 0
    for col in v_close_preds_df.columns:
        # Prepend the last close so the first forecast step yields a return.
        full_sequence = pd.concat(
            [pd.Series([last_close]), v_close_preds_df[col]]
        ).reset_index(drop=True)
        pred_log_returns = np.log(full_sequence / full_sequence.shift(1))
        predicted_vol = pred_log_returns.std()
        if predicted_vol > historical_vol:
            amplification_count += 1
    vol_amp_prob = amplification_count / len(v_close_preds_df.columns)

    horizon = Config["PRED_HORIZON"]
    print(f"上证指数上涨概率({horizon}个交易日):{upside_prob:.2%}")
    print(f"波动率放大概率({horizon}个交易日):{vol_amp_prob:.2%}")
    return upside_prob, vol_amp_prob


def _configure_chinese_font():
    """Select a font able to render the Chinese labels of the chart."""
    font_path = os.path.join(Config["REPO_PATH"], "fonts", "wqy-microhei.ttc")
    if os.path.exists(font_path):
        # rcParams expects font family names, not file paths: register the
        # bundled file first, then refer to it by the name it declares.
        font_manager.fontManager.addfont(font_path)
        font_name = font_manager.FontProperties(fname=font_path).get_name()
        plt.rcParams['font.family'] = 'sans-serif'
        plt.rcParams['font.sans-serif'] = [font_name, 'DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False
        print(f"使用项目自带中文字体:{font_path}")
    else:
        # Fallback: fonts that may be installed on the system.
        plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei', 'SimHei', 'DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False
        print("项目中文字体未找到,使用系统默认字体(可能仍有警告)")


def create_plot(hist_df, close_preds_df, volume_preds_df):
    """Render history and forecast (price and volume) to CHART_PATH.

    Args:
        hist_df: Historical rows to draw.
        close_preds_df: Forecast closes, one column per sampled path.
        volume_preds_df: Forecast volumes, one column per sampled path.
    """
    print("Generating comprehensive forecast chart...")
    _configure_chinese_font()

    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(15, 10), sharex=True,
        gridspec_kw={'height_ratios': [3, 1]}
    )

    # Time axes: history as fetched, forecast on the following business days
    # (the same dates make_prediction() hands to the model).
    hist_time = hist_df['timestamps']
    last_hist_time = hist_time.iloc[-1]
    pred_time = _future_trading_days(last_hist_time, len(close_preds_df))

    # Price panel. The data is fetched with adjustflag="2", which baostock
    # documents as forward adjustment, hence the label.
    ax1.plot(hist_time, hist_df['close'], color='#00274C',
             label='上证指数(前复权)', linewidth=1.5)
    mean_preds = close_preds_df.mean(axis=1)
    ax1.plot(pred_time, mean_preds, color='#FF6B00', linestyle='-', label='预测均价')
    ax1.fill_between(pred_time, close_preds_df.min(axis=1), close_preds_df.max(axis=1),
                     color='#FF6B00', alpha=0.2, label='预测区间(最小-最大)')
    ax1.set_title(
        f'{Config["STOCK_CODE"]} 上证指数 概率预测(未来{Config["PRED_HORIZON"]}个交易日)',
        fontsize=16, weight='bold'
    )
    ax1.set_ylabel('价格(元)')
    ax1.legend()
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

    # Volume panel.
    ax2.bar(hist_time, hist_df['volume'] / 1e8, color='#00A86B',
            label='历史成交量(亿手)', width=0.6)
    ax2.bar(pred_time, volume_preds_df.mean(axis=1) / 1e8, color='#FF6B00',
            label='预测成交量(亿手)', width=0.6)
    ax2.set_ylabel('成交量(亿手)')
    ax2.set_xlabel('日期')
    ax2.legend()
    ax2.grid(True, which='both', linestyle='--', linewidth=0.5)

    # Vertical separator between history and forecast.
    separator_time = last_hist_time + timedelta(hours=12)
    for ax in [ax1, ax2]:
        ax.axvline(x=separator_time, color='red', linestyle='--',
                   linewidth=1.5, label='_nolegend_')
        ax.tick_params(axis='x', rotation=45)

    fig.tight_layout()

    # Save the chart, making an existing file writable first.
    chart_path = Path(Config["CHART_PATH"])
    if not chart_path.parent.exists():
        chart_path.parent.mkdir(parents=True, exist_ok=True)
    if chart_path.exists():
        chart_path.chmod(0o666)
    fig.savefig(chart_path, dpi=120, bbox_inches='tight')
    plt.close(fig)
    print(f"图表已保存至:{chart_path}(权限:{oct(os.stat(chart_path).st_mode)[-3:]})")


def _replace_element_text(content, element_id, new_text):
    """Replace the inner text of the element carrying `element_id`.

    Returns:
        A tuple (updated_content, number_of_replacements).
    """
    # Anchoring on the id keeps the three substitutions apart; groups 1 and 3
    # are the opening and closing tags, group 2 is the tag name.
    pattern = re.compile(
        rf'(<(\w+)[^>]*\bid="{re.escape(element_id)}"[^>]*>).*?(</\2>)',
        re.DOTALL
    )
    # A callable replacement keeps backslashes in new_text literal.
    return pattern.subn(lambda m: f'{m.group(1)}{new_text}{m.group(3)}', content)


def update_html(upside_prob, vol_amp_prob):
    """Write the update time and both probabilities into HTML_PATH.

    The page is created on first use, either copied from
    templates/index.html in the repository or from a built-in minimal page.

    Args:
        upside_prob: Upside probability as a fraction.
        vol_amp_prob: Volatility-amplification probability as a fraction.
    """
    print("Updating index.html...")
    # 1. Paths: writable target under /tmp and the project template.
    html_path = Path(Config["HTML_PATH"])
    src_html_path = Config["REPO_PATH"] / "templates" / "index.html"

    now_utc_str = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
    upside_prob_str = f'{upside_prob:.1%}'
    vol_amp_prob_str = f'{vol_amp_prob:.1%}'

    # 2. Initialise the page when it does not exist yet.
    if not html_path.exists():
        html_path.parent.mkdir(parents=True, exist_ok=True)
        if src_html_path.exists():
            # Copy the project template to keep the original page styling.
            shutil.copy2(src_html_path, html_path)
            print(f"从项目模板复制HTML到:{html_path}")
        else:
            # No template in the project: create a basic page instead.
            with open(html_path, 'w', encoding='utf-8') as f:
                f.write(_BASE_HTML)
            print(f"在/tmp创建基础HTML:{html_path}")

    # 3. Read the page and replace the three values.
    with open(html_path, 'r', encoding='utf-8') as f:
        content = f.read()

    replacements = (
        (_UPDATE_TIME_ID, now_utc_str),
        (_UPSIDE_PROB_ID, upside_prob_str),
        (_VOL_AMP_PROB_ID, vol_amp_prob_str),
    )
    for element_id, value in replacements:
        content, hits = _replace_element_text(content, element_id, value)
        if hits == 0:
            # Make a template/id mismatch visible instead of failing silently.
            print(f"Warning: no element with id '{element_id}' in {html_path}; value not updated.")

    # 4. Write the updated page back.
    with open(html_path, 'w', encoding='utf-8') as f:
        f.write(content)
    print(f"HTML更新成功:{html_path}")


def git_commit_and_push(commit_message):
    """Copy the chart and the page into the repository, commit and push.

    Git failures are reported on stdout and not raised; a "nothing to commit"
    outcome is reported as such.

    Args:
        commit_message: Message used for the commit.
    """
    print("Performing Git operations...")
    repo_path = Config["REPO_PATH"]

    # Artifacts are produced under /tmp and copied into the tracked tree.
    artifacts = (
        (Config["CHART_PATH"], "prediction_chart.png", "图表"),
        (Config["HTML_PATH"], "index.html", "HTML"),
    )
    tracked = []
    for src, file_name, label in artifacts:
        dst = os.path.join(repo_path, file_name)
        if os.path.exists(src):
            shutil.copy2(src, dst)
            print(f"{label}已从 {src} 复制到 {dst}")
        # Only stage files that exist, otherwise `git add` aborts.
        if os.path.exists(dst):
            tracked.append(file_name)

    if not tracked:
        print("No new changes to commit or push.")
        return

    def run_git(*args):
        # cwd= scopes the working directory to the git call instead of
        # changing it for the whole process.
        return subprocess.run(['git', *args], cwd=repo_path, check=True,
                              capture_output=True, text=True)

    try:
        run_git('add', *tracked)
        commit_result = run_git('commit', '-m', commit_message)
        print(commit_result.stdout)
        push_result = run_git('push')
        print(push_result.stdout)
        print("Git push successful.")
    except subprocess.CalledProcessError as e:
        # Inspect both streams: git reports this condition on either one.
        output = (e.stdout or "") + (e.stderr or "")
        if "nothing to commit" in output or "Your branch is up to date" in output:
            print("No new changes to commit or push.")
        else:
            print(f"A Git error occurred:\n--- STDOUT ---\n{e.stdout}\n--- STDERR ---\n{e.stderr}")


def main_task(model):
    """Run one full update: fetch, predict, score, plot, publish.

    Args:
        model: The predictor returned by load_local_model().
    """
    print("\n" + "=" * 60 + f"\nStarting update task at {datetime.now(timezone.utc)}\n" + "=" * 60)

    df_full = fetch_stock_data()
    # The most recent row is excluded from the model input.
    df_for_model = df_full.iloc[:-1]

    close_preds, volume_preds, v_close_preds = make_prediction(df_for_model, model)

    hist_df_for_plot = df_for_model.tail(Config["HIST_POINTS"])
    hist_df_for_metrics = df_for_model.tail(Config["VOL_WINDOW"])

    upside_prob, vol_amp_prob = calculate_metrics(hist_df_for_metrics, close_preds, v_close_preds)
    create_plot(hist_df_for_plot, close_preds, volume_preds)
    update_html(upside_prob, vol_amp_prob)

    commit_message = f"Auto-update 上证指数预测 {datetime.now(timezone.utc):%Y-%m-%d %H:%M} UTC"
    git_commit_and_push(commit_message)

    # Drop the large frames before the process goes back to sleep.
    del df_full, df_for_model, close_preds, volume_preds, v_close_preds
    del hist_df_for_plot, hist_df_for_metrics
    gc.collect()

    print("-" * 60 + "\n--- Task completed successfully ---\n" + "-" * 60 + "\n")


def _run_task_with_retries(model):
    """Run main_task(), retrying after RETRY_DELAY_SECONDS on failure.

    Returns:
        True when a run succeeded, False when every attempt failed.
    """
    max_retries = Config["MAX_RETRIES"]
    retry_delay = Config["RETRY_DELAY_SECONDS"]

    for attempt in range(max_retries + 1):
        try:
            main_task(model)
            return True
        except Exception as e:
            # Top-level boundary: report everything and keep the scheduler alive.
            print(f"\n!!!!!! A critical error occurred in the main task !!!!!!!")
            print(f"Error: {e}")
            traceback.print_exc()
            if attempt < max_retries:
                print(f"Retrying in {retry_delay / 60:.0f} minutes...")
                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")
                time.sleep(retry_delay)
            else:
                print("Giving up until the next scheduled run.")
                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")
    return False


def run_scheduler(model):
    """Run the update once per day at 00:00:05 UTC, forever.

    A failed run is retried a bounded number of times before the scheduler
    falls back to the next daily slot.

    Args:
        model: The predictor returned by load_local_model().
    """
    while True:
        now = datetime.now(timezone.utc)
        next_run_time = (now + timedelta(days=1)).replace(hour=0, minute=0, second=5, microsecond=0)
        sleep_seconds = (next_run_time - now).total_seconds()

        if sleep_seconds > 0:
            print(f"Current time: {now:%Y-%m-%d %H:%M:%S UTC}.")
            print(f"Next run at: {next_run_time:%Y-%m-%d %H:%M:%S UTC}. Waiting for {sleep_seconds:.0f} seconds...")
            time.sleep(sleep_seconds)

        _run_task_with_retries(model)


if __name__ == '__main__':
    local_model_path = Path(Config["LOCAL_MODEL_PATH"])
    local_model_path.mkdir(parents=True, exist_ok=True)

    loaded_model = load_local_model()
    main_task(loaded_model)
    run_scheduler(loaded_model)