# norbert4-base / modeling_gptbert.py
from __future__ import annotations
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import _softmax_backward_data as _softmax_backward_data
from functools import partial
from .configuration_gptbert import GptBertConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.activations import gelu_new
from transformers.modeling_outputs import (
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
BaseModelOutput,
CausalLMOutput
)
import math
from typing import TYPE_CHECKING, Optional, Union, Tuple, List
try:
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
except ImportError:
pass
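# This module implements the GPT-BERT architecture packaged with the norbert4-base checkpoint: a transformer
# stack that can run either as a bidirectional encoder or as a causal decoder (config.is_decoder), with
# sliding-window attention, rotary position embeddings, and Hugging Face task wrappers at the bottom of the
# file. flex_attention is used when available; an eager masked-softmax path (config.not_flex) is the fallback.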
class ModelOutput:
def __init__(
self,
logits: torch.Tensor | None = None,
loss: torch.Tensor | float | None = None,
perplexity: torch.Tensor | float | None = None,
accuracy: float | None = None,
z_loss: torch.Tensor | float | None = None,
**kwargs
):
self.logits: torch.Tensor | None
self.loss: torch.Tensor | float | None
self.perplexity: torch.Tensor | float | None
self.accuracy: float | None
self.z_loss: torch.Tensor | float | None
self.logits = logits
self.loss = loss
self.perplexity = perplexity
self.accuracy = accuracy
self.z_loss = z_loss
for attr, value in kwargs.items():
setattr(self, attr, value)
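# The Casted* linear layers below cast their weights to the input dtype at call time and differ only in where
# an optional learned scale is applied: CastedLinearIn scales input features, CastedLinearOut scales output
# features, and the MultiCastedLinearOrtho* variants hold several output blocks that are concatenated into a
# single matmul (used for the fused query/key projection and the gated feed-forward up-projection).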
class CastedLinear(nn.Linear):
def __init__(self, in_features, out_features, bias):
super().__init__(in_features, out_features, bias=bias)
def reset_parameters(self) -> None:
std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
def forward(self, x):
return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
class CastedLinearIn(nn.Linear):
def __init__(self, in_features, out_features, bias):
super().__init__(in_features, out_features, bias=bias)
self.scale = nn.Parameter(torch.ones(in_features))
def reset_parameters(self) -> None:
std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
def forward(self, x):
return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
class CastedLinearOut(nn.Linear):
def __init__(self, in_features, out_features, bias):
super().__init__(in_features, out_features, bias=bias)
self.scale = nn.Parameter(torch.ones(out_features))
def reset_parameters(self) -> None:
std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
def forward(self, x):
return F.linear(x, (self.scale.unsqueeze(1) * self.weight).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
class MultiCastedLinearOrtho(nn.Module):
def __init__(self, in_features, out_features, bias):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weights = nn.ParameterList()
for out_feature in out_features:
self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
if bias:
self.bias = nn.Parameter(torch.zeros(sum(out_features)))
else:
self.bias = self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
for i, weight in enumerate(self.weights):
std: float = math.sqrt(2.0 / (self.in_features + self.out_features[i]))
nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
def forward(self, x):
return F.linear(x, torch.cat([weight for weight in self.weights], dim=0).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
class MultiCastedLinearOrthoIn(nn.Module):
def __init__(self, in_features, out_features, bias):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weights = nn.ParameterList()
for out_feature in out_features:
self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
if bias:
self.bias = nn.Parameter(torch.zeros(sum(out_features)))
else:
self.bias = self.register_parameter("bias", None)
self.scale = nn.Parameter(torch.ones(in_features))
self.reset_parameters()
def reset_parameters(self) -> None:
for weight in self.weights:
std = 0.5 * (self.in_features ** -0.5)
bound = (3 ** 0.5) * std
with torch.no_grad():
weight.uniform_(-bound, bound)
def forward(self, x):
return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
class MultiCastedLinearOrthoOut(nn.Module):
def __init__(self, in_features, out_features, bias):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weights = nn.ParameterList()
for out_feature in out_features:
self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
if bias:
self.bias = nn.Parameter(torch.zeros(sum(out_features)))
else:
self.bias = self.register_parameter("bias", None)
self.scale = nn.Parameter(torch.ones(sum(out_features)))
self.reset_parameters()
def reset_parameters(self) -> None:
for weight in self.weights:
std = 0.5 * (self.in_features ** -0.5)
bound = (3 ** 0.5) * std
with torch.no_grad():
weight.uniform_(-bound, bound)
def forward(self, x):
return F.linear(x, (self.scale.unsqueeze(1) * torch.cat([weight for weight in self.weights], dim=0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
class GeGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
x = x * gelu_new(gate)
return x
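# MaskedSoftmax is a custom autograd function: masked positions are filled with -inf before the softmax and
# zeroed afterwards, and the backward pass calls torch's internal _softmax_backward_data directly so that only
# the softmax output has to be saved for the gradient computation.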
class MaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
ctx.dim: int
ctx.dim = dim
x.masked_fill_(mask, float('-inf'))
x = torch.softmax(x, ctx.dim)
x.masked_fill_(mask, 0.0)
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
output: torch.Tensor
output, = ctx.saved_tensors
inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
return inputGrad, None, None
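# The encoder stacks `num_layers` transformer layers. At construction, the feed-forward weights of layer i are
# scaled by sqrt(1 / (2 * (i + 1))), so deeper layers start with smaller weights. set_window_length gives every
# `short_long_ratio`-th layer the full configured attention window and restricts the remaining layers to a
# short 256-token window.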
class Encoder(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.layers: nn.ModuleList[Layer]
self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
for i, layer in enumerate(self.layers):
for weight in layer.mlp.up_proj.weights:
weight.data *= math.sqrt(1.0 / (2.0 * (i + 1)))
layer.mlp.down_proj.weight.data *= math.sqrt(1.0 / (2.0 * (i + 1)))
self.short_long_ratio = config.short_long_ratio
def set_window_length(self, config) -> None:
for i, layer in enumerate(self.layers):
if (i+1) % self.short_long_ratio == 0:
layer.set_window_length(config.window_length, config.not_flex)
else:
layer.set_window_length(256, config.not_flex)
    def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, mask: torch.Tensor | None = None) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        hidden_states: List[torch.Tensor] = []
        attention_probs: List[torch.Tensor] = []
v1 = None
for layer in self.layers:
hidden_layer, v1, attention_p = layer(hidden_layer, embeddings, v1, mask)
hidden_states.append(hidden_layer)
attention_probs.append(attention_p)
return hidden_states, attention_probs
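# Each layer forms learned mixtures of the running hidden state and the static token embeddings, controlled by
# `lambdas`: one for the value path of attention, one for the query/key path, one (softplus-gated) for the MLP
# input, and one (softplus-gated) for the residual stream. The attention module also passes through the first
# layer's value tensor (v1) so that later layers can blend it into their own values.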
class Layer(nn.Module):
def __init__(self, config, layer_idx: int) -> None:
super().__init__()
self.attention: SelfAttention
self.mlp: FeedForward
self.attention = SelfAttention(config, layer_idx)
self.mlp = FeedForward(config)
self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
def set_window_length(self, window_length: int, not_flex: bool) -> None:
self.attention.set_window_length(window_length, not_flex)
    def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, mask: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
output: torch.Tensor
attention_p: torch.Tensor
attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
attention_output, v1, attention_p = self.attention(attention_output, qk_layer, v1, mask)
mlp_layer = mlp_layer + attention_output
hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
output = hidden_layer + attention_output + self.mlp(mlp_layer)
return output, v1, attention_p
class Embedding(nn.Module):
def __init__(self, config) -> None:
super().__init__()
assert hasattr(config, "vocab_size"), "The config must have a vocab_size attribute!"
assert hasattr(config, "hidden_size"), "The config must have a hidden_size attribute!"
assert hasattr(config, "embedding_dropout_p"), "The model must have a embedding_dropout_p attribute!"
self.word_embedding: nn.Embedding
self.word_norm: nn.LayerNorm
self.dropout: nn.Dropout
self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.word_norm_eps, elementwise_affine=False, bias=False)
self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
self.dropout = nn.Dropout(config.embedding_dropout_p)
self.initialize(config.hidden_size, config.vocab_size)
@torch.no_grad()
def initialize(self, hidden_size: int, vocab_size: int) -> None:
std: float
std = math.sqrt(2.0 / (hidden_size + vocab_size))
nn.init.trunc_normal_(self.word_embedding.weight, mean=0.0, std=std, a=-2*std, b=2*std)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
word_embedding: torch.Tensor
word_embedding = self.word_embedding(input_ids)
word_embedding = self.word_norm(word_embedding)
word_embedding = (word_embedding * (self.word_scale + 1.0).unsqueeze(0).unsqueeze(0))
return self.dropout(word_embedding)
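# The MLM head: its output projection (`emb2vocab`) is tied to the word-embedding matrix. When labels are
# given, only positions whose label differs from -100 are projected, which avoids computing vocabulary logits
# for unmasked tokens.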
class MaskClassifier(nn.Module):
def __init__(self, config, embedding_weights: nn.Parameter) -> None:
super().__init__()
        self.projection: CastedLinearIn
        self.emb2vocab: CastedLinearIn
self.pre_norm: nn.LayerNorm
self.post_norm: nn.LayerNorm
self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
self.emb2vocab = CastedLinearIn(config.hidden_size, config.vocab_size, bias=True)
self.initialize(config.hidden_size, config.vocab_size, embedding_weights)
@torch.no_grad()
def initialize(self, hidden_size: int, vocab_size: int, embedding_weights: nn.Parameter) -> None:
proj_std: float = math.sqrt(2.0 / (hidden_size + 4*hidden_size))
nn.init.trunc_normal_(self.projection.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
self.emb2vocab.weight = embedding_weights
self.emb2vocab.bias.zero_()
def project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
projection: torch.Tensor
projection = self.projection(hidden_layer)
projection = gelu_new(projection)
projection = self.post_norm(projection)
return projection
def calculate_output(self, hidden_layer: torch.Tensor) -> torch.Tensor:
return self.emb2vocab(hidden_layer)
def forward(self, hidden_layer: torch.Tensor, labels: torch.Tensor | None = None) -> torch.Tensor:
output: torch.Tensor
if labels is not None:
hidden_layer = torch.index_select(hidden_layer.flatten(0, 1), 0, torch.nonzero(labels.flatten() != -100).squeeze())
hidden_layer = self.pre_norm(hidden_layer)
hidden_layer = self.project(hidden_layer)
output = self.calculate_output(hidden_layer)
return output
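# Self-attention with grouped-query attention (num_kv_heads <= num_attention_heads), query/key layer norm with
# learned per-head scales, rotary embeddings whose base theta is 160k on "long" layers and 10k on "short"
# layers, and a sliding attention window. When flex_attention is available a block mask implements the window;
# otherwise an explicit masked softmax is used. The value tensor of the first layer (v1) is mixed into the
# values of every later layer through a learned lambda.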
class SelfAttention(nn.Module):
def __init__(self, config, layer_idx) -> None:
super().__init__()
self.d_qk = config.d_qk
self.d_v = config.d_v
self.num_attention_heads = config.num_attention_heads
self.num_kv_heads = config.num_kv_heads
self.hidden_size = config.hidden_size
self.q_out_dim = self.d_qk * self.num_attention_heads
self.k_out_dim = self.d_qk * self.num_kv_heads
self.v_out_dim = self.d_v * self.num_kv_heads
self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.attention_pre_norm_eps, elementwise_affine=config.attention_pre_norm_affine)
self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.attention_pre_norm_eps, elementwise_affine=config.attention_pre_norm_affine)
self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.attention_inter_norm_eps, elementwise_affine=config.attention_inter_norm_affine)
self.q_norm = nn.LayerNorm(config.d_qk, eps=config.attention_pre_norm_eps, elementwise_affine=False, bias=False)
self.k_norm = nn.LayerNorm(config.d_qk, eps=config.attention_pre_norm_eps, elementwise_affine=False, bias=False)
self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, config.d_qk))
self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, config.d_qk))
        theta = 160_000 if (layer_idx + 1) % config.short_long_ratio == 0 else 10_000
        self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
        self.scale: float = 1.0 / math.sqrt(self.d_qk)
        # A single dropout module is shared between the attention probabilities and the attention output.
        self.dropout = nn.Dropout(config.attention_dropout if hasattr(config, "attention_dropout") else 0.0)
self.lambdas = nn.Parameter(torch.tensor([0.5]))
self.initialize()
self.sequence_length = config.max_sequence_length
self.is_causal = config.is_decoder
self.not_flex = config.not_flex
@torch.no_grad()
def initialize(self) -> None:
std: float = math.sqrt(2.0 / (self.hidden_size + 4*self.hidden_size))
for weight in self.qk_proj.weights:
nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
        nn.init.trunc_normal_(self.v_proj.weight, mean=0.0, std=std, a=-2*std, b=2*std)
self.out_proj.weight.data.zero_()
def set_window_length(self, window_length: int, not_flex: bool) -> None:
self.window_length: int = window_length
if not not_flex:
self.block_mask = self.create_block_mask(window_length)
def causal_mask_mode(self, window_length, b, _, q_idx, kv_idx):
return (q_idx >= kv_idx) & ((q_idx - kv_idx) < window_length)
def bidirectional_mask_mode(self, window_length, b, _, q_idx, kv_idx):
return ((q_idx - kv_idx) < window_length) & ((kv_idx - q_idx) < window_length)
    def create_block_mask(self, window_length: int) -> torch.Tensor:
        if self.is_causal:
            return create_block_mask(
                partial(self.causal_mask_mode, window_length),
                1, 1, self.sequence_length, self.sequence_length, device=self.k_scale.device
            )
        else:
            return create_block_mask(
                partial(self.bidirectional_mask_mode, window_length),
                1, 1, self.sequence_length, self.sequence_length, device=self.k_scale.device
            )
def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
attention_scores: torch.Tensor
attention_probabilities: torch.Tensor
batch_size: int
query_length: int
key_length: int
batch_size, _, query_length, _ = query.size()
_, _, key_length, _ = key.size()
if self.is_causal:
window_mask = ~torch.ones(query_length, key_length, dtype=torch.bool, device=self.k_scale.device).tril().triu(diagonal=-self.window_length).view(1, 1, query_length, key_length)
else:
window_mask = ~torch.ones(query_length, key_length, dtype=torch.bool, device=self.k_scale.device).tril(diagonal=self.window_length).triu(diagonal=-self.window_length).view(1, 1, query_length, key_length)
if padding_mask is not None:
attention_mask = padding_mask | window_mask
else:
attention_mask = window_mask
        # Expand the key/value heads to match the number of query heads (grouped-query attention) before the
        # batched matrix multiplications below.
        if self.num_kv_heads != self.num_attention_heads:
            key = key.repeat_interleave(self.num_attention_heads // self.num_kv_heads, dim=1)
            value = value.repeat_interleave(self.num_attention_heads // self.num_kv_heads, dim=1)
        attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, T, T]
attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
attention_probabilities = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
attention_probabilities = self.dropout(attention_probabilities)
value = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
value = value.view(batch_size, self.num_attention_heads, query_length, self.d_v)
return value, attention_probabilities.detach()
    def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, mask: torch.Tensor | None = None, doc_ids: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
hidden_layer = self.pre_v_norm(hidden_layer)
qk_layer = self.pre_qk_norm(qk_layer)
query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
value = self.v_proj(hidden_layer)
query_length: int = hidden_layer.size(0)
key_length: int = hidden_layer.size(0)
batch_size: int = hidden_layer.size(1)
query = query.reshape(query_length, batch_size, self.num_attention_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D]
key = key.reshape(key_length, batch_size, self.num_kv_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D]
        value = value.reshape(key_length, batch_size, self.num_kv_heads, self.d_v).permute(1, 2, 0, 3) # shape: [B, H_kv, T, D_v]
query, key = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query), ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
if v1 is None:
v1 = value
value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
query = self.rope_embedding(query)
key = self.rope_embedding(key)
if self.not_flex:
output, attention_probabilities = self.attention_operation(query, key, value, mask)
else:
            def document_score_mod(score, b, _, q_idx, kv_idx):  # defined but currently unused; would restrict attention to tokens within the same document
                return torch.where(doc_ids[q_idx] == doc_ids[kv_idx], score, -float("inf"))
if self.is_causal:
block_mask = create_block_mask(
partial(self.causal_mask_mode, self.window_length),
1, 1, query_length, key_length, device=self.k_scale.device
)
else:
block_mask = create_block_mask(
partial(self.bidirectional_mask_mode, self.window_length),
1, 1, query_length, key_length, device=self.k_scale.device
)
output = flex_attention(
query, key, value, block_mask=block_mask, enable_gqa=True
)
attention_probabilities = None
output = output.permute(2, 0, 1, 3).flatten(2, 3) # shape: [T, B, H*D]
output = self.inter_norm(output)
output = self.out_proj(output)
return self.dropout(output), v1, attention_probabilities
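# Gated feed-forward block: a fused up-projection produces two streams that GeGLU combines (x * gelu(gate)),
# followed by an intermediate layer norm and a down-projection that is zero-initialised.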
class FeedForward(nn.Module):
def __init__(self, config) -> None:
super().__init__()
        self.up_proj: MultiCastedLinearOrthoIn
        self.down_proj: CastedLinearIn
self.pre_norm: nn.LayerNorm
self.inter_norm: nn.LayerNorm
self.activation: GeGLU
self.dropout: nn.Dropout
self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.feed_forward_pre_norm_eps, elementwise_affine=config.feed_forward_pre_norm_affine)
self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
self.activation = GeGLU()
self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.feed_forward_inter_norm_eps, elementwise_affine=config.feed_forward_inter_norm_affine)
self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
self.dropout = nn.Dropout(config.feed_forward_dropout_p)
self.initialize(config.hidden_size)
@torch.no_grad()
def initialize(self, hidden_size: int) -> None:
std: float = math.sqrt(2.0 / (5*hidden_size))
for weight in self.up_proj.weights:
nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
self.down_proj.weight.data.zero_()
def up_project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
hidden_layer = self.pre_norm(hidden_layer)
return self.up_proj(hidden_layer)
def activate(self, projection: torch.Tensor) -> torch.Tensor:
activated_projection: torch.Tensor
activated_projection = self.activation(projection)
activated_projection = self.inter_norm(activated_projection.float()).type_as(projection)
return activated_projection
def down_project(self, activated_projection: torch.Tensor) -> torch.Tensor:
output: torch.Tensor
output = self.down_proj(activated_projection)
return self.dropout(output)
def forward(self, hidden_layer: torch.Tensor) -> torch.Tensor:
output: torch.Tensor
output = self.up_project(hidden_layer)
output = self.activate(output)
output = self.down_project(output)
return output
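# Rotary position embeddings: for position n and pair index d, the rotation angle is n * theta^(-2d / head_size),
# and the forward pass applies out = x * cos + rotate_half(x) * sin, where rotate_half swaps and negates the two
# halves of the head dimension. The cos/sin tables are precomputed up to max_sequence_length.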
class RotaryPositionalEmbeddings(nn.Module):
def __init__(self, config, theta: int) -> None:
super().__init__()
assert hasattr(config, "d_qk"), "The config must have a d_qk attribute!"
assert hasattr(config, "max_sequence_length"), "The config must have a max_sequence_length attribute!"
self.inv_freq: torch.Tensor
self.cos_matrix: torch.Tensor
self.sin_matrix: torch.Tensor
head_size: int
max_seq_len: int
inv_freq: torch.Tensor
pos: torch.Tensor
embedding: torch.Tensor
head_size = config.d_qk
assert head_size % 2 == 0
max_seq_len = config.max_sequence_length
inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
pos = torch.arange(max_seq_len, dtype=torch.float32)
embedding = torch.einsum('n, d -> nd', pos, inv_freq)
embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
seq_len: int
cos_matrix: torch.Tensor
sin_matrix: torch.Tensor
x_rotate_half: torch.Tensor
out: torch.Tensor
hidden_layer = x.float()
seq_len = x.shape[2]
cos_matrix = self.cos_matrix[:, None, :seq_len, :]
sin_matrix = self.sin_matrix[:, None, :seq_len, :]
x_rotate_half = torch.cat(
[
-hidden_layer[:, :, :, x.size(-1) // 2:],
hidden_layer[:, :, :, :x.size(-1) // 2]
],
dim=-1
)
out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
return out.type_as(x)
#
# HuggingFace wrappers
#
class GptBertPreTrainedModel(PreTrainedModel):
config_class = GptBertConfig
supports_gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
raise NotImplementedError("Gradient checkpointing is not supported by this model")
def _init_weights(self, module):
        std = math.sqrt(2.0 / (5.0 * self.config.hidden_size))
if isinstance(module, nn.Linear):
nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class GptBertModel(GptBertPreTrainedModel):
def __init__(self, config, add_mlm_layer=False, **kwargs):
super().__init__(config, **kwargs)
self.config = config
self.hidden_size = config.hidden_size
self.embedding = Embedding(config)
self.encoder = Encoder(config)
self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
self.set_window_length(config)
def set_window_length(self, config) -> None:
self.encoder.set_window_length(config)
def get_input_embeddings(self):
return self.embedding.word_embedding
def set_input_embeddings(self, value):
self.embedding.word_embedding = value
def get_contextualized_embeddings(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
if input_ids is not None:
input_shape = input_ids.size()
else:
raise ValueError("You have to specify input_ids")
batch_size, seq_length = input_shape
device = input_ids.device
# if attention_mask is None:
# attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device)
if attention_mask is not None:
attention_mask = ~attention_mask.bool()
if len(attention_mask.size()) == 2:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
elif len(attention_mask.size()) == 3:
attention_mask = attention_mask.unsqueeze(1)
if self.config.is_decoder:
attention_mask = attention_mask | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
static_embeddings = self.embedding(input_ids.t())
contextualized_embeddings, attention_probs = self.encoder(static_embeddings, static_embeddings, attention_mask)
contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
last_layer = contextualized_embeddings[-1]
contextualized_embeddings = [contextualized_embeddings[0]] + [
contextualized_embeddings[i] - contextualized_embeddings[i - 1]
for i in range(1, len(contextualized_embeddings))
]
return last_layer, contextualized_embeddings, attention_probs
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
if not return_dict:
return (
sequence_output,
*([contextualized_embeddings] if output_hidden_states else []),
*([attention_probs] if output_attentions else [])
)
return BaseModelOutput(
last_hidden_state=sequence_output,
hidden_states=contextualized_embeddings if output_hidden_states else None,
attentions=attention_probs if output_attentions else None
)
class GptBertForMaskedLM(GptBertModel):
_keys_to_ignore_on_load_unexpected = ["head"]
def __init__(self, config, **kwargs):
super().__init__(config, add_mlm_layer=True, **kwargs)
def get_output_embeddings(self):
return self.classifier.emb2vocab.weight
def set_output_embeddings(self, new_embeddings):
self.classifier.emb2vocab.weight = new_embeddings
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
subword_prediction = self.classifier(sequence_output)
        # Soft-cap the logits to the range (0, 30) with a scaled sigmoid.
        subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
        masked_lm_loss = None
        if labels is not None:
            # Labels are shifted by one position: the logits at position t are scored against the token at
            # position t + 1 (masked next-token prediction); -100 labels are ignored by cross_entropy.
            labels_flatten = labels[:, 1:].flatten()
            subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
            masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
if not return_dict:
output = (
subword_prediction,
*([contextualized_embeddings] if output_hidden_states else []),
*([attention_probs] if output_attentions else [])
)
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=subword_prediction,
hidden_states=contextualized_embeddings if output_hidden_states else None,
attentions=attention_probs if output_attentions else None
)
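# Generic classification head used by the fine-tuning wrappers below: layer norm, dropout, a hidden projection
# with GELU, a second layer norm, and a final linear layer that maps to the requested number of labels.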
class Classifier(nn.Module):
def __init__(self, config, num_labels: int):
super().__init__()
drop_out = getattr(config, "cls_dropout", None)
drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
self.projection: CastedLinear
self.emb2vocab: CastedLinear
self.pre_norm: nn.LayerNorm
self.post_norm: nn.LayerNorm
self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
self.projection = CastedLinear(config.hidden_size, config.hidden_size, bias=False)
self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
self.emb2vocab = CastedLinear(config.hidden_size, num_labels, bias=True)
self.dropout = nn.Dropout(drop_out)
self.initialize(config.hidden_size, config.intermediate_size, num_labels)
@torch.no_grad()
def initialize(self, hidden_size: int, intermediate_size: int, vocab_size: int) -> None:
proj_std: float = math.sqrt(2.0 / (hidden_size + intermediate_size))
nn.init.trunc_normal_(self.projection.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
nn.init.trunc_normal_(self.emb2vocab.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
self.emb2vocab.bias.zero_()
    def project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
        projection: torch.Tensor
        projection = self.pre_norm(hidden_layer)
        projection = self.dropout(projection)
        projection = self.projection(projection)
        projection = gelu_new(projection)
        projection = self.post_norm(projection)
        return projection
def calculate_output(self, hidden_layer: torch.Tensor) -> torch.Tensor:
return self.emb2vocab(hidden_layer)
def forward(self, hidden_layer: torch.Tensor) -> torch.Tensor:
output: torch.Tensor
projection: torch.Tensor
projection = self.project(hidden_layer)
output = self.calculate_output(projection)
return output
class GptBertForCausalLM(GptBertModel):
_keys_to_ignore_on_load_unexpected = ["head"]
def __init__(self, config, **kwargs):
config.is_decoder = True
super().__init__(config, add_mlm_layer=True, **kwargs)
def get_output_embeddings(self):
return self.classifier.emb2vocab.weight
def set_output_embeddings(self, new_embeddings):
self.classifier.emb2vocab.weight = new_embeddings
def get_input_embeddings(self):
return self.embedding.word_embedding
def set_input_embeddings(self, value):
self.embedding.word_embedding = value
def set_decoder(self, decoder):
self.encoder = decoder
def get_decoder(self):
return self.encoder
def can_generate(self):
return True
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None
) -> Union[Tuple, CausalLMOutput]:
assert inputs_embeds is None, "inputs_embeds is not supported for now"
assert past_key_values is None, "past_key_values is not supported for now"
assert not use_cache, "use_cache is not supported for now"
sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
subword_prediction = self.classifier(sequence_output)
subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
masked_lm_loss = None
if labels is not None:
labels_flatten = labels[:, 1:].flatten()
subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
if not return_dict:
output = (
subword_prediction,
*([contextualized_embeddings] if output_hidden_states else []),
*([attention_probs] if output_attentions else [])
)
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return CausalLMOutput(
loss=masked_lm_loss,
logits=subword_prediction,
hidden_states=contextualized_embeddings if output_hidden_states else None,
attentions=attention_probs if output_attentions else None
)
def prepare_inputs_for_generation(
self,
input_ids: torch.Tensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
use_cache: bool = True,
num_logits_to_keep: Optional[int] = None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides
                # during decoding. Simply using `.contiguous()` is not sufficient: in the batch size = 1 case,
                # `position_ids` is already contiguous but with a varying stride, which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
class GptBertForSequenceClassification(GptBertModel):
_keys_to_ignore_on_load_unexpected = ["classifier"]
_keys_to_ignore_on_load_missing = ["head"]
def __init__(self, config, **kwargs):
super().__init__(config, add_mlm_layer=False, **kwargs)
self.num_labels = config.num_labels
self.head = Classifier(config, self.num_labels)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
logits = self.head(sequence_output[:, 0, :])
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = nn.MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (
logits,
*([contextualized_embeddings] if output_hidden_states else []),
*([attention_probs] if output_attentions else [])
)
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=contextualized_embeddings if output_hidden_states else None,
attentions=attention_probs if output_attentions else None
)
class GptBertForTokenClassification(GptBertModel):
_keys_to_ignore_on_load_unexpected = ["classifier"]
_keys_to_ignore_on_load_missing = ["head"]
def __init__(self, config, **kwargs):
super().__init__(config, add_mlm_layer=False, **kwargs)
self.num_labels = config.num_labels
self.head = Classifier(config, self.num_labels)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
logits = self.head(sequence_output)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (
logits,
*([contextualized_embeddings] if output_hidden_states else []),
*([attention_probs] if output_attentions else [])
)
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=contextualized_embeddings if output_hidden_states else None,
attentions=attention_probs if output_attentions else None
)
class GptBertForQuestionAnswering(GptBertModel):
_keys_to_ignore_on_load_unexpected = ["classifier"]
_keys_to_ignore_on_load_missing = ["head"]
def __init__(self, config, **kwargs):
super().__init__(config, add_mlm_layer=False, **kwargs)
self.num_labels = config.num_labels
self.head = Classifier(config, self.num_labels)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
**kwargs
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
logits = self.head(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it away
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (
start_logits,
end_logits,
*([contextualized_embeddings] if output_hidden_states else []),
*([attention_probs] if output_attentions else [])
)
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=contextualized_embeddings if output_hidden_states else None,
attentions=attention_probs if output_attentions else None
)
class GptBertForMultipleChoice(GptBertModel):
_keys_to_ignore_on_load_unexpected = ["classifier"]
_keys_to_ignore_on_load_missing = ["head"]
def __init__(self, config, **kwargs):
super().__init__(config, add_mlm_layer=False, **kwargs)
        self.num_labels = getattr(config, "num_labels", 2)
        # The head scores each choice independently; the scores are later reshaped to [batch_size, num_choices].
        self.head = Classifier(config, 1)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask)
        logits = self.head(sequence_output[:, 0, :])
reshaped_logits = logits.view(-1, num_choices)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
output = (
reshaped_logits,
*([contextualized_embeddings] if output_hidden_states else []),
*([attention_probs] if output_attentions else [])
)
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=contextualized_embeddings if output_hidden_states else None,
attentions=attention_probs if output_attentions else None
)
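# Example usage (a sketch, not part of the modeling code). The classes above are normally reached through the
# Hugging Face auto classes with `trust_remote_code=True`; the repository id and the mask-token handling below
# are placeholders and depend on the actual checkpoint and its tokenizer.
if __name__ == "__main__":
    from transformers import AutoTokenizer, AutoModelForMaskedLM

    repo_id = "ltg/norbert4-base"  # placeholder: replace with the real checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    text = f"Hovedstaden i Norge er {tokenizer.mask_token}."
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # The loss in GptBertForMaskedLM shifts labels by one position, so the logits at position p - 1
    # predict the token at position p.
    mask_positions = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    predicted_ids = logits[0, mask_positions - 1].argmax(dim=-1)
    print(tokenizer.decode(predicted_ids))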