xue wang committed on
Commit
1373f4d
·
verified ·
1 Parent(s): f8e09fb

Upload 14 files

Long_Term_Forecasting/data_loader.py ADDED
@@ -0,0 +1,314 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ from torch.utils.data import Dataset
5
+ from sklearn.preprocessing import StandardScaler
6
+
7
+
8
+
9
+ import warnings
10
+
11
+ warnings.filterwarnings('ignore')
12
+
13
+
14
+ class Dataset_ETT_hour(Dataset):
15
+ def __init__(self, root_path, flag='train', size=None,
16
+ features='S', data_path='ETTh1.csv',
17
+ target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
18
+ # size [seq_len, label_len, pred_len]
19
+ # info
20
+ if size is None:
21
+ self.seq_len = 24 * 4 * 4
22
+ self.label_len = 24 * 4
23
+ self.pred_len = 24 * 4
24
+ else:
25
+ self.seq_len = size[0]
26
+ self.label_len = size[1]
27
+ self.pred_len = size[2]
28
+ # init
29
+ assert flag in ['train', 'test', 'val']
30
+ type_map = {'train': 0, 'val': 1, 'test': 2}
31
+ self.set_type = type_map[flag]
32
+
33
+ self.features = features
34
+ self.target = target
35
+ self.scale = scale
36
+ self.timeenc = timeenc
37
+ self.freq = freq
38
+
39
+ self.root_path = root_path
40
+ self.data_path = data_path
41
+ self.raw_start = 96
42
+ self.raw_last = 720
43
+ self.seperate = [96, 192, 336, 720]
44
+ self.__read_data__()
45
+
46
+ def __read_data__(self):
47
+ self.scaler = StandardScaler()
48
+ df_raw = pd.read_csv(os.path.join(self.root_path,
49
+ self.data_path))
50
+
51
+ border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
52
+ border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
53
+ border1 = border1s[self.set_type]
54
+ border2 = border2s[self.set_type]
55
+
56
+ if self.features == 'M' or self.features == 'MS':
57
+ cols_data = df_raw.columns[1:]
58
+ df_data = df_raw[cols_data]
59
+ elif self.features == 'S':
60
+ df_data = df_raw[[self.target]]
61
+
62
+ if self.scale:
63
+ train_data = df_data[border1s[0]:border2s[0]]
64
+ self.scaler.fit(train_data.values)
65
+ data = self.scaler.transform(df_data.values)
66
+ else:
67
+ data = df_data.values
68
+
69
+ df_stamp = df_raw[['date']][border1:border2]
70
+ df_stamp['date'] = pd.to_datetime(df_stamp.date)
71
+ if self.timeenc == 0:
72
+ df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
73
+ df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
74
+ df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
75
+ df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
76
+
77
+
78
+
79
+ self.data_x = data[border1:border2]
80
+ self.data_y = data[border1:border2]
81
+
82
+
83
+ def __getitem__(self, index):
84
+ s_begin = index
85
+ s_end = s_begin + self.seq_len
86
+ r_begin = s_end - self.label_len
87
+ r_end = r_begin + self.label_len + self.pred_len
88
+
89
+ seq_x = self.data_x[s_begin:s_end]
90
+ seq_y = self.data_y[r_begin:r_end]
91
+ res = []
92
+ for end in self.seperate:
93
+ r_end = r_begin + self.label_len + end
94
+
95
+ if r_end <= self.data_y.shape[0]:
96
+ res.append(self.data_y[r_begin:r_end])
97
+ else:
98
+ res.append(np.full((r_end - r_begin,self.data_y.shape[-1]), np.nan))
99
+
100
+ return seq_x, seq_y, res[0], res[1], res[2], res[3]
101
+
102
+ def __len__(self):
103
+ return len(self.data_x) - self.seq_len - self.raw_start + 1
104
+
105
+ def inverse_transform(self, data):
106
+ return self.scaler.inverse_transform(data)
107
+
108
+
109
+ class Dataset_ETT_minute(Dataset):
110
+ def __init__(self, root_path, flag='train', size=None,
111
+ features='S', data_path='ETTm1.csv',
112
+ target='OT', scale=True, timeenc=0, freq='t', seasonal_patterns=None):
113
+ # size [seq_len, label_len, pred_len]
114
+ # info
115
+ if size is None:
116
+ self.seq_len = 24 * 4 * 4
117
+ self.label_len = 24 * 4
118
+ self.pred_len = 24 * 4
119
+ else:
120
+ self.seq_len = size[0]
121
+ self.label_len = size[1]
122
+ self.pred_len = size[2]
123
+ # init
124
+ assert flag in ['train', 'test', 'val']
125
+ type_map = {'train': 0, 'val': 1, 'test': 2}
126
+ self.set_type = type_map[flag]
127
+
128
+ self.features = features
129
+ self.target = target
130
+ self.scale = scale
131
+ self.timeenc = timeenc
132
+ self.freq = freq
133
+
134
+ self.root_path = root_path
135
+ self.data_path = data_path
136
+
137
+ self.raw_start = 96
138
+ self.raw_last = 720
139
+ self.seperate = [96, 192, 336, 720]
140
+ self.__read_data__()
141
+
142
+ def __read_data__(self):
143
+ self.scaler = StandardScaler()
144
+ df_raw = pd.read_csv(os.path.join(self.root_path,
145
+ self.data_path))
146
+
147
+ border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
148
+ border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
149
+ border1 = border1s[self.set_type]
150
+ border2 = border2s[self.set_type]
151
+
152
+ if self.features == 'M' or self.features == 'MS':
153
+ cols_data = df_raw.columns[1:]
154
+ df_data = df_raw[cols_data]
155
+ elif self.features == 'S':
156
+ df_data = df_raw[[self.target]]
157
+
158
+ if self.scale:
159
+ train_data = df_data[border1s[0]:border2s[0]]
160
+ self.scaler.fit(train_data.values)
161
+ data = self.scaler.transform(df_data.values)
162
+ else:
163
+ data = df_data.values
164
+
165
+ df_stamp = df_raw[['date']][border1:border2]
166
+ df_stamp['date'] = pd.to_datetime(df_stamp.date)
167
+ if self.timeenc == 0:
168
+ df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
169
+ df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
170
+ df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
171
+ df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
172
+ df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1)
173
+ df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15)
174
+
175
+
176
+
177
+ self.data_x = data[border1:border2]
178
+ self.data_y = data[border1:border2]
179
+
180
+ def __getitem__(self, index):
181
+
182
+ s_begin = index
183
+ s_end = s_begin + self.seq_len
184
+ r_begin = s_end - self.label_len
185
+ r_end = r_begin + self.label_len + self.pred_len
186
+
187
+ seq_x = self.data_x[s_begin:s_end]
188
+ seq_y = self.data_y[r_begin:r_end]
189
+ res = []
190
+ for end in self.seperate:
191
+ r_end = r_begin + self.label_len + end
192
+
193
+ if r_end <= self.data_y.shape[0]:
194
+ res.append(self.data_y[r_begin:r_end])
195
+ else:
196
+ res.append(np.full((r_end - r_begin,self.data_y.shape[-1]), np.nan))
197
+
198
+ return seq_x, seq_y, res[0], res[1], res[2], res[3]
199
+
200
+ def __len__(self):
201
+ return len(self.data_x) - self.seq_len - self.raw_start + 1
202
+
203
+ def inverse_transform(self, data):
204
+ return self.scaler.inverse_transform(data)
205
+
206
+
207
+ class Dataset_Custom(Dataset):
208
+ def __init__(self, root_path, flag='train', size=None,
209
+ features='S', data_path='ETTh1.csv',
210
+ target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
211
+ # size [seq_len, label_len, pred_len]
212
+ # info
213
+ if size is None:
214
+ self.seq_len = 24 * 4 * 4
215
+ self.label_len = 24 * 4
216
+ self.pred_len = 24 * 4
217
+ else:
218
+ self.seq_len = size[0]
219
+ self.label_len = size[1]
220
+ self.pred_len = size[2]
221
+ # init
222
+ assert flag in ['train', 'test', 'val']
223
+ type_map = {'train': 0, 'val': 1, 'test': 2}
224
+ self.set_type = type_map[flag]
225
+
226
+ self.features = features
227
+ self.target = target
228
+ self.scale = scale
229
+ self.timeenc = timeenc
230
+ self.freq = freq
231
+
232
+ self.root_path = root_path
233
+ self.data_path = data_path
234
+ self.raw_start = 96
235
+ self.raw_last = 720
236
+ self.seperate = [96, 192, 336, 720]
237
+ self.__read_data__()
238
+
239
+ def __read_data__(self):
240
+ self.scaler = StandardScaler()
241
+ df_raw = pd.read_csv(os.path.join(self.root_path,
242
+ self.data_path))
243
+
244
+ '''
245
+ df_raw.columns: ['date', ...(other features), target feature]
246
+ '''
247
+ cols = list(df_raw.columns)
248
+ cols.remove(self.target)
249
+ cols.remove('date')
250
+ df_raw = df_raw[['date'] + cols + [self.target]]
251
+ num_train = int(len(df_raw) * 0.7)
252
+ num_test = int(len(df_raw) * 0.2)
253
+ num_vali = len(df_raw) - num_train - num_test
254
+ border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
255
+ border2s = [num_train, num_train + num_vali, len(df_raw)]
256
+ border1 = border1s[self.set_type]
257
+ border2 = border2s[self.set_type]
258
+
259
+ if self.features == 'M' or self.features == 'MS':
260
+ cols_data = df_raw.columns[1:]
261
+ df_data = df_raw[cols_data]
262
+ elif self.features == 'S':
263
+ df_data = df_raw[[self.target]]
264
+
265
+ if self.scale:
266
+ train_data = df_data[border1s[0]:border2s[0]]
267
+ self.scaler.fit(train_data.values)
268
+ data = self.scaler.transform(df_data.values)
269
+ else:
270
+ data = df_data.values
271
+
272
+ df_stamp = df_raw[['date']][border1:border2]
273
+ df_stamp['date'] = pd.to_datetime(df_stamp.date)
274
+ if self.timeenc == 0:
275
+ df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
276
+ df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
277
+ df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
278
+ df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
279
+
280
+
281
+
282
+ self.data_x = data[border1:border2]
283
+ self.data_y = data[border1:border2]
284
+
285
+
286
+ def __getitem__(self, index):
287
+
288
+ s_begin = index
289
+ s_end = s_begin + self.seq_len
290
+ r_begin = s_end - self.label_len
291
+ r_end = r_begin + self.label_len + self.pred_len
292
+
293
+ seq_x = self.data_x[s_begin:s_end]
294
+ seq_y = self.data_y[r_begin:r_end]
295
+ res = []
296
+ for end in self.seperate:
297
+ r_end = r_begin + self.label_len + end
298
+
299
+ if r_end <= self.data_y.shape[0]:
300
+ res.append(self.data_y[r_begin:r_end])
301
+ else:
302
+ res.append(np.full((r_end - r_begin,self.data_y.shape[-1]), np.nan))
303
+
304
+ return seq_x, seq_y, res[0], res[1], res[2], res[3]
305
+
306
+ def __len__(self):
307
+ return len(self.data_x) - self.seq_len - self.raw_start + 1
308
+
309
+ def inverse_transform(self, data):
310
+ return self.scaler.inverse_transform(data)
311
+
312
+
313
+
314
+
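Each of the three dataset classes returns a 6-tuple per index: the input window, the label-plus-prediction window, and the same target cut at the four fixed horizons in `self.seperate` (96/192/336/720), NaN-padded where a horizon runs past the end of the split. A minimal usage sketch, assuming the ETT CSVs live under the dataset paths used by main.py below:

```python
# Minimal sketch: build the hourly ETT test split and inspect one batch.
# The paths are assumptions; point root_path at wherever ETTh1.csv actually lives.
from torch.utils.data import DataLoader
from data_loader import Dataset_ETT_hour

dataset = Dataset_ETT_hour(
    root_path='./Long_Term_Forecasting/dataset/ETT-small/',
    data_path='ETTh1.csv',
    flag='test',
    size=[96, 0, 96],   # [seq_len, label_len, pred_len], as in data_provider
    features='M',       # all columns except 'date'
)
loader = DataLoader(dataset, batch_size=32, shuffle=False)

seq_x, seq_y, y96, y192, y336, y720 = next(iter(loader))
print(seq_x.shape, y720.shape)  # (32, 96, 7) and (32, 720, 7) for ETTh1
```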
Long_Term_Forecasting/dataset/ETT-small/ETTh1.csv ADDED
The diff for this file is too large to render. See raw diff
 
Long_Term_Forecasting/dataset/ETT-small/ETTh2.csv ADDED
The diff for this file is too large to render. See raw diff
 
Long_Term_Forecasting/dataset/ETT-small/ETTm1.csv ADDED
The diff for this file is too large to render. See raw diff
 
Long_Term_Forecasting/dataset/ETT-small/ETTm2.csv ADDED
The diff for this file is too large to render. See raw diff
 
Long_Term_Forecasting/dataset/weather/weather.csv ADDED
The diff for this file is too large to render. See raw diff
 
Long_Term_Forecasting/main.py ADDED
@@ -0,0 +1,349 @@
1
+ import argparse
2
+ import time
3
+ import datetime
4
+ import os
5
+ import sys
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.data import DataLoader
10
+
11
+ import random
12
+ import numpy as np
13
+
14
+ from einops import rearrange
15
+ import torch.distributed as dist
16
+ import torch.multiprocessing as mp
17
+
18
+
19
+ import lightning as L
20
+ from lightning.fabric.strategies import DDPStrategy
21
+ import glob
22
+ from data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom
23
+ from metrics import metric
24
+ from tqdm import tqdm
25
+
26
+ from transformers import AutoModelForCausalLM
27
+
28
+
29
+
30
+ data_dict = {
31
+ 'ETTh1': Dataset_ETT_hour,
32
+ 'ETTh2': Dataset_ETT_hour,
33
+ 'ETTm1': Dataset_ETT_minute,
34
+ 'ETTm2': Dataset_ETT_minute,
35
+ 'custom': Dataset_Custom,
36
+ }
37
+
38
+
39
+
40
+ def data_provider(data,
41
+ root_path,
42
+ data_path,
43
+ batch_size = 128,
44
+ seq_len = 96,
45
+ pred_len = 96,
46
+ flag= 'test',
47
+ dataset = None,
48
+ num_workers = 8,
49
+ seasonal_patterns = None,
50
+ target = 'OT',
51
+ features = 'M',
52
+ ):
53
+
54
+
55
+ Data = data_dict[data]
56
+ shuffle_flag = False
57
+ drop_last = False
58
+ batch_size = batch_size
59
+ data_set = Data(
60
+ root_path=root_path,
61
+ data_path=data_path,
62
+ flag=flag,
63
+ size=[seq_len, 0, pred_len],
64
+ features=features,
65
+ target=target,
66
+ timeenc=1,
67
+ freq='h',
68
+ seasonal_patterns=seasonal_patterns
69
+ )
70
+
71
+ data_loader = DataLoader(
72
+ data_set,
73
+ batch_size=batch_size,
74
+ shuffle=shuffle_flag,
75
+ num_workers=num_workers,
76
+ drop_last=drop_last,
77
+ pin_memory = True)
78
+ return data_set, data_loader
79
+
80
+ datasets_configs = {
81
+ 'ETTh1': {'data': 'ETTh1',
82
+ 'root_path': './Long_Term_Forecasting/dataset/ETT-small/',
83
+ 'data_path': 'ETTh1.csv',
84
+ 'batch_size':32,
85
+ 'best_other_mse':'0.390',
86
+ 'best_other_mae':'0.406',
87
+ },
88
+ 'ETTh2': {'data': 'ETTh2',
89
+ 'root_path': './Long_Term_Forecasting/dataset/ETT-small/',
90
+ 'data_path': 'ETTh2.csv',
91
+ 'batch_size':32,
92
+ 'best_other_mse':'0.330',
93
+ 'best_other_mae':'0.375',
94
+ },
95
+ 'ETTm1': {'data': 'ETTm1',
96
+ 'root_path': './Long_Term_Forecasting/dataset/ETT-small/',
97
+ 'data_path': 'ETTm1.csv',
98
+ 'batch_size':32,
99
+ 'best_other_mse':'0.351',
100
+ 'best_other_mae':'0.372',
101
+ },
102
+ 'ETTm2': {'data': 'ETTm2',
103
+ 'root_path': './Long_Term_Forecasting/dataset/ETT-small/',
104
+ 'data_path': 'ETTm2.csv',
105
+ 'batch_size':32,
106
+ 'best_other_mse':'0.255',
107
+ 'best_other_mae':'0.315',
108
+ },
109
+ 'Weather': {'data': 'custom',
110
+ 'root_path': './Long_Term_Forecasting/dataset/weather/',
111
+ 'data_path': 'weather.csv',
112
+ 'batch_size':32,
113
+ 'best_other_mse':'0.226',
114
+ 'best_other_mae':'0.261',
115
+ },
116
+ 'Electricity': {'data': 'custom',
117
+ 'root_path': './dataset/',
118
+ 'data_path': 'electricity/electricity.csv',
119
+ 'batch_size':1,
120
+ 'best_other_mse':'0.159',
121
+ 'best_other_mae':'0.253',
122
+ },
123
+ 'Traffic': {'data': 'custom',
124
+ 'root_path': './dataset/',
125
+ 'data_path': 'traffic.csv',
126
+ 'batch_size':32,
127
+ 'best_other_mse':'0.391',
128
+ 'best_other_mae':'0.264',
129
+ },
130
+
131
+ 'GlobalTemp': {'data': 'Global_Temp',
132
+ 'root_path': './dataset/',
133
+ 'data_path': 'solar_AL.csv',
134
+ 'batch_size':1,
135
+ 'best_other_mse':'0.322',
136
+ 'best_other_mae':'0.370',
137
+ },
138
+
139
+ }
140
+
141
+
142
+ if __name__ == '__main__':
143
+
144
+
145
+
146
+
147
+
148
+ parser = argparse.ArgumentParser()
149
+ parser.add_argument('--data', type=str, default='ETTh2', help='dataset type')
150
+ parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
151
+ parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
152
+ parser.add_argument('--num_gpus', type=int,default=1)
153
+ parser.add_argument('--future_token', type=int,default=3072)
154
+ parser.add_argument('-t', '--task_list', action='append')
155
+ parser.add_argument('--model_name',type=str)
156
+
157
+
158
+
159
+ args = parser.parse_args()
160
+
161
+
162
+ torch.set_float32_matmul_precision("high")
163
+
164
+
165
+
166
+ strategy = DDPStrategy(find_unused_parameters=True)
167
+
168
+
169
+ fabric = L.Fabric(devices=args.num_gpus, strategy=strategy)
170
+
171
+ local_rank = fabric.global_rank
172
+
173
+ if local_rank == 0:
174
+ print(args)
175
+
176
+
177
+ model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True)
178
+ model = model.to(local_rank).bfloat16()
179
+
180
+ model = fabric.setup(model)
181
+ name = args.model_name
182
+
183
+ if local_rank == 0:
184
+ with open(f'results.txt', 'a') as f:
185
+ f.write(f"---------------------------------------------------------------------------------")
186
+
187
+
188
+
189
+ for task_name in args.task_list:
190
+
191
+ task = datasets_configs[task_name]
192
+ data = task['data']
193
+ root_path = task['root_path']
194
+ data_path = task['data_path']
195
+ best_mse = task['best_other_mse']
196
+ best_mae = task['best_other_mae']
197
+ batch_size = min(args.batch_size,task['batch_size'])
198
+ seq_len = args.seq_len
199
+ if local_rank == 0:
200
+ with open(f'results.txt', 'a') as f:
201
+ seconds_since_epoch = time.time()
202
+ human_readable_time = datetime.datetime.fromtimestamp(seconds_since_epoch).strftime('%Y-%m-%d %H:%M:%S')
203
+ f.write(f"{human_readable_time}-------------\n")
204
+
205
+
206
+
207
+
208
+
209
+
210
+ data_set, data_loader = data_provider(data,root_path,data_path,batch_size = batch_size,seq_len=seq_len)
211
+ data_loader = fabric.setup_dataloaders(data_loader)
212
+
213
+
214
+
215
+ model.eval()
216
+ preds = []
217
+ truths = []
218
+ preds_s = [[],[],[],[]]
219
+ truths_s = [[],[],[],[]]
220
+ intermediates = []
221
+ xs = []
222
+
223
+ seperate_s = [96, 192, 336, 720]
224
+ remains = args.future_token
225
+ prevs = 0
226
+
227
+ with torch.no_grad():
228
+ for idx,(x_ori,y,y1,y2,y3,y4) in enumerate(tqdm(data_loader,disable = local_rank != 0)):
229
+
230
+ b,c = x_ori.shape[0],x_ori.shape[2]
231
+ x = rearrange(x_ori, 'b l c -> (b c) l').float().to(local_rank).bfloat16().contiguous()
232
+ y = rearrange(y1, 'b l c -> (b c) l').float()
233
+
234
+ y_s = [y1,y2,y3,y4]
235
+ res = []
236
+ res1 = []
237
+ res2 = []
238
+ res3 = []
239
+
240
+ logits = 0
241
+ used = 0
242
+
243
+
244
+ for history in [512,1024,2048,4096]:
245
+ if history > x.shape[1]:
246
+ continue
247
+ else:
248
+ used += 2
249
+
250
+ x_mean = x[:,-history:].mean(dim = -1,keepdims = True)
251
+ x_std = x[:,-history:].std(dim = -1,keepdims = True)
252
+
253
+ x_train = torch.cat((x[:,-history:],-x[:,-history:]),dim = 0)
254
+
255
+
256
+ logits_all = model(idx = x_train, future_token = args.future_token)
257
+
258
+
259
+ logits_all = rearrange(logits_all, '(t b) l c d -> b (l c) d t', t = 2)
260
+ logits += logits_all[...,0] -logits_all[...,1].flip(dims = [-1])
261
+
262
+ logits = logits / used
263
+
264
+
265
+
266
+ x = torch.cat([x,logits[:,:720,49]],dim = -1).float()
267
+
268
+ median = logits[:,:720,49].float()
269
+ median = median[:,:720]
270
+
271
+ median0 = rearrange(median, '(b c) l -> b l c',b = b).contiguous().cpu().detach().numpy()
272
+ y0 = rearrange(y, '(b c) l -> b l c',b = b).contiguous().cpu().detach().numpy()
273
+
274
+
275
+
276
+ for i, seperate in enumerate(seperate_s):
277
+ median_s = logits[:,:seperate,49].float()
278
+ median_s = rearrange(median_s, '(b c) l -> b l c',b = b).contiguous().cpu().detach().numpy()
279
+ preds_s[i].append(median_s)
280
+ truths_s[i].append(y_s[i].contiguous().cpu().detach().numpy())
281
+
282
+
283
+ xs.append(x_ori.contiguous().cpu().detach().numpy())
284
+
285
+
286
+
287
+ def gather_losses(loss):
288
+ """Gather loss values from all GPUs."""
289
+ if dist.is_initialized():
290
+ loss_tensor = torch.tensor([loss], device=local_rank)
291
+ gathered_losses = [torch.zeros_like(loss_tensor) for _ in range(args.num_gpus)]
292
+ dist.all_gather(gathered_losses, loss_tensor)
293
+ return torch.cat(gathered_losses).mean()
294
+ else:
295
+ return loss
296
+
297
+
298
+
299
+
300
+
301
+
302
+
303
+ if local_rank == 0:
304
+ print(f'Eval on {task_name}-{seq_len}...')
305
+ mses = []
306
+ maes = []
307
+
308
+
309
+ for i, seperate in enumerate(seperate_s):
310
+ if i == 4:
311
+ break
312
+ truths = truths_s[i]
313
+ preds = preds_s[i]
314
+ truths = np.concatenate(truths,axis = 0)
315
+ preds = np.concatenate(preds,axis = 0)
316
+
317
+ truths = rearrange(truths,'b l c -> b c l')
318
+ preds = rearrange(preds,'b l c -> b c l')
319
+ mask = np.isnan(truths).any(axis=2)
320
+
321
+ truths1 = truths[~mask]
322
+ preds1 = preds[~mask]
323
+
324
+ truths = rearrange(truths,'b c l-> b l c')
325
+ preds = rearrange(preds,'b c l-> b l c')
326
+
327
+ mae, mse, rmse, mape, mspe = metric(preds1[:,:seperate], truths1[:,:seperate])
328
+ mae,mse = gather_losses(mae), gather_losses(mse)
329
+
330
+
331
+ if local_rank == 0:
332
+ print(f'ours-{name}: mse {mse:.4f} mae {mae:.4f}')
333
+ mses.append(mse.cpu().numpy())
334
+ maes.append(mae.cpu().numpy())
335
+ with open(f'results.txt', 'a') as f:
336
+
337
+ f.write(f"ours-{name}, {data_path.split('.')[0]}-{args.seq_len}-{seperate}-{args.future_token}, mse, {mse:.5f}, mae, {mae:.5f}\n")
338
+ if local_rank == 0:
339
+
340
+ print(f'ours-{name}-avg: mse {np.mean(mses):.3f} mae {np.mean(maes):.3f}')
341
+
342
+
343
+ with open(f'results.txt', 'a') as f:
344
+ f.write(f"ours-{name}, {data_path.split('.')[0]}-avg, mse, {np.mean(mses):.5f}, mae, {np.mean(maes):.5f}\n")
345
+ print(f'best-avg: mse {best_mse} mae {best_mae}')
346
+
347
+
348
+
349
+ dist.destroy_process_group()
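The loop above applies a test-time ensemble: each channel-flattened window is forecast from several history lengths, and every forecast is run on both x and -x, with the quantile axis of the negated pass flipped before the two are combined. A distilled sketch of that step, assuming `model` is the checkpoint loaded via AutoModelForCausalLM(..., trust_remote_code=True) as above:

```python
# Sketch of the multi-history, sign-flip ensembling used in the evaluation loop above.
# `model` and `x` (shape: batch x history, bfloat16 on GPU) are assumed to exist already.
import torch
from einops import rearrange

def ensemble_forecast(model, x, future_token=3072, histories=(512, 1024, 2048, 4096)):
    logits, used = 0, 0
    for history in histories:
        if history > x.shape[1]:
            continue
        used += 2  # two passes per history: the series and its negation
        x_pair = torch.cat((x[:, -history:], -x[:, -history:]), dim=0)
        out = model(idx=x_pair, future_token=future_token)
        out = rearrange(out, '(t b) l c d -> b (l c) d t', t=2)
        # negating the input mirrors the quantile axis, so flip it back before combining
        logits = logits + out[..., 0] - out[..., 1].flip(dims=[-1])
    return logits / used  # (batch, future steps, 99 quantiles); column 49 is the median
```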
Long_Term_Forecasting/metrics.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+
3
+
4
+ def RSE(pred, true):
5
+ return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))
6
+
7
+
8
+ def CORR(pred, true):
9
+ u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
10
+ d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
11
+ return (u / d).mean(-1)
12
+
13
+
14
+ def MAE(pred, true):
15
+ return np.mean(np.abs(pred - true))
16
+
17
+
18
+ def MSE(pred, true):
19
+ return np.mean((pred - true) ** 2)
20
+
21
+
22
+ def RMSE(pred, true):
23
+ return np.sqrt(MSE(pred, true))
24
+
25
+
26
+ def MAPE(pred, true):
27
+ return np.mean(np.abs((pred - true) / true))
28
+
29
+
30
+ def MSPE(pred, true):
31
+ return np.mean(np.square((pred - true) / true))
32
+
33
+
34
+ def metric(pred, true):
35
+ mae = MAE(pred, true)
36
+ mse = MSE(pred, true)
37
+ rmse = RMSE(pred, true)
38
+ mape = MAPE(pred, true)
39
+ mspe = MSPE(pred, true)
40
+
41
+ return mae, mse, rmse, mape, mspe
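A quick sanity check of these helpers; note that `metric` propagates NaNs, so callers (as main.py does) must mask NaN-padded targets first:

```python
# Toy check of the metric helpers; NaNs must be filtered out by the caller.
import numpy as np
from metrics import metric

pred = np.array([[1.0, 2.0, 3.0]])
true = np.array([[1.5, 2.0, 2.0]])
mae, mse, rmse, mape, mspe = metric(pred, true)
print(f"mae={mae:.3f} mse={mse:.3f} rmse={rmse:.3f}")  # mae=0.500 mse=0.417 rmse=0.645
```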
README.md CHANGED
@@ -1,3 +1,9 @@
1
- ---
2
- license: cc-by-4.0
3
- ---
1
+ ---
2
+ tags:
3
+ - model_hub_mixin
4
+ - pytorch_model_hub_mixin
5
+ ---
6
+
7
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
+ - Library: [More Information Needed]
9
+ - Docs: [More Information Needed]
config.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "architectures": ["YingLong"],
3
+
4
+ "auto_map": {
5
+ "AutoConfig": "model_config.YingLongConfig",
6
+ "AutoModelForCausalLM": "model.GPT"
7
+ },
8
+ "org": "Alibaba",
9
+ "_mlp_class": "LLaMAMLP",
10
+ "_norm_class": "FusedRMSNorm",
11
+ "bias": false,
12
+ "block_size": 8224,
13
+ "condense_ratio": 1,
14
+ "haar_trans": true,
15
+ "haar_trans_inv": true,
16
+ "haar_trans_norm": "backward",
17
+ "intermediate_size": 1024,
18
+ "n_embd": 256,
19
+ "n_head": 16,
20
+ "n_layer": 6,
21
+ "n_query_groups": 4,
22
+ "norm_eps": 1e-05,
23
+ "parallel_residual": false,
24
+ "patch_size": 32,
25
+ "quantitle": true,
26
+ "rope_base": 10000,
27
+ "rotary_percentage": 1.0,
28
+ "shared_attention_norm": false,
29
+ "unet": true,
30
+ "vocab_size": 1
31
+ }
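config.json wires the Hub auto classes to the custom code in this repo: with trust_remote_code=True, AutoConfig resolves to model_config.YingLongConfig and AutoModelForCausalLM to model.GPT, which is how main.py loads the checkpoint. A loading sketch (the repo id is a placeholder, not a confirmed model name):

```python
# Loading sketch; "org/YingLong" is a placeholder repo id.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("org/YingLong", trust_remote_code=True)
model = model.bfloat16().cuda().eval()   # the attention path expects bf16 on GPU (flash-attn)

x = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda")  # (batch, history)
with torch.no_grad():
    out = model(idx=x, future_token=720)
print(out.shape)  # (4, future patches, patch_size=32, 99 quantiles)
```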
model.py ADDED
@@ -0,0 +1,1526 @@
1
+ """
2
+
3
+ Based on the tinyllama implementation: https://github.com/jzhang38/TinyLlama
4
+
5
+ """
6
+
7
+
8
+ import math, random
9
+ import numpy as np
10
+ from typing import Any, List, Optional, Tuple
11
+ from typing_extensions import Self
12
+
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ from lightning_utilities.core.imports import RequirementCache
20
+ FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1")
21
+
22
+ from flash_attn import flash_attn_func
23
+ from xformers.ops import SwiGLU
24
+ from einops import rearrange
25
+
26
+
27
+ from transformers import PreTrainedModel
28
+ from .model_config import YingLongConfig
29
+
30
+
31
+
32
+
33
+ class Tokenizer(torch.nn.Module):
34
+ def __init__(self, config: YingLongConfig, *args,**kwargs) -> None:
35
+ super().__init__()
36
+
37
+ self.config = config
38
+ self.tokenizer = nn.Linear(config.patch_size,self.config.n_embd)
39
+
40
+ self.patch_size = config.patch_size
41
+ self.mask0 = nn.Linear(1,config.n_embd)
42
+
43
+ self.register_buffer('mask_token', torch.zeros(1000))
44
+ if self.config.haar_trans:
45
+
46
+ self.register_buffer('haar_transform',torch.Tensor(haarMatrix(self.config.patch_size,normalized = self.config.haar_trans_norm)))
47
+
48
+
49
+
50
+ def forward(self,x,
51
+ future_token = 0,
52
+ prev_token = 0,
53
+ factor = 0.2,
54
+ sequential = False,
55
+ *args, **kwargs):
56
+
57
+
58
+ b = x.shape[0]
59
+
60
+ x_raw = rearrange(x, "b (l c) -> b l c", c = self.patch_size)
61
+ x_raw_0 = rearrange(x, "b (l c) -> b l c", c = self.patch_size).detach().clone()
62
+
63
+ if future_token == 0:
64
+ if not sequential:
65
+ masks = torch.randperm(x_raw.shape[1])
66
+ unmasks,masks = masks[:int(x_raw.shape[1]*factor)],masks[int(x_raw.shape[1]*factor):]
67
+ else:
68
+ masks = [_ for _ in range(x_raw.shape[1])]
69
+ factor = np.random.rand()*0.6 + 0.2
70
+ unmasks,masks = masks[:int(x_raw.shape[1]*factor)],masks[int(x_raw.shape[1]*factor):]
71
+
72
+
73
+
74
+ x_raw_remains = x_raw[:,unmasks,:]
75
+
76
+ mean = x_raw_remains.mean(dim = (-2,-1),keepdims = True)
77
+ std = x_raw_remains.std(dim = (-2,-1),keepdims = True)
78
+ x_raw = (x_raw - mean)/ (std + 1e-4)
79
+
80
+
81
+ if self.config.haar_trans:
82
+ x_featured = torch.einsum('blc,ac->bla',x_raw,self.haar_transform)
83
+ x_featured = self.tokenizer(x_featured)
84
+ else:
85
+ x_featured = self.tokenizer(x_raw)
86
+
87
+
88
+ x_featured[:,masks,:] = self.mask0(self.mask_token[0].unsqueeze(0))
89
+
90
+
91
+
92
+ else:
93
+
94
+ factor = 1
95
+ more_rows = future_token // self.patch_size + 1
96
+ prev_more_rows = prev_token // self.patch_size + 1
97
+
98
+ mean = x_raw[:,prev_more_rows:-more_rows,:].mean(dim = (-2,-1),keepdims = True)
99
+ std = x_raw[:,prev_more_rows:-more_rows,:].std(dim = (-2,-1),keepdims = True)
100
+ x_raw = (x_raw - mean)/ (std + 1e-4)
101
+
102
+
103
+ if self.config.haar_trans:
104
+ x_featured = torch.einsum('blc,ac->bla',x_raw,self.haar_transform)
105
+ x_featured = self.tokenizer(x_featured)
106
+ else:
107
+ x_featured = self.tokenizer(x_raw)
108
+
109
+
110
+ masks = [jj for jj in range(x_featured.shape[1])]
111
+ masks = masks[-more_rows:]
112
+
113
+ x_featured[:,-more_rows:] = self.mask0(self.mask_token[:len(masks)].unsqueeze(-1)).repeat(x_featured.shape[0],1,1)
114
+ x_featured[:,:prev_more_rows] = self.mask0(self.mask_token[:prev_more_rows].unsqueeze(-1)).repeat(x_featured.shape[0],1,1)
115
+
116
+
117
+ return x_featured, x_raw_0, masks, mean, std, x_raw
118
+
119
+
120
+
121
+ class model_tmp(PreTrainedModel):
122
+ config_class = YingLongConfig
123
+ base_model_prefix = "model"
124
+
125
+
126
+
127
+ def _init_weights(self, module: nn.Module) -> None:
128
+ if isinstance(module, nn.Embedding):
129
+ torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
130
+ elif isinstance(module, nn.Linear):
131
+ torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
132
+ if module.bias is not None:
133
+ torch.nn.init.zeros_(module.bias)
134
+ for name, p in module.named_parameters():
135
+ if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU) or (name=="proj.weight" and isinstance(module, BidirectedlSelfAttention))):
136
+ nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / self.config.n_layer)
137
+
138
+
139
+
140
+
141
+
142
+
143
+
144
+ class GPT(model_tmp):
145
+ def __init__(self, config: YingLongConfig, *args,**kwargs) -> None:
146
+
147
+
148
+ super().__init__(config)
149
+
150
+ self.config = config
151
+ self.patch_size = config.patch_size
152
+ self.unet = config.unet
153
+
154
+
155
+ if self.config._norm_class == "RMSNorm":
156
+
157
+ self.config.norm_class = RMSNorm
158
+ elif self.config._norm_class == "FusedRMSNorm":
159
+ self.config.norm_class = FusedRMSNorm
160
+ elif self.config._norm_class == 'BatchNorm':
161
+ self.config.norm_class = iBatchNorm
162
+
163
+
164
+ if self.config._mlp_class == "GptNeoxMLP":
165
+ self.config.mlp_class = GptNeoxMLP
166
+ elif self.config._mlp_class == "LLaMAMLP":
167
+ self.config.mlp_class = LLaMAMLP
168
+
169
+
170
+
171
+
172
+ self.tokenizer = Tokenizer(config)
173
+
174
+
175
+ self.lm_head = nn.Linear(config.n_embd, 99*self.patch_size)
176
+
177
+
178
+ self.quantitleLoss = quantitleLoss(99,patch_size = self.patch_size)
179
+
180
+
181
+
182
+ if self.unet:
183
+ assert config.n_layer%2 == 0
184
+ self.unet_projection = nn.ModuleList(nn.Sequential(nn.Linear(config.n_embd*2,config.n_embd),
185
+ config.norm_class(config.n_embd, eps=config.norm_eps),
186
+ )
187
+ for _ in range(config.n_layer//2)
188
+ )
189
+ self.unet_merge = nn.ModuleList(nn.Sequential(nn.Linear(config.n_embd*2,config.n_embd),
190
+ config.norm_class(config.n_embd, eps=config.norm_eps),
191
+ )
192
+ for _ in range(config.n_layer//2)
193
+ )
194
+
195
+
196
+
197
+ self.transformer = nn.ModuleDict(dict(h = nn.ModuleList(Block(config)
198
+ for _ in range(config.n_layer))
199
+ )
200
+ )
201
+
202
+
203
+
204
+ self.rope_cache = None
205
+
206
+
207
+
208
+ def forward(
209
+ self, idx: torch.Tensor,
210
+ future_token: int = 0,
211
+ prev_token: int = 0,
212
+ *args,**kwargs,
213
+ ) -> torch.Tensor:
214
+
215
+ if future_token > 0:
216
+ more_rows = future_token // self.patch_size + 1
217
+ idx = torch.cat((idx,torch.zeros(idx.shape[0],more_rows*self.patch_size).to(idx.device)),dim = -1).bfloat16()
218
+ if prev_token > 0:
219
+ more_rows = prev_token // self.patch_size + 1
220
+ idx = torch.cat((torch.zeros(idx.shape[0],more_rows*self.patch_size).to(idx.device),idx),dim = -1).bfloat16()
221
+
222
+ B, T = idx.size()
223
+
224
+
225
+
226
+ block_size = self.config.block_size
227
+ max_seq_length = T
228
+
229
+ assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
230
+
231
+
232
+ self.rope_cache = self.build_rope_cache(idx)
233
+ cos, sin = self.rope_cache
234
+
235
+ cos = cos[:max(T,1024)]
236
+ sin = sin[:max(T,1024)]
237
+
238
+
239
+
240
+
241
+ x,x_raw,masks,mean,std,_ = self.tokenizer(idx, future_token =future_token,prev_token = prev_token)
242
+
243
+
244
+
245
+ if self.unet:
246
+ skips = []
247
+
248
+
249
+
250
+
251
+ for block_idx in range(len( self.transformer.h)):
252
+
253
+
254
+ block = self.transformer.h[block_idx]
255
+
256
+ if self.unet and block_idx >=len(self.transformer.h) //2:
257
+ x = self.unet_projection[block_idx - len(self.transformer.h) //2](torch.cat((skips.pop(),x),dim = -1))
258
+
259
+ x = block(x, (cos, sin), max_seq_length)
260
+
261
+ if self.unet and block_idx <len(self.transformer.h) //2:
262
+ skips.append(x)
263
+ x_delay = torch.cat((x[:,0,:].unsqueeze(1),x[:,:-1,:]),dim = 1)
264
+ x = self.unet_merge[block_idx](torch.cat((x_delay,x),dim = -1))
265
+
266
+
267
+
268
+
269
+ res = self.lm_head(x)
270
+
271
+
272
+
273
+ res = rearrange(res,'b c (l1 l2) -> b c l1 l2', l2 = 99)
274
+
275
+
276
+
277
+ if self.config.haar_trans_inv:
278
+ res = torch.einsum('bcal,ad->bcdl',res,self.tokenizer.haar_transform)
279
+ if self.config.haar_trans_norm == "backward":
280
+ res = res / np.sqrt(res.shape[-2])
281
+ elif self.config.haar_trans_norm == "forward":
282
+ res = res * np.sqrt(res.shape[-2])
283
+
284
+
285
+
286
+
287
+
288
+ res = res * (std.unsqueeze(-1) + 1e-4) + mean.unsqueeze(-1)
289
+
290
+
291
+
292
+
293
+ if future_token == 0:
294
+ return res[:,masks,:,:], x_raw[:,masks,:]
295
+ else:
296
+ return res[:,masks,:,:]
297
+
298
+ def generate(self,*args,**kwargs):
299
+ res = self.forward(*args,**kwargs)
300
+ res = rearrange(res, 'b l c d -> b (l c) d')
301
+ return res[:,:kwargs['future_token'],:]
302
+
303
+
304
+
305
+ @classmethod
306
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
307
+ return cls(Config.from_name(name, **kwargs))
308
+
309
+ def build_rope_cache(self, idx: torch.Tensor) :
310
+ return build_rope_cache(
311
+ seq_len=self.config.block_size,
312
+ n_elem=int(self.config.rotary_percentage * self.config.head_size),
313
+ dtype=torch.bfloat16,
314
+ device=idx.device,
315
+ base = self.config.rope_base,
316
+ condense_ratio=self.config.condense_ratio,
317
+ )
318
+
319
+
320
+ class Block(nn.Module):
321
+ def __init__(self, config:YingLongConfig) -> None:
322
+ super().__init__()
323
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
324
+ self.attn = BidirectedlSelfAttention(config)
325
+ if not config.shared_attention_norm:
326
+ self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
327
+ self.mlp = config.mlp_class(config)
328
+ self.config = config
329
+ def forward(
330
+ self,
331
+ x: torch.Tensor,
332
+ rope: Optional[Tuple[torch.Tensor, torch.Tensor]],
333
+ max_seq_length: int,
334
+ mask: Optional[torch.Tensor] = None,
335
+ input_pos: Optional[torch.Tensor] = None,
336
+ ) -> torch.Tensor:
337
+
338
+ n_1 = self.norm_1(x)
339
+ h = self.attn(n_1, rope, max_seq_length, mask, input_pos)
340
+ if self.config.parallel_residual:
341
+ n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
342
+ x = x + h + self.mlp(n_2)
343
+ else:
344
+ if self.config.shared_attention_norm:
345
+ raise NotImplementedError(
346
+ "No checkpoint amongst the ones we support uses this configuration"
347
+ " (non-parallel residual and shared attention norm)."
348
+ )
349
+
350
+ x = x + h
351
+ x = x + self.mlp(self.norm_2(x))
352
+ return x
353
+
354
+
355
+ class BidirectedlSelfAttention(nn.Module):
356
+ def __init__(self, config:YingLongConfig) -> None:
357
+ super().__init__()
358
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
359
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
360
+ self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
361
+ self.config = config
362
+
363
+ def forward(
364
+ self,
365
+ x: torch.Tensor,
366
+ rope: Tuple[torch.Tensor, torch.Tensor],
367
+ max_seq_length: int,
368
+ mask: Optional[torch.Tensor] = None,
369
+ input_pos: Optional[torch.Tensor] = None,
370
+ ) -> torch.Tensor:
371
+
372
+
373
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
374
+
375
+ qkv = self.attn(x)
376
+
377
+ # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
378
+ q_per_kv = self.config.n_head // self.config.n_query_groups
379
+ total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
380
+ qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs)
381
+
382
+
383
+ # split batched computation into three
384
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)
385
+
386
+ q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs)
387
+ k = k.reshape(B, T, -1, self.config.head_size)
388
+ v = v.reshape(B, T, -1, self.config.head_size)
389
+
390
+ cos, sin = rope
391
+
392
+ q = apply_rotary_emb_func(q, cos, sin, False, True)
393
+ k = apply_rotary_emb_func(k, cos, sin, False, True)
394
+
395
+
396
+ y = self.scaled_dot_product_attention(q, k, v, mask=mask)
397
+
398
+ y = y.reshape(B, T, C) # re-assemble all head outputs side by side
399
+
400
+ # output projection
401
+ y = self.proj(y)
402
+
403
+ return y
404
+
405
+
406
+
407
+ def scaled_dot_product_attention(
408
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
409
+ ):
410
+ scale = 1.0 / math.sqrt(self.config.head_size)
411
+
412
+ if (
413
+ FlashAttention2Available
414
+ and mask is None
415
+ and q.device.type == "cuda"
416
+ and q.dtype in (torch.float16, torch.bfloat16)
417
+ ):
418
+ from flash_attn import flash_attn_func
419
+
420
+ return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=False)
421
+ q = q.transpose(1, 2)
422
+ k = k.transpose(1, 2)
423
+ v = v.transpose(1, 2)
424
+ if q.size() != k.size():
425
+ k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1)
426
+ v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1)
427
+ y = torch.nn.functional.scaled_dot_product_attention(
428
+ q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=False
429
+ )
430
+ return y.transpose(1, 2)
431
+
432
+
433
+
434
+
435
+
436
+
437
+ class quantitleLoss(torch.nn.Module):
438
+ def __init__(self,
439
+ qSize = 99,
440
+ patch_size = 16,
441
+ *args,**kwargs):
442
+
443
+ super().__init__()
444
+ self.qSize = qSize
445
+ self.patch_size = patch_size
446
+
447
+
448
+ q = np.array([i+1 for i in range(self.qSize)])
449
+ q = q / (self.qSize + 1)
450
+ q = q.reshape((1,1,-1))
451
+
452
+ q_variance = q*(1-q)
453
+
454
+ self.register_buffer('q', torch.tensor(q))
455
+ self.register_buffer('q_variance', torch.tensor(q_variance))
456
+
457
+
458
+ def forward(self, input: torch.Tensor, target: torch.Tensor,rel_loss = False):
459
+
460
+
461
+
462
+ target = target.unsqueeze(-1)
463
+ input = input[:,:target.shape[1],:,:]
464
+
465
+
466
+ posPart = input - target
467
+ negPart = -posPart
468
+
469
+ raw_loss = torch.maximum(self.q * negPart, (1-self.q) * posPart)
470
+
471
+ target_absmean = torch.mean(target.abs(),dim = (1,2),keepdims = True)
472
+ raw_loss = raw_loss / torch.sqrt(self.q_variance) / (target_absmean + 1e-4)
473
+
474
+ return torch.mean(raw_loss)
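quantitleLoss is a pinball (quantile) loss evaluated at all 99 levels at once, then rescaled by sqrt(q(1-q)) and by the mean absolute target so that extreme quantiles and large-scale series do not dominate. A toy check of the raw pinball term for a single level, independent of the module above:

```python
# Toy pinball term for one quantile level q; quantitleLoss applies this per level
# and then divides by sqrt(q*(1-q)) and the mean |target|.
import torch

def pinball(pred, target, q):
    pos = pred - target                      # positive when over-predicting
    return torch.maximum(q * -pos, (1 - q) * pos).mean()

pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([2.0, 2.0, 2.0])
print(pinball(pred, target, q=0.5))          # tensor(0.3333) == 0.5 * mean |error|
```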
475
+
476
+
477
+ def haarMatrix_unnormalized(n):
478
+
479
+ n = 2**np.ceil(np.log2(n))
480
+ if n > 2:
481
+ h = haarMatrix(n / 2)
482
+ else:
483
+ return np.array([[1, 1], [1, -1]])
484
+ h_n = np.kron(h, [1, 1])
485
+ h_i = np.kron(np.eye(len(h)), [1, -1])
486
+ h = np.vstack((h_n, h_i))
487
+ return h
488
+
489
+ def haarMatrix(n,normalized = 'ortho'):
490
+ h = haarMatrix_unnormalized(n)
491
+ scaler = np.diag(1 / np.sqrt(np.diag(h @ h.transpose())))
492
+ if normalized == 'ortho':
493
+ return scaler @ h
494
+ elif normalized == 'forward':
495
+ return scaler @ h / np.sqrt(n)
496
+
497
+ else:
498
+ return scaler @ h * np.sqrt(n)
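haarMatrix produces the row-normalized Haar transform that the tokenizer applies within each patch before the linear embedding (haar_trans), and whose transpose is applied to the outputs when haar_trans_inv is set. For intuition, the 4x4 orthonormal case is shown below (a standalone check, not part of the model file):

```python
# What haarMatrix(4, normalized='ortho') evaluates to: orthonormal rows mixing a patch
# into an average, a half-vs-half contrast, and two local differences.
import numpy as np

s = 1 / np.sqrt(2)
H = np.array([
    [0.5, 0.5,  0.5,  0.5],
    [0.5, 0.5, -0.5, -0.5],
    [s,  -s,    0.0,  0.0],
    [0.0, 0.0,  s,   -s],
])
assert np.allclose(H @ H.T, np.eye(4))   # rows are orthonormal
```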
499
+
500
+
501
+
502
+ class GptNeoxMLP(nn.Module):
503
+ def __init__(self, config:YingLongConfig) -> None:
504
+ super().__init__()
505
+ self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
506
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
507
+
508
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
509
+ x = self.fc(x)
510
+ x = torch.nn.functional.gelu(x)
511
+ return self.proj(x)
512
+
513
+
514
+ class LLaMAMLP(nn.Module):
515
+ def __init__(self, config:YingLongConfig) -> None:
516
+ super().__init__()
517
+
518
+ self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False)
519
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
520
+ return self.swiglu(x)
521
+
522
+
523
+ def build_rope_cache(
524
+ seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
525
+ ) -> Tuple[torch.Tensor,torch.Tensor]:
526
+ """Enhanced Transformer with Rotary Position Embedding.
527
+
528
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
529
+ transformers/rope/__init__.py. MIT License:
530
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
531
+ """
532
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
533
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))
534
+
535
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
536
+ seq_idx = torch.arange(seq_len, device=device) / condense_ratio
537
+
538
+ # Calculate the product of position index and $\theta_i$
539
+ idx_theta = torch.outer(seq_idx, theta)
540
+
541
+ cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
542
+
543
+ # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding
544
+ if dtype == torch.bfloat16:
545
+ return cos.bfloat16(), sin.bfloat16()
546
+ # this is to mimic the behaviour of complex32, else we will get different results
547
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
548
+ return cos.half(), sin.half()
549
+ return cos, sin
550
+
551
+
552
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
553
+ head_size = x.size(-1)
554
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
555
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
556
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
557
+ roped = (x * cos) + (rotated * sin)
558
+ return roped.type_as(x)
559
+
560
+
561
+
562
+
563
+
564
+
565
+
566
+ ######################################
567
+ #layernorm
568
+ ######################################
569
+
570
+
571
+ import torch
572
+ # Copyright (c) 2022, Tri Dao.
573
+ # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16
574
+
575
+ import dropout_layer_norm
576
+ import torch
577
+ from torch.nn import init
578
+
579
+
580
+ def maybe_align(x, alignment_in_bytes=16):
581
+ """Assume that x already has last dim divisible by alignment_in_bytes"""
582
+ # TD [2023-07-04] I'm not 100% sure that clone will align the memory
583
+ # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
584
+ return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
585
+
586
+
587
+ def _dropout_add_layer_norm_forward(
588
+ x0,
589
+ residual,
590
+ gamma,
591
+ beta,
592
+ rowscale,
593
+ colscale,
594
+ dropout_p,
595
+ epsilon,
596
+ residual_in_fp32=False,
597
+ is_rms_norm=False,
598
+ ):
599
+ """Assume that arguments are contiguous and aligned to 16 bytes"""
600
+ hidden_size = gamma.numel()
601
+ x0mat = x0.view((-1, hidden_size))
602
+ residualmat = residual.view((-1, hidden_size)) if residual is not None else None
603
+ rowscale = rowscale.view(-1) if rowscale is not None else None
604
+ zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
605
+ x0mat,
606
+ residualmat,
607
+ gamma,
608
+ beta,
609
+ rowscale,
610
+ colscale,
611
+ None,
612
+ None,
613
+ dropout_p,
614
+ epsilon,
615
+ 1.0,
616
+ 0,
617
+ None,
618
+ residual_in_fp32,
619
+ is_rms_norm,
620
+ )
621
+ # dmask is None if dropout_p == 0.0
622
+ # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
623
+ return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
624
+
625
+
626
+ def _dropout_add_layer_norm_backward(
627
+ dz,
628
+ dx,
629
+ x,
630
+ x0,
631
+ dmask,
632
+ mu,
633
+ rsigma,
634
+ gamma,
635
+ rowscale,
636
+ colscale,
637
+ dropout_p,
638
+ has_residual,
639
+ is_rms_norm=False,
640
+ ):
641
+ """Assume that arguments are contiguous and aligned to 16 bytes
642
+ dx == None means that it was a post-norm architecture
643
+ (x = drop(x0) + residual was not returned in the fwd).
644
+ x0 must not be None if we have colscale.
645
+ """
646
+ hidden_size = gamma.numel()
647
+ xmat = x.view((-1, hidden_size))
648
+ dzmat = dz.view(xmat.shape)
649
+ dxmat = dx.view(xmat.shape) if dx is not None else None
650
+ x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
651
+ rowscale = rowscale.view(-1) if rowscale is not None else None
652
+ if colscale is not None:
653
+ assert x0 is not None, "x0 is required to compute the gradient of colscale"
654
+ dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
655
+ dzmat,
656
+ dxmat,
657
+ xmat,
658
+ x0mat,
659
+ dmask,
660
+ mu,
661
+ rsigma,
662
+ gamma,
663
+ rowscale,
664
+ colscale,
665
+ None,
666
+ None,
667
+ dropout_p,
668
+ 1.0,
669
+ 0,
670
+ has_residual,
671
+ is_rms_norm,
672
+ )
673
+ # dresidualmat is None if not has_residual
674
+ if colscale is None:
675
+ return dx0mat, dresidualmat, dgamma, dbeta
676
+ else:
677
+ dcolscale = rest[0]
678
+ return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
679
+
680
+
681
+ def _dropout_add_layer_norm_subset_forward(
682
+ x0,
683
+ residual,
684
+ gamma,
685
+ beta,
686
+ colscale,
687
+ x0_subset,
688
+ out_subset,
689
+ dropout_p,
690
+ epsilon,
691
+ rowscale_const,
692
+ out_numrows,
693
+ residual_in_fp32=False,
694
+ is_rms_norm=False,
695
+ ):
696
+ """Assume that arguments are contiguous and aligned to 16 bytes"""
697
+ hidden_size = gamma.numel()
698
+ x0mat = x0.view((-1, hidden_size))
699
+ residualmat = residual.view((-1, hidden_size)) if residual is not None else None
700
+ x0_subset = x0_subset.view(-1) if x0_subset is not None else None
701
+ out_subset = out_subset.view(-1) if out_subset is not None else None
702
+ zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
703
+ x0mat,
704
+ residualmat,
705
+ gamma,
706
+ beta,
707
+ None,
708
+ colscale,
709
+ x0_subset,
710
+ out_subset,
711
+ dropout_p,
712
+ epsilon,
713
+ rowscale_const,
714
+ out_numrows,
715
+ None,
716
+ residual_in_fp32,
717
+ is_rms_norm,
718
+ )
719
+ # dmask is None if dropout_p == 0.0
720
+ # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
721
+ return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
722
+
723
+
724
+ def _dropout_add_layer_norm_subset_backward(
725
+ dz,
726
+ dx,
727
+ x,
728
+ x0,
729
+ dmask,
730
+ mu,
731
+ rsigma,
732
+ gamma,
733
+ colscale,
734
+ x0_subset,
735
+ out_subset,
736
+ dropout_p,
737
+ rowscale_const,
738
+ x0_numrows,
739
+ has_residual,
740
+ is_rms_norm=False,
741
+ ):
742
+ """Assume that arguments are contiguous and aligned to 16 bytes
743
+ dx == None means that it was a post-norm architecture
744
+ (x = drop(x0) + residual was not returned in the fwd).
745
+ x0 must not be None if we have colscale.
746
+ """
747
+ hidden_size = gamma.numel()
748
+ xmat = x.view((-1, hidden_size))
749
+ dzmat = dz.view(-1, hidden_size)
750
+ dxmat = dx.view(xmat.shape) if dx is not None else None
751
+ x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
752
+ x0_subset = x0_subset.view(-1) if x0_subset is not None else None
753
+ out_subset = out_subset.view(-1) if out_subset is not None else None
754
+ if colscale is not None:
755
+ assert x0 is not None, "x0 is required to compute the gradient of colscale"
756
+ dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
757
+ dzmat,
758
+ dxmat,
759
+ xmat,
760
+ x0mat,
761
+ dmask,
762
+ mu,
763
+ rsigma,
764
+ gamma,
765
+ None,
766
+ colscale,
767
+ x0_subset,
768
+ out_subset,
769
+ dropout_p,
770
+ rowscale_const,
771
+ x0_numrows,
772
+ has_residual,
773
+ is_rms_norm,
774
+ )
775
+ # dresidualmat is None if not has_residual
776
+ if colscale is None:
777
+ return dx0mat, dresidualmat, dgamma, dbeta
778
+ else:
779
+ dcolscale = rest[0]
780
+ return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
781
+
782
+
783
+ def _dropout_add_layer_norm_parallel_residual_forward(
784
+ x0,
785
+ x1,
786
+ residual,
787
+ gamma0,
788
+ beta0,
789
+ gamma1,
790
+ beta1,
791
+ dropout_p,
792
+ epsilon,
793
+ residual_in_fp32=False,
794
+ is_rms_norm=False,
795
+ ):
796
+ """Assume that arguments are contiguous and aligned to 16 bytes"""
797
+ hidden_size = gamma0.numel()
798
+ x0mat = x0.view((-1, hidden_size))
799
+ x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
800
+ residualmat = residual.view((-1, hidden_size)) if residual is not None else None
801
+ (
802
+ z0mat,
803
+ z1mat,
804
+ xmat,
805
+ dmask0,
806
+ dmask1,
807
+ mu,
808
+ rsigma,
809
+ ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
810
+ x0mat,
811
+ x1mat,
812
+ residualmat,
813
+ gamma0,
814
+ beta0,
815
+ gamma1,
816
+ beta1,
817
+ dropout_p,
818
+ epsilon,
819
+ None,
820
+ residual_in_fp32,
821
+ is_rms_norm,
822
+ )
823
+ # dmask0 and dmask1 are None if dropout_p == 0.0
824
+ # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
825
+ return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
826
+
827
+
828
+ def _dropout_add_layer_norm_parallel_residual_backward(
829
+ dz0,
830
+ dz1,
831
+ dx,
832
+ x,
833
+ dmask0,
834
+ dmask1,
835
+ mu,
836
+ rsigma,
837
+ gamma0,
838
+ gamma1,
839
+ dropout_p,
840
+ has_x1,
841
+ has_residual,
842
+ is_rms_norm=False,
843
+ ):
844
+ """Assume that arguments are contiguous and aligned to 16 bytes
845
+ dx == None means that it was a post-norm architecture
846
+ (x = drop(x0) + residual was not returned in the fwd).
847
+ """
848
+ hidden_size = gamma0.numel()
849
+ xmat = x.view((-1, hidden_size))
850
+ dz0mat = dz0.view(xmat.shape)
851
+ dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
852
+ dxmat = dx.view(xmat.shape) if dx is not None else None
853
+ (
854
+ dx0mat,
855
+ dx1mat,
856
+ dresidualmat,
857
+ dgamma0,
858
+ dbeta0,
859
+ dgamma1,
860
+ dbeta1,
861
+ *rest,
862
+ ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
863
+ dz0mat,
864
+ dz1mat,
865
+ dxmat,
866
+ xmat,
867
+ dmask0,
868
+ dmask1,
869
+ mu,
870
+ rsigma,
871
+ gamma0,
872
+ gamma1,
873
+ dropout_p,
874
+ has_x1,
875
+ has_residual,
876
+ is_rms_norm,
877
+ )
878
+ # dresidualmat is None if not has_residual
879
+ return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
880
+
881
+
882
+ class DropoutAddLayerNormFn(torch.autograd.Function):
883
+ @staticmethod
884
+ def forward(
885
+ ctx,
886
+ x0,
887
+ residual,
888
+ gamma,
889
+ beta,
890
+ rowscale,
891
+ colscale,
892
+ dropout_p,
893
+ epsilon,
894
+ residual_in_fp32=False,
895
+ prenorm=False,
896
+ is_rms_norm=False,
897
+ return_dmask=False,
898
+ ):
899
+ x0 = maybe_align(x0.contiguous(), 16)
900
+ residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
901
+ gamma = maybe_align(gamma.contiguous(), 16)
902
+ beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
903
+ rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
904
+ colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
905
+ zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
906
+ x0,
907
+ residual,
908
+ gamma,
909
+ beta,
910
+ rowscale,
911
+ colscale,
912
+ dropout_p,
913
+ epsilon,
914
+ residual_in_fp32,
915
+ is_rms_norm,
916
+ )
917
+ # Only need to save x0 if we need to compute gradient wrt colscale
918
+ x0_saved = x0 if colscale is not None else None
919
+ ctx.save_for_backward(
920
+ xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
921
+ )
922
+ ctx.prenorm = prenorm
923
+ ctx.dropout_p = dropout_p
924
+ ctx.has_residual = residual is not None
925
+ ctx.is_rms_norm = is_rms_norm
926
+ ctx.has_beta = beta is not None
927
+ if not return_dmask:
928
+ return (
929
+ zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
930
+ )
931
+ else:
932
+ dmask = (
933
+ dmask.view(x0.shape)
934
+ if dropout_p > 0.0
935
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
936
+ )
937
+ ctx.mark_non_differentiable(dmask)
938
+ return (
939
+ (zmat.view(x0.shape), dmask)
940
+ if not prenorm
941
+ else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
942
+ )
943
+
944
+ @staticmethod
945
+ def backward(ctx, dz, *args):
946
+ # assert dz.is_contiguous()
947
+ dz = maybe_align(dz.contiguous(), 16) # this happens!
948
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
949
+ x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
950
+ # x0 is None if colscale is None
951
+ dropout_p = ctx.dropout_p
952
+ has_residual = ctx.has_residual
953
+ dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
954
+ dz,
955
+ dx,
956
+ x,
957
+ x0,
958
+ dmask,
959
+ mu,
960
+ rsigma,
961
+ gamma,
962
+ rowscale,
963
+ colscale,
964
+ dropout_p,
965
+ has_residual,
966
+ ctx.is_rms_norm,
967
+ )
968
+ dx0 = dx0mat.view(x.shape)
969
+ dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
970
+ dcolscale = rest[0] if colscale is not None else None
971
+ return (
972
+ dx0,
973
+ dresidual,
974
+ dgamma,
975
+ dbeta if ctx.has_beta else None,
976
+ None,
977
+ dcolscale,
978
+ None,
979
+ None,
980
+ None,
981
+ None,
982
+ None,
983
+ None,
984
+ )
985
+
986
+
987
+ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
988
+ @staticmethod
989
+ def forward(
990
+ ctx,
991
+ x0,
992
+ residual,
993
+ gamma,
994
+ beta,
995
+ colscale,
996
+ x0_subset,
997
+ out_subset,
998
+ dropout_p,
999
+ epsilon,
1000
+ rowscale_const,
1001
+ out_numrows,
1002
+ residual_in_fp32=False,
1003
+ prenorm=False,
1004
+ is_rms_norm=False,
1005
+ return_dmask=False,
1006
+ ):
1007
+ x0 = maybe_align(x0.contiguous(), 16)
1008
+ residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
1009
+ gamma = maybe_align(gamma.contiguous(), 16)
1010
+ beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
1011
+ colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
1012
+ zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
1013
+ x0,
1014
+ residual,
1015
+ gamma,
1016
+ beta,
1017
+ colscale,
1018
+ x0_subset,
1019
+ out_subset,
1020
+ dropout_p,
1021
+ epsilon,
1022
+ rowscale_const,
1023
+ out_numrows,
1024
+ residual_in_fp32,
1025
+ is_rms_norm,
1026
+ )
1027
+ # Only need to save x0 if we need to compute gradient wrt colscale
1028
+ x0_saved = x0 if colscale is not None else None
1029
+ x_shape = (-1, *x0.shape[1:])
1030
+ ctx.save_for_backward(
1031
+ xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
1032
+ )
1033
+ ctx.prenorm = prenorm
1034
+ ctx.dropout_p = dropout_p
1035
+ ctx.rowscale_const = rowscale_const
1036
+ ctx.x0_numrows = x0.shape[:-1].numel()
1037
+ ctx.has_residual = residual is not None
1038
+ ctx.is_rms_norm = is_rms_norm
1039
+ ctx.has_beta = beta is not None
1040
+ z_shape = (-1, *x0.shape[1:])
1041
+ if not return_dmask:
1042
+ return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
1043
+ else:
1044
+ z = zmat.view(z_shape)
1045
+ dmask = (
1046
+ dmask.view(x0.shape)
1047
+ if dropout_p > 0.0
1048
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
1049
+ )
1050
+ ctx.mark_non_differentiable(dmask)
1051
+ return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)
1052
+
1053
+ @staticmethod
1054
+ def backward(ctx, dz, *args):
1055
+ # assert dz.is_contiguous()
1056
+ dz = maybe_align(dz.contiguous(), 16) # this happens!
1057
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
1058
+ x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
1059
+ # x0 is None if colscale is None
1060
+ dropout_p = ctx.dropout_p
1061
+ has_residual = ctx.has_residual
1062
+ dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
1063
+ dz,
1064
+ dx,
1065
+ x,
1066
+ x0,
1067
+ dmask,
1068
+ mu,
1069
+ rsigma,
1070
+ gamma,
1071
+ colscale,
1072
+ x0_subset,
1073
+ out_subset,
1074
+ dropout_p,
1075
+ ctx.rowscale_const,
1076
+ ctx.x0_numrows,
1077
+ has_residual,
1078
+ ctx.is_rms_norm,
1079
+ )
1080
+ dx0 = dx0mat.view(-1, *x.shape[1:])
1081
+ dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
1082
+ dcolscale = rest[0] if colscale is not None else None
1083
+ return (
1084
+ dx0,
1085
+ dresidual,
1086
+ dgamma,
1087
+ dbeta if ctx.has_beta else None,
1088
+ dcolscale,
1089
+ None,
1090
+ None,
1091
+ None,
1092
+ None,
1093
+ None,
1094
+ None,
1095
+ None,
1096
+ None,
1097
+ None,
1098
+ None,
1099
+ )
1100
+
1101
+
1102
+ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
1103
+ @staticmethod
1104
+ def forward(
1105
+ ctx,
1106
+ x0,
1107
+ x1,
1108
+ residual,
1109
+ gamma0,
1110
+ beta0,
1111
+ gamma1,
1112
+ beta1,
1113
+ dropout_p,
1114
+ epsilon,
1115
+ residual_in_fp32=False,
1116
+ prenorm=False,
1117
+ is_rms_norm=False,
1118
+ return_dmask=False,
1119
+ ):
1120
+ x0 = maybe_align(x0.contiguous(), 16)
1121
+ x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
1122
+ residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
1123
+ gamma0 = maybe_align(gamma0.contiguous(), 16)
1124
+ beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
1125
+ gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
1126
+ beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
1127
+ (
1128
+ z0mat,
1129
+ z1mat,
1130
+ xmat,
1131
+ dmask0,
1132
+ dmask1,
1133
+ mu,
1134
+ rsigma,
1135
+ ) = _dropout_add_layer_norm_parallel_residual_forward(
1136
+ x0,
1137
+ x1,
1138
+ residual,
1139
+ gamma0,
1140
+ beta0,
1141
+ gamma1,
1142
+ beta1,
1143
+ dropout_p,
1144
+ epsilon,
1145
+ residual_in_fp32,
1146
+ is_rms_norm,
1147
+ )
1148
+ ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
1149
+ ctx.prenorm = prenorm
1150
+ ctx.dropout_p = dropout_p
1151
+ ctx.has_x1 = x1 is not None
1152
+ ctx.has_residual = residual is not None
1153
+ ctx.is_rms_norm = is_rms_norm
1154
+ ctx.has_beta = beta0 is not None
1155
+ z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
1156
+ if not return_dmask:
1157
+ return z if not prenorm else (*z, xmat.view(x0.shape))
1158
+ else:
1159
+ dmask0 = (
1160
+ dmask0.view(x0.shape)
1161
+ if dropout_p > 0.0
1162
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
1163
+ )
1164
+ dmask1 = (
1165
+ dmask1.view(x0.shape)
1166
+ if dropout_p > 0.0 and x1 is not None
1167
+ else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
1168
+ )
1169
+ ctx.mark_non_differentiable(dmask0)
1170
+ ctx.mark_non_differentiable(dmask1)
1171
+ return (
1172
+ (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
1173
+ )
1174
+
1175
+ @staticmethod
1176
+ def backward(ctx, dz0, dz1, *args):
1177
+ dz0 = maybe_align(dz0.contiguous(), 16) # this happens!
1178
+ dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
1179
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
1180
+ x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
1181
+ dropout_p = ctx.dropout_p
1182
+ has_x1 = ctx.has_x1
1183
+ has_residual = ctx.has_residual
1184
+ (
1185
+ dx0mat,
1186
+ dx1mat,
1187
+ dresidualmat,
1188
+ dgamma0,
1189
+ dbeta0,
1190
+ dgamma1,
1191
+ dbeta1,
1192
+ ) = _dropout_add_layer_norm_parallel_residual_backward(
1193
+ dz0,
1194
+ dz1,
1195
+ dx,
1196
+ x,
1197
+ dmask0,
1198
+ dmask1,
1199
+ mu,
1200
+ rsigma,
1201
+ gamma0,
1202
+ gamma1,
1203
+ dropout_p,
1204
+ has_x1,
1205
+ has_residual,
1206
+ ctx.is_rms_norm,
1207
+ )
1208
+ dx0 = dx0mat.view(x.shape)
1209
+ dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
1210
+ dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
1211
+ return (
1212
+ dx0,
1213
+ dx1,
1214
+ dresidual,
1215
+ dgamma0,
1216
+ dbeta0 if ctx.has_beta else None,
1217
+ dgamma1,
1218
+ dbeta1 if ctx.has_beta else None,
1219
+ None,
1220
+ None,
1221
+ None,
1222
+ None,
1223
+ None,
1224
+ None,
1225
+ )
1226
+
1227
+
1228
+ def layer_norm(x, weight, bias, epsilon):
1229
+ return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
1230
+
1231
+
1232
+ def dropout_add_layer_norm(
1233
+ x0,
1234
+ residual,
1235
+ weight,
1236
+ bias,
1237
+ dropout_p,
1238
+ epsilon,
1239
+ rowscale=None,
1240
+ layerscale=None,
1241
+ prenorm=False,
1242
+ residual_in_fp32=False,
1243
+ return_dropout_mask=False,
1244
+ ):
1245
+ """residual_in_fp32 only has an effect if residual is None.
1246
+ Otherwise residual dtype is residual.dtype.
1247
+ """
1248
+ return DropoutAddLayerNormFn.apply(
1249
+ x0,
1250
+ residual,
1251
+ weight,
1252
+ bias,
1253
+ rowscale,
1254
+ layerscale,
1255
+ dropout_p,
1256
+ epsilon,
1257
+ residual_in_fp32,
1258
+ prenorm,
1259
+ False,
1260
+ return_dropout_mask,
1261
+ )
1262
+
1263
+
1264
+ def dropout_add_layer_norm_subset(
1265
+ x0,
1266
+ residual,
1267
+ weight,
1268
+ bias,
1269
+ dropout_p,
1270
+ epsilon,
1271
+ layerscale=None,
1272
+ x0_subset=None,
1273
+ out_subset=None,
1274
+ rowscale_const=1.0,
1275
+ out_numrows=0,
1276
+ prenorm=False,
1277
+ residual_in_fp32=False,
1278
+ return_dropout_mask=False,
1279
+ ):
1280
+ """residual_in_fp32 only has an effect if residual is None.
1281
+ Otherwise residual dtype is residual.dtype.
1282
+ """
1283
+ return DropoutAddLayerNormSubsetFn.apply(
1284
+ x0,
1285
+ residual,
1286
+ weight,
1287
+ bias,
1288
+ layerscale,
1289
+ x0_subset,
1290
+ out_subset,
1291
+ dropout_p,
1292
+ epsilon,
1293
+ rowscale_const,
1294
+ out_numrows,
1295
+ residual_in_fp32,
1296
+ prenorm,
1297
+ False,
1298
+ return_dropout_mask,
1299
+ )
1300
+
1301
+
1302
+ def dropout_add_layer_norm_parallel_residual(
1303
+ x0,
1304
+ x1,
1305
+ residual,
1306
+ weight0,
1307
+ bias0,
1308
+ weight1,
1309
+ bias1,
1310
+ dropout_p,
1311
+ epsilon,
1312
+ prenorm=False,
1313
+ residual_in_fp32=False,
1314
+ return_dropout_mask=False,
1315
+ ):
1316
+ """residual_in_fp32 only has an effect if residual is None.
1317
+ Otherwise residual dtype is residual.dtype.
1318
+ """
1319
+ return DropoutAddLayerNormParallelResidualFn.apply(
1320
+ x0,
1321
+ x1,
1322
+ residual,
1323
+ weight0,
1324
+ bias0,
1325
+ weight1,
1326
+ bias1,
1327
+ dropout_p,
1328
+ epsilon,
1329
+ residual_in_fp32,
1330
+ prenorm,
1331
+ False,
1332
+ return_dropout_mask,
1333
+ )
1334
+
1335
+
1336
+ class DropoutAddLayerNorm(torch.nn.Module):
1337
+ def __init__(
1338
+ self,
1339
+ hidden_size,
1340
+ prenorm=False,
1341
+ p=0.0,
1342
+ eps=1e-5,
1343
+ residual_in_fp32=False,
1344
+ device=None,
1345
+ dtype=None,
1346
+ ):
1347
+ factory_kwargs = {"device": device, "dtype": dtype}
1348
+ super().__init__()
1349
+ self.prenorm = prenorm
1350
+ self.p = p
1351
+ self.eps = eps
1352
+ self.residual_in_fp32 = residual_in_fp32
1353
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1354
+ self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1355
+ self.reset_parameters()
1356
+
1357
+ def reset_parameters(self):
1358
+ init.ones_(self.weight)
1359
+ init.zeros_(self.bias)
1360
+
1361
+ def forward(self, x0, residual=None):
1362
+ return dropout_add_layer_norm(
1363
+ x0,
1364
+ residual,
1365
+ self.weight,
1366
+ self.bias,
1367
+ self.p if self.training else 0.0,
1368
+ self.eps,
1369
+ prenorm=self.prenorm,
1370
+ residual_in_fp32=self.residual_in_fp32,
1371
+ )
1372
+
1373
+ def rms_norm(x, weight, epsilon):
1374
+ return DropoutAddLayerNormFn.apply(
1375
+ x, None, weight, None, None, None, 0.0, epsilon, False, False, True
1376
+ )
1377
+ class FusedRMSNorm(torch.nn.Module):
1378
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5):
1379
+ super().__init__()
1380
+ self.eps = eps
1381
+ self.weight = torch.nn.Parameter(torch.ones(size))
1382
+ self.dim = dim
1383
+ self.reset_parameters()
1384
+
1385
+ def reset_parameters(self):
1386
+ init.ones_(self.weight)
1387
+
1388
+ def forward(self, x):
1389
+ return rms_norm(x, self.weight, self.eps)
1390
+
1391
+
1392
+ class RMSNorm(torch.nn.Module):
1393
+ """Root Mean Square Layer Normalization.
1394
+
1395
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
1396
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
1397
+ """
1398
+
1399
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
1400
+ super().__init__()
1401
+ self.weight = torch.nn.Parameter(torch.ones(size))
1402
+ self.eps = eps
1403
+ self.dim = dim
1404
+
1405
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1406
+ # NOTE: this is not numerically equivalent to the original RMSNorm paper implementation
1407
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
1408
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
1409
+ return self.weight * x_normed
1410
+
1411
+ def reset_parameters(self):
1412
+ torch.nn.init.ones_(self.weight)
1413
+
1414
+
1415
+
1416
+
1417
+
1418
+
1419
+
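A quick sanity check (an illustration, not part of the upload) that the plain-PyTorch RMSNorm above matches the formula it implements, x * rsqrt(mean(x^2) + eps) * weight; the tensor sizes are arbitrary.

import torch

norm = RMSNorm(size=8, eps=1e-5)
x = torch.randn(4, 8)
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-5) * norm.weight
assert torch.allclose(norm(x), manual)  # the module and the explicit formula agree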
1420
+ ######################################
1421
+ # rope_emb
1422
+ ######################################
1423
+
1424
+
1425
+
1426
+
1427
+
1428
+
1429
+
1430
+ # Copyright (c) 2023, Tri Dao.
1431
+
1432
+ import math
1433
+ from typing import Optional, Tuple
1434
+
1435
+ import rotary_emb
1436
+ import torch
1437
+ from einops import rearrange, repeat
1438
+
1439
+ class ApplyRotaryEmb(torch.autograd.Function):
1440
+ @staticmethod
1441
+ def forward(ctx, x, cos, sin, interleaved=False, inplace=False, future_token=0):  # future_token is accepted but unused here
1442
+ """
1443
+ x: (batch_size, seqlen, nheads, headdim)
1444
+ cos, sin: (seqlen, rotary_dim / 2)
1445
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
1446
+ of 1st half and 2nd half (GPT-NeoX style).
1447
+ rotary_dim must be <= headdim
1448
+ Apply rotary embedding to the first rotary_dim of x.
1449
+ """
1450
+ batch, seqlen, nheads, headdim = x.shape
1451
+ rotary_seqlen, rotary_dim = cos.shape
1452
+ rotary_dim *= 2
1453
+
1454
+
1455
+ # debug: print(x.shape, cos.shape)
456
+ # e.g. torch.Size([224, 96, 12, 64]) torch.Size([1, 32])
457
+ # debug output: 2049 2048
1458
+ assert rotary_dim <= headdim
1459
+ # print(seqlen,rotary_seqlen)
1460
+ assert seqlen <= rotary_seqlen
1461
+ assert sin.shape == (rotary_seqlen, rotary_dim // 2)
1462
+ x_ro = x[..., :rotary_dim]
1463
+ x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
1464
+ out = torch.empty_like(x) if not inplace else x
1465
+ out_ro = out[..., :rotary_dim]
1466
+ if inplace:
1467
+ o1, o2 = x1, x2
1468
+ else:
1469
+ o1, o2 = (
1470
+ out_ro.chunk(2, dim=-1)
1471
+ if not interleaved
1472
+ else (out_ro[..., ::2], out_ro[..., 1::2])
1473
+ )
1474
+ rotary_emb.apply_rotary(
1475
+ x1,
1476
+ x2,
1477
+ rearrange(cos[:seqlen], "s d -> s 1 d"),
1478
+ rearrange(sin[:seqlen], "s d -> s 1 d"),
1479
+ o1,
1480
+ o2,
1481
+ False,
1482
+ )
1483
+ if not inplace and rotary_dim < headdim:
1484
+ out[..., rotary_dim:].copy_(x[..., rotary_dim:])
1485
+ ctx.save_for_backward(cos, sin)
1486
+ ctx.interleaved = interleaved
1487
+ ctx.inplace = inplace
1488
+ return out if not inplace else x
1489
+
1490
+ @staticmethod
1491
+ def backward(ctx, do):
1492
+ cos, sin = ctx.saved_tensors
1493
+ _, seqlen, _, headdim = do.shape
1494
+ rotary_dim = cos.shape[-1]
1495
+ rotary_dim *= 2
1496
+ inplace = ctx.inplace
1497
+ do_ro = do[..., :rotary_dim]
1498
+ do1, do2 = (
1499
+ do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
1500
+ )
1501
+ dx = torch.empty_like(do) if not inplace else do
1502
+ if inplace:
1503
+ dx1, dx2 = do1, do2
1504
+ else:
1505
+ dx_ro = dx[..., :rotary_dim]
1506
+ dx1, dx2 = (
1507
+ dx_ro.chunk(2, dim=-1)
1508
+ if not ctx.interleaved
1509
+ else (dx_ro[..., ::2], dx_ro[..., 1::2])
1510
+ )
1511
+ rotary_emb.apply_rotary(
1512
+ do1,
1513
+ do2,
1514
+ rearrange(cos[:seqlen], "s d -> s 1 d"),
1515
+ rearrange(sin[:seqlen], "s d -> s 1 d"),
1516
+ dx1,
1517
+ dx2,
1518
+ True,
1519
+ )
1520
+ if not inplace and rotary_dim < headdim:
1521
+ dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
1522
+ return dx, None, None, None, None, None  # one gradient slot per forward input (x, cos, sin, interleaved, inplace, future_token)
1523
+
1524
+
1525
+ apply_rotary_emb_func = ApplyRotaryEmb.apply
1526
+
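The wrappers above are thin autograd bindings over the flash-attention CUDA extensions. A minimal usage sketch follows (an illustration, not part of the upload): it assumes the dropout_layer_norm and rotary_emb extensions are installed, a CUDA device is available, and the tensors are fp16; all sizes are made up.

import torch

hidden, nheads, head_dim, seqlen = 256, 16, 16, 64   # hypothetical sizes
x = torch.randn(8, seqlen, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)

# Fused dropout + residual add + LayerNorm; with prenorm=True the pre-norm sum
# is returned as well so it can feed the next residual branch.
out, pre = dropout_add_layer_norm(
    x, residual, weight, bias, dropout_p=0.1, epsilon=1e-5, prenorm=True
)

# Fused RMSNorm goes through the same autograd function with beta=None.
y = rms_norm(x, weight, 1e-5)

# Rotary embedding on (batch, seqlen, nheads, headdim); cos/sin are (seqlen, rotary_dim / 2).
q = torch.randn(8, seqlen, nheads, head_dim, device="cuda", dtype=torch.float16)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
freqs = torch.outer(torch.arange(seqlen, device="cuda").float(), inv_freq)
cos, sin = freqs.cos().half(), freqs.sin().half()
q_rot = apply_rotary_emb_func(q, cos, sin, False, False)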
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2df65a7dc7730709b3e1b943fec891d72ef003aa42d0761370f5cd7aa7bf440
3
+ size 14646052
model_config.py ADDED
@@ -0,0 +1,100 @@
1
+ from typing import List
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class YingLongConfig(PretrainedConfig):
6
+ model_type = "yinglong"
7
+ # keys_to_ignore_at_inference = ["past_key_values"]
8
+
9
+ def __init__(
10
+ self,
11
+ # input_token_len: int = 1,
12
+ # hidden_size: int = 1024,
13
+ # intermediate_size: int = 2048,
14
+ # output_token_lens: List[int] = [1, 8, 32, 64],
15
+ # num_hidden_layers: int = 8,
16
+ # num_attention_heads: int = 8,
17
+ # hidden_act: str = "silu",
18
+ # use_cache: bool = True,
19
+ # rope_theta: int = 10000,
20
+ # attention_dropout: float = 0.0,
21
+ # initializer_range: float = 0.02,
22
+ # max_position_embeddings: int = 10000,
23
+ #####
24
+ bias = False,
25
+ condense_ratio = 1,
26
+ haar_trans = True,
27
+ haar_trans_inv = True,
28
+ haar_trans_norm = 'backward',
29
+ half_diff = False,
30
+ intermediate_size = 1024,
31
+ n_embd = 256,
32
+ n_head = 16,
33
+ n_layer = 6,
34
+ n_query_groups = 4,
35
+ norm_eps = 1e-5,
36
+ org = 'Alibaba',
37
+ patch_size = 32,
38
+ rope_base = 10000,
39
+ rotary_percentage = 1.0,
40
+ shared_attention_norm = False,
41
+ unet = True,
42
+ _mlp_class = "LLaMAMLP",
43
+ _norm_class="FusedRMSNorm",
44
+ *args,
45
+ **kwargs,
46
+ ):
47
+
48
+ # self.input_token_len = input_token_len
49
+ # self.hidden_size = hidden_size
50
+ # self.intermediate_size = intermediate_size
51
+ # self.num_hidden_layers = num_hidden_layers
52
+ # self.num_attention_heads = num_attention_heads
53
+ # self.hidden_act = hidden_act
54
+ # self.output_token_lens = output_token_lens;
55
+ # self.use_cache = use_cache
56
+ # self.rope_theta = rope_theta
57
+ # self.attention_dropout = attention_dropout
58
+ # self.initializer_range = initializer_range
59
+ # self.max_position_embeddings = max_position_embeddings
60
+ self.org = org
61
+ self.patch_size = patch_size
62
+ self.unet = unet
63
+
64
+ self.n_embd = n_embd
65
+ self.intermediate_size = intermediate_size
66
+ self.n_head = n_head
67
+ self.n_layer = n_layer
68
+ self.n_query_groups = n_query_groups
69
+ self.norm_eps = norm_eps
70
+ self.bias = bias
71
+ self.shared_attention_norm = shared_attention_norm
72
+
73
+ self.condense_ratio = condense_ratio
74
+ self.rope_base = rope_base
75
+ self.rotary_percentage = rotary_percentage
76
+
77
+ self.haar_trans = haar_trans
78
+ self.haar_trans_inv = haar_trans_inv
79
+ self.haar_trans_norm = haar_trans_norm
80
+ self.half_diff = half_diff
81
+
82
+ self._norm_class = _norm_class
83
+
84
+ self._mlp_class = _mlp_class
85
+
86
+ assert self.n_embd % self.n_head == 0
87
+ assert self.n_head % self.n_query_groups == 0
88
+
89
+ self.head_size = self.n_embd // self.n_head
90
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
91
+ self.rope_condense_ratio = self.condense_ratio
92
+
93
+
94
+
95
+
96
+
97
+
98
+ super().__init__(
99
+ **kwargs,
100
+ )
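For reference, a short sketch (assumed import path, not part of the upload) showing how the derived fields of YingLongConfig follow from its constructor arguments:

from model_config import YingLongConfig  # assumed module path within this repo

cfg = YingLongConfig(n_embd=256, n_head=16, n_query_groups=4, rotary_percentage=1.0)
# head_size   = n_embd // n_head                   -> 256 // 16 = 16
# rope_n_elem = int(rotary_percentage * head_size) -> 16
print(cfg.head_size, cfg.rope_n_elem, cfg.patch_size)  # 16 16 32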
未命名.ipynb ADDED
@@ -0,0 +1,209 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 8,
6
+ "id": "94220666-124e-461d-b171-cf1f86056555",
7
+ "metadata": {
8
+ "ExecutionIndicator": {
9
+ "show": false
10
+ },
11
+ "execution": {
12
+ "iopub.execute_input": "2025-05-17T02:39:43.905633Z",
13
+ "iopub.status.busy": "2025-05-17T02:39:43.905106Z",
14
+ "iopub.status.idle": "2025-05-17T02:39:43.967707Z",
15
+ "shell.execute_reply": "2025-05-17T02:39:43.967125Z",
16
+ "shell.execute_reply.started": "2025-05-17T02:39:43.905610Z"
17
+ },
18
+ "tags": []
19
+ },
20
+ "outputs": [],
21
+ "source": [
22
+ "import torch\n",
23
+ "from transformers import AutoModelForCausalLM\n",
24
+ "model = AutoModelForCausalLM.from_pretrained('./', trust_remote_code=True,torch_dtype=torch.bfloat16).cuda()\n"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 1,
30
+ "id": "5feec891-588c-4150-a7c5-6d82e7f62c20",
31
+ "metadata": {
32
+ "ExecutionIndicator": {
33
+ "show": false
34
+ },
35
+ "execution": {
36
+ "iopub.execute_input": "2025-05-17T10:07:59.689508Z",
37
+ "iopub.status.busy": "2025-05-17T10:07:59.689021Z",
38
+ "iopub.status.idle": "2025-05-17T10:07:59.806489Z",
39
+ "shell.execute_reply": "2025-05-17T10:07:59.805876Z",
40
+ "shell.execute_reply.started": "2025-05-17T10:07:59.689488Z"
41
+ },
42
+ "tags": []
43
+ },
44
+ "outputs": [
45
+ {
46
+ "name": "stdout",
47
+ "output_type": "stream",
48
+ "text": [
49
+ "/usr/bin/sh: 1: fabric: not found\n"
50
+ ]
51
+ }
52
+ ],
53
+ "source": [
54
+ "batch_size, lookback_length = 1, 2880\n",
55
+ "seqs = torch.randn(batch_size, lookback_length).bfloat16().cuda()\n",
56
+ "prediction_length = 96\n",
57
+ "\n"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 45,
63
+ "id": "f2b2daff-5bdb-4d07-a9f6-e0ebf538dd58",
64
+ "metadata": {
65
+ "ExecutionIndicator": {
66
+ "show": false
67
+ },
68
+ "execution": {
69
+ "iopub.execute_input": "2025-05-16T21:34:39.066356Z",
70
+ "iopub.status.busy": "2025-05-16T21:34:39.065849Z",
71
+ "iopub.status.idle": "2025-05-16T21:34:39.068752Z",
72
+ "shell.execute_reply": "2025-05-16T21:34:39.068269Z",
73
+ "shell.execute_reply.started": "2025-05-16T21:34:39.066334Z"
74
+ },
75
+ "tags": []
76
+ },
77
+ "outputs": [],
78
+ "source": [
79
+ "output = model.generate(seqs, future_token=prediction_length)\n",
80
+ "output.shape"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 9,
86
+ "id": "d02706bd-9ec2-4a61-bbaa-b6c478025e42",
87
+ "metadata": {
88
+ "ExecutionIndicator": {
89
+ "show": false
90
+ },
91
+ "execution": {
92
+ "iopub.execute_input": "2025-05-17T02:39:54.043448Z",
93
+ "iopub.status.busy": "2025-05-17T02:39:54.043193Z",
94
+ "iopub.status.idle": "2025-05-17T02:39:54.046722Z",
95
+ "shell.execute_reply": "2025-05-17T02:39:54.046261Z",
96
+ "shell.execute_reply.started": "2025-05-17T02:39:54.043431Z"
97
+ },
98
+ "tags": []
99
+ },
100
+ "outputs": [],
101
+ "source": []
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": 10,
106
+ "id": "48f905ca-37ae-4382-9da5-8fa4cff50531",
107
+ "metadata": {
108
+ "ExecutionIndicator": {
109
+ "show": false
110
+ },
111
+ "execution": {
112
+ "iopub.execute_input": "2025-05-17T02:39:54.717760Z",
113
+ "iopub.status.busy": "2025-05-17T02:39:54.717228Z",
114
+ "iopub.status.idle": "2025-05-17T02:39:56.175917Z",
115
+ "shell.execute_reply": "2025-05-17T02:39:56.175406Z",
116
+ "shell.execute_reply.started": "2025-05-17T02:39:54.717720Z"
117
+ },
118
+ "tags": []
119
+ },
120
+ "outputs": [],
121
+ "source": []
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 11,
126
+ "id": "b8eca6b8-3ed1-4d87-b8c7-d84e6994acce",
127
+ "metadata": {
128
+ "ExecutionIndicator": {
129
+ "show": false
130
+ },
131
+ "execution": {
132
+ "iopub.execute_input": "2025-05-17T02:39:57.844387Z",
133
+ "iopub.status.busy": "2025-05-17T02:39:57.843876Z",
134
+ "iopub.status.idle": "2025-05-17T02:39:57.849493Z",
135
+ "shell.execute_reply": "2025-05-17T02:39:57.848990Z",
136
+ "shell.execute_reply.started": "2025-05-17T02:39:57.844364Z"
137
+ },
138
+ "tags": []
139
+ },
140
+ "outputs": [
141
+ {
142
+ "data": {
143
+ "text/plain": [
144
+ "torch.Size([1, 96, 99])"
145
+ ]
146
+ },
147
+ "execution_count": 11,
148
+ "metadata": {},
149
+ "output_type": "execute_result"
150
+ }
151
+ ],
152
+ "source": []
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": null,
157
+ "id": "66cf66aa-a991-4a6e-891d-fac9227c383a",
158
+ "metadata": {},
159
+ "outputs": [],
160
+ "source": []
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "id": "3a908d32-4abc-49c6-9ff7-9dfc01379eba",
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "# !fabric run \\\n",
170
+ "# --accelerator=cuda \\\n",
171
+ "# --devices=4 \\\n",
172
+ "# --num-nodes=1 \\\n",
173
+ "# --main-port=1145 \\\n",
174
+ "# Long_Term_Forecasting/main.py \\\n",
175
+ "# --batch_size 32 \\\n",
176
+ "# --seq_len 4096 \\\n",
177
+ "# --future_token 4096 \\\n",
178
+ "# --model_name ./\\\n",
179
+ "# --num_gpus 4 \\\n",
180
+ "# -t ETTh1 \\\n",
181
+ "# -t ETTh2 \\\n",
182
+ "# -t ETTm1 \\\n",
183
+ "# -t ETTm2 \\\n",
184
+ "# -t Weather "
185
+ ]
186
+ }
187
+ ],
188
+ "metadata": {
189
+ "kernelspec": {
190
+ "display_name": "Python 3 (ipykernel)",
191
+ "language": "python",
192
+ "name": "python3"
193
+ },
194
+ "language_info": {
195
+ "codemirror_mode": {
196
+ "name": "ipython",
197
+ "version": 3
198
+ },
199
+ "file_extension": ".py",
200
+ "mimetype": "text/x-python",
201
+ "name": "python",
202
+ "nbconvert_exporter": "python",
203
+ "pygments_lexer": "ipython3",
204
+ "version": "3.11.11"
205
+ }
206
+ },
207
+ "nbformat": 4,
208
+ "nbformat_minor": 5
209
+ }
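The notebook above (未命名.ipynb, i.e. "Untitled.ipynb") loads the checkpoint and forecasts through the custom generate() API. The same flow as a plain script, for convenience (a sketch that assumes the checkpoint sits in the current directory and a bfloat16-capable GPU is available):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "./", trust_remote_code=True, torch_dtype=torch.bfloat16
).cuda()

batch_size, lookback_length, prediction_length = 1, 2880, 96
seqs = torch.randn(batch_size, lookback_length).bfloat16().cuda()

# future_token sets the forecast horizon; the notebook records an output
# shape of torch.Size([1, 96, 99]) for this call.
forecast = model.generate(seqs, future_token=prediction_length)
print(forecast.shape)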