Akshit Chaturvedi
commited on
Commit
Β·
40c93e1
1
Parent(s):
a1f72cd
Initial code commit
Browse files- app.py +308 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from prophet import Prophet
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
from datetime import datetime, timedelta
|
8 |
+
from alpha_vantage.timeseries import TimeSeries # Alpha Vantage library
|
9 |
+
|
10 |
+
# --- Configuration ---
|
11 |
+
# Directory where your .json model files are (for hyperparameters)
|
12 |
+
MODEL_PARAMS_DIR = "./trained_models"
|
13 |
+
MODEL_PARAMS_PREFIX = "prophet_model_"
|
14 |
+
DATA_CACHE_FILE = "data_cache.json" # File to cache Alpha Vantage data
|
15 |
+
|
16 |
+
# Fetch Alpha Vantage API Key from Hugging Face Space Secrets
|
17 |
+
ALPHAVANTAGE_API_KEY = os.environ.get("ALPHAVANTAGE_API_KEY")
|
18 |
+
|
19 |
+
if not ALPHAVANTAGE_API_KEY:
|
20 |
+
print("CRITICAL WARNING: ALPHAVANTAGE_API_KEY secret not found in Space settings!")
|
21 |
+
# The app might still run but data fetching will fail.
|
22 |
+
# Gradio UI can show an error message if this happens during data fetch.
|
23 |
+
|
24 |
+
# Default Prophet parameters (can be overridden by those in JSON files)
|
25 |
+
# Based on your training script
|
26 |
+
DEFAULT_PROPHET_PARAMS = {
|
27 |
+
'yearly_seasonality': True,
|
28 |
+
'weekly_seasonality': False, # You had this as False
|
29 |
+
'daily_seasonality': False, # You had this as False
|
30 |
+
'changepoint_prior_scale': 0.05,
|
31 |
+
'seasonality_prior_scale': 10.0,
|
32 |
+
'growth': 'linear' # Common default
|
33 |
+
}
|
34 |
+
|
35 |
+
# --- Load Model Hyperparameters ---
|
36 |
+
# These JSONs now primarily serve to list available tickers and potentially
|
37 |
+
# override default hyperparameters if specific ones were saved.
|
38 |
+
model_hyperparams_catalogue = {}
|
39 |
+
print("Loading model hyperparameter configurations...")
|
40 |
+
if os.path.exists(MODEL_PARAMS_DIR):
|
41 |
+
for filename in os.listdir(MODEL_PARAMS_DIR):
|
42 |
+
if filename.startswith(MODEL_PARAMS_PREFIX) and filename.endswith(".json"):
|
43 |
+
model_name_key = filename.replace(MODEL_PARAMS_PREFIX, "").replace(".json", "")
|
44 |
+
file_path = os.path.join(MODEL_PARAMS_DIR, filename)
|
45 |
+
try:
|
46 |
+
with open(file_path, 'r') as f:
|
47 |
+
# The JSONs from model_to_json are full models.
|
48 |
+
# We primarily need the ticker name. Hyperparameters can be
|
49 |
+
# extracted if they are top-level, or we use defaults.
|
50 |
+
# For simplicity, we'll use DEFAULT_PROPHET_PARAMS for now
|
51 |
+
# but confirm they exist.
|
52 |
+
# json_data = json.load(f)
|
53 |
+
# specific_params = {
|
54 |
+
# 'yearly_seasonality': json_data.get('yearly_seasonality', DEFAULT_PROPHET_PARAMS['yearly_seasonality']),
|
55 |
+
# # ... extract other relevant params ...
|
56 |
+
# }
|
57 |
+
model_hyperparams_catalogue[model_name_key] = DEFAULT_PROPHET_PARAMS.copy() # Use defaults
|
58 |
+
print(f"Registered model config for: {model_name_key}")
|
59 |
+
except Exception as e:
|
60 |
+
print(f"Error reading or parsing model JSON {filename}: {e}")
|
61 |
+
else:
|
62 |
+
print(f"WARNING: Model parameters directory '{MODEL_PARAMS_DIR}' not found.")
|
63 |
+
|
64 |
+
available_model_names = sorted(list(model_hyperparams_catalogue.keys()))
|
65 |
+
if not available_model_names:
|
66 |
+
print("WARNING: No model configurations loaded. The application might not function correctly.")
|
67 |
+
|
68 |
+
# --- Data Fetching and Caching Logic ---
|
69 |
+
def load_data_cache():
|
70 |
+
if os.path.exists(DATA_CACHE_FILE):
|
71 |
+
try:
|
72 |
+
with open(DATA_CACHE_FILE, 'r') as f:
|
73 |
+
return json.load(f)
|
74 |
+
except json.JSONDecodeError:
|
75 |
+
print(f"Cache file {DATA_CACHE_FILE} is corrupted. Starting with an empty cache.")
|
76 |
+
return {}
|
77 |
+
except Exception as e:
|
78 |
+
print(f"Error loading cache file {DATA_CACHE_FILE}: {e}. Starting with an empty cache.")
|
79 |
+
return {}
|
80 |
+
return {}
|
81 |
+
|
82 |
+
def save_data_cache(cache):
|
83 |
+
try:
|
84 |
+
with open(DATA_CACHE_FILE, 'w') as f:
|
85 |
+
json.dump(cache, f, indent=4)
|
86 |
+
except Exception as e:
|
87 |
+
print(f"Error saving cache file {DATA_CACHE_FILE}: {e}")
|
88 |
+
|
89 |
+
def get_timeseries_data_from_alphavantage(ticker_symbol):
|
90 |
+
"""
|
91 |
+
Fetches time series data from Alpha Vantage for a given ticker symbol.
|
92 |
+
Returns a pandas DataFrame with 'ds' and 'y' columns, or None on error.
|
93 |
+
"""
|
94 |
+
if not ALPHAVANTAGE_API_KEY:
|
95 |
+
raise ValueError("Alpha Vantage API key is not configured.")
|
96 |
+
|
97 |
+
print(f"Attempting to fetch new data for {ticker_symbol} from Alpha Vantage...")
|
98 |
+
ts = TimeSeries(key=ALPHAVANTAGE_API_KEY, output_format='pandas')
|
99 |
+
try:
|
100 |
+
# Get daily adjusted prices for robustness (handles splits/dividends)
|
101 |
+
# 'full' gives up to 20+ years of data. 'compact' gives last 100 days.
|
102 |
+
# Prophet generally benefits from at least 1-2 years of data.
|
103 |
+
data_av, meta_data = ts.get_daily_adjusted(symbol=ticker_symbol, outputsize='full')
|
104 |
+
|
105 |
+
# Process Alpha Vantage DataFrame:
|
106 |
+
# 1. Sort by date (Alpha Vantage usually returns newest first)
|
107 |
+
data_av = data_av.sort_index(ascending=True)
|
108 |
+
# 2. Rename date index to 'ds' and the chosen price column to 'y'
|
109 |
+
# '5. adjusted close' is generally preferred.
|
110 |
+
df_prophet = data_av[['5. adjusted close']].reset_index()
|
111 |
+
df_prophet.rename(columns={'date': 'ds', '5. adjusted close': 'y'}, inplace=True)
|
112 |
+
# 3. Ensure 'ds' is datetime
|
113 |
+
df_prophet['ds'] = pd.to_datetime(df_prophet['ds'])
|
114 |
+
# 4. Ensure 'y' is numeric
|
115 |
+
df_prophet['y'] = pd.to_numeric(df_prophet['y'], errors='coerce')
|
116 |
+
df_prophet.dropna(subset=['y'], inplace=True) # Remove rows where y could not be coerced
|
117 |
+
|
118 |
+
if df_prophet.empty:
|
119 |
+
print(f"No valid data returned from Alpha Vantage for {ticker_symbol} after processing.")
|
120 |
+
return None
|
121 |
+
|
122 |
+
print(f"Successfully fetched and processed {len(df_prophet)} data points for {ticker_symbol}.")
|
123 |
+
return df_prophet[['ds', 'y']]
|
124 |
+
|
125 |
+
except Exception as e:
|
126 |
+
print(f"Alpha Vantage API Error for {ticker_symbol}: {type(e).__name__} - {e}")
|
127 |
+
# Common issues: Invalid API key, rate limit exceeded, ticker not found.
|
128 |
+
if "Invalid API call" in str(e) or "does not exist" in str(e):
|
129 |
+
print(f"Ticker {ticker_symbol} might not be valid or API call issue.")
|
130 |
+
elif "rate limit" in str(e).lower():
|
131 |
+
print("Alpha Vantage rate limit likely exceeded.")
|
132 |
+
return None
|
133 |
+
|
134 |
+
def get_and_cache_data(ticker_symbol, min_history_days=730): # Need enough history for Prophet
|
135 |
+
cache = load_data_cache()
|
136 |
+
today_str = datetime.now().strftime("%Y-%m-%d")
|
137 |
+
|
138 |
+
# Check cache first
|
139 |
+
if ticker_symbol in cache and cache[ticker_symbol].get("date_fetched") == today_str:
|
140 |
+
print(f"Using cached data for {ticker_symbol} from {today_str}")
|
141 |
+
try:
|
142 |
+
# Data stored as list of dicts; convert back to DataFrame
|
143 |
+
df_data = pd.DataFrame(cache[ticker_symbol]["data"])
|
144 |
+
df_data['ds'] = pd.to_datetime(df_data['ds'])
|
145 |
+
return df_data
|
146 |
+
except Exception as e:
|
147 |
+
print(f"Error loading data from cache for {ticker_symbol}: {e}. Will try fetching.")
|
148 |
+
|
149 |
+
# If not in cache or stale, fetch from Alpha Vantage
|
150 |
+
df_new_data = get_timeseries_data_from_alphavantage(ticker_symbol)
|
151 |
+
|
152 |
+
if df_new_data is not None and not df_new_data.empty:
|
153 |
+
if len(df_new_data) < min_history_days / 2: # Arbitrary check for too little history
|
154 |
+
print(f"WARNING: Fetched data for {ticker_symbol} is very short ({len(df_new_data)} days). Forecast quality may be poor.")
|
155 |
+
|
156 |
+
# Update cache
|
157 |
+
# Convert 'ds' to string for JSON serialization
|
158 |
+
data_to_cache = df_new_data.copy()
|
159 |
+
data_to_cache['ds'] = data_to_cache['ds'].dt.strftime('%Y-%m-%d')
|
160 |
+
cache[ticker_symbol] = {
|
161 |
+
"date_fetched": today_str,
|
162 |
+
"data": data_to_cache.to_dict(orient='records')
|
163 |
+
}
|
164 |
+
save_data_cache(cache)
|
165 |
+
return df_new_data # Return with 'ds' as datetime
|
166 |
+
else:
|
167 |
+
# If fetching failed, check if older cache data exists and return that with a warning
|
168 |
+
if ticker_symbol in cache and "data" in cache[ticker_symbol]:
|
169 |
+
print(f"Fetching new data failed for {ticker_symbol}. Using older cached data.")
|
170 |
+
df_data = pd.DataFrame(cache[ticker_symbol]["data"])
|
171 |
+
df_data['ds'] = pd.to_datetime(df_data['ds'])
|
172 |
+
return df_data # This data might be stale
|
173 |
+
print(f"Failed to fetch or find any cached data for {ticker_symbol}.")
|
174 |
+
return None
|
175 |
+
|
176 |
+
|
177 |
+
# --- Main Prediction Function (Re-fitting Prophet on demand) ---
|
178 |
+
def predict_dynamic_forecast(ticker_selection, forecast_periods_str):
|
179 |
+
status_message = ""
|
180 |
+
if not ALPHAVANTAGE_API_KEY:
|
181 |
+
return "ERROR: Alpha Vantage API Key not configured in Space Secrets.", None
|
182 |
+
if not ticker_selection:
|
183 |
+
return "Please select a ticker.", None
|
184 |
+
|
185 |
+
try:
|
186 |
+
forecast_periods = int(forecast_periods_str)
|
187 |
+
if forecast_periods <= 0:
|
188 |
+
return "Forecast periods must be a positive number.", None
|
189 |
+
except ValueError:
|
190 |
+
return "Invalid number for forecast periods.", None
|
191 |
+
|
192 |
+
hyperparams = model_hyperparams_catalogue.get(ticker_selection)
|
193 |
+
if not hyperparams: # Should not happen if dropdown is populated correctly
|
194 |
+
return f"Configuration for '{ticker_selection}' not found.", None
|
195 |
+
|
196 |
+
try:
|
197 |
+
status_message += f"Fetching/loading data for {ticker_selection}...\n"
|
198 |
+
# Prophet generally needs at least a year or two of data.
|
199 |
+
# Alpha Vantage 'full' outputsize should provide this.
|
200 |
+
historical_df = get_and_cache_data(ticker_selection, min_history_days=365 * 2)
|
201 |
+
|
202 |
+
if historical_df is None or historical_df.empty:
|
203 |
+
status_message += f"Failed to retrieve sufficient historical data for {ticker_selection}."
|
204 |
+
return status_message, None
|
205 |
+
|
206 |
+
if len(historical_df) < 30: # Prophet needs some minimal data
|
207 |
+
status_message += f"Historical data for {ticker_selection} is too short ({len(historical_df)} points) to make a forecast."
|
208 |
+
return status_message, None
|
209 |
+
|
210 |
+
status_message += f"Data loaded. Preprocessing for Prophet (log transform 'y')...\n"
|
211 |
+
fit_df = historical_df.copy()
|
212 |
+
# IMPORTANT: Log transform 'y' as done during original training
|
213 |
+
fit_df['y'] = np.log(fit_df['y'])
|
214 |
+
# Handle potential -inf/inf/NaN from log(0) or log(negative_value)
|
215 |
+
fit_df.replace([np.inf, -np.inf], np.nan, inplace=True)
|
216 |
+
fit_df['y'] = fit_df['y'].ffill().bfill() # Forward fill then backward fill
|
217 |
+
|
218 |
+
if fit_df['y'].isnull().any():
|
219 |
+
status_message += f"NaNs remain in log-transformed 'y' for {ticker_selection} after fill. Cannot fit model."
|
220 |
+
return status_message, None
|
221 |
+
|
222 |
+
status_message += f"Fitting Prophet model for {ticker_selection} with latest data...\n"
|
223 |
+
model = Prophet(**hyperparams)
|
224 |
+
model.fit(fit_df[['ds', 'y']]) # Fit on 'ds' and log-transformed 'y'
|
225 |
+
|
226 |
+
status_message += f"Generating forecast for {forecast_periods} periods...\n"
|
227 |
+
future_df = model.make_future_dataframe(periods=forecast_periods, freq='D') # Daily frequency
|
228 |
+
forecast_log_scale = model.predict(future_df)
|
229 |
+
|
230 |
+
# IMPORTANT: Inverse transform (exponentiate) predictions
|
231 |
+
output_df = forecast_log_scale[['ds']].copy()
|
232 |
+
output_df['Predicted Price (yhat)'] = np.exp(forecast_log_scale['yhat'])
|
233 |
+
output_df['Lower Bound (yhat_lower)'] = np.exp(forecast_log_scale['yhat_lower'])
|
234 |
+
output_df['Upper Bound (yhat_upper)'] = np.exp(forecast_log_scale['yhat_upper'])
|
235 |
+
|
236 |
+
# Return only the future forecast part
|
237 |
+
final_forecast_df = output_df.tail(forecast_periods).reset_index(drop=True)
|
238 |
+
|
239 |
+
status_message += "Forecast generated successfully."
|
240 |
+
return status_message, final_forecast_df
|
241 |
+
|
242 |
+
except Exception as e:
|
243 |
+
error_full_message = f"ERROR during prediction for {ticker_selection}: {type(e).__name__} - {str(e)}"
|
244 |
+
print(error_full_message) # Log to Hugging Face Space console for debugging
|
245 |
+
# Also include parts of it in status_message for the user
|
246 |
+
status_message += f"\nAn error occurred: {type(e).__name__}. Check Space logs for details."
|
247 |
+
return status_message, None
|
248 |
+
|
249 |
+
# --- Gradio Interface Definition ---
|
250 |
+
with gr.Blocks(css="footer {visibility: hidden}", title="Stock/Commodity Forecaster") as iface:
|
251 |
+
gr.Markdown("# Stock & Commodity Price Forecaster")
|
252 |
+
gr.Markdown(
|
253 |
+
"This tool fetches the latest market data using the Alpha Vantage API, "
|
254 |
+
"re-fits a Prophet time series model on-the-fly using pre-defined hyperparameters, "
|
255 |
+
"and generates a future price forecast."
|
256 |
+
"\n\n**Note:** Forecasts are for informational purposes only and not financial advice. "
|
257 |
+
"Data fetching may be slow on the first request for a ticker each day."
|
258 |
+
)
|
259 |
+
if not ALPHAVANTAGE_API_KEY:
|
260 |
+
gr.Markdown("<h3 style='color:red;'>WARNING: Alpha Vantage API Key is not configured in Space Secrets. Data fetching will fail.</h3>")
|
261 |
+
|
262 |
+
with gr.Row():
|
263 |
+
with gr.Column(scale=1):
|
264 |
+
ticker_dropdown = gr.Dropdown(
|
265 |
+
choices=available_model_names,
|
266 |
+
label="Select Ticker Symbol",
|
267 |
+
info="Choose the stock/commodity to forecast."
|
268 |
+
)
|
269 |
+
periods_input = gr.Number(
|
270 |
+
value=30,
|
271 |
+
label="Forecast Periods (Days)",
|
272 |
+
minimum=1,
|
273 |
+
maximum=365 * 2, # Max 2 years forecast
|
274 |
+
step=1,
|
275 |
+
info="Number of future days to predict."
|
276 |
+
)
|
277 |
+
predict_button = gr.Button("Generate Forecast", variant="primary")
|
278 |
+
|
279 |
+
with gr.Column(scale=3):
|
280 |
+
status_textbox = gr.Textbox(
|
281 |
+
label="Process Status & Logs",
|
282 |
+
lines=6,
|
283 |
+
interactive=False,
|
284 |
+
placeholder="Status messages will appear here..."
|
285 |
+
)
|
286 |
+
|
287 |
+
gr.Markdown("## Forecast Results")
|
288 |
+
forecast_output_table = gr.DataFrame(
|
289 |
+
label="Price Forecast Data",
|
290 |
+
headers=['Date (ds)', 'Predicted Price (yhat)', 'Lower Bound', 'Upper Bound'] # Match output_df columns
|
291 |
+
)
|
292 |
+
|
293 |
+
predict_button.click(
|
294 |
+
fn=predict_dynamic_forecast,
|
295 |
+
inputs=[ticker_dropdown, periods_input],
|
296 |
+
outputs=[status_textbox, forecast_output_table]
|
297 |
+
)
|
298 |
+
|
299 |
+
gr.Markdown("---")
|
300 |
+
gr.Markdown(
|
301 |
+
"**How it works:** Models are based on Facebook's Prophet. Hyperparameters are pre-set. "
|
302 |
+
"Historical data is log-transformed before fitting. Predictions are exponentiated back. "
|
303 |
+
"Data is cached daily to minimize Alpha Vantage API calls."
|
304 |
+
)
|
305 |
+
|
306 |
+
# --- Launch the Gradio App ---
|
307 |
+
if __name__ == "__main__":
|
308 |
+
iface.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
prophet
|
3 |
+
pandas
|
4 |
+
numpy
|
5 |
+
alpha-vantage
|