Sentence Similarity
Transformers
Safetensors
multilingual
nllb-llm2vec
feature-extraction
text-embedding
embeddings
information-retrieval
beir
text-classification
language-model
text-clustering
text-semantic-similarity
text-evaluation
text-reranking
natural_questions
ms_marco
fever
hotpot_qa
mteb
custom_code
import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from packaging import version
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    BaseModelOutputWithPooling,
    ModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
from transformers.tokenization_utils import BatchEncoding

from .configuration_nllbllm2vec import NLLBLLM2VecConfig
from .modeling_llama_encoder import LlamaEncoderModel
DEFAULT_TOKENIZE_KWARGS = {
    "padding": True,
    "truncation": True,
    "max_length": 512,
    "return_tensors": "pt",
}

DEFAULT_DATALOADER_KWARGS = {
    "shuffle": False,
    "batch_size": 32,
    "pin_memory": True,
}


def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
    def collate_fn(batch: list[str]) -> BatchEncoding:
        return tokenizer(batch, **tokenize_kwargs)

    return collate_fn


def defaulter(kwd_dict: Optional[Dict], default_dict: Dict) -> Dict:
    return default_dict if kwd_dict is None else {**default_dict, **kwd_dict}
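
# `defaulter` merges user-supplied kwargs into the defaults, with user values
# winning on key collisions. A quick illustration (hypothetical values):
#
#   defaulter({"batch_size": 8}, DEFAULT_DATALOADER_KWARGS)
#   # -> {"shuffle": False, "batch_size": 8, "pin_memory": True}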
# `ModelOutput` subclasses must be dataclasses; without the decorator the
# fields below are plain class attributes and the output breaks.
@dataclass
class SequenceClassifierOutputWithPastAndPooler(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    pooler_output: torch.FloatTensor = None
class NLLBLLM2Vec(PreTrainedModel):
    """
    NLLBLLM2Vec model combining the NLLB and Llama encoders.

    Args:
        config (Optional[NLLBLLM2VecConfig]): Configuration object.
        nllb_encoder (Optional[M2M100Encoder]): Pre-initialized NLLB encoder.
        llm2vec (Optional[LlamaEncoderModel]): Pre-initialized Llama encoder.
        *inputs: Additional positional arguments.
        **kwargs: Additional keyword arguments.
    """

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config: Optional[NLLBLLM2VecConfig] = None,
        nllb_encoder: Optional[M2M100Encoder] = None,
        llm2vec: Optional[LlamaEncoderModel] = None,
        *inputs,
        **kwargs,
    ):
        # Ensure that either a config or both encoders are provided
        if config is None and (nllb_encoder is None or llm2vec is None):
            raise ValueError(
                "Either `config` must be provided, or both `nllb_encoder` and `llm2vec` must be specified."
            )
        if config is not None:
            super().__init__(config, *inputs, **kwargs)
            # `from_pretrained` overwrites this after config instantiation, so make sure it is set correctly
            config.nllb_config._attn_implementation = config._attn_implementation
            config.llm2vec_config._attn_implementation = config._attn_implementation
            self.nllb_encoder = nllb_encoder or M2M100Encoder(config.nllb_config)
            self.llm2vec = llm2vec or LlamaEncoderModel(config.llm2vec_config)
            self.config = config
        else:
            # Both encoders are provided
            self.nllb_encoder = cast(M2M100Encoder, nllb_encoder)
            self.llm2vec = cast(LlamaEncoderModel, llm2vec)
            self.config = NLLBLLM2VecConfig(
                nllb_config=self.nllb_encoder.config,  # type: ignore
                llm2vec_config=self.llm2vec.config,  # type: ignore
            )
            super().__init__(self.config, *inputs, **kwargs)
        # Project NLLB encoder states into the Llama embedding space
        self.up_proj = nn.Linear(
            self.nllb_encoder.config.d_model,
            self.llm2vec.config.hidden_size,
            bias=False,
        )
        # TODO: update this once the upstream fix is included in a release
        min_version = "4.46.0"
        if self.config.nllb_config._attn_implementation == "flash_attention_2":
            if version.parse(transformers.__version__) < version.parse(min_version):
                warnings.warn(
                    f"The installed transformers version ({transformers.__version__}) never disables NLLB-encoder dropout with FlashAttention2. "
                    f"See https://github.com/huggingface/transformers/pull/33844 for more info. Consider upgrading to {min_version} or newer.",
                    UserWarning,
                )
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        indices: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask.
            indices (Optional[Tuple[torch.Tensor, torch.Tensor]]): Precomputed input indices and offsets.

        Returns:
            BaseModelOutputWithPooling: Model outputs with last hidden state and pooled output.
        """
        # Compute input indices and offsets if not provided
        if indices is None:
            seq_indices, seq_offsets = self._get_input_offsets(attention_mask)
        else:
            seq_indices, seq_offsets = indices
        nllb_outputs = self.nllb_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        nllb_last_hidden_state = nllb_outputs.last_hidden_state
        nllb_last_hidden_state = self.up_proj(nllb_last_hidden_state)
        outputs = self.llm2vec(
            inputs_embeds=nllb_last_hidden_state,
            attention_mask=attention_mask,
        )
        pooler_output = self._mean_embedding(
            hidden_states=outputs.last_hidden_state,
            input_indices=seq_indices,
            offsets=seq_offsets,
        )
        return BaseModelOutputWithPooling(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=pooler_output,
        )
    @property
    def tokenizer(self):
        """
        Get the tokenizer associated with the model.

        Returns:
            PreTrainedTokenizer: The tokenizer instance.
        """
        if not hasattr(self, "_tokenizer"):
            from transformers import AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(
                "facebook/nllb-200-distilled-600M", padding_side="right"
            )
        return self._tokenizer
    def encode(
        self,
        inputs: List[str],
        src_lang: str = "eng_Latn",
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
        tokenize_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn_closure: Optional[Callable] = None,
    ) -> torch.Tensor:
        """
        Encode input texts into embeddings.

        Args:
            inputs (List[str]): List of input texts.
            src_lang (str): Source language code for the tokenizer (default: `"eng_Latn"`).
            dataloader_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the dataloader excl. `collate_fn`.
                Defaults to:
                >> dataloader_kwargs = {
                >>     "shuffle": False,
                >>     "batch_size": 32,
                >>     "pin_memory": True,
                >> }
            tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
                Defaults to:
                >> tokenize_kwargs = {
                >>     "padding": True,
                >>     "truncation": True,
                >>     "max_length": 512,
                >>     "return_tensors": "pt",
                >> }
            collate_fn_closure (Optional[Callable]): Closure that should return a `collate_fn`.
                Defaults to:
                >> def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
                >>     def collate_fn(batch: list[str]) -> BatchEncoding:
                >>         return tokenizer(batch, **tokenize_kwargs)
                >>     return collate_fn

        Returns:
            torch.Tensor: Mean-pooled sequence embeddings of the inputs.
        """
        # Merge user kwargs with defaults, giving priority to user kwargs
        tokenize_kwargs = defaulter(tokenize_kwargs, DEFAULT_TOKENIZE_KWARGS)
        dataloader_kwargs = defaulter(dataloader_kwargs, DEFAULT_DATALOADER_KWARGS)
        tokenizer = self.tokenizer
        tokenizer.src_lang = src_lang
        device = next(self.parameters()).device
        if collate_fn_closure is None:
            collate_fn = default_collate_fn_closure(tokenizer, tokenize_kwargs)
        else:
            collate_fn = collate_fn_closure(tokenizer, tokenize_kwargs)
        assert (
            "collate_fn" not in dataloader_kwargs
        ), "`collate_fn` should be created via `collate_fn_closure`"
        self.eval()
        if len(inputs) > dataloader_kwargs.get("batch_size", 1):
            dataloader = DataLoader(inputs, collate_fn=collate_fn, **dataloader_kwargs)  # type: ignore
            all_embeddings = []
            # Iterate through the dataloader with a progress bar and autocast
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                for batch in tqdm(dataloader, desc="Encoding"):
                    # Move batch to device
                    batch = {k: v.to(device) for k, v in batch.items()}
                    # Forward pass; the model returns mean-pooled sequence embeddings
                    with torch.inference_mode():
                        pooled_embeddings = cast(
                            BaseModelOutputWithPooling, self(**batch)
                        ).pooler_output
                    all_embeddings.append(pooled_embeddings)
            # Concatenate all pooled embeddings along the batch dimension
            all_embeddings = torch.cat(all_embeddings, dim=0)
        else:
            batch = {k: v.to(device) for k, v in collate_fn(inputs).items()}
            with torch.inference_mode():
                all_embeddings = cast(
                    BaseModelOutputWithPooling, self(**batch)
                ).pooler_output
        return all_embeddings
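
    # Example usage of `encode` (illustrative sketch; the checkpoint path is a
    # placeholder, substitute the repository you are actually loading):
    #
    #   model = NLLBLLM2Vec.from_pretrained("path/to/nllb-llm2vec-checkpoint")
    #   embeddings = model.encode(
    #       ["NLLB-LLM2Vec embeds multilingual text.", "A second sentence."],
    #       src_lang="eng_Latn",  # NLLB language code matching the inputs
    #   )
    #   embeddings.shape  # -> (2, hidden_size)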
    @staticmethod
    def _get_input_offsets(
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute indices and offsets for mean pooling using EmbeddingBag.

        Args:
            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - input_indices: Indices of non-padded tokens in the flattened input.
                - offsets: Offsets indicating the start index of each sequence in the flattened input.
        """
        # Find the indices of non-padded tokens in the flattened hidden states
        input_indices = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()
        # Compute the offsets: for each sequence, where it starts in the flattened input
        non_padded_lengths = attention_mask.sum(dim=1)  # Count non-padded tokens per sequence
        offsets = non_padded_lengths.cumsum(dim=0).roll(shifts=1)
        offsets[0] = 0
        return input_indices, offsets
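
    # Worked example (illustrative): with right-padded sequences of lengths 2 and 3,
    #
    #   attention_mask = [[1, 1, 0],
    #                     [1, 1, 1]]
    #
    # the flattened mask is [1, 1, 0, 1, 1, 1], so
    #   input_indices = [0, 1, 3, 4, 5]   (positions of real tokens)
    #   offsets       = [0, 2]            (bag 0 = input_indices[0:2], bag 1 = input_indices[2:5])
    # i.e. offsets index into input_indices, not into the flattened tokens.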
    @staticmethod
    def _mean_embedding(
        hidden_states: torch.Tensor,
        input_indices: torch.Tensor,
        offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the mean of non-padded embeddings using `embedding_bag`,
        properly handling padding via offsets.

        Args:
            hidden_states (torch.Tensor): Hidden states of shape (batch_size, seq_len, embed_dim).
            input_indices (torch.Tensor): Indices of non-padded tokens in flattened form.
            offsets (torch.Tensor): Offsets specifying the start of each sequence.

        Returns:
            torch.Tensor: Pooled mean embeddings of shape (batch_size, embed_dim).
        """
        # Flatten hidden_states to 2D: shape (batch_size * seq_len, embed_dim)
        batch_size, seq_len, embed_dim = hidden_states.shape
        token_embeds = hidden_states.view(-1, embed_dim)
        # Use embedding_bag with mode "mean" and the precomputed indices
        return F.embedding_bag(
            input=input_indices,  # Indices of non-padded tokens in flattened form
            weight=token_embeds,  # The flattened hidden states as embedding matrix
            offsets=offsets,  # Offsets specifying the start of each sequence
            mode="mean",  # Aggregation mode
        )
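
    # Note: this embedding_bag pooling is equivalent to the usual masked mean
    # (reference sketch, with `attention_mask` as in `forward`):
    #
    #   mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    #   mean = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
    #
    # but it avoids materializing the masked product and reduces directly over
    # the non-padded tokens.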
class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    base_model_prefix = "model"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = NLLBLLM2Vec(config)
        self.score = nn.Linear(
            config.llm2vec_config.hidden_size, self.num_labels, bias=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        if module is self.score:
            # INFO:
            # - critical that the classification head is in float32 (NusaX performance drops oddly otherwise)
            # - initialization needs to be redone, otherwise it is broken
            # - use Kaiming uniform, because the Llama init (cf. `nn.Linear` below) performs worse
            self.score = self.score.to(torch.float32)
            torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.model.nllb_encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Classify on top of the mean-pooled sequence embedding
        hidden_states = transformer_outputs.pooler_output
        pooled_logits = self.score(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutputWithPastAndPooler(
            loss=loss,
            hidden_states=hidden_states,
            logits=pooled_logits,
            pooler_output=transformer_outputs.pooler_output,
        )
class NLLBLLM2VecForTokenClassification(PreTrainedModel):
    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    base_model_prefix = "model"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: NLLBLLM2VecConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = NLLBLLM2Vec(config)
        self.classifier = nn.Linear(
            config.llm2vec_config.hidden_size, self.num_labels, bias=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        if module is self.classifier:
            # INFO:
            # - critical that the classification head is in float32 (NusaX performance drops oddly otherwise)
            # - initialization needs to be redone, otherwise it is broken
            # - use Kaiming uniform, because the Llama init (cf. `nn.Linear` below) performs worse
            self.classifier = self.classifier.to(torch.float32)
            torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.model.nllb_encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value

    # adapted from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification
    # - removed classifier dropout
    # - use F.cross_entropy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Token-level logits are computed over the (unpooled) last hidden state
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Move labels to the correct device to enable model parallelism
            labels = labels.to(logits.device)
            loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
AutoModel.register(NLLBLLM2VecConfig, NLLBLLM2Vec)
AutoModelForSequenceClassification.register(
    NLLBLLM2VecConfig, NLLBLLM2VecForSequenceClassification
)
AutoModelForTokenClassification.register(
    NLLBLLM2VecConfig, NLLBLLM2VecForTokenClassification
)
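
# Example usage (illustrative sketch): the classes above register themselves
# with the Auto* factories, and the model ships as custom code (see the
# `custom_code` tag), so remote checkpoints load with `trust_remote_code=True`.
# The repo id below is a placeholder; substitute the actual checkpoint.
if __name__ == "__main__":
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "path/to/nllb-llm2vec-checkpoint",  # placeholder repo id
        trust_remote_code=True,
    )
    embeddings = model.encode(
        ["NLLB-LLM2Vec encodes multilingual text.", "Embeddings are mean-pooled."],
        src_lang="eng_Latn",  # NLLB language code of the inputs
    )
    print(embeddings.shape)  # (2, hidden_size)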