import gc import os import re import subprocess import time from datetime import datetime, timezone, timedelta from pathlib import Path import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch import baostock as bs from model import KronosTokenizer, Kronos, KronosPredictor # --- Configuration --- Config = { "REPO_PATH": Path(__file__).parent.resolve(), "LOCAL_MODEL_PATH": os.path.join(Path(__file__).parent.resolve(), "models"), "STOCK_CODE": "sh.000001", "FREQUENCY": "d", "START_DATE": "2022-01-01", "END_DATE": datetime.now().strftime("%Y-%m-%d"), "HIST_POINTS": 360, "PRED_HORIZON": 24, "N_PREDICTIONS": 10, "VOL_WINDOW": 24, "PREDICTION_CACHE": os.path.join("/tmp", "predictions_cache"), "CHART_PATH": os.path.join("/tmp", "prediction_chart.png"), # 新增1:HTML文件保存到/tmp目录(解决权限问题) "HTML_PATH": os.path.join("/tmp", "index.html") } os.makedirs(Config["PREDICTION_CACHE"], exist_ok=True) os.makedirs(Config["LOCAL_MODEL_PATH"], exist_ok=True) def load_local_model(): print("Loading local Kronos model...") tokenizer_path = os.path.join(Config["LOCAL_MODEL_PATH"], "tokenizer") model_path = os.path.join(Config["LOCAL_MODEL_PATH"], "model") if not os.path.exists(tokenizer_path): raise FileNotFoundError(f"分词器路径不存在:{tokenizer_path}") if not os.path.exists(model_path): raise FileNotFoundError(f"模型路径不存在:{model_path}") tokenizer = KronosTokenizer.from_pretrained(tokenizer_path, local_files_only=True) model = Kronos.from_pretrained(model_path, local_files_only=True) tokenizer.eval() model.eval() predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512) print("Local model loaded successfully.") return predictor def fetch_stock_data(): stock_code = Config["STOCK_CODE"] frequency = Config["FREQUENCY"] start_date = Config["START_DATE"] end_date = Config["END_DATE"] need_points = Config["HIST_POINTS"] + Config["VOL_WINDOW"] print(f"Fetching {need_points}个交易日的{stock_code}日线数据...") lg = bs.login() if lg.error_code != '0': raise ConnectionError(f"Baostock登录失败:{lg.error_msg}") try: fields = "date,open,high,low,close,volume" rs = bs.query_history_k_data_plus( code=stock_code, fields=fields, start_date=start_date, end_date=end_date, frequency=frequency, adjustflag="2" ) if rs.error_code != '0': raise ValueError(f"获取K线数据失败:{rs.error_msg}") data_list = [] while rs.next(): data_list.append(rs.get_row_data()) df = pd.DataFrame(data_list, columns=rs.fields) numeric_cols = ['open', 'high', 'low', 'close', 'volume'] for col in numeric_cols: df[col] = pd.to_numeric(df[col], errors='coerce') df = df.dropna(subset=numeric_cols) df['timestamps'] = pd.to_datetime(df['date'], format='%Y-%m-%d') df['amount'] = (df['open'] + df['high'] + df['low'] + df['close']) / 4 * df['volume'] df = df[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']] if len(df) < need_points: raise ValueError(f"数据量不足(仅{len(df)}个交易日),请提前start_date") df = df.tail(need_points).reset_index(drop=True) print(f"A股数据获取成功,共{len(df)}个交易日数据") print(df[['timestamps', 'open', 'close', 'volume', 'amount']].tail()) return df finally: bs.logout() def make_prediction(df, predictor): last_timestamp = df['timestamps'].max() start_new_range = last_timestamp + pd.Timedelta(days=1) new_timestamps_index = pd.date_range( start=start_new_range, periods=Config["PRED_HORIZON"], freq='D' ) y_timestamp = pd.Series(new_timestamps_index, name='y_timestamp') x_timestamp = df['timestamps'] x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']] with torch.no_grad(): print("Making main prediction (T=1.0)...") begin_time = time.time() close_preds_main, volume_preds_main = predictor.predict( df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp, pred_len=Config["PRED_HORIZON"], T=1.0, top_p=0.95, sample_count=Config["N_PREDICTIONS"], verbose=True ) print(f"Main prediction completed in {time.time() - begin_time:.2f} seconds.") close_preds_volatility = close_preds_main return close_preds_main, volume_preds_main, close_preds_volatility def calculate_metrics(hist_df, close_preds_df, v_close_preds_df): last_close = hist_df['close'].iloc[-1] final_day_preds = close_preds_df.iloc[-1] upside_prob = (final_day_preds > last_close).mean() hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1)) historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std() amplification_count = 0 for col in v_close_preds_df.columns: full_sequence = pd.concat([pd.Series([last_close]), v_close_preds_df[col]]).reset_index(drop=True) pred_log_returns = np.log(full_sequence / full_sequence.shift(1)) predicted_vol = pred_log_returns.std() if predicted_vol > historical_vol: amplification_count += 1 vol_amp_prob = amplification_count / len(v_close_preds_df.columns) print(f"上证指数上涨概率(24个交易日):{upside_prob:.2%}") print(f"波动率放大概率(24个交易日):{vol_amp_prob:.2%}") return upside_prob, vol_amp_prob # --- 修改1:增强create_plot函数(彻底解决中文字体警告)--- def create_plot(hist_df, close_preds_df, volume_preds_df): print("Generating comprehensive forecast chart...") # 新增:尝试加载本地中文字体(若项目有fonts目录) font_path = os.path.join(Config["REPO_PATH"], "fonts", "wqy-microhei.ttc") if os.path.exists(font_path): # 加载项目自带的中文字体(优先使用,彻底解决警告) plt.rcParams['font.family'] = 'sans-serif' plt.rcParams['font.sans-serif'] = [font_path] plt.rcParams['axes.unicode_minus'] = False print(f"使用项目自带中文字体:{font_path}") else: # 备用:使用系统可能存在的中文字体 plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei', 'SimHei', 'DejaVu Sans'] plt.rcParams['axes.unicode_minus'] = False print("项目中文字体未找到,使用系统默认字体(可能仍有警告)") fig, (ax1, ax2) = plt.subplots( 2, 1, figsize=(15, 10), sharex=True, gridspec_kw={'height_ratios': [3, 1]} ) # 数据准备(逻辑不变) hist_time = hist_df['timestamps'] last_hist_time = hist_time.iloc[-1] pred_time = pd.to_datetime([last_hist_time + timedelta(days=i + 1) for i in range(len(close_preds_df))]) # 价格子图(逻辑不变) ax1.plot(hist_time, hist_df['close'], color='#00274C', label='上证指数(后复权)', linewidth=1.5) mean_preds = close_preds_df.mean(axis=1) ax1.plot(pred_time, mean_preds, color='#FF6B00', linestyle='-', label='预测均价') ax1.fill_between(pred_time, close_preds_df.min(axis=1), close_preds_df.max(axis=1), color='#FF6B00', alpha=0.2, label='预测区间(最小-最大)') ax1.set_title(f'{Config["STOCK_CODE"]} 上证指数 概率预测(未来{Config["PRED_HORIZON"]}个交易日)', fontsize=16, weight='bold') ax1.set_ylabel('价格(元)') ax1.legend() ax1.grid(True, which='both', linestyle='--', linewidth=0.5) # 成交量子图(逻辑不变) ax2.bar(hist_time, hist_df['volume']/1e8, color='#00A86B', label='历史成交量(亿手)', width=0.6) ax2.bar(pred_time, volume_preds_df.mean(axis=1)/1e8, color='#FF6B00', label='预测成交量(亿手)', width=0.6) ax2.set_ylabel('成交量(亿手)') ax2.set_xlabel('日期') ax2.legend() ax2.grid(True, which='both', linestyle='--', linewidth=0.5) # 分割线(逻辑不变) separator_time = last_hist_time + timedelta(hours=12) for ax in [ax1, ax2]: ax.axvline(x=separator_time, color='red', linestyle='--', linewidth=1.5, label='_nolegend_') ax.tick_params(axis='x', rotation=45) fig.tight_layout() # 保存图表(逻辑不变) chart_path = Path(Config["CHART_PATH"]) if not chart_path.parent.exists(): chart_path.parent.mkdir(parents=True, exist_ok=True) if chart_path.exists(): chart_path.chmod(0o666) fig.savefig(chart_path, dpi=120, bbox_inches='tight') plt.close(fig) print(f"图表已保存至:{chart_path}(权限:{oct(os.stat(chart_path).st_mode)[-3:]})") # --- 修改2:重写update_html函数(解决index.html权限问题)--- def update_html(upside_prob, vol_amp_prob): print("Updating index.html...") # 1. 定义路径:目标HTML(/tmp可写)、源HTML(项目模板) html_path = Path(Config["HTML_PATH"]) src_html_path = Config["REPO_PATH"] / "templates" / "index.html" # 项目原模板路径 now_utc_str = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S') upside_prob_str = f'{upside_prob:.1%}' vol_amp_prob_str = f'{vol_amp_prob:.1%}' # 2. 初始化HTML:若/tmp无HTML,从项目模板复制或创建基础版 if not html_path.exists(): html_path.parent.mkdir(parents=True, exist_ok=True) if src_html_path.exists(): # 从项目templates目录复制(保留原页面样式) import shutil shutil.copy2(src_html_path, html_path) print(f"从项目模板复制HTML到:{html_path}") else: # 项目无模板时,创建基础HTML(避免报错) base_html = """
最后更新时间:未更新(UTC)
24个交易日上涨概率:--%
波动率放大概率:--%
