refine-codebase (#33)
- style: removing unused files, black, isort (6e55444329aa0c7f03cad24df7181b1599f8c0fc)
- docs: add comments (1752c7c313018d447f01539883f352af472d8001)
- configuration (3830381a980542ad592bcf3cc6c8d8cda25947fb)
- chore: remove model weights (f9962ee8bc4141783cf69a239d29866bc93c5343)
- docs: readme (1babf317f017ea44880d1e39a8e86ed86c087399)
- chore: remove tokenizer file (37bae6d5bd9f07db0adb1e5c058490dc5c985b79)
- chore: remove config (447c0db77b026b8f668214d36436c66a1cc154ce)
- chore: no need for tokenizer config as well (3b9e730c66bf50cd728107c28adf0d8c2be0c207)
- readme (e3fca02ce1e5e0e99c6d2cbb310a4409811de116)
- README.md +11 -4
- block.py +5 -4
- config.json +0 -31
- configuration_xlm_roberta.py +85 -36
- embedding.py +27 -13
- mha.py +101 -42
- mlp.py +33 -15
- modeling_lora.py +49 -22
- modeling_xlm_roberta.py +116 -194
- modeling_xlm_roberta_for_glue.py +0 -109
- pytorch_model.bin +0 -3
- rotary.py +43 -16
- stochastic_depth.py +1 -1
- tokenizer.json +0 -0
- tokenizer_config.json +0 -4
- xlm_padding.py +24 -10
README.md
CHANGED
@@ -1,5 +1,12 @@
+Core implementation of Jina XLM-RoBERTa
+
+This implementation is adapted from [XLM-Roberta](https://huggingface.co/docs/transformers/en/model_doc/xlm-roberta). In contrast to the original implementation, this model uses Rotary positional encodings and supports flash-attention 2.
+
+### Models that use this implementation
+
+to be added soon
+
+
+### Converting weights
+
+Weights from an [original XLMRoberta model](https://huggingface.co/FacebookAI/xlm-roberta-large) can be converted using the `convert_roberta_weights_to_flash.py` script in the model repository.
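The converted checkpoint is intended to be consumed through `transformers`' custom-code loading (the `auto_map` entries registered for this repository point `AutoModel` at the classes below). A minimal loading sketch; the repository id is a placeholder, not a name given in this PR:

```python
# Hedged sketch: the repository id is a placeholder; trust_remote_code pulls in the
# modeling code from this repository instead of the stock transformers classes.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "org-name/xlm-roberta-flash-implementation",  # placeholder repo id
    trust_remote_code=True,
)
```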
block.py
CHANGED
@@ -8,15 +8,14 @@ from typing import Optional

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from torch import Tensor

-from .stochastic_depth import StochasticDepth
 from .mha import MHA
 from .mlp import Mlp
+from .stochastic_depth import StochasticDepth

 try:
-    from flash_attn.ops.triton.layer_norm import
+    from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn
 except ImportError:
     layer_norm_fn, RMSNorm = None, None

@@ -233,7 +232,9 @@ class Block(nn.Module):
                     is_rms_norm=isinstance(self.norm1, RMSNorm),
                 )
             if not isinstance(self.mlp, nn.Identity):
-                mlp_out = self.mlp(
+                mlp_out = self.mlp(
+                    hidden_states, adapter_mask=mixer_kwargs.get("adapter_mask")
+                )
                 if self.return_residual:  # mlp out is actually a pair here
                     mlp_out, hidden_states = mlp_out
                 if not self.fused_dropout_add_ln:
config.json
DELETED
@@ -1,31 +0,0 @@
-{
-    "auto_map": {
-        "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
-        "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
-        "AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
-        "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM",
-        "AutoModelForSequenceClassification":"modeling_xlm_roberta.XLMRobertaForSequenceClassification"
-    },
-    "architectures": [
-        "XLMRobertaModel"
-    ],
-    "attention_probs_dropout_prob": 0.1,
-    "bos_token_id": 0,
-    "eos_token_id": 2,
-    "hidden_act": "gelu",
-    "hidden_dropout_prob": 0.1,
-    "hidden_size": 768,
-    "initializer_range": 0.02,
-    "intermediate_size": 3072,
-    "layer_norm_eps": 1e-05,
-    "max_position_embeddings": 8194,
-    "num_attention_heads": 12,
-    "num_hidden_layers": 12,
-    "output_past": true,
-    "pad_token_id": 1,
-    "position_embedding_type": "absolute",
-    "transformers_version": "4.17.0.dev0",
-    "type_vocab_size": 1,
-    "use_cache": false,
-    "vocab_size": 250002
-}
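The deleted config.json described a base-sized checkpoint (hidden size 768, 12 layers, 12 heads), while the new Python defaults in `configuration_xlm_roberta.py` target the large variant. A minimal sketch of recreating the base-sized settings with the new config class, assuming the module is importable from a repository checkout (note that `position_embedding_type` now defaults to `"rotary"` rather than the JSON's `"absolute"`):

```python
# Sketch: base-sized values from the deleted config.json, passed to the new
# Python-level config class. Anything omitted falls back to the large-model defaults.
from configuration_xlm_roberta import XLMRobertaFlashConfig

config = XLMRobertaFlashConfig(
    vocab_size=250002,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    max_position_embeddings=8194,
    hidden_act="gelu",
    layer_norm_eps=1e-05,
    type_vocab_size=1,
    pad_token_id=1,
    bos_token_id=0,
    eos_token_id=2,
)
print(config.hidden_size, config.num_hidden_layers)  # 768 12
```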
configuration_xlm_roberta.py
CHANGED
@@ -1,44 +1,89 @@
-from
+from typing import Any, Dict, List, Optional, Union
+
 import torch
+from transformers import PretrainedConfig
+

 class XLMRobertaFlashConfig(PretrainedConfig):
     def __init__(
+        self,
+        vocab_size: int = 250002,
+        hidden_size: int = 1024,
+        num_hidden_layers: int = 24,
+        num_attention_heads: int = 16,
+        intermediate_size: int = 4096,
+        hidden_act: str = "gelu",
+        hidden_dropout_prob: float = 0.1,
+        attention_probs_dropout_prob: float = 0.1,
+        max_position_embeddings: int = 8194,
+        type_vocab_size: int = 1,
+        initializer_range: float = 0.02,
+        layer_norm_eps: float = 1e-05,
+        pad_token_id: int = 1,
+        bos_token_id: int = 0,
+        eos_token_id: int = 2,
+        position_embedding_type: str = "rotary",
+        rotary_emb_base: float = 10000.0,
+        use_cache: bool = True,
+        classifier_dropout: Optional[float] = None,
+        lora_adaptations: Optional[List[str]] = None,
+        lora_prompts: Optional[Dict[str, str]] = None,
+        lora_rank: int = 4,
+        lora_dropout_p: float = 0.0,
+        lora_alpha: int = 1,
+        lora_main_params_trainable: bool = False,
+        load_trained_adapters: bool = False,
+        use_flash_attn: bool = True,
+        torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        emb_pooler: Optional[str] = None,
+        matryoshka_dimensions: Optional[List[int]] = None,
+        truncate_dim: Optional[int] = None,
+        **kwargs: Dict[str, Any],
     ):
+        """
+        Initialize the XLMRobertaFlashConfig configuration.
+
+        Args:
+            vocab_size (int): Size of the vocabulary.
+            hidden_size (int): Dimensionality of the encoder layers and the pooler layer.
+            num_hidden_layers (int): Number of hidden layers in the Transformer encoder.
+            num_attention_heads (int): Number of attention heads for each attention layer in the Transformer encoder.
+            intermediate_size (int): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer.
+            hidden_act (str): The activation function to use.
+            hidden_dropout_prob (float): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+            attention_probs_dropout_prob (float): The dropout ratio for the attention probabilities.
+            max_position_embeddings (int): The maximum length of the position embeddings.
+            type_vocab_size (int): The vocabulary size of the token type ids.
+            initializer_range (float): The standard deviation for initializing all weight matrices.
+            layer_norm_eps (float): The epsilon used by the layer normalization layers.
+            pad_token_id (int): The ID of the padding token.
+            bos_token_id (int): The ID of the beginning-of-sequence token.
+            eos_token_id (int): The ID of the end-of-sequence token.
+            position_embedding_type (str): Type of position embeddings. Options are 'absolute', 'alibi', or 'rotary'.
+            rotary_emb_base (float): Base for rotary embeddings.
+            use_cache (bool): Whether or not the model should return the last key/values attentions (not used by all models).
+            classifier_dropout (Optional[float]): The dropout ratio for the classification head.
+            lora_adaptations (Optional[List[str]]): LoRA adaptations configuration.
+            lora_prompts (Optional[Dict[str, str]]): LoRA prompts configuration.
+            lora_rank (int): Rank for LoRA adaptations.
+            lora_dropout_p (float): Dropout probability for LoRA adaptations.
+            lora_alpha (int): Alpha parameter for LoRA.
+            lora_main_params_trainable (bool): Whether to make the main model parameters trainable when using LoRA.
+            load_trained_adapters (bool): Whether to load trained adapters.
+            use_flash_attn (bool): Whether to use FlashAttention.
+            torch_dtype (Optional[Union[str, torch.dtype]]): Data type for the tensors.
+            emb_pooler (Optional[str]): Pooling layer configuration.
+            matryoshka_dimensions (Optional[List[int]]): Configuration for matryoshka dimension reduction.
+            truncate_dim (Optional[int]): Dimension to truncate embeddings to, if any.
+            **kwargs (Dict[str, Any]): Additional keyword arguments passed to the configuration.
+        """

+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )

         self.vocab_size = vocab_size
         self.hidden_size = hidden_size

@@ -67,7 +112,11 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.emb_pooler = emb_pooler
         self.matryoshka_dimensions = matryoshka_dimensions
         self.truncate_dim = truncate_dim
-        if
+        if (
+            torch_dtype
+            and hasattr(torch, torch_dtype)
+            and type(getattr(torch, torch_dtype)) is torch.dtype
+        ):
             self.torch_dtype = getattr(torch, torch_dtype)
         else:
             self.torch_dtype = torch_dtype
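The second hunk resolves a string `torch_dtype` such as `"bfloat16"` into the corresponding `torch.dtype`, leaving any other value untouched. A small standalone sketch of that resolution; the extra `isinstance(..., str)` guard is an addition of this sketch so it also tolerates a ready-made `torch.dtype`, which the hunk itself does not show:

```python
import torch

def resolve_torch_dtype(torch_dtype):
    # Mirrors the config logic: map a dtype name to torch.<name> when it exists,
    # otherwise keep the value unchanged. The isinstance check is an assumption
    # added for this sketch.
    if (
        torch_dtype
        and isinstance(torch_dtype, str)
        and hasattr(torch, torch_dtype)
        and type(getattr(torch, torch_dtype)) is torch.dtype
    ):
        return getattr(torch, torch_dtype)
    return torch_dtype

assert resolve_torch_dtype("bfloat16") is torch.bfloat16
assert resolve_torch_dtype(None) is None
```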
embedding.py
CHANGED
@@ -5,10 +5,8 @@

 import torch
 import torch.nn as nn
-from
-from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
+from transformers.models.xlm_roberta.modeling_xlm_roberta import \
+    create_position_ids_from_input_ids


 class XLMRobertaEmbeddings(nn.Module):

@@ -38,20 +36,29 @@ class XLMRobertaEmbeddings(nn.Module):
             max_position_embeddings, embed_dim, **factory_kwargs
         )
         if self.type_vocab_size > 0:
-            self.token_type_embeddings = nn.Embedding(
+            self.token_type_embeddings = nn.Embedding(
+                type_vocab_size, embed_dim, **factory_kwargs
+            )

-    def forward(
+    def forward(
+        self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None
+    ):
         """
         input_ids: (batch, seqlen)
         position_ids: (batch, seqlen)
         token_type_ids: (batch, seqlen)
+        adapter_mask: (batch, 1)
         """
         batch_size, seqlen = input_ids.shape
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask)
             embedding_dtype = next(self.word_embeddings.parameters()).dtype
-            embeddings = torch.empty(
+            embeddings = torch.empty(
+                *input_ids.shape,
+                self.word_embeddings.embedding_dim,
+                dtype=embedding_dtype,
+                device=input_ids.device
+            )
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_input_ids = input_ids[task_indices]

@@ -61,20 +68,27 @@ class XLMRobertaEmbeddings(nn.Module):
             embeddings = self.word_embeddings(input_ids)
         if self.max_position_embeddings > 0:
             if position_ids is None:
-                position_ids = create_position_ids_from_input_ids(
+                position_ids = create_position_ids_from_input_ids(
+                    input_ids, padding_idx=self.word_embeddings.padding_idx
+                ).to(input_ids.device)
             position_embeddings = self.position_embeddings(position_ids)
             embeddings = embeddings + position_embeddings
         if self.type_vocab_size > 0:
             if token_type_ids is None:
-                token_type_ids = torch.zeros(
+                token_type_ids = torch.zeros(
+                    seqlen, dtype=torch.long, device=input_ids.device
+                )

             if adapter_mask is not None:
                 unique_tasks = torch.unique(adapter_mask)
                 for task_id in unique_tasks:
-                    task_token_type_embeddings = self.token_type_embeddings(
+                    task_token_type_embeddings = self.token_type_embeddings(
+                        token_type_ids, task_id=task_id
+                    )
                     task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
-                    embeddings[task_indices] =
+                    embeddings[task_indices] = (
+                        embeddings[task_indices] + task_token_type_embeddings
+                    )
             else:
                 token_type_embeddings = self.token_type_embeddings(token_type_ids)
                 embeddings = embeddings + token_type_embeddings
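The `adapter_mask` branch above groups the rows of a batch by task id and embeds each group separately, so that task-specific (LoRA-parametrized) weights can be applied per group before the results are written back into a single output tensor. A toy, self-contained illustration of that routing pattern using a plain `nn.Embedding`, which simply ignores the task id:

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(100, 8)
input_ids = torch.randint(0, 100, (4, 6))   # (batch, seqlen)
adapter_mask = torch.tensor([0, 1, 0, 1])   # one task id per batch row

out = torch.empty(*input_ids.shape, embedding.embedding_dim,
                  dtype=embedding.weight.dtype, device=input_ids.device)
for task_id in torch.unique(adapter_mask):
    rows = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
    # The real model passes task_id here so a LoRA-parametrized embedding is used.
    out[rows] = embedding(input_ids[rows])
print(out.shape)  # torch.Size([4, 6, 8])
```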
mha.py
CHANGED
@@ -1,5 +1,8 @@
+# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
+# Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
+# Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
+
 # Copyright (c) 2023, Tri Dao.
-# Adapted from https://github.com/Dao-AILab/flash-attention/pull/556

 import math
 from functools import partial

@@ -9,20 +12,19 @@ import torch.nn as nn
 from einops import rearrange, repeat

 try:
-    from flash_attn import (
-        flash_attn_with_kvcache,
-    )
+    from flash_attn import (flash_attn_kvpacked_func,
+                            flash_attn_qkvpacked_func,
+                            flash_attn_varlen_kvpacked_func,
+                            flash_attn_varlen_qkvpacked_func,
+                            flash_attn_with_kvcache)
 except ImportError:
     flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
     flash_attn_with_kvcache = None

 try:
-    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense,
+    from flash_attn.ops.fused_dense import (ColumnParallelLinear, FusedDense,
+                                            RowParallelLinear)
 except ImportError:
     FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

@@ -42,7 +44,9 @@ def get_alibi_slopes(nheads):
     closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
     return (
         get_slopes_power_of_2(closest_power_of_2)
-        + get_alibi_slopes(2 * closest_power_of_2)[0::2][
+        + get_alibi_slopes(2 * closest_power_of_2)[0::2][
+            : nheads - closest_power_of_2
+        ]
     )

@@ -67,7 +71,9 @@ class FlashSelfAttention(nn.Module):
         deterministic=False,
     ):
         super().__init__()
-        assert
+        assert (
+            flash_attn_varlen_qkvpacked_func is not None
+        ), "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
         self.causal = causal
         self.softmax_scale = softmax_scale

@@ -147,7 +153,9 @@ class FlashCrossAttention(nn.Module):
         deterministic=False,
     ):
         super().__init__()
-        assert
+        assert (
+            flash_attn_varlen_kvpacked_func is not None
+        ), "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
         self.causal = causal
         self.softmax_scale = softmax_scale

@@ -313,7 +321,10 @@ class CrossAttention(nn.Module):
         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
         if key_padding_mask is not None:
             padding_mask = torch.full(
-                (batch_size, seqlen_k),
+                (batch_size, seqlen_k),
+                -10000.0,
+                dtype=scores.dtype,
+                device=scores.device,
             )
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)

@@ -425,20 +436,26 @@ class MHA(nn.Module):
         else:
             alibi_slopes = None
         if window_size != (-1, -1):
-            assert
+            assert (
+                use_flash_attn
+            ), "Local (sliding window) attention code path requires flash_attn"

         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
         assert (
             self.num_heads % self.num_heads_kv == 0
         ), "num_heads must be divisible by num_heads_kv"
-        assert
+        assert (
+            self.embed_dim % num_heads == 0
+        ), "embed_dim must be divisible by num_heads"
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
         kv_dim = 2 * self.head_dim * self.num_heads_kv

         if self.rotary_emb_dim > 0:
-            assert
+            assert (
+                not cross_attn
+            ), "MHA with rotary embedding does not support cross-attention yet"
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
             self.rotary_emb = RotaryEmbedding(
                 self.rotary_emb_dim,

@@ -453,23 +470,33 @@ class MHA(nn.Module):

         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (
-            LinearResidual
+            LinearResidual
+            if not fused_bias_fc
+            else partial(FusedDense, return_residual=True)
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
-            partial(
+            partial(
+                FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size
+            )
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(
+            partial(
+                FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size
+            )
             if use_flash_attn
             else CrossAttention
         )
         if not self.cross_attn:
-            self.Wqkv = wqkv_cls(
+            self.Wqkv = wqkv_cls(
+                embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs
+            )
         else:
-            self.Wq = linear_cls(
+            self.Wq = linear_cls(
+                embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs
+            )
             self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
         if self.dwconv:
             if self.num_heads_kv == self.num_heads:

@@ -480,7 +507,9 @@ class MHA(nn.Module):
             self.dwconv_q = nn.Conv1d(
                 embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
             )
-            self.dwconv_kv = nn.Conv1d(
+            self.dwconv_kv = nn.Conv1d(
+                kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim
+            )
         self.inner_attn = inner_attn_cls(
             causal=causal,
             softmax_scale=softmax_scale,

@@ -489,7 +518,9 @@ class MHA(nn.Module):
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        self.out_proj = linear_cls(
+        self.out_proj = linear_cls(
+            embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs
+        )

     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype

@@ -507,7 +538,9 @@ class MHA(nn.Module):
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
         assert not self.dwconv, "Generation does not support dwconv yet"
-        assert
+        assert (
+            self.layer_idx is not None
+        ), "Generation requires layer_idx in the constructor"
         return _update_kv_cache(kv, inference_params, self.layer_idx)

     def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):

@@ -523,7 +556,10 @@ class MHA(nn.Module):
             self.rotary_emb._update_cos_sin_cache(
                 inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
-            rotary_cos, rotary_sin =
+            rotary_cos, rotary_sin = (
+                self.rotary_emb._cos_cached,
+                self.rotary_emb._sin_cached,
+            )
         else:
             rotary_cos, rotary_sin = None, None
         batch = q.shape[0]

@@ -545,7 +581,9 @@ class MHA(nn.Module):
             cache_seqlens=cache_seqlens,
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
-            rotary_interleaved=
+            rotary_interleaved=(
+                self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False
+            ),
             alibi_slopes=alibi_slopes,
         )
         return context

@@ -640,40 +678,49 @@ class MHA(nn.Module):
             )
         )
         rotary_max_seqlen = (
-            inference_params.max_sequence_len
+            inference_params.max_sequence_len
+            if inference_params is not None
+            else max_seqlen
         )
-        batch, seqlen = x.shape[:2]
-        lora_kwargs = {}
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None

             if adapter_mask is not None:
                 unique_tasks = torch.unique(adapter_mask)
                 qkv_dtype = next(self.Wqkv.parameters()).dtype
-                qkv = torch.empty(
+                qkv = torch.empty(
+                    *x.shape[:-1],
+                    self.Wqkv.out_features,
+                    dtype=qkv_dtype,
+                    device=x.device,
+                )
                 for task_id in unique_tasks:
                     task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                     task_tensor = x[task_indices]
                     if not self.return_residual:
                         task_qkv = self.Wqkv(task_tensor, task_id=task_id)
                     else:
-                        task_qkv, _ = self.Wqkv(
+                        task_qkv, _ = self.Wqkv(
+                            task_tensor, task_id=task_id, residual=True
+                        )
                     qkv[task_indices] = task_qkv
             else:
                 if not self.return_residual:
                     qkv = self.Wqkv(x)
                 else:
-                    if hasattr(self.Wqkv,
+                    if hasattr(self.Wqkv, "parametrizations"):
                         qkv, x = self.Wqkv(x, residual=True)
                     else:
                         qkv, x = self.Wqkv(x)

             if self.dwconv:
                 qkv = rearrange(
-                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
+                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
+                    "b d s -> b s d",
                 ).contiguous()
-            qkv = rearrange(
+            qkv = rearrange(
+                qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
+            )
             if (
                 inference_params is None
                 or inference_params.seqlen_offset == 0

@@ -691,7 +738,9 @@ class MHA(nn.Module):
                 if not self.checkpointing:
                     context = self.inner_attn(qkv, **kwargs)
                 else:
-                    context = torch.utils.checkpoint.checkpoint(
+                    context = torch.utils.checkpoint.checkpoint(
+                        self.inner_attn, qkv, **kwargs
+                    )
             else:
                 context = self._update_kvcache_attention(
                     qkv[:, :, 0], qkv[:, :, 1:], inference_params

@@ -720,13 +769,17 @@ class MHA(nn.Module):
             q = qkv[..., : self.num_heads * self.head_dim]
             kv = qkv[..., self.num_heads * self.head_dim :]
             q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
-            kv = rearrange(
+            kv = rearrange(
+                kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim
+            )
             if self.dwconv:
                 q = rearrange(
-                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
+                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
+                    "b d s -> b s d",
                 ).contiguous()
                 kv = rearrange(
-                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
+                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
+                    "b d s -> b s d",
                 ).contiguous()
             if (
                 inference_params is None

@@ -752,14 +805,20 @@ class MHA(nn.Module):
             else:
                 context = self._update_kvcache_attention(q, kv, inference_params)
         else:
-            context = self._apply_rotary_update_kvcache_attention(
+            context = self._apply_rotary_update_kvcache_attention(
+                q, kv, inference_params
+            )

         inp = rearrange(context, "... h d -> ... (h d)")
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask)
             out_dtype = next(self.out_proj.parameters()).dtype
-            out = torch.empty(
+            out = torch.empty(
+                *inp.shape[:-1],
+                self.out_proj.out_features,
+                dtype=out_dtype,
+                device=inp.device,
+            )
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = inp[task_indices]
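As a quick sanity check on the projection sizes set up in `MHA.__init__` above (the numbers below are illustrative, not the model's actual configuration): with grouped-query attention, `Wqkv` produces `num_heads` query heads plus `2 * num_heads_kv` key/value heads.

```python
embed_dim, num_heads, num_heads_kv = 1024, 16, 8        # illustrative values
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads                       # 64
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)     # 64 * 32 = 2048
kv_dim = 2 * head_dim * num_heads_kv                    # 1024
print(head_dim, qkv_dim, kv_dim)
```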
mlp.py
CHANGED
@@ -8,14 +8,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup

 try:
     from flash_attn.ops.activations import swiglu
 except ImportError:
     swiglu = None

 try:
-    from flash_attn.ops.fused_dense import ColumnParallelLinear,
+    from flash_attn.ops.fused_dense import (ColumnParallelLinear,
+                                            RowParallelLinear)
 except ImportError:
     ColumnParallelLinear, RowParallelLinear = None, None

@@ -41,18 +41,23 @@ class Mlp(nn.Module):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         out_features = out_features if out_features is not None else in_features
-        hidden_features =
+        hidden_features = (
+            hidden_features if hidden_features is not None else in_features * 4
+        )
         self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.activation = activation
-        self.fc2 = nn.Linear(
+        self.fc2 = nn.Linear(
+            hidden_features, out_features, bias=bias2, **factory_kwargs
+        )

     def forward(self, x, adapter_mask=None):
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask)
             fc1_dtype = next(self.fc1.parameters()).dtype
-            y = torch.empty(
+            y = torch.empty(
+                *x.shape[:-1], self.fc1.out_features, dtype=fc1_dtype, device=x.device
+            )
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = x[task_indices]

@@ -66,8 +71,9 @@ class Mlp(nn.Module):
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask)
             fc2_dtype = next(self.fc2.parameters()).dtype
-            out = torch.empty(
+            out = torch.empty(
+                *y.shape[:-1], self.fc2.out_features, dtype=fc2_dtype, device=y.device
+            )
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = y[task_indices]

@@ -98,7 +104,9 @@ class ParallelMLP(nn.Module):
         assert ColumnParallelLinear is not None, "Need to install fused_dense"
         assert RowParallelLinear is not None, "Need to install fused_dense"
         out_features = out_features if out_features is not None else in_features
-        hidden_features =
+        hidden_features = (
+            hidden_features if hidden_features is not None else in_features * 4
+        )
         self.fc1 = ColumnParallelLinear(
             in_features,
             hidden_features,

@@ -144,17 +152,25 @@ class GatedMlp(nn.Module):
         hidden_features = (
             hidden_features if hidden_features is not None else int(8 * in_features / 3)
         )
-        hidden_features = (
+        hidden_features = (
+            (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        )
         self.return_residual = return_residual
-        self.fc1 = nn.Linear(
+        self.fc1 = nn.Linear(
+            in_features, 2 * hidden_features, bias=bias1, **factory_kwargs
+        )
         self.activation = activation
-        self.fc2 = nn.Linear(
+        self.fc2 = nn.Linear(
+            hidden_features, out_features, bias=bias2, **factory_kwargs
+        )

     def forward(self, x):
         y = self.fc1(x)
         if self.activation == F.sigmoid:  # Special case for GLU
             y = F.glu(y, dim=-1)
-        elif
+        elif (
+            self.activation == F.silu and swiglu is not None
+        ):  # Special case for SwiGLU
             y, gate = y.chunk(2, dim=-1)
             y = swiglu(gate, y)
         else:

@@ -187,7 +203,9 @@ class ParallelGatedMlp(nn.Module):
         hidden_features = (
             hidden_features if hidden_features is not None else int(8 * in_features / 3)
         )
-        hidden_features = (
+        hidden_features = (
+            (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        )
         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError("fused_dense is not installed")
         self.fc1 = ColumnParallelLinear(

@@ -216,4 +234,4 @@ class ParallelGatedMlp(nn.Module):
         y, gate = y.chunk(2, dim=-1)
         y = y * self.activation(gate)
         y = self.fc2(y)
-        return y
+        return y
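The `GatedMlp` sizing above first takes 8/3 of `in_features` and then rounds up to a multiple of `multiple_of`. A worked example with illustrative values (128 is a common choice for `multiple_of`, but the class's actual default is not visible in this hunk):

```python
in_features, multiple_of = 1024, 128   # illustrative values
hidden_features = int(8 * in_features / 3)                                    # 2730
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
print(hidden_features)                                                        # 2816
```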
modeling_lora.py
CHANGED
@@ -1,6 +1,5 @@
 import math
 import os
-import warnings
 from functools import partial
 from typing import Iterator, List, Optional, Tuple, Union

@@ -12,7 +11,8 @@ from torch.nn import Parameter
 from torch.nn import functional as F
 from transformers import PretrainedConfig

-from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel,
+from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
+                                   XLMRobertaPreTrainedModel)


 def initialized_weights(

@@ -162,6 +162,16 @@ class LoRAParametrization(nn.Module):
         dropout_p: float,
         alpha: float,
     ):
+        """
+        Registering LoRA adapters to all embedding and linear layers.
+
+        Additionally, we implement a custom forward function for LoRA parametrization.
+        This function modifies the layer's forward pass to optionally use task-specific
+        parameters. When a `task_id` is provided, it employs a LoRA parametrization
+        to modify the original weights according to the specific task. This allows
+        the layer to adapt dynamically to different tasks at runtime. If no `task_id`
+        is specified, the layer uses its original weights.
+        """
         if isinstance(layer, nn.Linear):
             parametrize.register_parametrization(
                 layer,

@@ -177,7 +187,9 @@ class LoRAParametrization(nn.Module):

         def new_forward(self, input, task_id=None, residual=False):
             if task_id is not None:
-                weights = self.parametrizations.weight[0].lora_forward(
+                weights = self.parametrizations.weight[0].lora_forward(
+                    self.weight, current_task=task_id
+                )
             else:
                 weights = self.weight

@@ -204,13 +216,21 @@ class LoRAParametrization(nn.Module):

         def new_forward(self, input, task_id=None):
             if task_id is not None:
-                weights = self.parametrizations.weight[0].lora_forward(
+                weights = self.parametrizations.weight[0].lora_forward(
+                    self.weight, current_task=task_id
+                )
             else:
                 weights = self.weight

             out = F.embedding(
-                input,
+                input,
+                weights,
+                self.padding_idx,
+                self.max_norm,
+                self.norm_type,
+                self.scale_grad_by_freq,
+                self.sparse,
+            )

             return out

@@ -218,10 +238,11 @@ class LoRAParametrization(nn.Module):


 class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
+    """
+    A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
+    """
     def __init__(
-        self,
-        config: XLMRobertaFlashConfig,
-        roberta: Optional[XLMRobertaModel] = None
+        self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None
     ):
         super().__init__(config)
         if roberta is None:

@@ -235,7 +256,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             or len(self._lora_adaptations) < 1
         ):
             raise ValueError(
-                f
+                f"`lora_adaptations` must be a list and contain at least one element"
             )
         self._lora_prompts = config.lora_prompts
         if (

@@ -244,9 +265,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             or not all([v in self._lora_adaptations for v in self._lora_prompts.keys()])
         ):
             raise ValueError(
-                f
-                f
+                f"`lora_prompts` must be a dict and contain the same number of elements "
+                f"as `lora_adaptations` with all keys in `lora_prompts` present in `lora_adaptations`."
+            )
         self._adaptation_map = {
             name: idx for idx, name in enumerate(self._lora_adaptations)
         }

@@ -261,7 +282,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         )
         self.main_params_trainable = config.lora_main_params_trainable

     @property
     def rotary_emb_base(self):
         return self.roberta.rotary_emb_base

@@ -305,13 +325,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         config = XLMRobertaFlashConfig.from_pretrained(
             pretrained_model_name_or_path, *model_args, **kwargs
         )
-        if config.load_trained_adapters:
+        if config.load_trained_adapters:  # checkpoint already contains LoRA adapters
             return super().from_pretrained(
                 pretrained_model_name_or_path, *model_args, **kwargs
             )
-        else:
-            roberta = XLMRobertaModel.from_pretrained(
+        else:  # initializing new adapters
+            roberta = XLMRobertaModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, **kwargs
+            )
             return cls(config, roberta=roberta)

     def _register_lora(self, num_adaptations, rank, dropout_p, alpha):

@@ -350,10 +371,12 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         **kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
-        Computes sentence embeddings
+        Computes sentence embeddings.

+        sentences(`str` or `List[str]`):
+            Sentence or sentences to be encoded
         task_type(`str`, *optional*, defaults to `None`):
-            Specifies the task for which the encoding is intended. If `task_type` is not
+            Specifies the task for which the encoding is intended. If `task_type` is not provided,
             all LoRA adapters are disabled, and the model reverts to its original,
             general-purpose weights.
         """

@@ -367,5 +390,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         if task_type:
             task_id = self._adaptation_map[task_type]
             num_examples = 1 if isinstance(sentences, str) else len(sentences)
-            adapter_mask = torch.full(
+            adapter_mask = torch.full(
+                (num_examples,), task_id, dtype=torch.int32, device=self.device
+            )
+        return self.roberta.encode(
+            sentences, *args, adapter_mask=adapter_mask, **kwargs
+        )
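The `encode()` path above turns a task name into an integer id via `self._adaptation_map` and broadcasts it into an `adapter_mask` with one entry per input sentence. A hedged sketch of that step in isolation; the task names below are placeholders, since the actual `lora_adaptations` values are not listed in this PR:

```python
import torch

adaptation_map = {"retrieval.query": 0, "retrieval.passage": 1}  # placeholder names
sentences = ["a query", "another query"]

task_id = adaptation_map["retrieval.query"]
adapter_mask = torch.full((len(sentences),), task_id, dtype=torch.int32)
# model.roberta.encode(sentences, adapter_mask=adapter_mask) would then route every
# LoRA-parametrized layer to the adapter registered for this task.
print(adapter_mask)  # tensor([0, 0], dtype=torch.int32)
```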
modeling_xlm_roberta.py
CHANGED
(Only the old side of this file's diff survived extraction; the replacement lines are cut off, so removed lines are shown as extracted.)

@@ -13,39 +13,29 @@ import re
 from collections import OrderedDict
 from collections.abc import Sequence
 from functools import partial
-import

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from
-from transformers import
 from transformers.modeling_utils import PreTrainedModel
-from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
-from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
-
 from transformers.models.bert.modeling_bert import (
-    BaseModelOutputWithPoolingAndCrossAttentions,
-
-

-from typing import List, Optional, Tuple, Union
-
-from .xlm_padding import (
-    index_first_axis,
-    index_first_axis_residual,
-    pad_input,
-    unpad_input,
-)
-from .configuration_xlm_roberta import XLMRobertaFlashConfig
 from .block import Block
 from .embedding import XLMRobertaEmbeddings
 from .mha import MHA
 from .mlp import FusedMLP, Mlp
-from .
-from .rotary import RotaryEmbedding

 try:
     from flash_attn.ops.fused_dense import FusedDense

@@ -79,7 +69,7 @@ def get_use_flash_attn(config: XLMRobertaFlashConfig):
         return False
     if importlib.util.find_spec("flash_attn") is None:
         logger.warning(
-
         )
         return False
     return True

@@ -109,7 +99,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         fused_bias_fc=fused_bias_fc,
         use_flash_attn=use_flash_attn,
         return_residual=return_residual,
-        use_alibi=config.position_embedding_type ==
         **rotary_kwargs,
     )
     return mixer_cls

@@ -204,15 +194,17 @@ class XLMRobertaEncoder(nn.Module):
     def gradient_checkpointing(self, value):
         self._grad_checkpointing = value

-    def forward(
         """If subset_mask is not None, we only want output for the subset of the sequence.
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
         """
         if key_padding_mask is None or not self.use_flash_attn:
-            mixer_kwargs = {
             if key_padding_mask is not None:
-                mixer_kwargs[
             for layer in self.layers:
                 if self._grad_checkpointing:
                     hidden_states = torch.utils.checkpoint.checkpoint(

@@ -227,10 +219,14 @@ class XLMRobertaEncoder(nn.Module):
                 hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask =
-                hidden_states, key_padding_mask, adapter_mask
             )
-            mixer_kwargs = {

             if subset_mask is None:
                 for layer in self.layers:

@@ -315,12 +311,18 @@ class XLMRobertaPooler(nn.Module):
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask)
             pool_dtype = next(self.dense.parameters()).dtype
-            pooled_output = torch.empty(
-
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_first_token_tensor = first_token_tensor[task_indices]
-                task_pooled_output = self.dense(
                 pooled_output[task_indices] = task_pooled_output
         else:
             pooled_output = self.dense(first_token_tensor)

@@ -413,12 +415,11 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
         *args,
         **kwargs,
     ):
-        if not
-            kwargs[
         return super().from_pretrained(*args, **kwargs)


-
 class XLMRobertaModel(XLMRobertaPreTrainedModel):
     def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
         super().__init__(config)

@@ -439,7 +440,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.embeddings = XLMRobertaEmbeddings(
             config.hidden_size,
             config.vocab_size,
-
             config.type_vocab_size,
             padding_idx=config.pad_token_id,
         )

@@ -449,16 +454,18 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None

         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
-        self.tokenizer = AutoTokenizer.from_pretrained(
         self._rotary_emb_base = config.rotary_emb_base

     @torch.inference_mode()
     def encode(
-        self:
         sentences: Union[str, List[str]],
         batch_size: int = 32,
         show_progress_bar: Optional[bool] = None,
-        output_value: str =
         convert_to_numpy: bool = True,
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,

@@ -516,12 +523,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         if convert_to_tensor:
             convert_to_numpy = False

-        if output_value !=
             convert_to_tensor = False
             convert_to_numpy = False

         input_was_string = False
-        if isinstance(sentences, str) or not hasattr(sentences,
             sentences = [sentences]
             input_was_string = True

@@ -532,11 +539,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         inverse_permutation = np.argsort(permutation)
         sentences = [sentences[idx] for idx in permutation]

-        tokenizer_kwargs[
-        tokenizer_kwargs[
-
         )
-        tokenizer_kwargs[

         all_embeddings = []

@@ -550,11 +557,13 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             )
         else:
             range_iter = range(0, len(sentences), batch_size)
-        lora_arguments =
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],
-                return_tensors=
                 **tokenizer_kwargs,
             ).to(self.device)
             token_embs = self.forward(**encoded_input, **lora_arguments)[0]

@@ -562,18 +571,18 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             # Accumulate in fp32 to avoid overflow
             token_embs = token_embs.float()

-            if output_value ==
                 raise NotImplementedError
             elif output_value is None:
                 raise NotImplementedError
             else:
-                if self.config.emb_pooler ==
                     embeddings = self.cls_pooling(
-                        token_embs, encoded_input[
                     )
                 else:
                     embeddings = self.mean_pooling(
-                        token_embs, encoded_input[
                     )

             if normalize_embeddings:
@@ -603,14 +612,16 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 603 |
def truncate_embeddings(self, embeddings, truncate_dim):
|
| 604 |
if not self.config.matryoshka_dimensions:
|
| 605 |
logger.warning(
|
| 606 |
-
|
| 607 |
)
|
| 608 |
return embeddings
|
| 609 |
elif truncate_dim in self.config.matryoshka_dimensions:
|
| 610 |
return [tensor[:truncate_dim] for tensor in embeddings]
|
| 611 |
else:
|
| 612 |
-
raise ValueError(
|
| 613 |
-
|
|
|
|
|
|
|
| 614 |
|
| 615 |
def mean_pooling(
|
| 616 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
|
@@ -622,10 +633,8 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 622 |
input_mask_expanded.sum(1), min=1e-9
|
| 623 |
)
|
| 624 |
|
| 625 |
-
def cls_pooling(
|
| 626 |
-
|
| 627 |
-
):
|
| 628 |
-
return token_embeddings[:,0]
|
| 629 |
|
| 630 |
@property
|
| 631 |
def rotary_emb_base(self):
|
|
@@ -635,7 +644,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 635 |
def rotary_emb_base(self, base):
|
| 636 |
if not isinstance(base, (int, float)):
|
| 637 |
raise TypeError("Base must be an integer or float")
|
| 638 |
-
logger.info(f
|
| 639 |
for layer in self.encoder.layers:
|
| 640 |
layer.mixer.rotary_emb.base = base
|
| 641 |
self._rotary_emb_base = base
|
|
@@ -655,12 +664,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 655 |
layer output for these tokens.
|
| 656 |
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
| 657 |
"""
|
| 658 |
-
adapter_mask = kwargs.pop(
|
| 659 |
if kwargs:
|
| 660 |
for key, value in kwargs.items():
|
| 661 |
if value is not None:
|
| 662 |
logger.warning(
|
| 663 |
-
|
| 664 |
key,
|
| 665 |
)
|
| 666 |
|
|
@@ -669,7 +678,10 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 669 |
)
|
| 670 |
|
| 671 |
hidden_states = self.embeddings(
|
| 672 |
-
input_ids,
|
|
|
|
|
|
|
|
|
|
| 673 |
)
|
| 674 |
# TD [2022-12:18]: Don't need to force residual in fp32
|
| 675 |
# BERT puts embedding LayerNorm before embedding dropout.
|
|
@@ -693,12 +705,17 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 693 |
subset_mask = None
|
| 694 |
|
| 695 |
sequence_output = self.encoder(
|
| 696 |
-
hidden_states,
|
|
|
|
|
|
|
|
|
|
| 697 |
)
|
| 698 |
|
| 699 |
if masked_tokens_mask is None:
|
| 700 |
pooled_output = (
|
| 701 |
-
self.pooler(sequence_output, adapter_mask=adapter_mask)
|
|
|
|
|
|
|
| 702 |
)
|
| 703 |
else:
|
| 704 |
# TD [2022-03-01]: the indexing here is very tricky.
|
|
@@ -712,7 +729,9 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
| 712 |
pool_input = sequence_output[first_col_mask[subset_mask]]
|
| 713 |
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
| 714 |
pooled_output = (
|
| 715 |
-
self.pooler(pool_input, pool=False, adapter_mask=adapter_mask)
|
|
|
|
|
|
|
| 716 |
)
|
| 717 |
|
| 718 |
if not return_dict:
|
|
@@ -817,103 +836,6 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
        )


-# class XLMRobertaForPreTraining(XLMRobertaPreTrainedModel):
-#     def __init__(self, config: XLMRobertaFlashConfig):
-#         super().__init__(config)
-#         # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
-#         # (around 15%) to the classifier heads.
-#         self.dense_seq_output = getattr(config, "dense_seq_output", False)
-#         # If last_layer_subset, we only need the compute the last layer for a subset of tokens
-#         # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
-#         self.last_layer_subset = getattr(config, "last_layer_subset", False)
-#         if self.last_layer_subset:
-#             assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
-#         use_xentropy = getattr(config, "use_xentropy", False)
-#         if use_xentropy and CrossEntropyLoss is None:
-#             raise ImportError("xentropy_cuda is not installed")
-#         loss_cls = (
-#             nn.CrossEntropyLoss
-#             if not use_xentropy
-#             else partial(CrossEntropyLoss, inplace_backward=True)
-#         )
-#
-#         self.xlm = XLMRobertaModel(config)
-#         self.cls = XLMRobertaPreTrainingHeads(config)
-#         self.mlm_loss = loss_cls(ignore_index=0)
-#         self.nsp_loss = loss_cls(ignore_index=-1)
-#
-#         # Initialize weights and apply final processing
-#         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
-#         self.tie_weights()
-#
-#     def tie_weights(self):
-#         self.cls.predictions.decoder.weight = self.xlm.embeddings.word_embeddings.weight
-#
-#     def forward(
-#         self,
-#         input_ids,
-#         position_ids=None,
-#         token_type_ids=None,
-#         attention_mask=None,
-#         labels=None,
-#         next_sentence_label=None,
-#     ):
-#         """
-#         If labels are provided, they must be 0 for masked out tokens (as specified in the attention
-#         mask).
-#         Outputs:
-#             if `labels` and `next_sentence_label` are not `None`:
-#                 Outputs the total_loss which is the sum of the masked language modeling loss and the next
-#                 sentence classification loss.
-#             if `labels` or `next_sentence_label` is `None`:
-#                 Outputs a tuple comprising
-#                 - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
-#                 - the next sentence classification logits of shape [batch_size, 2].
-#
-#         """
-#         masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
-#         outputs = self.xlm(
-#             input_ids,
-#             position_ids=position_ids,
-#             token_type_ids=token_type_ids,
-#             attention_mask=attention_mask.bool() if attention_mask is not None else None,
-#             masked_tokens_mask=masked_tokens_mask,
-#         )
-#         sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
-#         if self.dense_seq_output and labels is not None:
-#             masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
-#             if not self.last_layer_subset:
-#                 sequence_output = index_first_axis(
-#                     rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
-#                 )
-#         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
-#
-#         total_loss = None
-#         if labels is not None and next_sentence_label is not None:
-#             if (
-#                 self.dense_seq_output and labels is not None
-#             ):  # prediction_scores are already flattened
-#                 masked_lm_loss = self.mlm_loss(
-#                     prediction_scores, labels.flatten()[masked_token_idx]
-#                 )
-#             else:
-#                 masked_lm_loss = self.mlm_loss(
-#                     rearrange(prediction_scores, "... v -> (...) v"),
-#                     rearrange(labels, "... -> (...)"),
-#                 )
-#             next_sentence_loss = self.nsp_loss(
-#                 rearrange(seq_relationship_score, "... t -> (...) t"),
-#                 rearrange(next_sentence_label, "... -> (...)"),
-#             )
-#             total_loss = masked_lm_loss.float() + next_sentence_loss.float()
-#
-#         return BertForPreTrainingOutput(
-#             loss=total_loss,
-#             prediction_logits=prediction_scores,
-#             seq_relationship_logits=seq_relationship_score,
-#         )
-
-
def remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
@@ -1065,47 +987,47 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
        if not last_layer_subset or d != (config.num_hidden_layers - 1):
            Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
            Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
-            state_dict[ …
-            state_dict[ …
-            state_dict[ …
-            state_dict[ …
-            state_dict[ …
-            state_dict[ …
        else:
            Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
            Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
            Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
            Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
-            state_dict[ …
-            state_dict[ …
-            state_dict[ …
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                : Wkv_biases.shape[0] // 2
            ]
-            state_dict[ …

    def inv_key_mapping_ln(key):
        key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)

@@ -1294,4 +1216,4 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
-        )
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
+from typing import List, Optional, Tuple, Union

+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoTokenizer, PretrainedConfig
+from transformers.modeling_outputs import (MaskedLMOutput,
+                                            SequenceClassifierOutput)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bert.modeling_bert import (
+    BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)
+from transformers.models.xlm_roberta.modeling_xlm_roberta import \
+    XLMRobertaLMHead

from .block import Block
+from .configuration_xlm_roberta import XLMRobertaFlashConfig
from .embedding import XLMRobertaEmbeddings
from .mha import MHA
from .mlp import FusedMLP, Mlp
+from .xlm_padding import index_first_axis_residual, pad_input, unpad_input

try:
    from flash_attn.ops.fused_dense import FusedDense

        return False
    if importlib.util.find_spec("flash_attn") is None:
        logger.warning(
+            "flash_attn is not installed. Using PyTorch native attention implementation."
        )
        return False
    return True

        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
        return_residual=return_residual,
+        use_alibi=config.position_embedding_type == "alibi",
        **rotary_kwargs,
    )
    return mixer_cls

    def gradient_checkpointing(self, value):
        self._grad_checkpointing = value

+    def forward(
+        self, hidden_states, key_padding_mask=None, subset_mask=None, adapter_mask=None
+    ):
        """If subset_mask is not None, we only want output for the subset of the sequence.
        This means that we only compute the last layer output for these tokens.
        subset_mask: (batch, seqlen), dtype=torch.bool
        """
        if key_padding_mask is None or not self.use_flash_attn:
+            mixer_kwargs = {"adapter_mask": adapter_mask}
            if key_padding_mask is not None:
+                mixer_kwargs["key_padding_mask"] = key_padding_mask.bool()
            for layer in self.layers:
                if self._grad_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(

                hidden_states = hidden_states[subset_mask]
        else:
            batch, seqlen = hidden_states.shape[:2]
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = (
+                unpad_input(hidden_states, key_padding_mask, adapter_mask)
            )
+            mixer_kwargs = {
+                "cu_seqlens": cu_seqlens,
+                "max_seqlen": max_seqlen_in_batch,
+                "adapter_mask": cu_adapter_mask,
+            }

            if subset_mask is None:
                for layer in self.layers:

        if adapter_mask is not None:
            unique_tasks = torch.unique(adapter_mask)
            pool_dtype = next(self.dense.parameters()).dtype
+            pooled_output = torch.empty(
+                first_token_tensor.shape[0],
+                self.dense.out_features,
+                dtype=pool_dtype,
+                device=first_token_tensor.device,
+            )
            for task_id in unique_tasks:
                task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                task_first_token_tensor = first_token_tensor[task_indices]
+                task_pooled_output = self.dense(
+                    task_first_token_tensor, task_id=task_id
+                )
                pooled_output[task_indices] = task_pooled_output
        else:
            pooled_output = self.dense(first_token_tensor)
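The per-task pooling loop above groups batch rows by their adapter id before projecting them. A minimal sketch of the same group-then-scatter pattern, using a plain nn.Linear in place of the task-aware dense layer (the real layer additionally takes a task_id argument, which is omitted here):

import torch
import torch.nn as nn

# Toy stand-in for the pooler's dense projection (the real layer is LoRA-parametrized).
dense = nn.Linear(4, 4)

first_token_tensor = torch.randn(6, 4)           # CLS vectors for a batch of 6
adapter_mask = torch.tensor([0, 1, 0, 2, 1, 0])  # per-sample task ids

pooled_output = torch.empty(
    first_token_tensor.shape[0], dense.out_features,
    dtype=next(dense.parameters()).dtype,
)
for task_id in torch.unique(adapter_mask):
    task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
    # Each group is projected in one call, then scattered back to its original rows.
    pooled_output[task_indices] = dense(first_token_tensor[task_indices])

print(pooled_output.shape)  # torch.Size([6, 4])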
        *args,
        **kwargs,
    ):
+        if not "torch_dtype" in kwargs:
+            kwargs["torch_dtype"] = "auto"
        return super().from_pretrained(*args, **kwargs)
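Because from_pretrained now defaults torch_dtype to "auto", the dtype stored in the checkpoint is used unless the caller overrides it. A hedged usage sketch ("path/to/checkpoint" is a placeholder, not a published model id):

import torch
from transformers import AutoModel

# Placeholder path for a repo or local directory that ships this modeling code.
model = AutoModel.from_pretrained("path/to/checkpoint", trust_remote_code=True)

# Explicit dtypes still take precedence over the "auto" default:
model_fp16 = AutoModel.from_pretrained(
    "path/to/checkpoint", trust_remote_code=True, torch_dtype=torch.float16
)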
class XLMRobertaModel(XLMRobertaPreTrainedModel):
    def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
        super().__init__(config)

        self.embeddings = XLMRobertaEmbeddings(
            config.hidden_size,
            config.vocab_size,
+            (
+                config.max_position_embeddings
+                if config.position_embedding_type == "absolute"
+                else -1
+            ),
            config.type_vocab_size,
            padding_idx=config.pad_token_id,
        )

        self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.name_or_path, trust_remote_code=True
+        )
        self._rotary_emb_base = config.rotary_emb_base

    @torch.inference_mode()
    def encode(
+        self: "XLMRobertaModel",
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        show_progress_bar: Optional[bool] = None,
+        output_value: str = "sentence_embedding",
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        device: Optional[torch.device] = None,

        if convert_to_tensor:
            convert_to_numpy = False

+        if output_value != "sentence_embedding":
            convert_to_tensor = False
            convert_to_numpy = False

        input_was_string = False
+        if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
            sentences = [sentences]
            input_was_string = True

            inverse_permutation = np.argsort(permutation)
            sentences = [sentences[idx] for idx in permutation]

+        tokenizer_kwargs["padding"] = tokenizer_kwargs.get("padding", True)
+        tokenizer_kwargs["max_length"] = tokenizer_kwargs.get(
+            "max_length", self.tokenizer.init_kwargs.get("model_max_length", 8192)
        )
+        tokenizer_kwargs["truncation"] = tokenizer_kwargs.get("truncation", True)

        all_embeddings = []

            )
        else:
            range_iter = range(0, len(sentences), batch_size)
+        lora_arguments = (
+            {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
+        )
        for i in range_iter:
            encoded_input = self.tokenizer(
                sentences[i : i + batch_size],
+                return_tensors="pt",
                **tokenizer_kwargs,
            ).to(self.device)
            token_embs = self.forward(**encoded_input, **lora_arguments)[0]

            # Accumulate in fp32 to avoid overflow
            token_embs = token_embs.float()

+            if output_value == "token_embeddings":
                raise NotImplementedError
            elif output_value is None:
                raise NotImplementedError
            else:
+                if self.config.emb_pooler == "cls":
                    embeddings = self.cls_pooling(
+                        token_embs, encoded_input["attention_mask"]
                    )
                else:
                    embeddings = self.mean_pooling(
+                        token_embs, encoded_input["attention_mask"]
                    )

                if normalize_embeddings:

    def truncate_embeddings(self, embeddings, truncate_dim):
        if not self.config.matryoshka_dimensions:
            logger.warning(
+                "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
            )
            return embeddings
        elif truncate_dim in self.config.matryoshka_dimensions:
            return [tensor[:truncate_dim] for tensor in embeddings]
        else:
+            raise ValueError(
+                f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
+                f"Supported dimensions are {self.config.matryoshka_dimensions}."
+            )
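Matryoshka-style truncation simply keeps the leading truncate_dim components of an embedding (typically followed by re-normalization). A small standalone sketch, assuming 64 were one of the configured matryoshka_dimensions:

import torch
import torch.nn.functional as F

embedding = torch.randn(1024)        # full-size embedding
truncate_dim = 64                    # assumed to be in config.matryoshka_dimensions

truncated = embedding[:truncate_dim]                 # keep only the leading components
truncated = F.normalize(truncated, p=2, dim=-1)      # re-normalize for cosine similarity
print(truncated.shape)                               # torch.Size([64])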
    def mean_pooling(
        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor

            input_mask_expanded.sum(1), min=1e-9
        )

+    def cls_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
+        return token_embeddings[:, 0]
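Mean pooling averages token embeddings while masking out padding, whereas CLS pooling keeps only the first token. A self-contained sketch of both reductions on toy tensors (mirroring the two methods above rather than calling them):

import torch

token_embs = torch.randn(2, 5, 8)                       # (batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Mean pooling: sum the valid token vectors and divide by the number of valid tokens.
mask = attention_mask.unsqueeze(-1).float()
mean_pooled = (token_embs * mask).sum(1) / torch.clamp(mask.sum(1), min=1e-9)

# CLS pooling: take the embedding of the first token.
cls_pooled = token_embs[:, 0]

print(mean_pooled.shape, cls_pooled.shape)  # torch.Size([2, 8]) torch.Size([2, 8])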
    @property
    def rotary_emb_base(self):

    def rotary_emb_base(self, base):
        if not isinstance(base, (int, float)):
            raise TypeError("Base must be an integer or float")
+        logger.info(f"Changing RoPE base value to {base}")
        for layer in self.encoder.layers:
            layer.mixer.rotary_emb.base = base
        self._rotary_emb_base = base
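The rotary_emb_base setter propagates a new RoPE base to every layer's rotary embedding, which is the hook for stretching the usable context at inference time (a larger base yields slower-rotating frequencies). A hedged usage sketch, assuming model is an XLMRobertaModel instance and the numbers are only examples:

# Assumes `model` was created elsewhere; 160000 is an example value, not a recommendation.
print(model.rotary_emb_base)    # the base taken from config.rotary_emb_base
model.rotary_emb_base = 160000  # updates layer.mixer.rotary_emb.base for every encoder layer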
            layer output for these tokens.
        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
        """
+        adapter_mask = kwargs.pop("adapter_mask", None)
        if kwargs:
            for key, value in kwargs.items():
                if value is not None:
                    logger.warning(
+                        "Flash attention implementation does not support kwargs: %s",
                        key,
                    )

        )

        hidden_states = self.embeddings(
+            input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            adapter_mask=adapter_mask,
        )
        # TD [2022-12:18]: Don't need to force residual in fp32
        # BERT puts embedding LayerNorm before embedding dropout.

            subset_mask = None

        sequence_output = self.encoder(
+            hidden_states,
+            key_padding_mask=attention_mask,
+            subset_mask=subset_mask,
+            adapter_mask=adapter_mask,
        )

        if masked_tokens_mask is None:
            pooled_output = (
+                self.pooler(sequence_output, adapter_mask=adapter_mask)
+                if self.pooler is not None
+                else None
            )
        else:
            # TD [2022-03-01]: the indexing here is very tricky.

            pool_input = sequence_output[first_col_mask[subset_mask]]
            sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
            pooled_output = (
+                self.pooler(pool_input, pool=False, adapter_mask=adapter_mask)
+                if self.pooler is not None
+                else None
            )

        if not return_dict:

        )
def remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.

        if not last_layer_subset or d != (config.num_hidden_layers - 1):
            Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
            Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
+            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
+                Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
+                Wqkv_weights[
+                    Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
+                ]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
+                Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = (
+                Wqkv_biases[: Wqkv_biases.shape[0] // 3]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = (
+                Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
+                Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
+            )
        else:
            Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
            Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
            Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
            Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
+            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
+                Wq_weight
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
+                Wkv_weights[: Wkv_weights.shape[0] // 2, :]
+            )
+            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
+                Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
+            )
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                : Wkv_biases.shape[0] // 2
            ]
+            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
+                Wkv_biases[Wkv_biases.shape[0] // 2 :]
+            )
    def inv_key_mapping_ln(key):
        key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)

            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
+        )
modeling_xlm_roberta_for_glue.py
DELETED

@@ -1,109 +0,0 @@
-from typing import Optional, Union, Tuple
-
-import torch
-from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
-from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
-
-from .modeling_xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
-from .configuration_xlm_roberta import XLMRobertaFlashConfig
-
-
-class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
-    def __init__(self, config: XLMRobertaFlashConfig):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.config = config
-
-        self.roberta = XLMRobertaModel(config)
-        classifier_dropout = (
-            config.classifier_dropout
-            if config.classifier_dropout is not None
-            else config.hidden_dropout_prob
-        )
-        self.dropout = nn.Dropout(classifier_dropout)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
-        assert head_mask is None
-        assert inputs_embeds is None
-        assert output_attentions is None
-        assert output_hidden_states is None
-        assert return_dict
-        outputs = self.roberta(
-            input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        pooled_output = outputs[1]
-
-        pooled_output = self.dropout(pooled_output)
-        logits = self.classifier(pooled_output)
-
-        loss = None
-        if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (
-                    labels.dtype == torch.long or labels.dtype == torch.int
-                ):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(logits, labels)
-        if not return_dict:
-            output = (logits,) + outputs[2:]
-            return ((loss,) + output) if loss is not None else output
-
-        return SequenceClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
pytorch_model.bin
DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:cfa8fa7c7e120199548fe7149512c0adfe58f6bc13ce19f09b895aa25e8af910
-size 1113232188
rotary.py
CHANGED

@@ -1,4 +1,7 @@
-# …
# Copyright (c) 2023, Tri Dao.

import math

@@ -11,8 +14,9 @@ if torch.cuda.is_available():
    try:
        from flash_attn.ops.triton.rotary import apply_rotary
    except ImportError:
        def apply_rotary(*args, **kwargs):
-            raise RuntimeError( …


def rotate_half(x, interleaved=False):

@@ -21,7 +25,9 @@ def rotate_half(x, interleaved=False):
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
-        return rearrange( …


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):

@@ -32,13 +38,20 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos, sin = (
-        cos[:x.shape[1]],
-        sin[:x.shape[1]],
    )
-    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
-    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
-        [ …
        dim=-1,
    )

@@ -68,7 +81,9 @@ class ApplyRotaryEmb(torch.autograd.Function):
        )

        if isinstance(seqlen_offsets, int):
-            ctx.save_for_backward( …
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)

@@ -336,7 +351,9 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
            max_seqlen=max_seqlen,
        )
        if isinstance(seqlen_offsets, int):
-            ctx.save_for_backward( …
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)

@@ -451,7 +468,8 @@ class RotaryEmbedding(torch.nn.Module):
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
-            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) …
            if scale_base is not None
            else None
        )

@@ -477,7 +495,10 @@ class RotaryEmbedding(torch.nn.Module):
    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
-            ** ( …
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):

@@ -516,10 +537,14 @@ class RotaryEmbedding(torch.nn.Module):
            self._sin_cached = torch.sin(freqs).to(dtype)
        else:
            power = (
-                torch.arange( …
                - seqlen // 2
            ) / self.scale_base
-            scale = self.scale.to(device=power.device) ** rearrange( …
            # We want the multiplication by scale to happen in fp32
            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)

@@ -550,7 +575,9 @@ class RotaryEmbedding(torch.nn.Module):
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
-            self._update_cos_sin_cache( …
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(

@@ -606,4 +633,4 @@ class RotaryEmbedding(torch.nn.Module):
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
-        return q, kv
+# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py
+# Commit id: 3566596ad867ee415dd3c12616dd50c610176f6c
+# Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
+
# Copyright (c) 2023, Tri Dao.

import math

    try:
        from flash_attn.ops.triton.rotary import apply_rotary
    except ImportError:
+
        def apply_rotary(*args, **kwargs):
+            raise RuntimeError("RoPE requires flash-attention to be installed")


def rotate_half(x, interleaved=False):

        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):

    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos, sin = (
+        cos[: x.shape[1]],
+        sin[: x.shape[1]],
+    )
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
        dim=-1,
    )

        )

        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(
+                cos, sin, cu_seqlens
+            )  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)

            max_seqlen=max_seqlen,
        )
        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(
+                cos, sin, cu_seqlens
+            )  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)

        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
+            / (1.4 * dim)
            if scale_base is not None
            else None
        )

    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
+            ** (
+                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
+                / self.dim
+            )
        )
    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):

            self._sin_cached = torch.sin(freqs).to(dtype)
        else:
            power = (
+                torch.arange(
+                    seqlen, dtype=self.scale.dtype, device=self.scale.device
+                )
                - seqlen // 2
            ) / self.scale_base
+            scale = self.scale.to(device=power.device) ** rearrange(
+                power, "s -> s 1"
+            )
            # We want the multiplication by scale to happen in fp32
            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)

        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
+            self._update_cos_sin_cache(
+                seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype
+            )
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(

                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )
+            return q, kv
stochastic_depth.py
CHANGED

@@ -34,7 +34,7 @@

import torch
import torch.fx
-from torch import …
+from torch import Tensor, nn


def stochastic_depth(
tokenizer.json
DELETED

The diff for this file is too large to render; see the raw diff.
tokenizer_config.json
DELETED

@@ -1,4 +0,0 @@
-{
-    "model_max_length": 8194,
-    "tokenizer_class": "XLMRobertaTokenizer"
-}
xlm_padding.py
CHANGED

@@ -18,7 +18,9 @@ class IndexFirstAxis(torch.autograd.Function):
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
-            rearrange(input, "b ... -> b (...)"), …
        ).reshape(-1, *other_shape)

    @staticmethod

@@ -34,7 +36,9 @@ class IndexFirstAxis(torch.autograd.Function):
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
-        grad_input.scatter_( …
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


@@ -112,9 +116,15 @@ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad( …

-    cu_adapter_mask = …

    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim

@@ -184,14 +194,18 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
-    attention_mask_2d = torch.arange( …
-        …
-    real_indices_idx = torch.nonzero( …
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad( …
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to

@@ -219,4 +233,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
-    return rearrange(output, "(b s) ... -> b s ...", b=batch)
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
+            rearrange(input, "b ... -> b (...)"),
+            0,
+            repeat(indices, "z -> z d", d=second_dim),
        ).reshape(-1, *other_shape)

    @staticmethod

        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
+        grad_input.scatter_(
+            0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
+        )
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(
+        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
+    )

+    cu_adapter_mask = (
+        torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1])
+        if adapter_mask is not None
+        else None
+    )

    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim

    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
+    attention_mask_2d = torch.arange(
+        seqlen, device=length.device, dtype=length.dtype
+    ).expand(len(length), seqlen) < length.unsqueeze(1)
+    real_indices_idx = torch.nonzero(
+        attention_mask_in_length.flatten(), as_tuple=False
+    ).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(
+        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
+    )
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to

    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
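A short standalone sketch of the unpadding bookkeeping used above: cu_seqlens is the cumulative sequence-length vector over the padding-free batch, and indices lets the flattened tokens be scattered back into a padded layout. This illustrates the idea with toy tensors rather than calling the functions in the diff:

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
hidden = torch.arange(16, dtype=torch.float32).reshape(2, 4, 2)  # (batch, seqlen, dim)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)            # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() # positions of real tokens
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)

# Unpad: gather only the real tokens into one flat (total_tokens, dim) tensor.
unpadded = hidden.reshape(-1, 2)[indices]

# Pad again: scatter the flat tokens back into a zero-initialized padded batch.
repadded = torch.zeros(2 * 4, 2)
repadded[indices] = unpadded
repadded = repadded.reshape(2, 4, 2)
assert torch.equal(repadded * attention_mask.unsqueeze(-1), repadded)  # padding stays zero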