from __future__ import annotations import torch import torch.nn as nn from torch.nn import functional as F from torch import _softmax_backward_data as _softmax_backward_data from functools import partial from .configuration_gptbert import GptBertConfig from transformers.modeling_utils import PreTrainedModel from transformers.activations import gelu_new from transformers.modeling_outputs import ( MaskedLMOutput, MultipleChoiceModelOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, BaseModelOutput, CausalLMOutput ) import math from typing import TYPE_CHECKING, Optional, Union, Tuple, List try: from torch.nn.attention.flex_attention import flex_attention, create_block_mask except ImportError: pass class ModelOutput: def __init__( self, logits: torch.Tensor | None = None, loss: torch.Tensor | float | None = None, perplexity: torch.Tensor | float | None = None, accuracy: float | None = None, z_loss: torch.Tensor | float | None = None, **kwargs ): self.logits: torch.Tensor | None self.loss: torch.Tensor | float | None self.perplexity: torch.Tensor | float | None self.accuracy: float | None self.z_loss: torch.Tensor | float | None self.logits = logits self.loss = loss self.perplexity = perplexity self.accuracy = accuracy self.z_loss = z_loss for attr, value in kwargs.items(): setattr(self, attr, value) class CastedLinear(nn.Linear): def __init__(self, in_features, out_features, bias): super().__init__(in_features, out_features, bias=bias) def reset_parameters(self) -> None: std: float = math.sqrt(2.0 / (self.in_features + self.out_features)) nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std) def forward(self, x): return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None) class CastedLinearIn(nn.Linear): def __init__(self, in_features, out_features, bias): super().__init__(in_features, out_features, bias=bias) self.scale = nn.Parameter(torch.ones(in_features)) def reset_parameters(self) -> None: std: float = math.sqrt(2.0 / (self.in_features + self.out_features)) nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std) def forward(self, x): return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None) class CastedLinearOut(nn.Linear): def __init__(self, in_features, out_features, bias): super().__init__(in_features, out_features, bias=bias) self.scale = nn.Parameter(torch.ones(out_features)) def reset_parameters(self) -> None: std: float = math.sqrt(2.0 / (self.in_features + self.out_features)) nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std) def forward(self, x): return F.linear(x, (self.scale.unsqueeze(1) * self.weight).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None) class MultiCastedLinearOrtho(nn.Module): def __init__(self, in_features, out_features, bias): super().__init__() self.in_features = in_features self.out_features = out_features self.weights = nn.ParameterList() for out_feature in out_features: self.weights.append(nn.Parameter(torch.empty((out_feature, in_features)))) if bias: self.bias = nn.Parameter(torch.zeros(sum(out_features))) else: self.bias = self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: for i, weight in enumerate(self.weights): std: float = math.sqrt(2.0 / (self.in_features + self.out_features[i])) nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std) def forward(self, x): return F.linear(x, torch.cat([weight for weight in self.weights], dim=0).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None) class MultiCastedLinearOrthoIn(nn.Module): def __init__(self, in_features, out_features, bias): super().__init__() self.in_features = in_features self.out_features = out_features self.weights = nn.ParameterList() for out_feature in out_features: self.weights.append(nn.Parameter(torch.empty((out_feature, in_features)))) if bias: self.bias = nn.Parameter(torch.zeros(sum(out_features))) else: self.bias = self.register_parameter("bias", None) self.scale = nn.Parameter(torch.ones(in_features)) self.reset_parameters() def reset_parameters(self) -> None: for weight in self.weights: std = 0.5 * (self.in_features ** -0.5) bound = (3 ** 0.5) * std with torch.no_grad(): weight.uniform_(-bound, bound) def forward(self, x): return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None) class MultiCastedLinearOrthoOut(nn.Module): def __init__(self, in_features, out_features, bias): super().__init__() self.in_features = in_features self.out_features = out_features self.weights = nn.ParameterList() for out_feature in out_features: self.weights.append(nn.Parameter(torch.empty((out_feature, in_features)))) if bias: self.bias = nn.Parameter(torch.zeros(sum(out_features))) else: self.bias = self.register_parameter("bias", None) self.scale = nn.Parameter(torch.ones(sum(out_features))) self.reset_parameters() def reset_parameters(self) -> None: for weight in self.weights: std = 0.5 * (self.in_features ** -0.5) bound = (3 ** 0.5) * std with torch.no_grad(): weight.uniform_(-bound, bound) def forward(self, x): return F.linear(x, (self.scale.unsqueeze(1) * torch.cat([weight for weight in self.weights], dim=0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None) class GeGLU(nn.Module): def forward(self, x): x, gate = x.chunk(2, dim=-1) x = x * gelu_new(gate) return x class MaskedSoftmax(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor: ctx.dim: int ctx.dim = dim x.masked_fill_(mask, float('-inf')) x = torch.softmax(x, ctx.dim) x.masked_fill_(mask, 0.0) ctx.save_for_backward(x) return x @staticmethod def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]: output: torch.Tensor output, = ctx.saved_tensors inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype) return inputGrad, None, None class Encoder(nn.Module): def __init__(self, config) -> None: super().__init__() self.layers: nn.ModuleList[Layer] self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)]) for i, layer in enumerate(self.layers): for weight in layer.mlp.up_proj.weights: weight.data *= math.sqrt(1.0 / (2.0 * (i + 1))) layer.mlp.down_proj.weight.data *= math.sqrt(1.0 / (2.0 * (i + 1))) self.short_long_ratio = config.short_long_ratio def set_window_length(self, config) -> None: for i, layer in enumerate(self.layers): if (i+1) % self.short_long_ratio == 0: layer.set_window_length(config.window_length, config.not_flex) else: layer.set_window_length(256, config.not_flex) def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: hidden_layer: List[torch.Tensor] attention_probs: List[torch.Tensor] hidden_states = [] attention_probs = [] v1 = None for layer in self.layers: hidden_layer, v1, attention_p = layer(hidden_layer, embeddings, v1, mask) hidden_states.append(hidden_layer) attention_probs.append(attention_p) return hidden_states, attention_probs class Layer(nn.Module): def __init__(self, config, layer_idx: int) -> None: super().__init__() self.attention: SelfAttention self.mlp: FeedForward self.attention = SelfAttention(config, layer_idx) self.mlp = FeedForward(config) self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.])) def set_window_length(self, window_length: int, not_flex: bool) -> None: self.attention.set_window_length(window_length, not_flex) def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, mask: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]: output: torch.Tensor attention_p: torch.Tensor attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings) attention_output, v1, attention_p = self.attention(attention_output, qk_layer, v1, mask) mlp_layer = mlp_layer + attention_output hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings) output = hidden_layer + attention_output + self.mlp(mlp_layer) return output, v1, attention_p class Embedding(nn.Module): def __init__(self, config) -> None: super().__init__() assert hasattr(config, "vocab_size"), "The config must have a vocab_size attribute!" assert hasattr(config, "hidden_size"), "The config must have a hidden_size attribute!" assert hasattr(config, "embedding_dropout_p"), "The model must have a embedding_dropout_p attribute!" self.word_embedding: nn.Embedding self.word_norm: nn.LayerNorm self.dropout: nn.Dropout self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size) self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.word_norm_eps, elementwise_affine=False, bias=False) self.word_scale = nn.Parameter(torch.zeros(config.hidden_size)) self.dropout = nn.Dropout(config.embedding_dropout_p) self.initialize(config.hidden_size, config.vocab_size) @torch.no_grad() def initialize(self, hidden_size: int, vocab_size: int) -> None: std: float std = math.sqrt(2.0 / (hidden_size + vocab_size)) nn.init.trunc_normal_(self.word_embedding.weight, mean=0.0, std=std, a=-2*std, b=2*std) def forward(self, input_ids: torch.Tensor) -> torch.Tensor: word_embedding: torch.Tensor word_embedding = self.word_embedding(input_ids) word_embedding = self.word_norm(word_embedding) word_embedding = (word_embedding * (self.word_scale + 1.0).unsqueeze(0).unsqueeze(0)) return self.dropout(word_embedding) class MaskClassifier(nn.Module): def __init__(self, config, embedding_weights: nn.Parameter) -> None: super().__init__() self.projection: CastedLinear self.emb2vocab: CastedLinear self.pre_norm: nn.LayerNorm self.post_norm: nn.LayerNorm self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine) self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False) self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine) self.emb2vocab = CastedLinearIn(config.hidden_size, config.vocab_size, bias=True) self.initialize(config.hidden_size, config.vocab_size, embedding_weights) @torch.no_grad() def initialize(self, hidden_size: int, vocab_size: int, embedding_weights: nn.Parameter) -> None: proj_std: float = math.sqrt(2.0 / (hidden_size + 4*hidden_size)) nn.init.trunc_normal_(self.projection.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std) self.emb2vocab.weight = embedding_weights self.emb2vocab.bias.zero_() def project(self, hidden_layer: torch.Tensor) -> torch.Tensor: projection: torch.Tensor projection = self.projection(hidden_layer) projection = gelu_new(projection) projection = self.post_norm(projection) return projection def calculate_output(self, hidden_layer: torch.Tensor) -> torch.Tensor: return self.emb2vocab(hidden_layer) def forward(self, hidden_layer: torch.Tensor, labels: torch.Tensor | None = None) -> torch.Tensor: output: torch.Tensor if labels is not None: hidden_layer = torch.index_select(hidden_layer.flatten(0, 1), 0, torch.nonzero(labels.flatten() != -100).squeeze()) hidden_layer = self.pre_norm(hidden_layer) hidden_layer = self.project(hidden_layer) output = self.calculate_output(hidden_layer) return output class SelfAttention(nn.Module): def __init__(self, config, layer_idx) -> None: super().__init__() self.d_qk = config.d_qk self.d_v = config.d_v self.num_attention_heads = config.num_attention_heads self.num_kv_heads = config.num_kv_heads self.hidden_size = config.hidden_size self.q_out_dim = self.d_qk * self.num_attention_heads self.k_out_dim = self.d_qk * self.num_kv_heads self.v_out_dim = self.d_v * self.num_kv_heads self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False) self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False) self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False) self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.attention_pre_norm_eps, elementwise_affine=config.attention_pre_norm_affine) self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.attention_pre_norm_eps, elementwise_affine=config.attention_pre_norm_affine) self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.attention_inter_norm_eps, elementwise_affine=config.attention_inter_norm_affine) self.q_norm = nn.LayerNorm(config.d_qk, eps=config.attention_pre_norm_eps, elementwise_affine=False, bias=False) self.k_norm = nn.LayerNorm(config.d_qk, eps=config.attention_pre_norm_eps, elementwise_affine=False, bias=False) self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, config.d_qk)) self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, config.d_qk)) self.dropout = nn.Dropout(config.attention_output_dropout_p) theta = 160_000 if (layer_idx + 1) % config.short_long_ratio == 0 else 10_000 self.rope_embedding = RotaryPositionalEmbeddings(config, theta) self.scale: float = 1.0 / math.sqrt(self.d_qk) self.dropout = nn.Dropout(config.attention_dropout if hasattr(config, "attention_dropout") else 0.0) self.lambdas = nn.Parameter(torch.tensor([0.5])) self.initialize() self.sequence_length = config.max_sequence_length self.is_causal = config.is_decoder self.not_flex = config.not_flex @torch.no_grad() def initialize(self) -> None: std: float = math.sqrt(2.0 / (self.hidden_size + 4*self.hidden_size)) for weight in self.qk_proj.weights: nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std) nn.init.trunc_normal_(self.v_proj.weight, mean=0.0, std=std, a=2*std, b=2*std) self.out_proj.weight.data.zero_() def set_window_length(self, window_length: int, not_flex: bool) -> None: self.window_length: int = window_length if not not_flex: self.block_mask = self.create_block_mask(window_length) def causal_mask_mode(self, window_length, b, _, q_idx, kv_idx): return (q_idx >= kv_idx) & ((q_idx - kv_idx) < window_length) def bidirectional_mask_mode(self, window_length, b, _, q_idx, kv_idx): return ((q_idx - kv_idx) < window_length) & ((kv_idx - q_idx) < window_length) def create_block_mask(self, window_length: int) -> torch.Tensor: if self.is_causal: return create_block_mask( partial(self.causal_mask_mode, self.window_length), 1, 1, self.sequence_length, self.sequence_length, device=self.k_scale.device ) else: return create_block_mask( partial(self.bidirectional_mask_mode, self.window_length), 1, 1, self.sequence_length, self.sequence_length, device=self.k_scale.device ) def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: attention_scores: torch.Tensor attention_probabilities: torch.Tensor batch_size: int query_length: int key_length: int batch_size, _, query_length, _ = query.size() _, _, key_length, _ = key.size() if self.is_causal: window_mask = ~torch.ones(query_length, key_length, dtype=torch.bool, device=self.k_scale.device).tril().triu(diagonal=-self.window_length).view(1, 1, query_length, key_length) else: window_mask = ~torch.ones(query_length, key_length, dtype=torch.bool, device=self.k_scale.device).tril(diagonal=self.window_length).triu(diagonal=-self.window_length).view(1, 1, query_length, key_length) if padding_mask is not None: attention_mask = padding_mask | window_mask else: attention_mask = window_mask attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, T, T] attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length) attention_probabilities = MaskedSoftmax.apply(attention_scores, attention_mask, -1) attention_probabilities = self.dropout(attention_probabilities) value = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1)) value = value.view(batch_size, self.num_attention_heads, query_length, self.d_v) return value, attention_probabilities.detach() def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, mask: torch.Tensor | None = None, doc_ids: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]: hidden_layer = self.pre_v_norm(hidden_layer) qk_layer = self.pre_qk_norm(qk_layer) query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1) value = self.v_proj(hidden_layer) query_length: int = hidden_layer.size(0) key_length: int = hidden_layer.size(0) batch_size: int = hidden_layer.size(1) query = query.reshape(query_length, batch_size, self.num_attention_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D] key = key.reshape(key_length, batch_size, self.num_kv_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D] value = value.reshape(key_length, batch_size, self.num_kv_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D] query, key = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query), ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key) if v1 is None: v1 = value value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1 query = self.rope_embedding(query) key = self.rope_embedding(key) if self.not_flex: output, attention_probabilities = self.attention_operation(query, key, value, mask) else: def document_score_mod(score, b, _, q_idx, kv_idx): return torch.where(doc_ids[q_idx] == doc_ids[kv_idx], score, -float("inf")) if self.is_causal: block_mask = create_block_mask( partial(self.causal_mask_mode, self.window_length), 1, 1, query_length, key_length, device=self.k_scale.device ) else: block_mask = create_block_mask( partial(self.bidirectional_mask_mode, self.window_length), 1, 1, query_length, key_length, device=self.k_scale.device ) output = flex_attention( query, key, value, block_mask=block_mask, enable_gqa=True ) attention_probabilities = None output = output.permute(2, 0, 1, 3).flatten(2, 3) # shape: [T, B, H*D] output = self.inter_norm(output) output = self.out_proj(output) return self.dropout(output), v1, attention_probabilities class FeedForward(nn.Module): def __init__(self, config) -> None: super().__init__() self.up_proj: CastedLinear self.down_proj: CastedLinear self.pre_norm: nn.LayerNorm self.inter_norm: nn.LayerNorm self.activation: GeGLU self.dropout: nn.Dropout self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.feed_forward_pre_norm_eps, elementwise_affine=config.feed_forward_pre_norm_affine) self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False) self.activation = GeGLU() self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.feed_forward_inter_norm_eps, elementwise_affine=config.feed_forward_inter_norm_affine) self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False) self.dropout = nn.Dropout(config.feed_forward_dropout_p) self.initialize(config.hidden_size) @torch.no_grad() def initialize(self, hidden_size: int) -> None: std: float = math.sqrt(2.0 / (5*hidden_size)) for weight in self.up_proj.weights: nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std) self.down_proj.weight.data.zero_() def up_project(self, hidden_layer: torch.Tensor) -> torch.Tensor: hidden_layer = self.pre_norm(hidden_layer) return self.up_proj(hidden_layer) def activate(self, projection: torch.Tensor) -> torch.Tensor: activated_projection: torch.Tensor activated_projection = self.activation(projection) activated_projection = self.inter_norm(activated_projection.float()).type_as(projection) return activated_projection def down_project(self, activated_projection: torch.Tensor) -> torch.Tensor: output: torch.Tensor output = self.down_proj(activated_projection) return self.dropout(output) def forward(self, hidden_layer: torch.Tensor) -> torch.Tensor: output: torch.Tensor output = self.up_project(hidden_layer) output = self.activate(output) output = self.down_project(output) return output class RotaryPositionalEmbeddings(nn.Module): def __init__(self, config, theta: int) -> None: super().__init__() assert hasattr(config, "d_qk"), "The config must have a d_qk attribute!" assert hasattr(config, "max_sequence_length"), "The config must have a max_sequence_length attribute!" self.inv_freq: torch.Tensor self.cos_matrix: torch.Tensor self.sin_matrix: torch.Tensor head_size: int max_seq_len: int inv_freq: torch.Tensor pos: torch.Tensor embedding: torch.Tensor head_size = config.d_qk assert head_size % 2 == 0 max_seq_len = config.max_sequence_length inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size)) pos = torch.arange(max_seq_len, dtype=torch.float32) embedding = torch.einsum('n, d -> nd', pos, inv_freq) embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0) self.register_buffer("cos_matrix", embedding.cos(), persistent=False) self.register_buffer("sin_matrix", embedding.sin(), persistent=False) def forward(self, x: torch.Tensor) -> torch.Tensor: seq_len: int cos_matrix: torch.Tensor sin_matrix: torch.Tensor x_rotate_half: torch.Tensor out: torch.Tensor hidden_layer = x.float() seq_len = x.shape[2] cos_matrix = self.cos_matrix[:, None, :seq_len, :] sin_matrix = self.sin_matrix[:, None, :seq_len, :] x_rotate_half = torch.cat( [ -hidden_layer[:, :, :, x.size(-1) // 2:], hidden_layer[:, :, :, :x.size(-1) // 2] ], dim=-1 ) out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix return out.type_as(x) # # HuggingFace wrappers # class GptBertPreTrainedModel(PreTrainedModel): config_class = GptBertConfig supports_gradient_checkpointing = False def _set_gradient_checkpointing(self, module, value=False): raise NotImplementedError("Gradient checkpointing is not supported by this model") def _init_weights(self, module): std = math.sqrt(2.0 / (5.0 * self.hidden_size)) if isinstance(module, nn.Linear): nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class GptBertModel(GptBertPreTrainedModel): def __init__(self, config, add_mlm_layer=False, **kwargs): super().__init__(config, **kwargs) self.config = config self.hidden_size = config.hidden_size self.embedding = Embedding(config) self.encoder = Encoder(config) self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None self.set_window_length(config) def set_window_length(self, config) -> None: self.encoder.set_window_length(config) def get_input_embeddings(self): return self.embedding.word_embedding def set_input_embeddings(self, value): self.embedding.word_embedding = value def get_contextualized_embeddings( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None ) -> List[torch.Tensor]: if input_ids is not None: input_shape = input_ids.size() else: raise ValueError("You have to specify input_ids") batch_size, seq_length = input_shape device = input_ids.device # if attention_mask is None: # attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device) if attention_mask is not None: attention_mask = ~attention_mask.bool() if len(attention_mask.size()) == 2: attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) elif len(attention_mask.size()) == 3: attention_mask = attention_mask.unsqueeze(1) if self.config.is_decoder: attention_mask = attention_mask | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0) static_embeddings = self.embedding(input_ids.t()) contextualized_embeddings, attention_probs = self.encoder(static_embeddings, static_embeddings, attention_mask) contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings] last_layer = contextualized_embeddings[-1] contextualized_embeddings = [contextualized_embeddings[0]] + [ contextualized_embeddings[i] - contextualized_embeddings[i - 1] for i in range(1, len(contextualized_embeddings)) ] return last_layer, contextualized_embeddings, attention_probs def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple[torch.Tensor], BaseModelOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask) if not return_dict: return ( sequence_output, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []) ) return BaseModelOutput( last_hidden_state=sequence_output, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None ) class GptBertForMaskedLM(GptBertModel): _keys_to_ignore_on_load_unexpected = ["head"] def __init__(self, config, **kwargs): super().__init__(config, add_mlm_layer=True, **kwargs) def get_output_embeddings(self): return self.classifier.emb2vocab.weight def set_output_embeddings(self, new_embeddings): self.classifier.emb2vocab.weight = new_embeddings def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, **kwargs ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask) subword_prediction = self.classifier(sequence_output) subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5) masked_lm_loss = None if labels is not None: labels_flatten = labels[:, 1:].flatten() subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1) masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten) if not return_dict: output = ( subword_prediction, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []) ) return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return MaskedLMOutput( loss=masked_lm_loss, logits=subword_prediction, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None ) class Classifier(nn.Module): def __init__(self, config, num_labels: int): super().__init__() drop_out = getattr(config, "cls_dropout", None) drop_out = config.hidden_dropout_prob if drop_out is None else drop_out self.projection: CastedLinear self.emb2vocab: CastedLinear self.pre_norm: nn.LayerNorm self.post_norm: nn.LayerNorm self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine) self.projection = CastedLinear(config.hidden_size, config.hidden_size, bias=False) self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine) self.emb2vocab = CastedLinear(config.hidden_size, num_labels, bias=True) self.dropout = nn.Dropout(drop_out) self.initialize(config.hidden_size, config.intermediate_size, num_labels) @torch.no_grad() def initialize(self, hidden_size: int, intermediate_size: int, vocab_size: int) -> None: proj_std: float = math.sqrt(2.0 / (hidden_size + intermediate_size)) nn.init.trunc_normal_(self.projection.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std) nn.init.trunc_normal_(self.emb2vocab.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std) self.emb2vocab.bias.zero_() def project(self, hidden_layer: torch.Tensor) -> torch.Tensor: projection: torch.Tensor projection = self.pre_norm(hidden_layer) projection = self.dropout(projection) projection = self.projection(hidden_layer) projection = gelu_new(projection) projection = self.post_norm(projection) return projection def calculate_output(self, hidden_layer: torch.Tensor) -> torch.Tensor: return self.emb2vocab(hidden_layer) def forward(self, hidden_layer: torch.Tensor) -> torch.Tensor: output: torch.Tensor projection: torch.Tensor projection = self.project(hidden_layer) output = self.calculate_output(projection) return output class GptBertForCausalLM(GptBertModel): _keys_to_ignore_on_load_unexpected = ["head"] def __init__(self, config, **kwargs): config.is_decoder = True super().__init__(config, add_mlm_layer=True, **kwargs) def get_output_embeddings(self): return self.classifier.emb2vocab.weight def set_output_embeddings(self, new_embeddings): self.classifier.emb2vocab.weight = new_embeddings def get_input_embeddings(self): return self.embedding.word_embedding def set_input_embeddings(self, value): self.embedding.word_embedding = value def set_decoder(self, decoder): self.encoder = decoder def get_decoder(self): return self.encoder def can_generate(self): return True def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.Tensor] = None, past_key_values: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None ) -> Union[Tuple, CausalLMOutput]: assert inputs_embeds is None, "inputs_embeds is not supported for now" assert past_key_values is None, "past_key_values is not supported for now" assert not use_cache, "use_cache is not supported for now" sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask) subword_prediction = self.classifier(sequence_output) subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5) masked_lm_loss = None if labels is not None: labels_flatten = labels[:, 1:].flatten() subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1) masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten) if not return_dict: output = ( subword_prediction, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []) ) return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return CausalLMOutput( loss=masked_lm_loss, logits=subword_prediction, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None ) def prepare_inputs_for_generation( self, input_ids: torch.Tensor, past_key_values: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, use_cache: bool = True, num_logits_to_keep: Optional[int] = None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases if num_logits_to_keep is not None: model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, } ) return model_inputs class GptBertForSequenceClassification(GptBertModel): _keys_to_ignore_on_load_unexpected = ["classifier"] _keys_to_ignore_on_load_missing = ["head"] def __init__(self, config, **kwargs): super().__init__(config, add_mlm_layer=False, **kwargs) self.num_labels = config.num_labels self.head = Classifier(config, self.num_labels) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, **kwargs ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask) logits = self.head(sequence_output[:, 0, :]) loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = nn.MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = nn.BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = ( logits, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []) ) return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None ) class GptBertForTokenClassification(GptBertModel): _keys_to_ignore_on_load_unexpected = ["classifier"] _keys_to_ignore_on_load_missing = ["head"] def __init__(self, config, **kwargs): super().__init__(config, add_mlm_layer=False, **kwargs) self.num_labels = config.num_labels self.head = Classifier(config, self.num_labels) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, **kwargs ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask) logits = self.head(sequence_output) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = ( logits, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []) ) return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None ) class GptBertForQuestionAnswering(GptBertModel): _keys_to_ignore_on_load_unexpected = ["classifier"] _keys_to_ignore_on_load_missing = ["head"] def __init__(self, config, **kwargs): super().__init__(config, add_mlm_layer=False, **kwargs) self.num_labels = config.num_labels self.head = Classifier(config, self.num_labels) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, start_positions: Optional[torch.Tensor] = None, end_positions: Optional[torch.Tensor] = None, **kwargs ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask) logits = self.head(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1).contiguous() end_logits = end_logits.squeeze(-1).contiguous() total_loss = None if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if not return_dict: output = ( start_logits, end_logits, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []) ) return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None ) class GptBertForMultipleChoice(GptBertModel): _keys_to_ignore_on_load_unexpected = ["classifier"] _keys_to_ignore_on_load_missing = ["head"] def __init__(self, config, **kwargs): super().__init__(config, add_mlm_layer=False, **kwargs) self.num_labels = getattr(config, "num_labels", 2) self.head = Classifier(config, self.num_labels) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict num_choices = input_ids.shape[1] flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask) logits = self.head(sequence_output) reshaped_logits = logits.view(-1, num_choices) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(reshaped_logits, labels) if not return_dict: output = ( reshaped_logits, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []) ) return ((loss,) + output) if loss is not None else output return MultipleChoiceModelOutput( loss=loss, logits=reshaped_logits, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None )