Spaces:
Build error
Build error
| import os | |
| import gradio as gr | |
| import torch | |
| from darts import TimeSeries, concatenate | |
| from darts.dataprocessing.transformers import Scaler | |
| from darts.utils.timeseries_generation import datetime_attribute_timeseries | |
| from darts.models.forecasting.tft_model import TFTModel | |
| from darts.metrics import mape | |
| from dateutil.relativedelta import relativedelta | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import logging | |
| logging.disable(logging.CRITICAL) | |
| import pandas as pd | |
| import numpy as np | |
| from typing import Any, List, Optional | |
| import plotly.graph_objects as go | |
# Raw monthly datasets; every CSV has a "Date" column parsed as datetime.
df_final, df_comtrade_flour, df_comtrade_grain = (
    pd.read_csv(csv_path, parse_dates=['Date'])
    for csv_path in (
        'data/all_afghan.csv',
        'data/comtrade_flour.csv',
        'data/comtrade_grain.csv',
    )
)

# Multivariate series holding every column of interest from the main dataset;
# individual components are pulled out of it further down.
series = TimeSeries.from_dataframe(
    df_final,
    time_col='Date',
    value_cols=[
        'price',
        'usdprice',
        'wheat_grain',
        'exchange_rate',
        'common_unit_price',
        'black_sea',
    ],
)

# Train/validation cut-off: six months before the last observed date.
six_months = df_final['Date'].max() + relativedelta(months=-6)

# Target series (common-unit wheat price). The scaler is fitted on the training
# window only; get_forecast uses ``transformer`` to map predictions back.
data_series = series['common_unit_price']
train, val = data_series.split_after(six_months)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed, series_transformed = (
    transformer.transform(part) for part in (val, data_series)
)
# Calendar covariates built on the target's time axis: the numeric year, the
# month one-hot encoded and a running integer index, followed by holiday
# covariates; everything is cast to float32.
# NOTE(review): holidays use country_code="ES" (Spain) although the data is
# Afghan; presumably kept to match the features the saved model was trained
# on — confirm before changing.
covariates = (
    datetime_attribute_timeseries(series_transformed, attribute="year", one_hot=False)
    .stack(
        datetime_attribute_timeseries(series_transformed, attribute="month", one_hot=True)
    )
    .stack(
        TimeSeries.from_times_and_values(
            times=series_transformed.time_index,
            values=np.arange(len(series_transformed)),
        )
    )
    .add_holidays(country_code="ES")
    .astype(np.float32)
)

# Scale the covariates with statistics taken from the training window only.
scaler_covs = Scaler()
cov_train, cov_val = covariates.split_after(six_months)
cov_train, cov_val = scaler_covs.fit_transform(cov_train), scaler_covs.transform(cov_val)
covariates_transformed = scaler_covs.transform(covariates)
def _fit_split_scaler(ts):
    """Split ``ts`` after ``six_months`` and scale it with a freshly fitted Scaler.

    The scaler is fitted on the training part only.

    Returns:
        ``(scaler, raw_train, raw_val, scaled_train, scaled_val, scaled_full)``
    """
    scaler = Scaler()
    raw_train, raw_val = ts.split_after(six_months)
    scaled_train = scaler.fit_transform(raw_train)
    scaled_val = scaler.transform(raw_val)
    scaled_full = scaler.transform(ts)
    return scaler, raw_train, raw_val, scaled_train, scaled_val, scaled_full


# Wheat grain price: ``grain_train`` / ``grain_val`` hold the SCALED splits.
grain_series = series['wheat_grain']
grain_scaler, _, _, grain_train, grain_val, grain_series_scaled = _fit_split_scaler(
    grain_series
)

# Pakistan price: ``pakistan_train`` / ``pakistan_val`` hold the SCALED splits.
pakistan_series = series["price"]
(
    pakistan_scaler,
    _,
    _,
    pakistan_train,
    pakistan_val,
    pakistan_series_scaled,
) = _fit_split_scaler(pakistan_series)

# USD price: ``usd_train`` / ``usd_val`` hold the SCALED splits.
usd_series = series['usdprice']
usd_scaler, _, _, usd_train, usd_val, usd_series_scaled = _fit_split_scaler(usd_series)

# Exchange rate: ``erate_train`` / ``erate_val`` stay RAW; the scaled splits
# live in the ``*_transformed`` names.
erate_series = series['exchange_rate']
(
    erate_scaler,
    erate_train,
    erate_val,
    erate_train_transformed,
    erate_val_transformed,
    erate_series_scaled,
) = _fit_split_scaler(erate_series)

# Black-sea price: ``black_sea`` is the raw series, ``black_sea_series`` the
# scaled one; ``black_train`` / ``black_val`` stay RAW.
black_sea = series['black_sea']
(
    black_sea_scaler,
    black_train,
    black_val,
    black_train_transformed,
    black_val_transformed,
    black_sea_series,
) = _fit_split_scaler(black_sea)
# Exogenous ComTrade series, taken as-is from their CSVs.
comtrade_flour_series = TimeSeries.from_dataframe(df_comtrade_flour, time_col="Date")
comtrade_grain_series = TimeSeries.from_dataframe(df_comtrade_grain, time_col="Date")

# Past covariates passed to the model at inference time (see get_forecast).
# The USD price series is excluded (it was commented out here).
# NOTE(review): the ComTrade components are concatenated unscaled while every
# other component is scaled; presumably this mirrors how the saved model was
# trained — confirm before changing.
my_multivariate_series = concatenate(
    [
        grain_series_scaled,
        pakistan_series_scaled,
        erate_series_scaled,
        black_sea_series,
        comtrade_flour_series,
        comtrade_grain_series,
        covariates_transformed,
    ],
    axis=1,
)

# Training-window counterpart of the covariates. Every component must be on
# the same scaled footing as at inference time: for the exchange rate that is
# ``erate_train_transformed`` (``erate_train`` is the raw, unscaled split).
# NOTE(review): not used by the code shown here, and it lacks the black-sea
# and ComTrade components present in ``my_multivariate_series``; verify
# before training a model with it.
multivariate_series_train = concatenate(
    [
        grain_train,
        pakistan_train,
        erate_train_transformed,
        cov_train,
    ],
    axis=1,
)
class FlaggingHandler(gr.FlaggingCallback):
    """Flagging callback that echoes flagged samples to stdout and persists
    them through Gradio's built-in ``CSVLogger``."""

    def __init__(self):
        super().__init__()
        # All persistence is delegated to this logger.
        self._csv_logger = gr.CSVLogger()

    def setup(self, components: List[gr.components.Component], flagging_dir: str):
        """Prepare the underlying CSV logger.

        ``gr.Interface`` calls this at the beginning of ``launch()``; with
        ``gr.Blocks`` it has to be called explicitly before anything is flagged.

        Parameters:
            components: Components that provide the flagged data, in the same
                order as the values later passed to :meth:`flag`.
            flagging_dir: Directory in which the flagging CSV file is stored.
        """
        self.components = components
        self._csv_logger.setup(components=components, flagging_dir=flagging_dir)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        username: Optional[str] = None,
    ) -> int:
        """Record one flagged sample; called whenever a flag button is clicked.

        Parameters:
            flag_data: Values of the flagged components, in ``setup`` order.
            flag_option (optional): Label of the flag that was chosen
                (e.g. "Incorrect").
            username (optional): Name of the flagging user, if logged in.
                Accepted for interface compatibility; it is NOT forwarded to
                the CSV logger.

        Returns:
            (int) The total number of flagged samples, as reported by
            ``CSVLogger.flag``.
        """
        for item in flag_data:
            print(f"Flagging: {item}")
        if flag_option:
            print(f"Flag option: {flag_option}")
        return self._csv_logger.flag(flag_data=flag_data, flag_option=flag_option)
# Serialized TFT model evaluated on every forecast request.
_AFGHAN_MODEL_PATH = "Afghan_w_blacksea_allcomtrade_aug31.pt"
# Timestamp from which historical (backtest) forecasts are produced.
_BACKTEST_START = pd.Timestamp("20180131")


def _parse_horizon(period_):
    """Return the leading integer of a horizon label such as ``"3 months"``.

    Raises:
        ValueError: if no horizon was selected or the label does not start
            with digits.
    """
    import re

    match = re.match(r"\s*(\d+)", period_ or "")
    if match is None:
        raise ValueError(f"Unrecognised forecast horizon: {period_!r}")
    return int(match.group(1))


def _forecast_frame(model, period):
    """Predict ``period`` steps and attach a fixed +/-10% band.

    Returns the ``reset_index()`` of the merged frame: ``Lower_Limit``,
    ``Wheat_Forecast`` and ``Upper_Limit`` in price units, plus the ``Date``
    index (assumes the forecast's time index is named "Date", which the merge
    requires).
    """
    preds = transformer.inverse_transform(model.predict(n=period, num_samples=1))
    # The band is a flat +/-10% of the point forecast, not a statistical interval.
    df_ = preds.pd_dataframe().rename(columns={'common_unit_price': 'Wheat_Forecast'})
    df_90 = (preds * 0.9).pd_dataframe().rename(columns={'common_unit_price': 'Lower_Limit'})
    df_110 = (preds * 1.1).pd_dataframe().rename(columns={'common_unit_price': 'Upper_Limit'})
    merged = pd.merge(df_90, df_, on=['Date']).merge(df_110, on=['Date'])
    return merged.reset_index()


def _backtest_frames(model, period):
    """Backtest ``model`` from ``_BACKTEST_START`` without retraining.

    Returns:
        ``(df_series, df_backtest, actual_prices, backtest_prices)``: actual
        prices over the backtest span, the historical forecasts (both as
        DataFrames), and the full inverse-scaled actual and backtest series.
    """
    backtest_series_ = model.historical_forecasts(
        series_transformed,
        past_covariates=my_multivariate_series,
        start=_BACKTEST_START,
        forecast_horizon=period,
        retrain=False,
        verbose=False,
    )
    # Inverse-scale once and reuse the results below.
    actual_prices = transformer.inverse_transform(series_transformed)
    backtest_prices = transformer.inverse_transform(backtest_series_)
    # Align actuals with the backtest by taking the same number of final points.
    recent_actual = actual_prices[-len(backtest_series_):]
    df_series = pd.DataFrame(
        data={
            'Date': recent_actual.time_index,
            'actual_prices': recent_actual.values().ravel(),
        }
    )
    df_backtest = pd.DataFrame(
        data={
            'Date': backtest_series_.time_index,
            'historical_forecasts': backtest_prices.values().ravel(),
        }
    )
    return df_series, df_backtest, actual_prices, backtest_prices


def _backtest_figure(df_backtest, df_series, merged_df, error_pct):
    """Build the Plotly figure: backtest, actual prices, forecast and its band."""
    fig = go.Figure()
    traces = (
        (df_backtest.Date, df_backtest.historical_forecasts, 'historical forecasts'),
        (df_series.Date, df_series.actual_prices, "actual prices"),
        (merged_df.Date, merged_df.Upper_Limit, "Upper limit"),
        (merged_df.Date, merged_df.Lower_Limit, "Lower limit"),
        (merged_df.Date, merged_df.Wheat_Forecast, " Wheat Forecast"),
    )
    for xs, ys, label in traces:
        fig.add_trace(go.Scatter(x=list(xs), y=list(ys), name=label))
    fig.update_layout(
        title_text=f"\n Mean Absolute Percentage Error {error_pct:.2f}%"
    )
    # Range-selector buttons plus a range slider on the date axis.
    fig.update_layout(
        xaxis=dict(
            rangeselector=dict(
                buttons=[
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="todate"),
                    dict(count=1, label="YTD", step="year", stepmode="todate"),
                ]
            ),
            rangeslider=dict(visible=True),
            type="date",
        )
    )
    return fig


def get_forecast(period_: str, pred_model: str):
    """Forecast Afghan wheat prices and backtest the model for the Gradio UI.

    Parameters:
        period_: Horizon label from the UI, e.g. "3 months"; its leading
            integer is the number of steps to predict.
        pred_model: Selected commodity label. Currently unused: only the
            wheat model exists.

    Returns:
        ``(merged_df, fig)``: the forecast with lower/upper limits, and a
        Plotly figure comparing historical forecasts, actual prices and the
        new forecast, titled with the backtest MAPE.

    Side effects:
        Writes ``data/afghan_wheatfcasts.csv`` and
        ``data/aghanwheat_allhistorical.csv``.

    Raises:
        ValueError: if ``period_`` is missing or not a recognisable horizon.
    """
    period = _parse_horizon(period_)
    afgh_model = TFTModel.load(_AFGHAN_MODEL_PATH, map_location=torch.device('cpu'))

    merged_df = _forecast_frame(afgh_model, period)
    merged_df.to_csv('data/afghan_wheatfcasts.csv', index=False)

    df_series, df_backtest, actual_prices, backtest_prices = _backtest_frames(
        afgh_model, period
    )
    df_wheat_output = pd.merge(
        df_series,
        df_backtest[['Date', "historical_forecasts"]],
        on=['Date'],
        how='left',
    )
    df_wheat_output.to_csv('data/aghanwheat_allhistorical.csv', index=False)

    error_pct = mape(actual_prices, backtest_prices)
    return merged_df, _backtest_figure(df_backtest, df_series, merged_df, error_pct)
def main():
    """Build the Gradio Blocks UI (forecast and flagging controls) and launch it."""
    flagging_handler = FlaggingHandler()
    with gr.Blocks() as iface:
        gr.Markdown(
            """
**Timeseries Forecasting model Temporal Fusion Transformer(TFT) built on Darts library**.
"""
        )
        # Defaults ensure get_forecast never receives None when "Forecast." is
        # clicked before a choice is made.
        commodity = gr.Radio(
            ["Wheat Price Forecasting"],
            value="Wheat Price Forecasting",
            label="Commodity to Forecast",
        )
        period = gr.Radio(
            ['3 months', "6 months"], value="3 months", label="Forecast horizon"
        )
        with gr.Row():
            btn = gr.Button("Forecast.")
        feedback = gr.Textbox(label="Give feedback")
        data_points = gr.Textbox(
            label="Forecast values. Lower and upper values include a 10% error rate"
        )
        plot = gr.Plot(label="Backtesting plot, from 2018")
        btn.click(
            get_forecast,
            inputs=[period, commodity],
            outputs=[data_points, plot],
        )

        with gr.Row():
            btn_incorrect = gr.Button("Flag as incorrect")
            btn_other = gr.Button("Flag as other")

        # One list drives both the CSV header (setup) and the flagged values
        # (click inputs), so logged columns line up with their labels.
        flagged_components = [commodity, data_points, period, feedback]
        flagging_handler.setup(
            components=flagged_components,
            flagging_dir="data/flagged",
        )

        def _flagger(option):
            """Return a click handler that flags the raw inputs with ``option``."""
            return lambda *args: flagging_handler.flag(flag_data=args, flag_option=option)

        btn_incorrect.click(
            _flagger("Incorrect"), flagged_components, None, preprocess=False
        )
        btn_other.click(
            _flagger("Other"), flagged_components, None, preprocess=False
        )
    iface.launch(debug=True, inline=False)
# Launch the app only when run as a script, not when the module is imported.
if __name__ == "__main__":
    main()