Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	change conversion
Browse files
    	
        app.py
    CHANGED
    
    | @@ -26,10 +26,11 @@ def raw_preds_to_df(raw, quantiles = None): | |
| 26 | 
             
                in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
         | 
| 27 | 
             
                pred_idx the index of the predicted date i.e. time_idx + h - 1
         | 
| 28 | 
             
                """
         | 
| 29 | 
            -
                index = raw | 
| 30 | 
            -
                 | 
| 31 | 
            -
                 | 
| 32 | 
            -
                 | 
|  | |
| 33 | 
             
                preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
         | 
| 34 | 
             
                preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
         | 
| 35 | 
             
                preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
         | 
| @@ -39,7 +40,7 @@ def raw_preds_to_df(raw, quantiles = None): | |
| 39 |  | 
| 40 | 
             
                preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
         | 
| 41 | 
             
                return preds_df
         | 
| 42 | 
            -
             | 
| 43 | 
             
            def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
         | 
| 44 | 
             
                if rain != "Default":
         | 
| 45 | 
             
                    df["MTXWTH_Day_precip"] = mapping[rain]
         | 
| @@ -96,7 +97,7 @@ def generate_plot(df): | |
| 96 |  | 
| 97 | 
             
            @st.cache_data
         | 
| 98 | 
             
            def load_data():
         | 
| 99 | 
            -
                with open('data/ | 
| 100 | 
             
                    parameters = pickle.load(f)
         | 
| 101 | 
             
                df = pd.read_pickle('data/test_data.pkl')
         | 
| 102 | 
             
                df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
         | 
| @@ -104,7 +105,7 @@ def load_data(): | |
| 104 |  | 
| 105 | 
             
            @st.cache_resource
         | 
| 106 | 
             
            def init_model():
         | 
| 107 | 
            -
                model = TemporalFusionTransformer.load_from_checkpoint('model/ | 
| 108 | 
             
                return model
         | 
| 109 |  | 
| 110 | 
             
            def main():
         | 
|  | |
| 26 | 
             
                in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
         | 
| 27 | 
             
                pred_idx the index of the predicted date i.e. time_idx + h - 1
         | 
| 28 | 
             
                """
         | 
| 29 | 
            +
                index = raw[2]
         | 
| 30 | 
            +
                output = raw[0]
         | 
| 31 | 
            +
                preds = output.prediction
         | 
| 32 | 
            +
                dec_len = output.prediction.shape[1]
         | 
| 33 | 
            +
                n_quantiles = output.prediction.shape[-1]
         | 
| 34 | 
             
                preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
         | 
| 35 | 
             
                preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
         | 
| 36 | 
             
                preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
         | 
|  | |
| 40 |  | 
| 41 | 
             
                preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
         | 
| 42 | 
             
                return preds_df
         | 
| 43 | 
            +
                
         | 
| 44 | 
             
            def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
         | 
| 45 | 
             
                if rain != "Default":
         | 
| 46 | 
             
                    df["MTXWTH_Day_precip"] = mapping[rain]
         | 
|  | |
| 97 |  | 
| 98 | 
             
            @st.cache_data
         | 
| 99 | 
             
            def load_data():
         | 
| 100 | 
            +
                with open('data/parameters_q.pkl', 'rb') as f:
         | 
| 101 | 
             
                    parameters = pickle.load(f)
         | 
| 102 | 
             
                df = pd.read_pickle('data/test_data.pkl')
         | 
| 103 | 
             
                df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
         | 
|  | |
| 105 |  | 
| 106 | 
             
            @st.cache_resource
         | 
| 107 | 
             
            def init_model():
         | 
| 108 | 
            +
                model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check_q.ckpt', map_location=torch.device('cpu'))
         | 
| 109 | 
             
                return model
         | 
| 110 |  | 
| 111 | 
             
            def main():
         | 
