fiewolf1000 committed on
Commit
85c1b5f
·
verified ·
1 Parent(s): 3d9ed7f

Update update_predictions.py

Browse files
Files changed (1) hide show
  1. update_predictions.py +548 -0
update_predictions.py CHANGED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import re
4
+ import subprocess
5
+ import time
6
+ from datetime import datetime, timedelta
7
+ from pathlib import Path
8
+
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ import baostock as bs
14
+ from pytz import timezone # 处理中国时区(Asia/Shanghai)
15
+
16
+ from model import KronosTokenizer, Kronos, KronosPredictor
17
+
18
+ # --- Configuration ---
19
# Directory containing this script; every repo-relative asset hangs off it.
_REPO_ROOT = Path(__file__).parent.resolve()

# Static settings plus the mutable per-day state shared by the functions below.
Config = {
    "REPO_PATH": _REPO_ROOT,
    "LOCAL_MODEL_PATH": os.path.join(_REPO_ROOT, "models"),
    "STOCK_CODE": "sh.000001",
    "FREQUENCY": "d",
    "START_DATE": "2022-01-01",
    "PRED_HORIZON": 24,
    "N_PREDICTIONS": 10,
    "VOL_WINDOW": 24,
    "PREDICTION_CACHE": os.path.join("/tmp", "predictions_cache"),
    "CHART_PATH": os.path.join("/tmp", "prediction_chart.png"),
    "HTML_PATH": os.path.join("/tmp", "index.html"),
    # Flag flipped by main_task once an inference run has succeeded.
    "IS_TODAY_INFERENCED": False,
    # Slots filled by main_task and read back by create_plot / update_html.
    "CACHED_RESULTS": dict.fromkeys((
        "close_preds",
        "volume_preds",
        "v_close_preds",
        "upside_prob",
        "vol_amp_prob",
        "hist_df_for_plot",
    )),
}

# The font path is derived from REPO_PATH, so it is added after the dict exists.
Config["CHINESE_FONT_PATH"] = os.path.join(Config["REPO_PATH"], "fonts", "wqy-microhei.ttf")

# Make sure the working directories exist before anything tries to use them.
for _dir_key in ("PREDICTION_CACHE", "LOCAL_MODEL_PATH"):
    os.makedirs(Config[_dir_key], exist_ok=True)
49
+
50
+
51
def get_china_time():
    """Return the current time as a timezone-aware datetime in Asia/Shanghai."""
    return datetime.now(timezone("Asia/Shanghai"))
55
+
56
+
57
def load_local_model():
    """Load the Kronos tokenizer and model from the local model directory.

    Both artefacts are read from sub-directories of
    ``Config["LOCAL_MODEL_PATH"]`` with ``local_files_only=True``.

    Returns:
        A ``KronosPredictor`` built with ``device="cpu"`` and
        ``max_context=512``.

    Raises:
        FileNotFoundError: if the tokenizer or the model directory is missing.
    """
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 开始加载本地Kronos模型...")

    model_root = Config["LOCAL_MODEL_PATH"]
    tokenizer_dir = os.path.join(model_root, "tokenizer")
    model_dir = os.path.join(model_root, "model")

    # Fail fast with a clear message; the tokenizer is checked first.
    for problem, location in (("分词器路径不存在", tokenizer_dir),
                              ("模型路径不存在", model_dir)):
        if not os.path.exists(location):
            raise FileNotFoundError(f"{problem}:{location}")

    kronos_tokenizer = KronosTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)
    kronos_model = Kronos.from_pretrained(model_dir, local_files_only=True)

    # Switch both modules to evaluation mode before inference.
    kronos_tokenizer.eval()
    kronos_model.eval()

    kronos_predictor = KronosPredictor(kronos_model, kronos_tokenizer, device="cpu", max_context=512)
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 本地模型加载成功")
    return kronos_predictor
78
+
79
+
80
def fetch_stock_data():
    """Download daily bars for ``Config["STOCK_CODE"]`` from Baostock.

    The query spans ``Config["START_DATE"]`` up to today's date in China time.
    Rows with non-numeric price/volume values are dropped, an ``amount``
    column is synthesised as the mean of open/high/low/close times volume,
    and only the most recent ``2 * Config["VOL_WINDOW"]`` rows are kept.

    Returns:
        A DataFrame with columns ``timestamps, open, high, low, close,
        volume, amount`` and a fresh 0-based index.

    Raises:
        ConnectionError: if the Baostock login fails.
        ValueError: if the query fails or returns too few rows.
    """
    started_at = get_china_time()
    query_end = started_at.strftime("%Y-%m-%d")
    required_rows = 2 * Config["VOL_WINDOW"]

    print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 开始获取{Config['STOCK_CODE']}日线数据(结束日期:{query_end})")
    session = bs.login()
    if session.error_code != '0':
        raise ConnectionError(f"Baostock登录失败:{session.error_msg}")

    try:
        result_set = bs.query_history_k_data_plus(
            code=Config["STOCK_CODE"],
            fields="date,open,high,low,close,volume",
            start_date=Config["START_DATE"],
            end_date=query_end,
            frequency=Config["FREQUENCY"],
            # NOTE(review): the original comment called this back-adjusted,
            # but the Baostock docs list 1=back-adjusted, 2=forward-adjusted
            # - confirm which one is intended.
            adjustflag="2"
        )
        if result_set.error_code != '0':
            raise ValueError(f"获取K线数据失败:{result_set.error_msg}")

        # Drain the cursor-style result set into a list of rows.
        rows = []
        while result_set.next():
            rows.append(result_set.get_row_data())
        frame = pd.DataFrame(rows, columns=result_set.fields)

        # Coerce every price/volume column to numbers; unparsable cells
        # become NaN and the affected rows are discarded.
        value_columns = ['open', 'high', 'low', 'close', 'volume']
        for column in value_columns:
            frame[column] = pd.to_numeric(frame[column], errors='coerce')
        frame = frame.dropna(subset=value_columns)

        frame['timestamps'] = pd.to_datetime(frame['date'], format='%Y-%m-%d')
        # No turnover is requested from the API, so approximate it as
        # average bar price times volume.
        average_price = (frame['open'] + frame['high'] + frame['low'] + frame['close']) / 4
        frame['amount'] = average_price * frame['volume']
        frame = frame[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]

        available_rows = len(frame)
        if available_rows < required_rows:
            raise ValueError(f"数据量不足(仅{available_rows}个交易日),请提前START_DATE")
        frame = frame.tail(required_rows).reset_index(drop=True)

        print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 股票数据获取成功,共{len(frame)}个交易日")
        print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 最新5条数据:\n{frame[['timestamps', 'open', 'close', 'volume']].tail()}")
        return frame

    finally:
        # Always release the Baostock session, even when the query failed.
        bs.logout()
        print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] Baostock已登出")
135
+
136
+
137
def make_prediction(df, predictor):
    """Run the Kronos predictor over the supplied history.

    Args:
        df: history frame with a ``timestamps`` column plus ``open, high,
            low, close, volume, amount``.
        predictor: object exposing ``predict(df=..., x_timestamp=...,
            y_timestamp=..., pred_len=..., T=..., top_p=...,
            sample_count=..., verbose=...)`` and returning a
            ``(close_preds, volume_preds)`` pair.

    Returns:
        ``(close_preds, volume_preds, close_preds_volatility)``; the third
        element is the same object as the first because the volatility
        estimate reuses the close-price paths.
    """
    china_now = get_china_time()
    print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始执行模型推理(预测未来{Config['PRED_HORIZON']}个交易日)")

    horizon = Config["PRED_HORIZON"]

    # The horizon is expressed in trading days, so the future timestamps
    # follow the business-day calendar (weekends skipped), matching the
    # x-axis built in create_plot. Exchange holidays are not modelled.
    last_timestamp = df['timestamps'].max()
    future_index = pd.date_range(
        start=last_timestamp + pd.Timedelta(days=1),
        periods=horizon,
        freq='B'
    )
    y_timestamp = pd.Series(future_index, name='y_timestamp')
    x_timestamp = df['timestamps']
    x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]

    # Inference only: gradients are never needed.
    with torch.no_grad():
        # perf_counter is monotonic, so the duration cannot go negative if
        # the wall clock is adjusted mid-run.
        begin_time = time.perf_counter()
        close_preds_main, volume_preds_main = predictor.predict(
            df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
            pred_len=horizon, T=1.0, top_p=0.95,
            sample_count=Config["N_PREDICTIONS"], verbose=True
        )
        infer_time = time.perf_counter() - begin_time
        print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 推理完成,耗时{infer_time:.2f}秒")

    # The volatility metric reuses the close-price sample paths.
    close_preds_volatility = close_preds_main
    return close_preds_main, volume_preds_main, close_preds_volatility
168
+
169
+
170
def calculate_metrics(hist_df, close_preds_df, v_close_preds_df):
    """Compute the upside probability and volatility-amplification probability.

    Args:
        hist_df: recent history with a ``close`` column; its last row is the
            reference close.
        close_preds_df: predicted closes, one row per forecast step and one
            column per sampled path.
        v_close_preds_df: predicted closes used for the volatility estimate,
            in the same layout.

    Returns:
        ``(upside_prob, vol_amp_prob)``: the share of paths whose final
        predicted close exceeds the last historical close, and the share of
        paths whose log-return standard deviation exceeds the historical one
        over the last ``Config["VOL_WINDOW"]`` rows.
    """
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 开始计算预测指标...")

    # Upside probability: final forecast step versus the latest close.
    last_close = hist_df['close'].iloc[-1]
    upside_prob = (close_preds_df.iloc[-1] > last_close).mean()

    # Historical volatility: std of log returns over the trailing window.
    hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1))
    historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std()

    def _path_volatility(path):
        # Prepend the last real close so the first predicted step also
        # contributes a return.
        sequence = pd.concat([pd.Series([last_close]), path]).reset_index(drop=True)
        return np.log(sequence / sequence.shift(1)).std()

    path_names = v_close_preds_df.columns
    amplification_count = sum(
        1 for name in path_names
        if _path_volatility(v_close_preds_df[name]) > historical_vol
    )
    vol_amp_prob = amplification_count / len(path_names)

    # Report the configured horizon instead of a hard-coded "24" so the log
    # stays correct when PRED_HORIZON changes.
    horizon = Config["PRED_HORIZON"]
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 指标计算完成:")
    print(f" - {horizon}个交易日上涨概率:{upside_prob:.2%}")
    print(f" - {horizon}个交易日波动率放大概率:{vol_amp_prob:.2%}")
    return upside_prob, vol_amp_prob
197
+
198
+
199
+
200
def create_plot():
    """Render the cached history and forecast as a two-panel PNG.

    Reads ``hist_df_for_plot``, ``close_preds`` and ``volume_preds`` from
    ``Config["CACHED_RESULTS"]`` and saves the figure to
    ``Config["CHART_PATH"]``. The upper panel shows the historical close, the
    mean predicted close and the min-max band across sampled paths; the lower
    panel shows historical and mean predicted volume. Returns nothing.
    """
    from matplotlib.font_manager import FontProperties

    started_at = get_china_time()
    print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 开始生成预测图表(适配低版本matplotlib字体)")

    cache = Config["CACHED_RESULTS"]
    history = cache["hist_df_for_plot"]
    close_paths = cache["close_preds"]
    volume_paths = cache["volume_preds"]

    figure, (price_ax, volume_ax) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

    # Use the bundled .ttf when it is present; otherwise fall back to a
    # font family looked up by name.
    font_file = Config["CHINESE_FONT_PATH"]
    if not os.path.exists(font_file):
        label_font = FontProperties(family='SimHei', size=10)
        print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 字体文件不存在,使用系统默认字体:SimHei")
    else:
        label_font = FontProperties(fname=font_file)
        print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 成功加载.ttf字体:{font_file}")

    # Global defaults for text that does not get an explicit font.
    plt.rcParams.update({
        "font.family": ["sans-serif"],
        "font.sans-serif": ["WenQuanYi Micro Hei", "SimHei", "Heiti TC"],
        'axes.unicode_minus': False,  # render minus signs correctly
    })

    # Time axes: history as cached, forecast on the following business days.
    history_dates = history['timestamps']
    last_history_date = history_dates.max()
    forecast_dates = pd.date_range(
        start=last_history_date + pd.Timedelta(days=1),
        periods=Config["PRED_HORIZON"],
        freq='B',
    )

    # --- Price panel (artist order determines the legend order) ---
    price_ax.plot(history_dates, history['close'], color='#00274C', linewidth=1.5)
    mean_close = close_paths.mean(axis=1)
    price_ax.plot(forecast_dates, mean_close, color='#FF6B00', linestyle='-')
    band_low = close_paths.min(axis=1)
    band_high = close_paths.max(axis=1)
    price_ax.fill_between(forecast_dates, band_low, band_high, color='#FF6B00', alpha=0.2)
    price_ax.set_title(
        f'{Config["STOCK_CODE"]} 上证指数概率预测(未来{Config["PRED_HORIZON"]}个交易日)',
        fontsize=16, weight='bold', fontproperties=label_font,
    )
    price_ax.set_ylabel('价格(元)', fontsize=12, fontproperties=label_font)
    price_ax.legend(
        ['上证指数(后复权)', '预测均价', '预测区间(最小-最大)'],
        fontsize=10, prop=label_font,
    )
    price_ax.grid(True, which='both', linestyle='--', linewidth=0.5)

    # --- Volume panel (values scaled by 1e8) ---
    volume_ax.bar(history_dates, history['volume']/1e8, color='#00A86B', width=0.6)
    mean_volume = volume_paths.mean(axis=1)
    volume_ax.bar(forecast_dates, mean_volume/1e8, color='#FF6B00', width=0.6)
    volume_ax.set_ylabel('成交量(亿手)', fontsize=12, fontproperties=label_font)
    volume_ax.set_xlabel('日期', fontsize=12, fontproperties=label_font)
    volume_ax.legend(
        ['历史成交量(亿手)', '预测成交量(亿手)'],
        fontsize=10, prop=label_font,
    )
    volume_ax.grid(True, which='both', linestyle='--', linewidth=0.5)

    # Vertical divider between history and forecast, kept out of the legends.
    divider_at = last_history_date + pd.Timedelta(hours=12)
    for axis in (price_ax, volume_ax):
        axis.axvline(x=divider_at, color='red', linestyle='--', linewidth=1.5, label='_nolegend_')
        axis.tick_params(axis='x', rotation=45)

    figure.tight_layout()
    chart_path = Path(Config["CHART_PATH"])
    if chart_path.exists():
        # Make an existing chart writable before overwriting it.
        chart_path.chmod(0o666)
    figure.savefig(chart_path, dpi=120, bbox_inches='tight')
    plt.close(figure)

    print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 图表生成完成,保存路径:{chart_path}")
277
+
278
+
279
def update_html():
    """Write the cached metrics and a fresh timestamp into the HTML page.

    If ``Config["HTML_PATH"]`` does not exist it is seeded from
    ``templates/index.html`` under the repository root, or from a built-in
    minimal page when that template is missing. The update time and the two
    probabilities are then substituted into the elements with ids
    ``update-time``, ``upside-prob`` and ``vol-amp-prob``. Returns nothing;
    a missing metric or a read/write failure is logged and the function
    returns early.
    """
    import shutil

    started_at = get_china_time()
    print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 开始更新HTML页面...")

    # 1. Pull the metrics from the cache; nothing to publish without them.
    cache = Config["CACHED_RESULTS"]
    upside_prob = cache.get("upside_prob")
    vol_amp_prob = cache.get("vol_amp_prob")
    if upside_prob is None or vol_amp_prob is None:
        print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 警告:缓存中未找到指标数据,无法更新HTML")
        return

    # Percentages with one decimal place, plus the displayed update time.
    upside_text = f'{upside_prob:.1%}'
    vol_amp_text = f'{vol_amp_prob:.1%}'
    updated_at_text = started_at.strftime('%Y-%m-%d %H:%M:%S')

    # 2. Seed the page when it does not exist yet.
    html_path = Path(Config["HTML_PATH"])
    template_path = Config["REPO_PATH"] / "templates" / "index.html"

    if not html_path.exists():
        html_path.parent.mkdir(parents=True, exist_ok=True)
        if template_path.exists():
            shutil.copy2(template_path, html_path)
            print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 从项目模板复制HTML:{template_path} -> {html_path}")
        else:
            # Minimal page whose element ids match the patterns used below.
            base_html = """
            <!DOCTYPE html>
            <html>
            <head>
            <title>清华大模型Kronos上证指数预测</title>
            <style>
            body { max-width: 1200px; margin: 0 auto; padding: 20px; font-family: "WenQuanYi Micro Hei", Arial; }
            .metric { margin: 20px 0; padding: 10px; background: #f5f5f5; border-radius: 5px; }
            .metric-value { font-size: 1.2em; color: #0066cc; }
            img { max-width: 100%; height: auto; }
            h1 { color: #333; }
            </style>
            </head>
            <body>
            <h1>清华大学K线大模型Kronos上证指数(sh.000001)概率预测</h1>
            <p>最后更新时间(中国时间):<strong id="update-time">未更新</strong></p>
            <p>同 步 网 站:<strong><a href="http://15115656.top" target="_blank">火狼工具站</a></strong></p>
            <div class="metric">
            <p>24个交易日上涨概率:<span class="metric-value" id="upside-prob">--%</span></p>
            </div>
            <div class="metric">
            <p>24个交易日波动率放大概率:<span class="metric-value" id="vol-amp-prob">--%</span></p>
            </div>
            <div><img src="/prediction_chart.png" alt="上证指数预测图表"></div>
            </body>
            </html>
            """
            with open(html_path, 'w', encoding='utf-8') as handle:
                handle.write(base_html)
            print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 在/tmp创建基础HTML:{html_path}")

    # 3. Load the current page.
    try:
        with open(html_path, 'r', encoding='utf-8') as handle:
            content = handle.read()
    except Exception as e:
        print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 读取HTML失败:{str(e)}")
        return

    # 4. Swap the text between each opening/closing tag pair. A callable
    #    replacement keeps the inserted value from being read as a regex
    #    template; the value is bound per iteration via a default argument.
    substitutions = (
        (r'(<strong id="update-time">).*?(</strong>)', updated_at_text),
        (r'(<span class="metric-value" id="upside-prob">).*?(</span>)', upside_text),
        (r'(<span class="metric-value" id="vol-amp-prob">).*?(</span>)', vol_amp_text),
    )
    for tag_pattern, new_text in substitutions:
        content = re.sub(
            tag_pattern,
            lambda m, new_text=new_text: f'{m.group(1)}{new_text}{m.group(2)}',
            content,
        )

    # 5. Persist the updated page.
    try:
        with open(html_path, 'w', encoding='utf-8') as handle:
            handle.write(content)
        print(f"[{started_at:%Y-%m-%d %H:%M:%S}] HTML更新完成,路径:{html_path}")
        print(f"[DEBUG] 上涨概率更新为:{upside_text}")
        print(f"[DEBUG] 波动率概率更新为:{vol_amp_text}")
    except Exception as e:
        print(f"[{started_at:%Y-%m-%d %H:%M:%S}] 写入HTML失败:{str(e)}")
380
+
381
def git_commit_and_push():
    """Copy the generated chart and page into the repo, then commit and push.

    The whole step is skipped when ``git`` cannot be executed. Git failures
    and permission errors are logged rather than raised, so publishing stays
    best-effort. Returns nothing.
    """
    # Imported at function scope, before any branch, so BOTH copy steps can
    # use it. The original imported it inside the chart branch only, which
    # made the HTML copy raise UnboundLocalError whenever the chart was
    # missing.
    import shutil

    china_now = get_china_time()
    commit_message = f"Auto-update: 上证指数预测({china_now:%Y-%m-%d 中国时间})"
    print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始执行Git提交操作,提交信息:{commit_message}")

    # Probe for a usable git binary first.
    try:
        subprocess.run(['git', '--version'], check=True, capture_output=True, text=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git未安装或未在PATH中,跳过Git操作")
        return

    try:
        os.chdir(Config["REPO_PATH"])

        # Stage the artefacts from their /tmp locations into the tracked tree.
        chart_src = Config["CHART_PATH"]
        chart_dst = Config["REPO_PATH"] / "prediction_chart.png"
        html_src = Config["HTML_PATH"]
        html_dst = Config["REPO_PATH"] / "index.html"

        if os.path.exists(chart_src):
            shutil.copy2(chart_src, chart_dst)
            print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 图表复制到Git目录:{chart_dst}")
        if os.path.exists(html_src):
            shutil.copy2(html_src, html_dst)
            print(f"[{china_now:%Y-%m-%d %H:%M:%S}] HTML复制到Git目录:{html_dst}")

        subprocess.run(['git', 'add', 'prediction_chart.png', 'index.html'], check=True, capture_output=True, text=True)
        commit_result = subprocess.run(['git', 'commit', '-m', commit_message], check=True, capture_output=True, text=True)
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git提交输出:\n{commit_result.stdout}")
        push_result = subprocess.run(['git', 'push'], check=True, capture_output=True, text=True)
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git推送输出:\n{push_result.stdout}")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git操作完成")

    except subprocess.CalledProcessError as e:
        # Inspect both streams: the benign message can land in either one,
        # and either may be empty or None.
        output = (e.stdout or "") + (e.stderr or "")
        if "nothing to commit" in output or "Your branch is up to date" in output:
            print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 无新内容需要提交或推送")
        else:
            print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git错误:\nSTDOUT: {e.stdout}\nSTDERR: {e.stderr}")
    except PermissionError as e:
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] Git权限错误:{str(e)},跳过Git操作")
429
+
430
+
431
def main_task(model):
    """Run one publish cycle, inferring at most once per China calendar day.

    When a successful inference is already cached for today, only the chart,
    HTML page and git publish steps are repeated. Otherwise the full pipeline
    runs: fetch data, predict, compute metrics, cache, plot, update HTML and
    publish.

    Args:
        model: predictor object passed through to ``make_prediction``.

    Any exception is caught and logged; the cache flag is cleared so that the
    next call retries.
    """
    china_now = get_china_time()
    today_cn = china_now.date()
    print(f"\n[{china_now:%Y-%m-%d %H:%M:%S}] " + "="*60)
    print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 开始执行主任务")

    # The cache counts as "today's" only when it was produced on the current
    # China date. A bare boolean (as before) let the first scheduled run
    # after startup republish the previous day's predictions. A missing
    # stamp is accepted so a flag set elsewhere keeps its old meaning.
    cached_on = Config.get("LAST_INFERENCE_DATE")
    cache_is_fresh = Config["IS_TODAY_INFERENCED"] and cached_on in (None, today_cn)

    if cache_is_fresh:
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 今日(中国时间)已完成推理,直接复用缓存结果")
        create_plot()
        update_html()
        git_commit_and_push()
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 主任务完成(复用缓存)")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] " + "="*60 + "\n")
        return

    try:
        # 1. Fetch data; the newest row is withheld from the model input
        #    (the original comment gives "avoid data leakage" as the reason).
        df_full = fetch_stock_data()
        df_for_model = df_full.iloc[:-1]

        # 2. Inference.
        close_preds, volume_preds, v_close_preds = make_prediction(df_for_model, model)

        # 3. Metrics over the trailing volatility window.
        hist_df_for_metrics = df_for_model.tail(Config["VOL_WINDOW"])
        upside_prob, vol_amp_prob = calculate_metrics(hist_df_for_metrics, close_preds, v_close_preds)

        # 4. Cache everything the plot/HTML steps need and stamp it with
        #    today's China date.
        hist_df_for_plot = df_for_model.tail(Config["VOL_WINDOW"])
        Config["CACHED_RESULTS"] = {
            "close_preds": close_preds,
            "volume_preds": volume_preds,
            "v_close_preds": v_close_preds,
            "upside_prob": upside_prob,
            "vol_amp_prob": vol_amp_prob,
            "hist_df_for_plot": hist_df_for_plot
        }
        Config["IS_TODAY_INFERENCED"] = True
        Config["LAST_INFERENCE_DATE"] = today_cn
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 今日推理结果已缓存,后续调用将复用")

        # 5-7. Publish.
        create_plot()
        update_html()
        git_commit_and_push()

        # 8. Drop the large intermediates; the cached slices stay referenced.
        del df_full, df_for_model, hist_df_for_metrics
        gc.collect()

        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 主任务完成(首次推理)")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] " + "="*60 + "\n")

    except Exception as e:
        # Top-level boundary for the cycle: log, clear the flag so the next
        # call retries, and keep the process alive.
        Config["IS_TODAY_INFERENCED"] = False
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 主任务执行失败,今日推理标记为未完成")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 错误信息:{str(e)}")
        import traceback
        traceback.print_exc()
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] " + "="*60 + "\n")
499
+
500
+
501
def run_scheduler(model):
    """Run ``main_task`` once per day shortly after midnight, China time.

    Loops forever: sleep until 00:00:05 Asia/Shanghai of the next calendar
    day, run the task, then clear the cached-inference flag. If the task
    raises, the error is logged and the loop pauses five minutes before
    scheduling the next run.

    Args:
        model: predictor object handed to ``main_task``.
    """
    china_tz = timezone("Asia/Shanghai")
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 定时器启动(中国时间),每天0点执行推理")

    while True:
        china_now = get_china_time()

        # Build the target as a naive wall-clock time and attach the zone
        # with localize(). Passing a pytz zone via replace(tzinfo=...) picks
        # the zone's LMT offset (+08:06), which put the target about six
        # minutes before real midnight and could make the next computed
        # sleep negative.
        naive_target = (china_now.replace(tzinfo=None) + timedelta(days=1)).replace(
            hour=0, minute=0, second=5, microsecond=0
        )
        next_midnight = china_tz.localize(naive_target)

        # Clamp at zero: time.sleep() raises ValueError on negative input.
        sleep_seconds = max(0.0, (next_midnight - china_now).total_seconds())

        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 当前时间:{china_now:%Y-%m-%d %H:%M:%S}(中国时间)")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 下次执行时间:{next_midnight:%Y-%m-%d %H:%M:%S}(中国时间)")
        print(f"[{china_now:%Y-%m-%d %H:%M:%S}] 等待时间:{sleep_seconds:.0f}秒(约{sleep_seconds/3600:.1f}小时)")

        time.sleep(sleep_seconds)

        try:
            main_task(model)
            # Clear the flag so results are not reused across days.
            Config["IS_TODAY_INFERENCED"] = False
        except Exception as e:
            print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 定时器触发任务失败:{str(e)}")
            import traceback
            traceback.print_exc()
            print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 5分钟后重试...")
            # NOTE(review): after this pause the loop waits for the next
            # midnight again, so this is a back-off rather than an immediate
            # retry of the failed run.
            time.sleep(300)
534
+
535
+
536
if __name__ == '__main__':
    # Startup sequence: announce, load the model once, run the task
    # immediately, then hand control to the daily scheduler (never returns).
    print(f"[{get_china_time():%Y-%m-%d %H:%M:%S}] 程序启动(中国时间)")

    kronos_predictor = load_local_model()
    main_task(kronos_predictor)
    run_scheduler(kronos_predictor)