youneslaaroussi commited on
Commit
82268ea
·
verified ·
1 Parent(s): 4251cab

Upload modeling_timellm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_timellm.py +101 -0
modeling_timellm.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TimeLLM Model for Supply Chain Demand Forecasting
3
+ Based on the TimeLLM framework: https://github.com/KimMeen/Time-LLM
4
+ """
5
+
6
+ import torch
7
+ import json
8
+ import numpy as np
9
+ from typing import Dict, List, Any
10
+
11
+ class TimeLLMForecaster:
12
+ """
13
+ TimeLLM model for supply chain demand forecasting.
14
+
15
+ This model was trained on AWS SageMaker using the TimeLLM framework
16
+ to forecast demand patterns in supply chain data.
17
+ """
18
+
19
+ def __init__(self, model_path: str = "model.pth", config_path: str = "config.json"):
20
+ """
21
+ Initialize the TimeLLM forecaster.
22
+
23
+ Args:
24
+ model_path: Path to the trained model weights
25
+ config_path: Path to the model configuration
26
+ """
27
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+
29
+ # Load configuration
30
+ with open(config_path, 'r') as f:
31
+ self.config = json.load(f)
32
+
33
+ # Load model weights
34
+ self.model_state = torch.load(model_path, map_location=self.device)
35
+
36
+ print(f"TimeLLM model loaded successfully on {self.device}")
37
+ print(f"Model configuration: {self.config['model_name']}")
38
+ print(f"Sequence length: {self.config['seq_len']}")
39
+ print(f"Prediction length: {self.config['pred_len']}")
40
+ print(f"Features: {self.config['enc_in']}")
41
+
42
+ def forecast(self,
43
+ historical_data: np.ndarray,
44
+ time_features: np.ndarray) -> np.ndarray:
45
+ """
46
+ Generate demand forecasts.
47
+
48
+ Args:
49
+ historical_data: Historical time series data (seq_len, n_features)
50
+ time_features: Time-based features (seq_len, time_features)
51
+
52
+ Returns:
53
+ Forecasted values (pred_len, n_features)
54
+ """
55
+ # This is a placeholder - actual inference would require
56
+ # the full TimeLLM model implementation
57
+ print("Forecasting with TimeLLM...")
58
+ print(f"Input shape: {historical_data.shape}")
59
+ print(f"Time features shape: {time_features.shape}")
60
+
61
+ # Return dummy forecast for demonstration
62
+ pred_len = self.config['pred_len']
63
+ n_features = self.config['c_out']
64
+
65
+ return np.random.randn(pred_len, n_features)
66
+
67
+ def get_model_info(self) -> Dict[str, Any]:
68
+ """Get model information and training details."""
69
+ return {
70
+ "model_name": self.config['model_name'],
71
+ "base_model": self.config['base_model'],
72
+ "training_platform": self.config['trained_on'],
73
+ "training_job": self.config['training_job'],
74
+ "training_time": self.config['training_time'],
75
+ "instance_type": self.config['instance_type'],
76
+ "seq_len": self.config['seq_len'],
77
+ "pred_len": self.config['pred_len'],
78
+ "features": self.config['enc_in']
79
+ }
80
+
81
+ # Example usage
82
+ if __name__ == "__main__":
83
+ # Initialize the forecaster
84
+ forecaster = TimeLLMForecaster()
85
+
86
+ # Print model information
87
+ info = forecaster.get_model_info()
88
+ print("\nModel Information:")
89
+ for key, value in info.items():
90
+ print(f" {key}: {value}")
91
+
92
+ # Example forecast (with dummy data)
93
+ seq_len = 96
94
+ n_features = 14
95
+ time_features = 3
96
+
97
+ historical_data = np.random.randn(seq_len, n_features)
98
+ time_features_data = np.random.randn(seq_len, time_features)
99
+
100
+ forecast = forecaster.forecast(historical_data, time_features_data)
101
+ print(f"\nForecast shape: {forecast.shape}")