missing parts
automodel.py (+26 -0)
@@ -233,6 +233,32 @@ class ClinicalMosaicForSequenceClassification(BertPreTrainedModel):
             attentions=None,
         )
 
+class BertLMPredictionHead(nn.Module):
+
+    def __init__(self, config, bert_model_embedding_weights):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
+                                 bert_model_embedding_weights.size(0))
+        self.decoder.weight = bert_model_embedding_weights
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+class BertOnlyMLMHead(nn.Module):
+
+    def __init__(self, config, bert_model_embedding_weights):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config,
+                                                bert_model_embedding_weights)
+
+    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
 
 class ClinicalMosaicForForMaskedLM(BertPreTrainedModel):
     config_class = BertConfig
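
For context, a minimal usage sketch of the heads this commit adds. It assumes the `BertOnlyMLMHead` defined in the diff above plus `BertPredictionHeadTransform` and `BertConfig` from `transformers`; the config sizes and dummy tensors are illustrative, not part of the commit. It demonstrates the weight tying the in-code comment describes: the decoder reuses the word-embedding matrix, while `nn.Linear`'s own bias supplies the output-only per-token bias.

```python
# Illustrative sketch (not part of the commit): wiring up BertOnlyMLMHead
# with a tied embedding matrix and scoring MLM logits over the vocabulary.
import torch
import torch.nn as nn
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertPredictionHeadTransform

config = BertConfig(vocab_size=30522, hidden_size=768)  # assumed sizes
word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

# Passing the embedding weight ties decoder.weight to the embedding matrix,
# so masked-token logits are scored against the same vectors used as inputs.
mlm_head = BertOnlyMLMHead(config, word_embeddings.weight)
assert mlm_head.predictions.decoder.weight is word_embeddings.weight

sequence_output = torch.randn(2, 16, config.hidden_size)  # (batch, seq, hidden)
prediction_scores = mlm_head(sequence_output)
print(prediction_scores.shape)  # torch.Size([2, 16, 30522])
```

Tying the decoder to the input embeddings keeps the model from carrying a second `vocab_size x hidden_size` matrix, which is the standard BERT MLM-head design; presumably `ClinicalMosaicForForMaskedLM` passes its backbone's word-embedding weights into this head.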
|