File size: 17,400 Bytes
cad4384 aa1f56f cad4384 cf48cb9 cad4384 f93d8f2 cad4384 72ed94e cad4384 aa1f56f 72ed94e aa1f56f 72ed94e aa1f56f 72ed94e aa1f56f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 |
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
import torch.utils.checkpoint
from torch import nn, Tensor
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import Optional, Tuple, Union, Any
# helper functions
def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``."""
    if val is None:
        return d
    return val
def padding_to_multiple_of(n, mult):
    """Return how many items must be appended to a length-``n`` sequence so
    that its length becomes a multiple of ``mult`` (0 if it already is)."""
    # (-n) % mult is the distance from n up to the next multiple of mult
    return -n % mult
# scalenorm
class ScaleNorm(nn.Module):
    """Normalise the last dimension by its root-mean-square, times one learned gain.

    Computes ``x / max(||x|| * dim**-0.5, eps) * g``, where ``||x||`` is the L2
    norm over the last dimension and ``g`` is a single scalar parameter that
    starts at 1.
    """
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # ||x|| * dim**-0.5 == sqrt(mean(x ** 2)) over the last dimension
        rms = x.norm(dim = -1, keepdim = True).mul(self.scale)
        # eps floor avoids division by zero for all-zero vectors
        denom = rms.clamp(min = self.eps)
        return x.div(denom).mul(self.g)
# absolute positional encodings
class ScaledSinuEmbedding(nn.Module):
    """Fixed sinusoidal absolute-position table multiplied by one learned scalar.

    For ``x`` whose dim 1 has length ``n``, returns an ``(n, 2 * ceil(dim / 2))``
    tensor: the sines of ``position * inv_freq`` followed by the cosines.
    """
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1,))
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x):
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
        # angles[i, k] = position_i * inv_freq_k
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles.sin(), angles.cos()), dim = -1)
        return table * self.scale
# T5 relative positional bias
class T5RelativePositionBias(nn.Module):
    """Learned scalar attention bias indexed by bucketed relative position (T5-style).

    Calling the module on a tensor whose last two dims are ``(i, j)`` returns an
    ``(i, j)`` bias. Entry ``[q, k]`` is the learned value for
    ``bucket(k - q)`` multiplied by ``scale``.
    """
    def __init__(
        self,
        scale,
        causal = False,
        num_buckets = 32,
        max_distance = 128
    ):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, 1)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        causal = True,
        num_buckets = 32,
        max_distance = 128
    ):
        """Map signed relative positions (key - query) to bucket ids.

        Distances below half of the available buckets each get their own bucket.
        Larger distances are binned logarithmically and clamped into the last
        bucket. In non-causal mode, half of the buckets are reserved for keys
        that come after the query. In causal mode, those keys all map to
        bucket 0.
        """
        distance = -relative_position
        if causal:
            bucket_offset = 0
            distance = distance.clamp(min = 0)
        else:
            num_buckets //= 2
            # keys after the query (negative distance) use the upper half
            bucket_offset = (distance < 0).long() * num_buckets
            distance = distance.abs()

        exact_limit = num_buckets // 2
        # log-spaced bucket for large distances; for distance 0 this is -inf,
        # but torch.where below never selects it in that case
        log_ratio = torch.log(distance.float() / exact_limit) / math.log(max_distance / exact_limit)
        log_bucket = exact_limit + (log_ratio * (num_buckets - exact_limit)).long()
        log_bucket = log_bucket.clamp(max = num_buckets - 1)

        return bucket_offset + torch.where(distance < exact_limit, distance, log_bucket)

    def forward(self, x):
        rows, cols = x.shape[-2:]
        query_idx = torch.arange(rows, dtype = torch.long, device = x.device)
        key_idx = torch.arange(cols, dtype = torch.long, device = x.device)
        # relative[q, k] = k - q
        relative = key_idx.unsqueeze(0) - query_idx.unsqueeze(1)
        buckets = self._relative_position_bucket(
            relative,
            causal = self.causal,
            num_buckets = self.num_buckets,
            max_distance = self.max_distance
        )
        # embedding output is (i, j, 1); drop the trailing singleton
        return self.relative_attention_bias(buckets).squeeze(-1) * self.scale
# per-head offset and scale
class OffsetScale(nn.Module):
    """Produce ``heads`` affine variants of ``x``: ``x * weight[h] + bias[h]``.

    ``weight`` starts as N(0, 0.02) noise and ``bias`` as zeros, both of shape
    ``(heads, dim)``. ``forward`` returns a tuple of ``heads`` tensors, each
    shaped like ``x``.
    """
    def __init__(self, dim, heads = 1):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(heads, dim))
        self.bias = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.weight, std = 0.02)

    def forward(self, x):
        # (..., 1, d) * (h, d) -> (..., h, d)
        per_head = x.unsqueeze(-2) * self.weight + self.bias
        return per_head.unbind(dim = -2)
# activation functions
class ReLUSquared(nn.Module):
    """Squared ReLU: ``max(x, 0) ** 2`` elementwise."""
    def forward(self, x):
        return F.relu(x).square()
class LaplacianAttnFn(nn.Module):
    """Attention activation from https://arxiv.org/abs/2209.10655, which claims
    it is more stable than squared ReLU.

    Elementwise Gaussian CDF with mean ``sqrt(1/2)`` and standard deviation
    ``sqrt(1 / (4 * pi))``, so outputs lie in [0, 1].
    """
    def forward(self, x):
        center = math.sqrt(0.5)
        spread = math.sqrt((4 * math.pi) ** -1)
        z = (x - center) / (spread * math.sqrt(2))
        return torch.special.erf(z).add(1).mul(0.5)
class FLASH(nn.Module):
    """One FLASH layer.

    The layer combines gated quadratic attention within fixed-size groups and
    linear attention across groups, with a pre-norm and a residual connection.

    Constructor arguments:
        dim: feature size of the layer input and output.
        group_size: length of the chunks the sequence is split into. Quadratic
            attention is computed only inside each chunk.
        query_key_dim: width of the shared query/key projection.
        expansion_factor: the value and gate width is
            ``int(dim * expansion_factor)``.
        causal: if True, quadratic attention keeps only keys j <= i inside a
            group, and linear attention only sees earlier groups.
        dropout: dropout probability applied to the quadratic attention weights.
        rotary_pos_emb: optional module. If given, its
            ``rotate_queries_or_keys`` is applied to all four query/key variants.
        norm_klass: class used for the pre-norm, called as ``norm_klass(dim)``.
        shift_tokens: if True, the first half of the normalised features is
            shifted one position later along the sequence before projection.
        laplace_attn_fn: use ``LaplacianAttnFn`` instead of ``ReLUSquared``.
        reduce_group_non_causal_attn: in non-causal mode, sum the linear
            attention context over all groups (True) or keep it per group
            (False).
    """
    def __init__(
        self,
        *,
        dim,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 2.,
        causal = False,
        dropout = 0.,
        rotary_pos_emb = None,
        norm_klass = nn.LayerNorm,
        shift_tokens = False,
        laplace_attn_fn = False,
        reduce_group_non_causal_attn = True
    ):
        super().__init__()
        # width of the value and gate projections
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens
        # elementwise function mapping scores to attention weights (no softmax is used)
        self.attn_fn = ReLUSquared() if not laplace_attn_fn else LaplacianAttnFn()
        # positional embeddings
        self.rotary_pos_emb = rotary_pos_emb
        # relative-position bias added to the quadratic scores, scaled by sqrt(query_key_dim)
        self.rel_pos_bias = T5RelativePositionBias(query_key_dim ** 0.5, causal = causal)
        # norm
        self.norm = norm_klass(dim)
        self.dropout = nn.Dropout(dropout)
        # whether to reduce groups in non causal linear attention
        self.reduce_group_non_causal_attn = reduce_group_non_causal_attn
        # projections
        # values and gate come from one projection, split with chunk(2) in forward
        self.to_hidden = nn.Sequential(
            nn.Linear(dim, hidden_dim * 2),
            nn.SiLU()
        )
        # one shared query/key projection; OffsetScale derives the four variants from it
        self.to_qk = nn.Sequential(
            nn.Linear(dim, query_key_dim),
            nn.SiLU()
        )
        # four heads, unpacked in forward as: quadratic q, linear q, quadratic k, linear k
        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
        self.to_out = nn.Linear(hidden_dim, dim)
    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """
        Apply the layer and return ``to_out(gate * (quad + linear attention)) + x``.

        x    - input features; assumed (batch, seq_len, dim), since batch is read
               from x.shape[0] and seq_len from x.shape[-2] (TODO confirm for callers)
        mask - optional (batch, seq_len) mask; positions where it is falsy are
               excluded as keys from both the quadratic and the linear attention

        Einsum / einops letters used below:
        b - batch
        n - sequence length (within groups)
        g - group dimension
        d - feature dimension (keys)
        e - feature dimension (values)
        i - sequence dimension (source)
        j - sequence dimension (target)
        """
        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
        # prenorm
        normed_x = self.norm(x)
        # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim = -1)
            # pad one position at the start of the sequence and drop the last one,
            # so position t receives the shifted features of position t - 1
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
            normed_x = torch.cat((x_shift, x_pass), dim = -1)
        # initial projections
        v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)
        qk = self.to_qk(normed_x)
        # offset and scale
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
        # mask out linear attention keys
        if exists(mask):
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask.bool(), 0.)
        # rotate queries and keys
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
        # padding for groups
        padding = padding_to_multiple_of(n, g)
        if padding > 0:
            # zero-pad the sequence dimension up to a multiple of the group size
            quad_q, quad_k, lin_q, lin_k, v = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v))
            # padded positions must never be attended to, so a mask is created if absent
            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
            mask = F.pad(mask, (0, padding), value = False)
        # group along sequence
        # (in this einops pattern, 'g' is the group size and 'n' the number of groups;
        # from here on the einsum strings use 'g' for groups and 'n' for positions within a group)
        quad_q, quad_k, lin_q, lin_k, v = map(lambda t: rearrange(t, 'b (n g) d -> b n g d', g = self.group_size), (quad_q, quad_k, lin_q, lin_k, v))
        if exists(mask):
            # key mask per group, broadcast over query rows
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
        # calculate quadratic attention output
        # scores are divided by the group size rather than sqrt(d)
        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
        sim = sim + self.rel_pos_bias(sim)
        attn = self.attn_fn(sim)
        attn = self.dropout(attn)
        if exists(mask):
            attn = attn.masked_fill(~mask.bool(), 0.)
        if self.causal:
            # triu(1) marks keys j > i inside each group
            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask.bool(), 0.)
        quad_out = einsum('... i j, ... j d -> ... i d', attn, v)
        # calculate linear attention output
        if self.causal:
            # per-group key/value summaries, normalised by group size
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
            # exclusive cumulative sum along group dimension
            # (cumsum, then shift one group later so each group sees only earlier groups)
            lin_kv = lin_kv.cumsum(dim = 1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
        else:
            # either one context summed over all groups, or one context per group
            context_einsum_eq = 'b d e' if self.reduce_group_non_causal_attn else 'b g d e'
            # normalised by the unpadded sequence length n, not by the number of unmasked tokens
            lin_kv = einsum(f'b g n d, b g n e -> {context_einsum_eq}', lin_k, v) / n
            lin_out = einsum(f'b g n d, {context_einsum_eq} -> b g n e', lin_q, lin_kv)
        # fold back groups into full sequence, and excise out padding
        quad_attn_out, lin_attn_out = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out, lin_out))
        # gate
        out = gate * (quad_attn_out + lin_attn_out)
        # projection out and residual
        return self.to_out(out) + x
# FLASH Transformer
class FLASHTransformer(nn.Module):
    """Token embedding, a stack of ``depth`` FLASH layers, and a vocabulary head.

    ``forward(x, mask=None)`` takes token ids and returns ``(logits, x_norm)``.
    ``x_norm`` is the LayerNorm-ed output of the last layer, and ``logits`` is
    the linear projection of ``x_norm`` onto ``num_tokens`` classes.
    """
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 2.,
        causal = False,
        attn_dropout = 0.,
        norm_type = 'scalenorm',
        shift_tokens = True,
        laplace_attn_fn = False,
        reduce_group_non_causal_attn = True
    ):
        super().__init__()
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
        norm_klass = {'scalenorm': ScaleNorm, 'layernorm': nn.LayerNorm}[norm_type]

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.abs_pos_emb = ScaledSinuEmbedding(dim)
        self.group_size = group_size

        # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
        # a single rotary module is shared by every layer
        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))

        layer_kwargs = dict(
            dim = dim,
            group_size = group_size,
            query_key_dim = query_key_dim,
            expansion_factor = expansion_factor,
            causal = causal,
            dropout = attn_dropout,
            rotary_pos_emb = rotary_pos_emb,
            norm_klass = norm_klass,
            shift_tokens = shift_tokens,
            reduce_group_non_causal_attn = reduce_group_non_causal_attn,
            laplace_attn_fn = laplace_attn_fn,
        )
        self.layers = nn.ModuleList([FLASH(**layer_kwargs) for _ in range(depth)])

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        hidden = self.token_emb(x)
        hidden = self.abs_pos_emb(hidden) + hidden
        for flash in self.layers:
            hidden = flash(hidden, mask = mask)
        # to_logits is (final LayerNorm, vocabulary projection); the normalised
        # states are returned alongside the logits
        final_norm, to_vocab = self.to_logits
        x_norm = final_norm(hidden)
        return to_vocab(x_norm), x_norm
class FLASHTransformerConfig(PretrainedConfig):
    """Hugging Face configuration for FLASH transformer models.

    Every named argument is stored as an attribute of the same name. Any other
    keyword arguments are forwarded to ``PretrainedConfig``.
    """
    model_type = "flash_transformer"

    def __init__(
        self,
        hidden_size=512,
        vocab_size=4096,
        num_layers=12,
        group_size=256,
        query_key_dim=128,
        expansion_factor=2.0,
        causal=False,
        attn_dropout=0.1,
        norm_type="scalenorm",
        shift_tokens=True,
        laplace_attn_fn=False,
        reduce_group_non_causal_attn=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        # assigned in this fixed order so serialised configs keep the same key order
        architecture = {
            "hidden_size": hidden_size,
            "vocab_size": vocab_size,
            "num_layers": num_layers,
            "group_size": group_size,
            "query_key_dim": query_key_dim,
            "expansion_factor": expansion_factor,
            "causal": causal,
            "attn_dropout": attn_dropout,
            "norm_type": norm_type,
            "shift_tokens": shift_tokens,
            "laplace_attn_fn": laplace_attn_fn,
            "reduce_group_non_causal_attn": reduce_group_non_causal_attn,
        }
        for attribute, value in architecture.items():
            setattr(self, attribute, value)
class FLASHTransformerForPretrained(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``FLASHTransformer``.

    ``forward`` always returns a ``MaskedLMOutput``:
      * ``logits``: vocabulary logits from the wrapped model.
      * ``hidden_states``: the final LayerNorm-ed hidden-state tensor. Note that
        this is not the per-layer tuple most Hugging Face models return.
      * ``loss``: masked-LM cross-entropy when ``labels`` is given, else None.
        Label positions equal to -100 are ignored (the ``CrossEntropyLoss``
        default ``ignore_index``).
    """
    config_class = FLASHTransformerConfig
    base_model_prefix = "flash_transformer"

    def __init__(self, config):
        super().__init__(config)
        self.model = FLASHTransformer(
            dim=config.hidden_size,
            num_tokens=config.vocab_size,
            depth=config.num_layers,
            group_size=config.group_size,
            query_key_dim=config.query_key_dim,
            expansion_factor=config.expansion_factor,
            causal=config.causal,
            attn_dropout=config.attn_dropout,
            norm_type=config.norm_type,
            shift_tokens=config.shift_tokens,
            laplace_attn_fn=config.laplace_attn_fn,
            reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ) -> Union[Tuple, MaskedLMOutput]:
        """Run the model and optionally compute a masked-LM loss.

        Args:
            input_ids: token ids passed to ``FLASHTransformer``. Required.
            attention_mask: optional key mask forwarded to every FLASH layer.
            labels: optional per-token targets with the same shape as
                ``input_ids``. Positions set to -100 do not contribute.
            position_ids, inputs_embeds, output_attentions,
            output_hidden_states, return_dict: accepted for Hugging Face API
                compatibility but currently ignored.

        Returns:
            ``MaskedLMOutput`` with ``logits``, ``hidden_states`` and ``loss``.

        Raises:
            ValueError: if ``input_ids`` is None (``inputs_embeds`` is not
                supported by the wrapped model).
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided; inputs_embeds is not supported by FLASHTransformer")

        logits, x = self.model(input_ids, mask=attention_mask)

        loss = None
        if labels is not None:
            # Previously labels were silently dropped, so training never got a loss.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        return MaskedLMOutput(logits=logits, hidden_states=x, loss=loss, attentions=None)
class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
    """Sequence classifier/regressor on top of ``FLASHTransformerForPretrained``.

    The base model's final hidden states are mean-pooled over positions where
    ``attention_mask`` is non-zero (all positions when no mask is given). The
    pooled vector is passed to ``self.score``, which is a bias-free linear layer
    or, when ``config.use_mlp_classifier`` is truthy, a small MLP.
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        if getattr(config, "use_mlp_classifier", False):
            self.score = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(config.hidden_size, self.num_labels, bias=False),
            )
        else:
            self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

    @staticmethod
    def _masked_mean_pool(hidden_states, attention_mask):
        """Average ``hidden_states`` over dim 1, counting only positions where
        ``attention_mask`` is non-zero.

        Assumes hidden_states is (batch, seq_len, hidden) and attention_mask is
        (batch, seq_len) with 0/1 (or bool) entries. With no mask, this is a
        plain mean over dim 1. A row whose mask is all zeros yields a zero
        vector instead of NaN.
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        mask = attention_mask.to(hidden_states.dtype).unsqueeze(-1)  # (batch, seq_len, 1)
        summed = (hidden_states * mask).sum(dim=1)
        # masks are 0/1, so counts are whole numbers; clamp(1) guards empty rows
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        When ``return_dict`` is falsy (including the default None), a tuple
        ``(loss, logits)`` or ``(logits,)`` is returned. Otherwise a
        ``SequenceClassifierOutput`` is returned.
        """
        # Get the base model output. Labels are not forwarded: they are
        # sequence-level targets handled below, not per-token targets.
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs["hidden_states"]
        # Mean-pool over non-padding positions (previously this indexed the
        # input_ids tensor with a string and used an undefined name).
        mean_pooled = self._masked_mean_pool(hidden_states, attention_mask)
        logits = self.score(mean_pooled)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(loss=loss, logits=logits)