shmelev committed
Commit 47f03d3 · 1 Parent(s): be4ca9c

Added BertForEnformer class

Files changed (1):
  1. modeling_bert.py +84 -0
modeling_bert.py CHANGED
@@ -2206,3 +2206,87 @@ def rotate_half(x):
 def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
     cos, sin = cos[:, :, offset: q.shape[2] + offset, :], sin[:, :, offset: q.shape[2] + offset, :]
     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+
+from torch import nn
+from transformers.modeling_outputs import TokenClassifierOutput
+
+class BertForEnformer(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.activation = nn.Softplus()
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        bins_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        labels_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (`torch.FloatTensor`, *optional*):
+            Targets for the Poisson regression loss; `labels[labels_mask]` must match the shape of the bin predictions, `(num_selected_bins, config.num_labels)`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        # select SEP tokens that represent target bins
+        bins_output = sequence_output[bins_mask]
+
+        bins_output = self.dropout(bins_output)
+        logits = self.classifier(bins_output)
+        pred = self.activation(logits)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.PoissonNLLLoss(log_input=False, reduction='mean')
+            labels = labels[labels_mask]
+            loss = loss_fct(pred, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
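
For orientation, a minimal usage sketch of the class added in this commit (the sketch itself is not part of the commit). It assumes this repo's modeling_bert.py is importable, a stock BertConfig, and illustrative shapes: bins_mask is a boolean (batch, seq_len) mask selecting the tokens that represent target bins, and labels/labels_mask supply matching non-negative Poisson targets. The mask positions and num_labels value below are made up for illustration.

    import torch
    from transformers import BertConfig
    from modeling_bert import BertForEnformer  # assumes this repo's file is on the path

    # Toy configuration; num_labels = number of predicted values per bin (assumed).
    config = BertConfig(num_labels=10)
    model = BertForEnformer(config)

    batch, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch, seq_len))
    attention_mask = torch.ones(batch, seq_len, dtype=torch.long)

    # Mark a few positions per sequence as target-bin tokens (e.g. SEP positions).
    bins_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    bins_mask[:, [4, 9, 14]] = True

    # Non-negative targets for the Poisson NLL loss, aligned with the selected bins.
    labels = torch.rand(batch, seq_len, config.num_labels) * 10
    labels_mask = bins_mask

    out = model(
        input_ids,
        attention_mask=attention_mask,
        bins_mask=bins_mask,
        labels=labels,
        labels_mask=labels_mask,
    )
    print(out.loss, out.logits.shape)  # logits: (num_selected_bins, num_labels)

Note that the returned logits are the raw classifier outputs; the Softplus-activated pred is used only inside the loss, so callers that need non-negative predictions should apply torch.nn.functional.softplus to out.logits themselves.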