Support torch_dtype and CLS pooling

#6
configuration_xlm_roberta.py CHANGED
@@ -1,4 +1,5 @@
 from transformers import PretrainedConfig
+import torch
 
 class XLMRobertaFlashConfig(PretrainedConfig):
     def __init__(
@@ -22,6 +23,8 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         use_cache=True,
         classifier_dropout=None,
         use_flash_attn=True,
+        torch_dtype=None,
+        emb_pooler=None,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -42,4 +45,9 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout
         self.use_flash_attn = use_flash_attn
-
+        self.emb_pooler = emb_pooler
+        if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
+            self.torch_dtype = getattr(torch, torch_dtype)
+        else:
+            self.torch_dtype = torch_dtype
+
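
For readers skimming the hunk above, here is a minimal sketch of how the two new config fields behave. It is not part of the diff: the direct import assumes you are running from a checkout of this repository, while the class name and arguments come straight from the change.

import torch

from configuration_xlm_roberta import XLMRobertaFlashConfig

# A dtype passed as a string is resolved to the matching torch.dtype attribute.
config = XLMRobertaFlashConfig(torch_dtype='bfloat16', emb_pooler='cls')
assert config.torch_dtype is torch.bfloat16
assert config.emb_pooler == 'cls'

# A value that does not name a torch.dtype (including None) is stored unchanged.
config = XLMRobertaFlashConfig(torch_dtype=None)
assert config.torch_dtype is None
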
modeling_xlm_roberta.py CHANGED
@@ -395,6 +395,17 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
         if isinstance(module, XLMRobertaEncoder):
             module.gradient_checkpointing = value
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        *args,
+        **kwargs,
+    ):
+        if not 'torch_dtype' in kwargs:
+            kwargs['torch_dtype'] = 'auto'
+        return super().from_pretrained(*args, **kwargs)
+
+
 
 class XLMRobertaModel(XLMRobertaPreTrainedModel):
     def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
@@ -545,9 +556,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
                 elif output_value is None:
                     raise NotImplementedError
                 else:
-                    embeddings = self.mean_pooling(
-                        token_embs, encoded_input['attention_mask']
-                    )
+                    if self.config.emb_pooler == 'cls':
+                        embeddings = self.cls_pooling(
+                            token_embs, encoded_input['attention_mask']
+                        )
+                    else:
+                        embeddings = self.mean_pooling(
+                            token_embs, encoded_input['attention_mask']
+                        )
 
                     if normalize_embeddings:
                         embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
@@ -580,6 +596,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         )
 
 
+    def cls_pooling(
+        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+    ):
+        return token_embeddings[:,0]
+
+
     def forward(
         self,
         input_ids,
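
Taken together, the two files change how a checkpoint built on this code is loaded and pooled. The sketch below is a hypothetical usage example, not part of the diff: the repository id is a placeholder, and it assumes the pooling hunks sit inside the model's encode() helper (the output_value / normalize_embeddings handling suggests as much) and that the checkpoint is loaded with trust_remote_code.

from transformers import AutoModel

# Placeholder repository id; substitute the checkpoint that ships this code.
model = AutoModel.from_pretrained('<org>/<model-name>', trust_remote_code=True)

# The overridden from_pretrained injects torch_dtype='auto' unless the caller
# passes one, so the weights come up in the dtype stored with the checkpoint.
print(model.dtype)

# With emb_pooler set to 'cls' (here at runtime, or via config.json), encode()
# takes the first token's embedding instead of mean-pooling over the mask.
model.config.emb_pooler = 'cls'
embeddings = model.encode(['How is the weather today?'])
print(embeddings.shape)
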