Added BetForEnformer class
Browse files- modeling_bert.py +84 -0
modeling_bert.py
CHANGED
@@ -2206,3 +2206,87 @@ def rotate_half(x):
|
|
2206 |
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
|
2207 |
cos, sin = cos[:, :, offset: q.shape[2] + offset, :], sin[:, :, offset: q.shape[2] + offset, :]
|
2208 |
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2206 |
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
|
2207 |
cos, sin = cos[:, :, offset: q.shape[2] + offset, :], sin[:, :, offset: q.shape[2] + offset, :]
|
2208 |
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
2209 |
+
|
2210 |
+
|
2211 |
+
from torch import nn
|
2212 |
+
from transformers.modeling_outputs import TokenClassifierOutput
|
2213 |
+
|
2214 |
+
class BertForEnformer(BertPreTrainedModel):
|
2215 |
+
|
2216 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
2217 |
+
|
2218 |
+
def __init__(self, config):
|
2219 |
+
super().__init__(config)
|
2220 |
+
self.num_labels = config.num_labels
|
2221 |
+
self.config = config
|
2222 |
+
|
2223 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
2224 |
+
classifier_dropout = (
|
2225 |
+
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
2226 |
+
)
|
2227 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
2228 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
2229 |
+
self.activation = nn.Softplus()
|
2230 |
+
|
2231 |
+
# Initialize weights and apply final processing
|
2232 |
+
self.post_init()
|
2233 |
+
|
2234 |
+
def forward(
|
2235 |
+
self,
|
2236 |
+
input_ids=None,
|
2237 |
+
attention_mask=None,
|
2238 |
+
token_type_ids=None,
|
2239 |
+
bins_mask=None,
|
2240 |
+
position_ids=None,
|
2241 |
+
head_mask=None,
|
2242 |
+
inputs_embeds=None,
|
2243 |
+
labels=None,
|
2244 |
+
labels_mask=None,
|
2245 |
+
output_attentions=None,
|
2246 |
+
output_hidden_states=None,
|
2247 |
+
return_dict=None,
|
2248 |
+
):
|
2249 |
+
r"""
|
2250 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
2251 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
2252 |
+
"""
|
2253 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
2254 |
+
|
2255 |
+
outputs = self.bert(
|
2256 |
+
input_ids,
|
2257 |
+
attention_mask=attention_mask,
|
2258 |
+
token_type_ids=token_type_ids,
|
2259 |
+
position_ids=position_ids,
|
2260 |
+
head_mask=head_mask,
|
2261 |
+
inputs_embeds=inputs_embeds,
|
2262 |
+
output_attentions=output_attentions,
|
2263 |
+
output_hidden_states=output_hidden_states,
|
2264 |
+
return_dict=return_dict,
|
2265 |
+
)
|
2266 |
+
|
2267 |
+
sequence_output = outputs[0]
|
2268 |
+
|
2269 |
+
# select SEP tokens that represent target bins
|
2270 |
+
bins_output = sequence_output[bins_mask]
|
2271 |
+
|
2272 |
+
bins_output = self.dropout(bins_output)
|
2273 |
+
logits = self.classifier(bins_output)
|
2274 |
+
pred = self.activation(logits)
|
2275 |
+
|
2276 |
+
loss = None
|
2277 |
+
if labels is not None:
|
2278 |
+
loss_fct = nn.PoissonNLLLoss(log_input=False, reduction='mean')
|
2279 |
+
labels = labels[labels_mask]
|
2280 |
+
loss = loss_fct(pred, labels)
|
2281 |
+
|
2282 |
+
if not return_dict:
|
2283 |
+
output = (logits,) + outputs[2:]
|
2284 |
+
return ((loss,) + output) if loss is not None else output
|
2285 |
+
|
2286 |
+
return TokenClassifierOutput(
|
2287 |
+
loss=loss,
|
2288 |
+
logits=logits,
|
2289 |
+
hidden_states=outputs.hidden_states,
|
2290 |
+
attentions=outputs.attentions,
|
2291 |
+
)
|
2292 |
+
|