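# Astock / update_predictions-2.py
# fiewolf1000 - commit 976d99d (verified): "Rename update_predictions.py to update_predictions-2.py" (13.7 kB)
# (Header captured from the hosting site's file view; its "raw" / "history" / "blame"
#  entries are page navigation links, not part of the script.)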
import gc
import os
import re
import subprocess
import time
from datetime import datetime, timezone, timedelta
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import baostock as bs
from model import KronosTokenizer, Kronos, KronosPredictor
# --- Configuration ---
# Directory holding this script; it is used both as the repository root and
# as the parent of the local model directory.
_SCRIPT_DIR = Path(__file__).parent.resolve()

Config = {
    # Repository root (a pathlib.Path); index.html lives here.
    "REPO_PATH": _SCRIPT_DIR,
    # Local model directory (a str path); load_local_model expects
    # "tokenizer" and "model" sub-directories inside it.
    "LOCAL_MODEL_PATH": os.path.join(_SCRIPT_DIR, "models"),
    # Baostock security code and K-line frequency.
    "STOCK_CODE": "sh.000001",
    "FREQUENCY": "d",
    # Query window. END_DATE is evaluated once, when this module is imported.
    "START_DATE": "2022-01-01",
    "END_DATE": datetime.now().strftime("%Y-%m-%d"),
    # History length, forecast length, sample paths per forecast and the
    # volatility look-back, all counted in bars.
    "HIST_POINTS": 360,
    "PRED_HORIZON": 24,
    "N_PREDICTIONS": 10,
    "VOL_WINDOW": 24,
    # Scratch outputs live under /tmp, which is writable by default on
    # Hugging Face Spaces.
    "PREDICTION_CACHE": os.path.join("/tmp", "predictions_cache"),
    "CHART_PATH": os.path.join("/tmp", "prediction_chart.png"),
}

os.makedirs(Config["PREDICTION_CACHE"], exist_ok=True)
os.makedirs(Config["LOCAL_MODEL_PATH"], exist_ok=True)
# --- 其他函数(load_local_model、fetch_stock_data、make_prediction、calculate_metrics)保持不变 ---
def load_local_model():
    """Load the Kronos tokenizer and model from ``Config["LOCAL_MODEL_PATH"]``.

    Both are loaded with ``local_files_only=True`` from the ``tokenizer`` and
    ``model`` sub-directories, switched to eval mode, and wrapped in a
    CPU ``KronosPredictor`` with ``max_context=512``.

    Returns:
        The ready-to-use ``KronosPredictor``.

    Raises:
        FileNotFoundError: if the tokenizer or model directory is missing
            (the tokenizer directory is checked first).
    """
    print("Loading local Kronos model...")
    model_root = Config["LOCAL_MODEL_PATH"]
    tokenizer_dir = os.path.join(model_root, "tokenizer")
    weights_dir = os.path.join(model_root, "model")

    # Fail fast with a clear message instead of a loader error.
    if not os.path.exists(tokenizer_dir):
        raise FileNotFoundError(f"分词器路径不存在:{tokenizer_dir}")
    if not os.path.exists(weights_dir):
        raise FileNotFoundError(f"模型路径不存在:{weights_dir}")

    tokenizer = KronosTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)
    model = Kronos.from_pretrained(weights_dir, local_files_only=True)
    for module in (tokenizer, model):
        module.eval()

    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
    print("Local model loaded successfully.")
    return predictor
def fetch_stock_data():
    """Download K-line bars for ``Config["STOCK_CODE"]`` from Baostock.

    Queries ``[START_DATE, END_DATE]`` at ``Config["FREQUENCY"]`` with
    ``adjustflag="2"``. Rows with unparsable prices/volume are dropped. An
    ``amount`` column is approximated as the OHLC average times volume.

    Returns:
        DataFrame with columns timestamps, open, high, low, close, volume and
        amount, holding the last ``HIST_POINTS + VOL_WINDOW`` rows with a
        fresh 0-based index.

    Raises:
        ConnectionError: if the Baostock login fails.
        ValueError: if the query fails or returns too few usable rows.
    """
    code = Config["STOCK_CODE"]
    freq = Config["FREQUENCY"]
    date_from = Config["START_DATE"]
    date_to = Config["END_DATE"]
    window = Config["HIST_POINTS"] + Config["VOL_WINDOW"]
    print(f"Fetching {window}个交易日的{code}日线数据...")

    session = bs.login()
    if session.error_code != '0':
        raise ConnectionError(f"Baostock登录失败:{session.error_msg}")

    # Always log out, even when the query or post-processing fails.
    try:
        query = bs.query_history_k_data_plus(
            code=code,
            fields="date,open,high,low,close,volume",
            start_date=date_from,
            end_date=date_to,
            frequency=freq,
            adjustflag="2",
        )
        if query.error_code != '0':
            raise ValueError(f"获取K线数据失败:{query.error_msg}")

        rows = []
        while query.next():
            rows.append(query.get_row_data())
        bars = pd.DataFrame(rows, columns=query.fields)

        # Baostock returns strings; anything unparsable becomes NaN and the
        # row is discarded.
        value_cols = ['open', 'high', 'low', 'close', 'volume']
        for column in value_cols:
            bars[column] = pd.to_numeric(bars[column], errors='coerce')
        bars = bars.dropna(subset=value_cols)

        bars['timestamps'] = pd.to_datetime(bars['date'], format='%Y-%m-%d')
        bars['amount'] = (bars['open'] + bars['high'] + bars['low'] + bars['close']) / 4 * bars['volume']
        bars = bars[['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]

        if len(bars) < window:
            raise ValueError(f"数据量不足(仅{len(bars)}个交易日),请提前start_date")

        bars = bars.tail(window).reset_index(drop=True)
        print(f"A股数据获取成功,共{len(bars)}个交易日数据")
        print(bars[['timestamps', 'open', 'close', 'volume', 'amount']].tail())
        return bars
    finally:
        bs.logout()
def make_prediction(df, predictor):
    """Sample ``Config["N_PREDICTIONS"]`` forecast paths of ``Config["PRED_HORIZON"]`` steps.

    Args:
        df: history with a ``timestamps`` column plus open/high/low/close/
            volume/amount columns; every row is used as model input.
        predictor: object whose ``predict(...)`` returns a
            ``(close_preds, volume_preds)`` pair (as unpacked below).

    Returns:
        ``(close_preds, volume_preds, close_preds_volatility)``. The third
        item is the same object as the first; no separate volatility run
        is made.
    """
    last_timestamp = df['timestamps'].max()
    # Forecast timestamps fall on business days (Mon-Fri). The input series
    # are daily trading bars and the horizon is reported as trading days, so
    # calendar days ('D') would feed weekend dates that A-share daily history
    # does not contain. Exchange holidays are not excluded.
    new_timestamps_index = pd.date_range(
        start=last_timestamp + pd.Timedelta(days=1),
        periods=Config["PRED_HORIZON"],
        freq='B',
    )
    y_timestamp = pd.Series(new_timestamps_index, name='y_timestamp')
    x_timestamp = df['timestamps']
    x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        print("Making main prediction (T=1.0)...")
        begin_time = time.time()
        close_preds_main, volume_preds_main = predictor.predict(
            df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
            pred_len=Config["PRED_HORIZON"], T=1.0, top_p=0.95,
            sample_count=Config["N_PREDICTIONS"], verbose=True
        )
        print(f"Main prediction completed in {time.time() - begin_time:.2f} seconds.")

    # The volatility metric reuses the main sample paths.
    close_preds_volatility = close_preds_main
    return close_preds_main, volume_preds_main, close_preds_volatility
def calculate_metrics(hist_df, close_preds_df, v_close_preds_df):
    """Derive the two headline probabilities from the sampled forecasts.

    Args:
        hist_df: recent history with a ``close`` column. Its last close is the
            reference price. ``VOL_WINDOW + 1`` rows are needed to obtain
            ``VOL_WINDOW`` historical log returns.
        close_preds_df: predicted closes, one row per forecast step and one
            column per sample path.
        v_close_preds_df: predicted closes used for the volatility test, same
            layout.

    Returns:
        ``(upside_prob, vol_amp_prob)``: the share of sample paths whose final
        close exceeds the last observed close, and the share of paths whose
        log-return standard deviation exceeds the historical one.

    Raises:
        ValueError: if ``v_close_preds_df`` has no sample columns.
    """
    last_close = hist_df['close'].iloc[-1]

    # Fraction of sample paths ending above the last observed close.
    upside_prob = (close_preds_df.iloc[-1] > last_close).mean()

    # dropna() removes the undefined first return (and gaps from missing
    # closes), so the window holds real returns only.
    hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1)).dropna()
    historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std()

    n_paths = len(v_close_preds_df.columns)
    if n_paths == 0:
        raise ValueError("v_close_preds_df contains no prediction samples")

    amplification_count = 0
    for col in v_close_preds_df.columns:
        # Prepend the last observed close so the first forecast step also
        # contributes a return.
        full_sequence = pd.concat([pd.Series([last_close]), v_close_preds_df[col]]).reset_index(drop=True)
        pred_log_returns = np.log(full_sequence / full_sequence.shift(1))
        if pred_log_returns.std() > historical_vol:
            amplification_count += 1
    vol_amp_prob = amplification_count / n_paths

    # Report the configured horizon rather than a hard-coded 24.
    horizon = Config["PRED_HORIZON"]
    print(f"上证指数上涨概率({horizon}个交易日):{upside_prob:.2%}")
    print(f"波动率放大概率({horizon}个交易日):{vol_amp_prob:.2%}")
    return upside_prob, vol_amp_prob
# --- 关键修改:create_plot 函数(解决字体和权限问题)---
def create_plot(hist_df, close_preds_df, volume_preds_df):
    """Draw the price/volume forecast chart and save it to ``Config["CHART_PATH"]``.

    Top panel: historical closes, the mean predicted close and the min-max
    band across sample paths. Bottom panel: historical and mean predicted
    volume divided by 1e8. A dashed red line separates history from forecast.

    Args:
        hist_df: history with ``timestamps``, ``close`` and ``volume`` columns.
        close_preds_df: predicted closes, one row per forecast step and one
            column per sample path (aggregated with ``axis=1`` below).
        volume_preds_df: predicted volumes with the same layout.
    """
    print("Generating comprehensive forecast chart...")
    # Prefer fonts that can render the Chinese labels (WenQuanYi Zen Hei,
    # then SimHei); DejaVu Sans is the fallback.
    plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei', 'SimHei', 'DejaVu Sans']
    # Draw minus signs with a plain hyphen so they render under CJK fonts.
    plt.rcParams['axes.unicode_minus'] = False

    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(15, 10), sharex=True,
        gridspec_kw={'height_ratios': [3, 1]}
    )
    # pyplot keeps every figure alive until it is closed. This function may
    # run repeatedly in one process, so release the figure even on failure.
    try:
        hist_time = hist_df['timestamps']
        last_hist_time = hist_time.iloc[-1]
        if isinstance(close_preds_df.index, pd.DatetimeIndex):
            # Plot at the timestamps the predictions carry.
            pred_time = close_preds_df.index
        else:
            # Forecast steps are trading days: use business days (Mon-Fri)
            # after the last historical bar; holidays are not excluded.
            pred_time = pd.bdate_range(start=last_hist_time + timedelta(days=1),
                                       periods=len(close_preds_df))

        # Price panel. fetch_stock_data queries with adjustflag="2", which is
        # forward adjustment (前复权) in Baostock.
        ax1.plot(hist_time, hist_df['close'], color='#00274C', label='上证指数(前复权)', linewidth=1.5)
        ax1.plot(pred_time, close_preds_df.mean(axis=1), color='#FF6B00', linestyle='-', label='预测均价')
        ax1.fill_between(pred_time, close_preds_df.min(axis=1), close_preds_df.max(axis=1),
                         color='#FF6B00', alpha=0.2, label='预测区间(最小-最大)')
        ax1.set_title(f'{Config["STOCK_CODE"]} 上证指数 概率预测(未来{Config["PRED_HORIZON"]}个交易日)',
                      fontsize=16, weight='bold')
        ax1.set_ylabel('价格(元)')
        ax1.legend()
        ax1.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Volume panel. Baostock volume is in shares (股), so /1e8 gives 亿股.
        ax2.bar(hist_time, hist_df['volume'] / 1e8, color='#00A86B', label='历史成交量(亿股)', width=0.6)
        ax2.bar(pred_time, volume_preds_df.mean(axis=1) / 1e8, color='#FF6B00', label='预测成交量(亿股)', width=0.6)
        ax2.set_ylabel('成交量(亿股)')
        ax2.set_xlabel('日期')
        ax2.legend()
        ax2.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Separator half a day after the last historical bar.
        separator_time = last_hist_time + timedelta(hours=12)
        for ax in (ax1, ax2):
            ax.axvline(x=separator_time, color='red', linestyle='--', linewidth=1.5, label='_nolegend_')
            ax.tick_params(axis='x', rotation=45)
        fig.tight_layout()

        chart_path = Path(Config["CHART_PATH"])
        chart_path.parent.mkdir(parents=True, exist_ok=True)
        # Best effort: make an existing chart writable before overwriting it.
        # A chmod failure (e.g. file owned by another user) is reported, not
        # fatal; savefig decides whether the write actually succeeds.
        if chart_path.exists():
            try:
                chart_path.chmod(0o666)
            except OSError as exc:
                print(f"Warning: could not chmod {chart_path}: {exc}")
        fig.savefig(chart_path, dpi=120, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"图表已保存至:{chart_path}(权限:{oct(os.stat(chart_path).st_mode)[-3:]})")
# --- 其他函数(update_html、git_commit_and_push、main_task、run_scheduler)保持不变 ---
def update_html(upside_prob, vol_amp_prob):
    """Rewrite ``index.html`` in the repo root with the latest metrics.

    Replaces the contents of the elements with ids ``update-time`` (current
    UTC time), ``upside-prob`` and ``vol-amp-prob`` (probabilities formatted
    as percentages with one decimal). It also swaps a few fixed English
    strings, including "BTCUSDT", for their Chinese counterparts.
    """
    print("Updating index.html...")
    target = Config["REPO_PATH"] / 'index.html'
    stamp = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
    element_values = [
        (r'(<strong id="update-time">).*?(</strong>)', stamp),
        (r'(<p class="metric-value" id="upside-prob">).*?(</p>)', f'{upside_prob:.1%}'),
        (r'(<p class="metric-value" id="vol-amp-prob">).*?(</p>)', f'{vol_amp_prob:.1%}'),
    ]
    label_translations = [
        ("24h Upside Probability", "24个交易日上涨概率"),
        ("Volatility Amplification Probability", "波动率放大概率"),
        ("BTCUSDT", "上证指数(sh.000001)"),
    ]

    with open(target, 'r', encoding='utf-8') as handle:
        page = handle.read()

    for pattern, value in element_values:
        # A callable replacement inserts `value` literally (no backslash or
        # group-reference processing); `v=value` binds it per iteration.
        page = re.sub(pattern, lambda m, v=value: f'{m.group(1)}{v}{m.group(2)}', page)
    for english, chinese in label_translations:
        page = page.replace(english, chinese)

    with open(target, 'w', encoding='utf-8') as handle:
        handle.write(page)
    print("HTML文件更新成功")
def git_commit_and_push(commit_message):
    """Copy the chart into the repo, then ``git add``/``commit``/``push`` it with index.html.

    Git is run with ``cwd=Config["REPO_PATH"]`` rather than ``os.chdir`` so
    the process-wide working directory is left untouched. Git failures are
    reported, not raised. "Nothing to commit" is recognised in either output
    stream and treated as a no-op.

    Args:
        commit_message: message passed to ``git commit -m``.
    """
    import shutil

    print("Performing Git operations...")
    repo = Config["REPO_PATH"]

    def _git(*args):
        # Run one git command in the repo; raises CalledProcessError on failure.
        return subprocess.run(['git', *args], cwd=repo, check=True, capture_output=True, text=True)

    try:
        # The chart is written under /tmp, which git does not track, so copy
        # it into the repository before staging.
        chart_src = Config["CHART_PATH"]
        chart_dst = os.path.join(repo, "prediction_chart.png")
        if os.path.exists(chart_src):
            shutil.copy2(chart_src, chart_dst)
            print(f"图表已从 {chart_src} 复制到 {chart_dst}")
        _git('add', 'prediction_chart.png', 'index.html')
        commit_result = _git('commit', '-m', commit_message)
        print(commit_result.stdout)
        push_result = _git('push')
        print(push_result.stdout)
        print("Git push successful.")
    except subprocess.CalledProcessError as e:
        # Inspect both streams; the original check ignored stderr whenever
        # stdout was non-empty.
        output = f"{e.stdout or ''}\n{e.stderr or ''}"
        if "nothing to commit" in output or "Your branch is up to date" in output:
            print("No new changes to commit or push.")
        else:
            print(f"A Git error occurred:\n--- STDOUT ---\n{e.stdout}\n--- STDERR ---\n{e.stderr}")
def main_task(model):
    """Run one update cycle: fetch, forecast, compute metrics, plot, update HTML, commit and push.

    Args:
        model: the predictor handed to ``make_prediction``.
    """
    print("\n" + "=" * 60 + f"\nStarting update task at {datetime.now(timezone.utc)}\n" + "=" * 60)
    # Config["END_DATE"] is computed once at import time. Refresh it so that
    # repeated runs in a long-lived process query data up to the current date
    # instead of the process start date.
    Config["END_DATE"] = datetime.now().strftime("%Y-%m-%d")
    df_full = fetch_stock_data()
    # The most recent bar is excluded from the model input.
    # NOTE(review): presumably meant to drop an unfinished bar; for daily bars
    # fetched after the close this discards the latest complete session - confirm.
    df_for_model = df_full.iloc[:-1]
    close_preds, volume_preds, v_close_preds = make_prediction(df_for_model, model)
    hist_df_for_plot = df_for_model.tail(Config["HIST_POINTS"])
    # VOL_WINDOW + 1 closes yield VOL_WINDOW log returns. With only VOL_WINDOW
    # closes the first return is undefined and the historical volatility would
    # be measured over one return fewer.
    hist_df_for_metrics = df_for_model.tail(Config["VOL_WINDOW"] + 1)
    upside_prob, vol_amp_prob = calculate_metrics(hist_df_for_metrics, close_preds, v_close_preds)
    create_plot(hist_df_for_plot, close_preds, volume_preds)
    update_html(upside_prob, vol_amp_prob)
    commit_message = f"Auto-update 上证指数预测 {datetime.now(timezone.utc):%Y-%m-%d %H:%M} UTC"
    git_commit_and_push(commit_message)
    # Drop large intermediates before the next scheduled run.
    del df_full, df_for_model, close_preds, volume_preds, v_close_preds
    del hist_df_for_plot, hist_df_for_metrics
    gc.collect()
    print("-" * 60 + "\n--- Task completed successfully ---\n" + "-" * 60 + "\n")
def run_scheduler(model, *, retry_delay=300, max_retries=3):
    """Run ``main_task(model)`` every day at 00:00:05 UTC, forever.

    A failed run is retried up to ``max_retries`` times, ``retry_delay``
    seconds apart, within the same daily slot. Once retries are exhausted the
    scheduler waits for the next daily slot.

    Args:
        model: predictor passed through to ``main_task``.
        retry_delay: seconds to wait between attempts after a failure.
        max_retries: extra attempts allowed after the first failed run.
    """
    import traceback

    while True:
        now = datetime.now(timezone.utc)
        # Next 00:00:05 UTC strictly after now (today's slot if still ahead).
        next_run_time = now.replace(hour=0, minute=0, second=5, microsecond=0)
        if next_run_time <= now:
            next_run_time += timedelta(days=1)
        sleep_seconds = (next_run_time - now).total_seconds()
        print(f"Current time: {now:%Y-%m-%d %H:%M:%S UTC}.")
        print(f"Next run at: {next_run_time:%Y-%m-%d %H:%M:%S UTC}. Waiting for {sleep_seconds:.0f} seconds...")
        time.sleep(sleep_seconds)

        # Retry inside this slot. Falling back to the outer loop after a
        # failure would wait for the next day, so the announced retry would
        # never happen.
        for attempt in range(max_retries + 1):
            try:
                main_task(model)
                break
            except Exception as e:  # top-level boundary: report and keep the scheduler alive
                print(f"\n!!!!!! A critical error occurred in the main task !!!!!!!")
                print(f"Error: {e}")
                traceback.print_exc()
                if attempt < max_retries:
                    print(f"Retrying in {retry_delay / 60:g} minutes...")
                    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")
                    time.sleep(retry_delay)
                else:
                    print("Retries exhausted; waiting for the next scheduled run.")
                    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")
if __name__ == '__main__':
    # Ensure the model directory exists before trying to load from it.
    Path(Config["LOCAL_MODEL_PATH"]).mkdir(parents=True, exist_ok=True)
    predictor = load_local_model()
    # One immediate run at start-up, then hand over to the daily scheduler.
    main_task(predictor)
    run_scheduler(predictor)