Sentence Similarity
Transformers
Safetensors
multilingual
nllb-llm2vec
feature-extraction
text-embedding
embeddings
information-retrieval
beir
text-classification
language-model
text-clustering
text-semantic-similarity
text-evaluation
text-reranking
Sentence Similarity
natural_questions
ms_marco
fever
hotpot_qa
mteb
custom_code
Commit
·
c90eb91
1
Parent(s):
b0221f6
feat: support AutoModelForSequenceClassification
Browse files- modeling_nllbllm2vec.py +97 -2
modeling_nllbllm2vec.py
CHANGED
|
@@ -1,12 +1,16 @@
|
|
| 1 |
-
from typing import Any, Dict, List, Optional, Tuple, cast
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.nn.functional as F
|
| 6 |
from transformers.models.auto import AutoModel
|
| 7 |
-
from transformers.modeling_outputs import
|
|
|
|
|
|
|
|
|
|
| 8 |
from transformers.modeling_utils import PreTrainedModel
|
| 9 |
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
|
|
|
|
| 10 |
|
| 11 |
from .configuration_nllbllm2vec import NLLBLLM2VecConfig
|
| 12 |
from .modeling_llama_encoder import LlamaEncoderModel
|
|
@@ -479,3 +483,94 @@ def repl():
|
|
| 479 |
|
| 480 |
with open("./model.safetensors.index.json", "r") as f:
|
| 481 |
print(json.load(f))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional, Tuple, cast, Union
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.nn.functional as F
|
| 6 |
from transformers.models.auto import AutoModel
|
| 7 |
+
from transformers.modeling_outputs import (
|
| 8 |
+
BaseModelOutputWithPooling,
|
| 9 |
+
SequenceClassifierOutputWithPast,
|
| 10 |
+
)
|
| 11 |
from transformers.modeling_utils import PreTrainedModel
|
| 12 |
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
|
| 13 |
+
from transformers.cache_utils import Cache
|
| 14 |
|
| 15 |
from .configuration_nllbllm2vec import NLLBLLM2VecConfig
|
| 16 |
from .modeling_llama_encoder import LlamaEncoderModel
|
|
|
|
| 483 |
|
| 484 |
with open("./model.safetensors.index.json", "r") as f:
|
| 485 |
print(json.load(f))
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
|
| 489 |
+
def __init__(self, config):
|
| 490 |
+
super().__init__(config)
|
| 491 |
+
self.num_labels = config.num_labels
|
| 492 |
+
self.model = NLLBLLM2Vec(config)
|
| 493 |
+
self.score = nn.Linear(
|
| 494 |
+
config.llm2vec_config.hidden_size, self.num_labels, bias=False
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
# Initialize weights and apply final processing
|
| 498 |
+
self.post_init()
|
| 499 |
+
|
| 500 |
+
def get_input_embeddings(self):
|
| 501 |
+
return self.model.nllb.embed_tokens
|
| 502 |
+
|
| 503 |
+
def set_input_embeddings(self, value):
|
| 504 |
+
self.model.nllb.embed_tokens = value
|
| 505 |
+
|
| 506 |
+
def forward(
|
| 507 |
+
self,
|
| 508 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 509 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 510 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 511 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
| 512 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 513 |
+
labels: Optional[torch.LongTensor] = None,
|
| 514 |
+
use_cache: Optional[bool] = None,
|
| 515 |
+
output_attentions: Optional[bool] = None,
|
| 516 |
+
output_hidden_states: Optional[bool] = None,
|
| 517 |
+
return_dict: Optional[bool] = None,
|
| 518 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
| 519 |
+
r"""
|
| 520 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 521 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 522 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 523 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 524 |
+
"""
|
| 525 |
+
return_dict = (
|
| 526 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
transformer_outputs = self.model(
|
| 530 |
+
input_ids,
|
| 531 |
+
attention_mask=attention_mask,
|
| 532 |
+
position_ids=position_ids,
|
| 533 |
+
past_key_values=past_key_values,
|
| 534 |
+
inputs_embeds=inputs_embeds,
|
| 535 |
+
use_cache=use_cache,
|
| 536 |
+
output_attentions=output_attentions,
|
| 537 |
+
output_hidden_states=output_hidden_states,
|
| 538 |
+
return_dict=return_dict,
|
| 539 |
+
)
|
| 540 |
+
hidden_states = transformer_outputs.pooler_output
|
| 541 |
+
pooled_logits = self.score(hidden_states)
|
| 542 |
+
|
| 543 |
+
loss = None
|
| 544 |
+
if labels is not None:
|
| 545 |
+
if self.config.problem_type is None:
|
| 546 |
+
if self.num_labels == 1:
|
| 547 |
+
self.config.problem_type = "regression"
|
| 548 |
+
elif self.num_labels > 1 and (
|
| 549 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
| 550 |
+
):
|
| 551 |
+
self.config.problem_type = "single_label_classification"
|
| 552 |
+
else:
|
| 553 |
+
self.config.problem_type = "multi_label_classification"
|
| 554 |
+
|
| 555 |
+
if self.config.problem_type == "regression":
|
| 556 |
+
if self.num_labels == 1:
|
| 557 |
+
loss = F.mse_loss(pooled_logits.squeeze(), labels.squeeze())
|
| 558 |
+
else:
|
| 559 |
+
loss = F.mse_loss(pooled_logits, labels)
|
| 560 |
+
elif self.config.problem_type == "single_label_classification":
|
| 561 |
+
loss = F.cross_entropy(
|
| 562 |
+
pooled_logits.view(-1, self.num_labels), labels.view(-1)
|
| 563 |
+
)
|
| 564 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 565 |
+
loss = F.binary_cross_entropy_with_logits(pooled_logits, labels)
|
| 566 |
+
if not return_dict:
|
| 567 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
| 568 |
+
return ((loss,) + output) if loss is not None else output
|
| 569 |
+
|
| 570 |
+
return SequenceClassifierOutputWithPast(
|
| 571 |
+
loss=loss,
|
| 572 |
+
logits=pooled_logits,
|
| 573 |
+
past_key_values=transformer_outputs.past_key_values,
|
| 574 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 575 |
+
attentions=transformer_outputs.attentions,
|
| 576 |
+
)
|