Support torch_dtype and CLS pooling (#6)
opened by michael-guenther

Files changed:
- configuration_xlm_roberta.py  +9 -1
- modeling_xlm_roberta.py  +25 -3
configuration_xlm_roberta.py CHANGED

@@ -1,4 +1,5 @@
 from transformers import PretrainedConfig
+import torch
 
 class XLMRobertaFlashConfig(PretrainedConfig):
     def __init__(
@@ -22,6 +23,8 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         use_cache=True,
         classifier_dropout=None,
         use_flash_attn=True,
+        torch_dtype=None,
+        emb_pooler=None,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -42,4 +45,9 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout
         self.use_flash_attn = use_flash_attn
-
+        self.emb_pooler = emb_pooler
+        if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
+            self.torch_dtype = getattr(torch, torch_dtype)
+        else:
+            self.torch_dtype = torch_dtype
+
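For reference, a minimal standalone sketch of the dtype coercion the new config code performs (the helper name resolve_torch_dtype is not part of the PR): a string that names a real torch dtype is converted to the corresponding torch.dtype object, while anything else, including the 'auto' sentinel that transformers resolves at load time, is stored unchanged.

import torch


def resolve_torch_dtype(torch_dtype):
    # Mirrors the check in XLMRobertaFlashConfig.__init__: only convert strings
    # that name an actual torch.dtype attribute (e.g. 'float16', 'bfloat16').
    if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
        return getattr(torch, torch_dtype)
    return torch_dtype


print(resolve_torch_dtype('bfloat16'))  # torch.bfloat16
print(resolve_torch_dtype('auto'))      # stays the string 'auto' for transformers to resolve
print(resolve_torch_dtype(None))        # None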
modeling_xlm_roberta.py CHANGED

@@ -395,6 +395,17 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
         if isinstance(module, XLMRobertaEncoder):
             module.gradient_checkpointing = value
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        *args,
+        **kwargs,
+    ):
+        if not 'torch_dtype' in kwargs:
+            kwargs['torch_dtype'] = 'auto'
+        return super().from_pretrained(*args, **kwargs)
+
+
 
 class XLMRobertaModel(XLMRobertaPreTrainedModel):
     def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
@@ -545,9 +556,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             elif output_value is None:
                 raise NotImplementedError
             else:
-                embeddings = self.mean_pooling(
-                    token_embs, encoded_input['attention_mask']
-                )
+                if self.config.emb_pooler == 'cls':
+                    embeddings = self.cls_pooling(
+                        token_embs, encoded_input['attention_mask']
+                    )
+                else:
+                    embeddings = self.mean_pooling(
+                        token_embs, encoded_input['attention_mask']
+                    )
 
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
@@ -580,6 +596,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         )
 
 
+    def cls_pooling(
+        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+    ):
+        return token_embeddings[:,0]
+
+
     def forward(
         self,
         input_ids,
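Illustration only, not part of the diff: standalone versions of the two pooling strategies that encode() now chooses between via config.emb_pooler. cls_pooling mirrors the method added above; the mean_pooling body shown here is the standard masked average and is an assumption, since the existing method lies outside the displayed hunks.

import torch


def cls_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
    # Embedding of the first token of every sequence, as in the new method above.
    return token_embeddings[:, 0]


def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
    # Standard masked average over non-padding tokens (assumed implementation).
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)


token_embs = torch.randn(2, 6, 8)            # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])

emb_pooler = 'cls'                           # the value config.emb_pooler would carry
pooled = (cls_pooling(token_embs, attention_mask)
          if emb_pooler == 'cls'
          else mean_pooling(token_embs, attention_mask))
print(pooled.shape)                          # torch.Size([2, 8])

On top of the pooling switch, the overridden from_pretrained injects torch_dtype='auto' whenever the caller does not pass one, so weights are loaded in the dtype stored in the checkpoint; an explicit torch_dtype argument (e.g. torch.float32) still takes precedence.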