Asheng98 commited on
Commit
40f71f0
·
verified ·
1 Parent(s): 1eb1b53

Upload DeepTime.py

Browse files
Files changed (1) hide show
  1. DeepTime.py +147 -0
DeepTime.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer
6
+ from layers.SelfAttention_Family import FullAttention, AttentionLayer
7
+ from layers.Embed import PatchEmbedding
8
+ from collections import Counter
9
+ from layers.SharedWavMoE import WavMoE
10
+ from layers.RevIN import RevIN
11
+ import torch.fft
12
+ from layers.Embed import DataEmbedding
13
+
14
+ class FlattenHead(nn.Module):
15
+ def __init__(self, n_vars, nf, target_window, head_dropout=0):
16
+ super().__init__()
17
+ self.n_vars = n_vars
18
+ # self.flatten = nn.Flatten(start_dim=-2)
19
+ self.linear = nn.Linear(nf, target_window)
20
+ self.dropout = nn.Dropout(head_dropout)
21
+
22
+ def forward(self, x): # x: [bs x nvars x d_model x patch_num]
23
+ # x = self.flatten(x)
24
+ # print(self.linear,x.shape)
25
+ x = self.linear(x)
26
+ x = self.dropout(x)
27
+ return x
28
+
29
+
30
+
31
+ class Model(nn.Module):
32
+ """
33
+ """
34
+
35
+ def __init__(self, configs):
36
+ super(Model, self).__init__()
37
+ self.task_name = configs.task_name
38
+ self.seq_len = configs.seq_len
39
+ self.patch_len = configs.input_token_len
40
+ self.stride = self.patch_len
41
+ self.pred_len = configs.test_pred_len
42
+ self.test_seq_len = configs.test_seq_len
43
+ # embedding configs
44
+ self.output_attention = configs.output_attention
45
+ self.padding = configs.padding
46
+ # MoE设置
47
+ self.hidden_size = configs.hidden_size
48
+ self.intermediate_size = configs.intermediate_size
49
+ self.top_k = configs.top_k
50
+ self.shared_experts = configs.shared_experts
51
+ self.wavelet = configs.wavelet
52
+ self.level = configs.shared_experts
53
+ self.proj_wight = configs.proj_wight
54
+ # Embedding
55
+ self.patch_embedding = PatchEmbedding(
56
+ configs.d_model, self.patch_len, self.stride, self.padding, configs.dropout)
57
+
58
+ self.data_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
59
+ configs.dropout)
60
+
61
+ self.revin_layer = RevIN(configs.enc_in)
62
+ self.encoder_patch = Encoder(
63
+ [
64
+ EncoderLayer(
65
+ AttentionLayer(
66
+ FullAttention(False, configs.factor,
67
+ attention_dropout=configs.dropout,
68
+ output_attention=configs.output_attention),
69
+ configs.d_model, configs.n_heads),
70
+ configs.d_model,
71
+ configs.d_ff,
72
+ dropout=configs.dropout,
73
+ activation=configs.activation
74
+ ) for l in range(configs.e_layers)
75
+ ],
76
+ norm_layer=torch.nn.LayerNorm(configs.d_model)
77
+ )
78
+ self.encoder_time = Encoder(
79
+ [
80
+ EncoderLayer(
81
+ AttentionLayer(
82
+ FullAttention(False, configs.factor,
83
+ attention_dropout=configs.dropout,
84
+ output_attention=configs.output_attention),
85
+ configs.d_model, configs.n_heads),
86
+ configs.d_model,
87
+ configs.d_ff,
88
+ dropout=configs.dropout,
89
+ activation=configs.activation
90
+ ) for l in range(configs.e_layers)
91
+ ],
92
+ norm_layer=torch.nn.LayerNorm(configs.d_model)
93
+ )
94
+ self.head_nf = configs.d_model * \
95
+ int((configs.seq_len - self.patch_len) / self.stride + 1)
96
+ self.projection = nn.Linear(self.head_nf, int(configs.seq_len*self.proj_wight), bias=True)
97
+
98
+ self.data_projection = nn.Linear(configs.d_model, configs.enc_in, bias=True)
99
+ self.wavmoe = WavMoE(configs)
100
+ self.head = FlattenHead(configs.enc_in, nf= int(configs.seq_len*self.proj_wight), target_window= self.seq_len,
101
+ head_dropout=configs.dropout)
102
+ self.gelu = nn.GELU()
103
+
104
+
105
+
106
+ def main(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
107
+ # 归一化并且嵌入
108
+ x_revin = self.revin_layer(x_enc, 'norm').permute(0, 2, 1)
109
+ # print("x_revin.shape:",x_revin.shape)
110
+ B, D, S = x_revin.shape
111
+
112
+ # 进入注意力机制
113
+ x_inver=self.data_embedding(x_revin.permute(0, 2, 1), x_mark_enc)
114
+ nav_out, attn_w = self.encoder_time(x_inver, attn_mask=None)
115
+ #print("nav_out.shape:", nav_out.shape,self.data_projection)
116
+ nav_out = self.data_projection(nav_out)
117
+ #print("nav_out.shape:", nav_out.shape)
118
+
119
+ #patch embedding进入多头FullAttention
120
+
121
+ # u: [bs * nvars x patch_num x d_model]
122
+ x_pe, n_vars = self.patch_embedding(x_revin+nav_out.permute(0, 2, 1))
123
+ #print("x_pe.shape:",x_pe.shape, n_vars)
124
+ enc_out, attn = self.encoder_patch(x_pe)
125
+ dec_out = enc_out.reshape(B, D, -1)
126
+ #print("dec_out.shape:",dec_out.shape, self.head_nf)
127
+ act_val = self.projection(dec_out)
128
+ #print("act_val:", act_val.shape)
129
+
130
+ # 专家系统
131
+ moe_out, router_logits = self.wavmoe(act_val + nav_out.permute(0, 2, 1))
132
+ #print("moe_out", moe_out.shape)
133
+ head_out = self.head(moe_out)
134
+
135
+ # 逆归一化输出
136
+ x_out = self.revin_layer(head_out.permute(0, 2, 1), 'denorm')
137
+ #print(x_out.shape)
138
+ return x_out
139
+ def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
140
+ if self.task_name == 'long_term_forecast' or self.task_name == 'forecast':
141
+ dec_out = self.main(x_enc, x_mark_enc, x_dec, x_mark_dec)
142
+ return dec_out[:, -self.test_seq_len :, :] # [B, L, D]
143
+ if self.task_name == 'anomaly_detection':
144
+ dec_out = self.main(x_enc, x_mark_enc, x_dec, x_mark_dec)
145
+ return dec_out # [B, L, D]
146
+ return None
147
+