wangleiofficial committed
Commit cad4384 · verified · 1 Parent(s): 9e9c22d
Files changed (5)
  1. config.json +24 -0
  2. dnaflash.py +414 -0
  3. special_tokens_map.json +37 -0
  4. tokenizer.json +0 -0
  5. tokenizer_config.json +52 -0
config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "architectures": [
+     "FLASHTransformerForPretrained"
+   ],
+   "auto_map": {
+     "AutoConfig": "dnaflash.FLASHTransformerConfig",
+     "AutoModel": "dnaflash.FLASHTransformerForPretrained"
+   },
+   "attn_dropout": 0.0,
+   "causal": false,
+   "expansion_factor": 2.0,
+   "group_size": 256,
+   "hidden_size": 1024,
+   "laplace_attn_fn": false,
+   "model_type": "flash_transformer",
+   "norm_type": "scalenorm",
+   "num_layers": 36,
+   "query_key_dim": 128,
+   "reduce_group_non_causal_attn": true,
+   "shift_tokens": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.39.3",
+   "vocab_size": 4096
+ }
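
The "auto_map" entries above point AutoConfig and AutoModel at the classes defined in dnaflash.py, so the checkpoint has to be loaded with trust_remote_code=True. A minimal loading sketch; the repository id is a placeholder guessed from the committer name, not something stated in this commit.

# Hypothetical repo id -- not stated in this commit; replace with the actual repo path.
from transformers import AutoConfig, AutoModel

repo_id = "wangleiofficial/dnaflash"

# trust_remote_code=True lets transformers resolve auto_map to dnaflash.py
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# values taken from config.json above
print(config.hidden_size, config.num_layers, config.vocab_size)  # 1024 36 4096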
dnaflash.py ADDED
@@ -0,0 +1,414 @@
+ import math
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, einsum
+
+ from einops import rearrange
+ from rotary_embedding_torch import RotaryEmbedding
+
+ from transformers import PreTrainedModel, PretrainedConfig
+ from transformers.modeling_outputs import MaskedLMOutput
+
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     return val if exists(val) else d
+
+ def padding_to_multiple_of(n, mult):
+     remainder = n % mult
+     if remainder == 0:
+         return 0
+     return mult - remainder
+
+ # scalenorm
+
+ class ScaleNorm(nn.Module):
+     def __init__(self, dim, eps = 1e-5):
+         super().__init__()
+         self.scale = dim ** -0.5
+         self.eps = eps
+         self.g = nn.Parameter(torch.ones(1))
+
+     def forward(self, x):
+         norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
+         return x / norm.clamp(min = self.eps) * self.g
+
+ # absolute positional encodings
+
+ class ScaledSinuEmbedding(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.scale = nn.Parameter(torch.ones(1,))
+         inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer('inv_freq', inv_freq)
+
+     def forward(self, x):
+         n, device = x.shape[1], x.device
+         t = torch.arange(n, device = device).type_as(self.inv_freq)
+         sinu = einsum('i , j -> i j', t, self.inv_freq)
+         emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
+         return emb * self.scale
+
+ # T5 relative positional bias
+
+ class T5RelativePositionBias(nn.Module):
+     def __init__(
+         self,
+         scale,
+         causal = False,
+         num_buckets = 32,
+         max_distance = 128
+     ):
+         super().__init__()
+         self.scale = scale
+         self.causal = causal
+         self.num_buckets = num_buckets
+         self.max_distance = max_distance
+         self.relative_attention_bias = nn.Embedding(num_buckets, 1)
+
+     @staticmethod
+     def _relative_position_bucket(
+         relative_position,
+         causal = True,
+         num_buckets = 32,
+         max_distance = 128
+     ):
+         ret = 0
+         n = -relative_position
+         if not causal:
+             num_buckets //= 2
+             ret += (n < 0).long() * num_buckets
+             n = torch.abs(n)
+         else:
+             n = torch.max(n, torch.zeros_like(n))
+
+         max_exact = num_buckets // 2
+         is_small = n < max_exact
+
+         val_if_large = max_exact + (
+             torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+         ).long()
+         val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+         ret += torch.where(is_small, n, val_if_large)
+         return ret
+
+     def forward(self, x):
+         i, j, device = *x.shape[-2:], x.device
+         q_pos = torch.arange(i, dtype = torch.long, device = device)
+         k_pos = torch.arange(j, dtype = torch.long, device = device)
+         rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
+         rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
+         values = self.relative_attention_bias(rp_bucket)
+         bias = rearrange(values, 'i j 1 -> i j')
+         return bias * self.scale
+
+ # class
+
+ class OffsetScale(nn.Module):
+     def __init__(self, dim, heads = 1):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(heads, dim))
+         self.bias = nn.Parameter(torch.zeros(heads, dim))
+         nn.init.normal_(self.weight, std = 0.02)
+
+     def forward(self, x):
+         out = einsum('... d, h d -> ... h d', x, self.weight) + self.bias
+         return out.unbind(dim = -2)
+
+ # activation functions
+
+ class ReLUSquared(nn.Module):
+     def forward(self, x):
+         return F.relu(x) ** 2
+
+ class LaplacianAttnFn(nn.Module):
+     """ https://arxiv.org/abs/2209.10655 claims this is more stable than Relu squared """
+
+     def forward(self, x):
+         mu = math.sqrt(0.5)
+         std = math.sqrt((4 * math.pi) ** -1)
+         return (1 + torch.special.erf((x - mu) / (std * math.sqrt(2)))) * 0.5
+
+
+ class FLASH(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         group_size = 256,
+         query_key_dim = 128,
+         expansion_factor = 2.,
+         causal = False,
+         dropout = 0.,
+         rotary_pos_emb = None,
+         norm_klass = nn.LayerNorm,
+         shift_tokens = False,
+         laplace_attn_fn = False,
+         reduce_group_non_causal_attn = True
+     ):
+         super().__init__()
+         hidden_dim = int(dim * expansion_factor)
+         self.group_size = group_size
+         self.causal = causal
+         self.shift_tokens = shift_tokens
+
+         self.attn_fn = ReLUSquared() if not laplace_attn_fn else LaplacianAttnFn()
+
+         # positional embeddings
+
+         self.rotary_pos_emb = rotary_pos_emb
+         self.rel_pos_bias = T5RelativePositionBias(query_key_dim ** 0.5, causal = causal)
+
+         # norm
+
+         self.norm = norm_klass(dim)
+         self.dropout = nn.Dropout(dropout)
+
+         # whether to reduce groups in non causal linear attention
+
+         self.reduce_group_non_causal_attn = reduce_group_non_causal_attn
+
+         # projections
+
+         self.to_hidden = nn.Sequential(
+             nn.Linear(dim, hidden_dim * 2),
+             nn.SiLU()
+         )
+
+         self.to_qk = nn.Sequential(
+             nn.Linear(dim, query_key_dim),
+             nn.SiLU()
+         )
+
+         self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
+         self.to_out = nn.Linear(hidden_dim, dim)
+
+     def forward(
+         self,
+         x,
+         *,
+         mask = None
+     ):
+         """
+         b - batch
+         n - sequence length (within groups)
+         g - group dimension
+         d - feature dimension (keys)
+         e - feature dimension (values)
+         i - sequence dimension (source)
+         j - sequence dimension (target)
+         """
+
+         b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
+
+         # prenorm
+
+         normed_x = self.norm(x)
+
+         # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
+
+         if self.shift_tokens:
+             x_shift, x_pass = normed_x.chunk(2, dim = -1)
+             x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
+             normed_x = torch.cat((x_shift, x_pass), dim = -1)
+
+         # initial projections
+
+         v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)
+         qk = self.to_qk(normed_x)
+
+         # offset and scale
+
+         quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
+
+         # mask out linear attention keys
+
+         if exists(mask):
+             lin_mask = rearrange(mask, '... -> ... 1')
+             lin_k = lin_k.masked_fill(~lin_mask.bool(), 0.)
+
+         # rotate queries and keys
+
+         if exists(self.rotary_pos_emb):
+             quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
+
+         # padding for groups
+
+         padding = padding_to_multiple_of(n, g)
+
+         if padding > 0:
+             quad_q, quad_k, lin_q, lin_k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v))
+
+             mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
+             mask = F.pad(mask, (0, padding), value = False)
+
+         # group along sequence
+
+         quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (n g) d -> b n g d', g = self.group_size), (quad_q, quad_k, lin_q, lin_k, v))
+
+         if exists(mask):
+             mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
+
+         # calculate quadratic attention output
+
+         sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
+
+         sim = sim + self.rel_pos_bias(sim)
+
+         attn = self.attn_fn(sim)
+         attn = self.dropout(attn)
+
+         if exists(mask):
+             attn = attn.masked_fill(~mask.bool(), 0.)
+
+         if self.causal:
+             causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
+             attn = attn.masked_fill(causal_mask.bool(), 0.)
+
+         quad_out = einsum('... i j, ... j d -> ... i d', attn, v)
+
+         # calculate linear attention output
+
+         if self.causal:
+             lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
+
+             # exclusive cumulative sum along group dimension
+
+             lin_kv = lin_kv.cumsum(dim = 1)
+             lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
+
+             lin_out = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
+         else:
+             context_einsum_eq = 'b d e' if self.reduce_group_non_causal_attn else 'b g d e'
+             lin_kv = einsum(f'b g n d, b g n e -> {context_einsum_eq}', lin_k, v) / n
+             lin_out = einsum(f'b g n d, {context_einsum_eq} -> b g n e', lin_q, lin_kv)
+
+         # fold back groups into full sequence, and excise out padding
+
+         quad_attn_out, lin_attn_out = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out, lin_out))
+
+         # gate
+
+         out = gate * (quad_attn_out + lin_attn_out)
+
+         # projection out and residual
+
+         return self.to_out(out) + x
+
+ # FLASH Transformer
+
+ class FLASHTransformer(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         num_tokens,
+         depth,
+         group_size = 256,
+         query_key_dim = 128,
+         expansion_factor = 2.,
+         causal = False,
+         attn_dropout = 0.,
+         norm_type = 'scalenorm',
+         shift_tokens = True,
+         laplace_attn_fn = False,
+         reduce_group_non_causal_attn = True
+     ):
+         super().__init__()
+         assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
+
+         if norm_type == 'scalenorm':
+             norm_klass = ScaleNorm
+         elif norm_type == 'layernorm':
+             norm_klass = nn.LayerNorm
+
+         self.token_emb = nn.Embedding(num_tokens, dim)
+         self.abs_pos_emb = ScaledSinuEmbedding(dim)
+         self.group_size = group_size
+
+         rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
+         # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
+
+         self.layers = nn.ModuleList([FLASH(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens, reduce_group_non_causal_attn = reduce_group_non_causal_attn, laplace_attn_fn = laplace_attn_fn) for _ in range(depth)])
+
+         self.to_logits = nn.Sequential(
+             nn.LayerNorm(dim),
+             nn.Linear(dim, num_tokens)
+         )
+
+     def forward(
+         self,
+         x,
+         *,
+         mask = None
+     ):
+         x = self.token_emb(x)
+         x = self.abs_pos_emb(x) + x
+
+         for flash in self.layers:
+             x = flash(x, mask = mask)
+
+         return self.to_logits(x), x
+
+ class FLASHTransformerConfig(PretrainedConfig):
+     model_type = "flash_transformer"
+
+     def __init__(
+         self,
+         hidden_size=512,
+         vocab_size=4096,
+         num_layers=12,
+         group_size=256,
+         query_key_dim=128,
+         expansion_factor=2.0,
+         causal=False,
+         attn_dropout=0.1,
+         norm_type="scalenorm",
+         shift_tokens=True,
+         laplace_attn_fn=False,
+         reduce_group_non_causal_attn=True,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.vocab_size = vocab_size
+         self.num_layers = num_layers
+         self.group_size = group_size
+         self.query_key_dim = query_key_dim
+         self.expansion_factor = expansion_factor
+         self.causal = causal
+         self.attn_dropout = attn_dropout
+         self.norm_type = norm_type
+         self.shift_tokens = shift_tokens
+         self.laplace_attn_fn = laplace_attn_fn
+         self.reduce_group_non_causal_attn = reduce_group_non_causal_attn
+
+
+ class FLASHTransformerForPretrained(PreTrainedModel):
+     config_class = FLASHTransformerConfig
+     base_model_prefix = "flash_transformer"
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = FLASHTransformer(
+             dim=config.hidden_size,
+             num_tokens=config.vocab_size,
+             depth=config.num_layers,
+             group_size=config.group_size,
+             query_key_dim=config.query_key_dim,
+             expansion_factor=config.expansion_factor,
+             causal=config.causal,
+             attn_dropout=config.attn_dropout,
+             norm_type=config.norm_type,
+             shift_tokens=config.shift_tokens,
+             laplace_attn_fn=config.laplace_attn_fn,
+             reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
+         )
+
+     def forward(self, input_ids, mask=None):
+         logits, x = self.model(input_ids, mask=mask)
+         return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)
+
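
Since dnaflash.py defines both the config and the model wrapper, a small smoke-test sketch can exercise them directly. The reduced hidden_size and num_layers below are illustrative only (the checkpoint itself uses 1024 and 36, per config.json); it assumes dnaflash.py is importable from the working directory and that einops and rotary-embedding-torch are installed.

import torch
from dnaflash import FLASHTransformerConfig, FLASHTransformerForPretrained

# small illustrative config; the released checkpoint uses hidden_size=1024, num_layers=36
config = FLASHTransformerConfig(hidden_size=256, num_layers=2, vocab_size=4096)
model = FLASHTransformerForPretrained(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 512))  # (batch, seq_len)
mask = torch.ones(1, 512, dtype=torch.bool)                # all positions valid

with torch.no_grad():
    out = model(input_ids, mask=mask)

print(out.logits.shape)         # torch.Size([1, 512, 4096])
print(out.hidden_states.shape)  # torch.Size([1, 512, 256])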
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "cls_token": {
+     "content": "[CLS]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "[MASK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "[PAD]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "[SEP]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "[UNK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,52 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "tokenizer_class": "PreTrainedTokenizerFast",
+   "unk_token": "[UNK]"
+ }
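
The tokenizer files above are consumed through PreTrainedTokenizerFast (see tokenizer_class); the vocabulary itself lives in tokenizer.json, which is too large to render here. A tokenization sketch, assuming the files from this commit sit in a local directory named "dnaflash" (the path is illustrative, not part of the commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("dnaflash")  # resolves to PreTrainedTokenizerFast

seq = "ACGTACGTAGCTAGCTTGCA"
enc = tokenizer(seq, return_tensors="pt")

print(enc["input_ids"])         # token ids produced by the vocabulary in tokenizer.json
print(tokenizer.mask_token_id)  # 4, per added_tokens_decoder above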