Sifal committed on
Commit 228bda9 · verified · 1 Parent(s): 2bfa274

Delete bert_embeddings.py

Files changed (1)
  1. bert_embeddings.py +0 -82
bert_embeddings.py DELETED
@@ -1,82 +0,0 @@
-import logging
-from typing import Optional
-
-import torch
-from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
-from transformers import BertPreTrainedModel
-
-from bert_layers_mosa import BertModel
-
-logger = logging.getLogger(__name__)
-
-
-class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):
-
-    def __init__(self, config, add_pooling_layer=False):
-        """
-        Initializes the MosaicBertForEmbeddingGeneration class.
-
-        Args:
-            config (BertConfig): The configuration for the BERT model.
-            add_pooling_layer (bool, optional): Whether to add a pooling layer. Defaults to False.
-        """
-        super().__init__(config)
-        assert (
-            config.num_hidden_layers >= config.num_embedding_layers
-        ), "num_hidden_layers should be greater than or equal to num_embedding_layers"
-        self.config = config
-        self.strategy = config.strategy
-        self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
-        # post_init() resets the weights
-        self.post_init()
-
-    @classmethod
-    def from_pretrained(
-        cls, pretrained_checkpoint, state_dict=None, config=None, *inputs, **kwargs
-    ):
-        """Load from pre-trained."""
-        # build a freshly initialized model ...
-        model = cls(config, *inputs, **kwargs)
-
-        # ... then load the checkpoint's state_dict into it
-        state_dict = torch.load(pretrained_checkpoint)
-        # remove the `model.` prefix to avoid key-mismatch errors
-        consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
-        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-
-        if len(missing_keys) > 0:
-            logger.warning(
-                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
-            )
-
-            logger.warning(f"Number of missing keys: {len(missing_keys)}")
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
-            )
-            logger.warning(f"Number of unexpected keys: {len(unexpected_keys)}")
-
-        return model
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        subset_mask: Optional[torch.Tensor] = None,
-        output_all_encoded_layers: bool = True,
-    ) -> torch.Tensor:
-
-        embedding_output = self.bert.embeddings(input_ids, token_type_ids, position_ids)
-
-        encoder_outputs_all = self.bert.encoder(
-            embedding_output,
-            attention_mask,
-            output_all_encoded_layers=output_all_encoded_layers,
-            subset_mask=subset_mask,
-        )
-
-        # encoder hidden states (all encoded layers when output_all_encoded_layers=True)
-        return encoder_outputs_all
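
For reference, the deleted class was meant to be used roughly as follows. This is a minimal sketch, not part of the repository: it assumes a MosaicBERT-style config carrying the custom num_embedding_layers and strategy fields, that bert_layers_mosa is importable, and it uses placeholder names (checkpoint.pt, bert-base-uncased) that are not taken from this commit.

# Minimal usage sketch; paths, config values, and tokenizer id below are placeholders.
import torch
from transformers import AutoTokenizer, BertConfig

from bert_embeddings import MosaicBertForEmbeddingGeneration  # the file deleted in this commit

# Hypothetical config: num_embedding_layers and strategy are custom fields read by __init__
config = BertConfig.from_pretrained("bert-base-uncased")
config.num_embedding_layers = 6   # assumed custom attribute
config.strategy = "last"          # assumed custom attribute

# from_pretrained builds a fresh model, then loads the state_dict saved at the given path
model = MosaicBertForEmbeddingGeneration.from_pretrained(
    "checkpoint.pt",              # placeholder path to a saved state_dict
    config=config,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["An example sentence."], return_tensors="pt")

with torch.no_grad():
    # forward returns the encoder outputs (all layers when output_all_encoded_layers=True)
    hidden_states = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )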