Akshit Chaturvedi committed on
Commit
40c93e1
·
1 Parent(s): a1f72cd

Initial code commit

Files changed (2)
  1. app.py +308 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,308 @@
+ import gradio as gr
+ from prophet import Prophet
+ import pandas as pd
+ import numpy as np
+ import os
+ import json
+ from datetime import datetime, timedelta
+ from alpha_vantage.timeseries import TimeSeries  # Alpha Vantage library
+
+ # --- Configuration ---
+ # Directory where your .json model files are (for hyperparameters)
+ MODEL_PARAMS_DIR = "./trained_models"
+ MODEL_PARAMS_PREFIX = "prophet_model_"
+ DATA_CACHE_FILE = "data_cache.json"  # File to cache Alpha Vantage data
+
+ # Fetch Alpha Vantage API Key from Hugging Face Space Secrets
+ ALPHAVANTAGE_API_KEY = os.environ.get("ALPHAVANTAGE_API_KEY")
+
+ if not ALPHAVANTAGE_API_KEY:
+     print("CRITICAL WARNING: ALPHAVANTAGE_API_KEY secret not found in Space settings!")
+     # The app might still run but data fetching will fail.
+     # Gradio UI can show an error message if this happens during data fetch.
+
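+ # (Locally, the key can also be supplied via the environment before launching the app,
+ # e.g. `export ALPHAVANTAGE_API_KEY="your-key"`; on Spaces it is read from the secret above.)
+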
+ # Default Prophet parameters (can be overridden by those in JSON files)
+ # Based on the original training script
+ DEFAULT_PROPHET_PARAMS = {
+     'yearly_seasonality': True,
+     'weekly_seasonality': False,   # disabled in the original training script
+     'daily_seasonality': False,    # disabled in the original training script
+     'changepoint_prior_scale': 0.05,
+     'seasonality_prior_scale': 10.0,
+     'growth': 'linear'  # Common default
+ }
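+ # These defaults are passed straight to Prophet(**hyperparams) when a forecast is requested
+ # (see predict_dynamic_forecast below).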
+
+ # --- Load Model Hyperparameters ---
+ # These JSONs now primarily serve to list available tickers and potentially
+ # override default hyperparameters if specific ones were saved.
+ model_hyperparams_catalogue = {}
+ print("Loading model hyperparameter configurations...")
+ if os.path.exists(MODEL_PARAMS_DIR):
+     for filename in os.listdir(MODEL_PARAMS_DIR):
+         if filename.startswith(MODEL_PARAMS_PREFIX) and filename.endswith(".json"):
+             model_name_key = filename.replace(MODEL_PARAMS_PREFIX, "").replace(".json", "")
+             file_path = os.path.join(MODEL_PARAMS_DIR, filename)
+             try:
+                 with open(file_path, 'r') as f:
+                     # The JSONs produced by model_to_json are full serialized models.
+                     # We primarily need the ticker name; hyperparameters could be
+                     # extracted here if they were saved at the top level.
+                     # For simplicity, DEFAULT_PROPHET_PARAMS is used for every ticker;
+                     # the JSON file's presence just confirms a trained configuration exists.
+                     # json_data = json.load(f)
+                     # specific_params = {
+                     #     'yearly_seasonality': json_data.get('yearly_seasonality', DEFAULT_PROPHET_PARAMS['yearly_seasonality']),
+                     #     # ... extract other relevant params ...
+                     # }
+                     model_hyperparams_catalogue[model_name_key] = DEFAULT_PROPHET_PARAMS.copy()  # Use defaults
+                 print(f"Registered model config for: {model_name_key}")
+             except Exception as e:
+                 print(f"Error reading or parsing model JSON {filename}: {e}")
+ else:
+     print(f"WARNING: Model parameters directory '{MODEL_PARAMS_DIR}' not found.")
+
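+ # Note: model_to_json() saves a fully fitted model, so an alternative to re-fitting would be to
+ # deserialize it directly. A minimal sketch (not used here, since this app re-fits on fresh data):
+ #     from prophet.serialize import model_from_json
+ #     with open(file_path, 'r') as f:
+ #         loaded_model = model_from_json(f.read())
+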
+ available_model_names = sorted(list(model_hyperparams_catalogue.keys()))
+ if not available_model_names:
+     print("WARNING: No model configurations loaded. The application might not function correctly.")
+
+ # --- Data Fetching and Caching Logic ---
+ def load_data_cache():
+     if os.path.exists(DATA_CACHE_FILE):
+         try:
+             with open(DATA_CACHE_FILE, 'r') as f:
+                 return json.load(f)
+         except json.JSONDecodeError:
+             print(f"Cache file {DATA_CACHE_FILE} is corrupted. Starting with an empty cache.")
+             return {}
+         except Exception as e:
+             print(f"Error loading cache file {DATA_CACHE_FILE}: {e}. Starting with an empty cache.")
+             return {}
+     return {}
+
+ def save_data_cache(cache):
+     try:
+         with open(DATA_CACHE_FILE, 'w') as f:
+             json.dump(cache, f, indent=4)
+     except Exception as e:
+         print(f"Error saving cache file {DATA_CACHE_FILE}: {e}")
+
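+ # Cache file layout (as written by get_and_cache_data below):
+ #     {"<TICKER>": {"date_fetched": "YYYY-MM-DD",
+ #                   "data": [{"ds": "YYYY-MM-DD", "y": <adjusted close>}, ...]}}
+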
+ def get_timeseries_data_from_alphavantage(ticker_symbol):
+     """
+     Fetches time series data from Alpha Vantage for a given ticker symbol.
+     Returns a pandas DataFrame with 'ds' and 'y' columns, or None on error.
+     """
+     if not ALPHAVANTAGE_API_KEY:
+         raise ValueError("Alpha Vantage API key is not configured.")
+
+     print(f"Attempting to fetch new data for {ticker_symbol} from Alpha Vantage...")
+     ts = TimeSeries(key=ALPHAVANTAGE_API_KEY, output_format='pandas')
+     try:
+         # Get daily adjusted prices for robustness (handles splits/dividends)
+         # 'full' gives up to 20+ years of data. 'compact' gives the last 100 days.
+         # Prophet generally benefits from at least 1-2 years of data.
+         data_av, meta_data = ts.get_daily_adjusted(symbol=ticker_symbol, outputsize='full')
+
+         # Process the Alpha Vantage DataFrame:
+         # 1. Sort by date (Alpha Vantage usually returns newest first)
+         data_av = data_av.sort_index(ascending=True)
+         # 2. Rename the date index to 'ds' and the chosen price column to 'y'
+         #    ('5. adjusted close' is generally preferred).
+         df_prophet = data_av[['5. adjusted close']].reset_index()
+         df_prophet.rename(columns={'date': 'ds', '5. adjusted close': 'y'}, inplace=True)
+         # 3. Ensure 'ds' is datetime
+         df_prophet['ds'] = pd.to_datetime(df_prophet['ds'])
+         # 4. Ensure 'y' is numeric
+         df_prophet['y'] = pd.to_numeric(df_prophet['y'], errors='coerce')
+         df_prophet.dropna(subset=['y'], inplace=True)  # Remove rows where y could not be coerced
+
+         if df_prophet.empty:
+             print(f"No valid data returned from Alpha Vantage for {ticker_symbol} after processing.")
+             return None
+
+         print(f"Successfully fetched and processed {len(df_prophet)} data points for {ticker_symbol}.")
+         return df_prophet[['ds', 'y']]
+
+     except Exception as e:
+         print(f"Alpha Vantage API Error for {ticker_symbol}: {type(e).__name__} - {e}")
+         # Common issues: invalid API key, rate limit exceeded, ticker not found.
+         if "Invalid API call" in str(e) or "does not exist" in str(e):
+             print(f"Ticker {ticker_symbol} might not be valid, or the API call failed.")
+         elif "rate limit" in str(e).lower():
+             print("Alpha Vantage rate limit likely exceeded.")
+         return None
+
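+ # Note: free Alpha Vantage keys are heavily rate-limited, which is the main reason responses
+ # are cached once per day (see get_and_cache_data below).
+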
+ def get_and_cache_data(ticker_symbol, min_history_days=730):  # Need enough history for Prophet
+     cache = load_data_cache()
+     today_str = datetime.now().strftime("%Y-%m-%d")
+
+     # Check cache first
+     if ticker_symbol in cache and cache[ticker_symbol].get("date_fetched") == today_str:
+         print(f"Using cached data for {ticker_symbol} from {today_str}")
+         try:
+             # Data stored as list of dicts; convert back to DataFrame
+             df_data = pd.DataFrame(cache[ticker_symbol]["data"])
+             df_data['ds'] = pd.to_datetime(df_data['ds'])
+             return df_data
+         except Exception as e:
+             print(f"Error loading data from cache for {ticker_symbol}: {e}. Will try fetching.")
+
+     # If not in cache or stale, fetch from Alpha Vantage
+     df_new_data = get_timeseries_data_from_alphavantage(ticker_symbol)
+
+     if df_new_data is not None and not df_new_data.empty:
+         if len(df_new_data) < min_history_days / 2:  # Arbitrary check for too little history
+             print(f"WARNING: Fetched data for {ticker_symbol} is very short ({len(df_new_data)} days). Forecast quality may be poor.")
+
+         # Update cache
+         # Convert 'ds' to string for JSON serialization
+         data_to_cache = df_new_data.copy()
+         data_to_cache['ds'] = data_to_cache['ds'].dt.strftime('%Y-%m-%d')
+         cache[ticker_symbol] = {
+             "date_fetched": today_str,
+             "data": data_to_cache.to_dict(orient='records')
+         }
+         save_data_cache(cache)
+         return df_new_data  # Return with 'ds' as datetime
+     else:
+         # If fetching failed, check if older cache data exists and return that with a warning
+         if ticker_symbol in cache and "data" in cache[ticker_symbol]:
+             print(f"Fetching new data failed for {ticker_symbol}. Using older cached data.")
+             df_data = pd.DataFrame(cache[ticker_symbol]["data"])
+             df_data['ds'] = pd.to_datetime(df_data['ds'])
+             return df_data  # This data might be stale
+         print(f"Failed to fetch or find any cached data for {ticker_symbol}.")
+         return None
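+
+ # Illustrative use (the ticker symbol here is hypothetical):
+ #     df = get_and_cache_data("IBM")  # returns a 'ds'/'y' DataFrame, or None on failure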
+
+
+ # --- Main Prediction Function (Re-fitting Prophet on demand) ---
+ def predict_dynamic_forecast(ticker_selection, forecast_periods_str):
+     status_message = ""
+     if not ALPHAVANTAGE_API_KEY:
+         return "ERROR: Alpha Vantage API Key not configured in Space Secrets.", None
+     if not ticker_selection:
+         return "Please select a ticker.", None
+
+     try:
+         forecast_periods = int(forecast_periods_str)
+         if forecast_periods <= 0:
+             return "Forecast periods must be a positive number.", None
+     except ValueError:
+         return "Invalid number for forecast periods.", None
+
+     hyperparams = model_hyperparams_catalogue.get(ticker_selection)
+     if not hyperparams:  # Should not happen if dropdown is populated correctly
+         return f"Configuration for '{ticker_selection}' not found.", None
+
+     try:
+         status_message += f"Fetching/loading data for {ticker_selection}...\n"
+         # Prophet generally needs at least a year or two of data.
+         # Alpha Vantage 'full' outputsize should provide this.
+         historical_df = get_and_cache_data(ticker_selection, min_history_days=365 * 2)
+
+         if historical_df is None or historical_df.empty:
+             status_message += f"Failed to retrieve sufficient historical data for {ticker_selection}."
+             return status_message, None
+
+         if len(historical_df) < 30:  # Prophet needs some minimal data
+             status_message += f"Historical data for {ticker_selection} is too short ({len(historical_df)} points) to make a forecast."
+             return status_message, None
+
+         status_message += f"Data loaded. Preprocessing for Prophet (log transform 'y')...\n"
+         fit_df = historical_df.copy()
+         # IMPORTANT: Log transform 'y' as done during original training
+         fit_df['y'] = np.log(fit_df['y'])
+         # Handle potential -inf/inf/NaN from log(0) or log(negative_value)
+         fit_df.replace([np.inf, -np.inf], np.nan, inplace=True)
+         fit_df['y'] = fit_df['y'].ffill().bfill()  # Forward fill then backward fill
+
+         if fit_df['y'].isnull().any():
+             status_message += f"NaNs remain in log-transformed 'y' for {ticker_selection} after fill. Cannot fit model."
+             return status_message, None
+
+         status_message += f"Fitting Prophet model for {ticker_selection} with latest data...\n"
+         model = Prophet(**hyperparams)
+         model.fit(fit_df[['ds', 'y']])  # Fit on 'ds' and log-transformed 'y'
+
+         status_message += f"Generating forecast for {forecast_periods} periods...\n"
+         future_df = model.make_future_dataframe(periods=forecast_periods, freq='D')  # Daily frequency
+         forecast_log_scale = model.predict(future_df)
+
+         # IMPORTANT: Inverse transform (exponentiate) predictions
+         output_df = forecast_log_scale[['ds']].copy()
+         output_df['Predicted Price (yhat)'] = np.exp(forecast_log_scale['yhat'])
+         output_df['Lower Bound (yhat_lower)'] = np.exp(forecast_log_scale['yhat_lower'])
+         output_df['Upper Bound (yhat_upper)'] = np.exp(forecast_log_scale['yhat_upper'])
+
+         # Return only the future forecast part
+         final_forecast_df = output_df.tail(forecast_periods).reset_index(drop=True)
+
+         status_message += "Forecast generated successfully."
+         return status_message, final_forecast_df
+
+     except Exception as e:
+         error_full_message = f"ERROR during prediction for {ticker_selection}: {type(e).__name__} - {str(e)}"
+         print(error_full_message)  # Log to Hugging Face Space console for debugging
+         # Also include parts of it in status_message for the user
+         status_message += f"\nAn error occurred: {type(e).__name__}. Check Space logs for details."
+         return status_message, None
+
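+ # Illustrative direct call, bypassing the UI (the ticker name is hypothetical and must match
+ # one of the loaded configurations):
+ #     status, forecast_df = predict_dynamic_forecast("GOLD", "30")
+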
+ # --- Gradio Interface Definition ---
+ with gr.Blocks(css="footer {visibility: hidden}", title="Stock/Commodity Forecaster") as iface:
+     gr.Markdown("# Stock & Commodity Price Forecaster")
+     gr.Markdown(
+         "This tool fetches the latest market data using the Alpha Vantage API, "
+         "re-fits a Prophet time series model on-the-fly using pre-defined hyperparameters, "
+         "and generates a future price forecast."
+         "\n\n**Note:** Forecasts are for informational purposes only and not financial advice. "
+         "Data fetching may be slow on the first request for a ticker each day."
+     )
+     if not ALPHAVANTAGE_API_KEY:
+         gr.Markdown("<h3 style='color:red;'>WARNING: Alpha Vantage API Key is not configured in Space Secrets. Data fetching will fail.</h3>")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             ticker_dropdown = gr.Dropdown(
+                 choices=available_model_names,
+                 label="Select Ticker Symbol",
+                 info="Choose the stock/commodity to forecast."
+             )
+             periods_input = gr.Number(
+                 value=30,
+                 label="Forecast Periods (Days)",
+                 minimum=1,
+                 maximum=365 * 2,  # Max 2 years forecast
+                 step=1,
+                 info="Number of future days to predict."
+             )
+             predict_button = gr.Button("Generate Forecast", variant="primary")
+
+         with gr.Column(scale=3):
+             status_textbox = gr.Textbox(
+                 label="Process Status & Logs",
+                 lines=6,
+                 interactive=False,
+                 placeholder="Status messages will appear here..."
+             )
+
+     gr.Markdown("## Forecast Results")
+     forecast_output_table = gr.DataFrame(
+         label="Price Forecast Data",
+         headers=['ds', 'Predicted Price (yhat)', 'Lower Bound (yhat_lower)', 'Upper Bound (yhat_upper)']  # Match output_df columns
+     )
+
+     predict_button.click(
+         fn=predict_dynamic_forecast,
+         inputs=[ticker_dropdown, periods_input],
+         outputs=[status_textbox, forecast_output_table]
+     )
+
+     gr.Markdown("---")
+     gr.Markdown(
+         "**How it works:** Models are based on Facebook's Prophet. Hyperparameters are pre-set. "
+         "Historical data is log-transformed before fitting. Predictions are exponentiated back. "
+         "Data is cached daily to minimize Alpha Vantage API calls."
+     )
+
+ # --- Launch the Gradio App ---
+ if __name__ == "__main__":
+     iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio
+ prophet
+ pandas
+ numpy
+ alpha-vantage
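+ # Optionally pin versions for reproducible Space builds, e.g. prophet==1.1.5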