Spaces:
Runtime error
Runtime error
change conversion
Browse files
app.py
CHANGED
|
@@ -26,10 +26,11 @@ def raw_preds_to_df(raw, quantiles = None):
|
|
| 26 |
in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
|
| 27 |
pred_idx the index of the predicted date i.e. time_idx + h - 1
|
| 28 |
"""
|
| 29 |
-
index = raw
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
| 33 |
preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
|
| 34 |
preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
|
| 35 |
preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
|
|
@@ -39,7 +40,7 @@ def raw_preds_to_df(raw, quantiles = None):
|
|
| 39 |
|
| 40 |
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
|
| 41 |
return preds_df
|
| 42 |
-
|
| 43 |
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
| 44 |
if rain != "Default":
|
| 45 |
df["MTXWTH_Day_precip"] = mapping[rain]
|
|
@@ -96,7 +97,7 @@ def generate_plot(df):
|
|
| 96 |
|
| 97 |
@st.cache_data
|
| 98 |
def load_data():
|
| 99 |
-
with open('data/
|
| 100 |
parameters = pickle.load(f)
|
| 101 |
df = pd.read_pickle('data/test_data.pkl')
|
| 102 |
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
|
@@ -104,7 +105,7 @@ def load_data():
|
|
| 104 |
|
| 105 |
@st.cache_resource
|
| 106 |
def init_model():
|
| 107 |
-
model = TemporalFusionTransformer.load_from_checkpoint('model/
|
| 108 |
return model
|
| 109 |
|
| 110 |
def main():
|
|
|
|
| 26 |
in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
|
| 27 |
pred_idx the index of the predicted date i.e. time_idx + h - 1
|
| 28 |
"""
|
| 29 |
+
index = raw[2]
|
| 30 |
+
output = raw[0]
|
| 31 |
+
preds = output.prediction
|
| 32 |
+
dec_len = output.prediction.shape[1]
|
| 33 |
+
n_quantiles = output.prediction.shape[-1]
|
| 34 |
preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
|
| 35 |
preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
|
| 36 |
preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
|
|
|
|
| 40 |
|
| 41 |
preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
|
| 42 |
return preds_df
|
| 43 |
+
|
| 44 |
def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
|
| 45 |
if rain != "Default":
|
| 46 |
df["MTXWTH_Day_precip"] = mapping[rain]
|
|
|
|
| 97 |
|
| 98 |
@st.cache_data
|
| 99 |
def load_data():
|
| 100 |
+
with open('data/parameters_q.pkl', 'rb') as f:
|
| 101 |
parameters = pickle.load(f)
|
| 102 |
df = pd.read_pickle('data/test_data.pkl')
|
| 103 |
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
|
|
|
| 105 |
|
| 106 |
@st.cache_resource
|
| 107 |
def init_model():
|
| 108 |
+
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check_q.ckpt', map_location=torch.device('cpu'))
|
| 109 |
return model
|
| 110 |
|
| 111 |
def main():
|