Optimized wrapper with correct API
modeling_norbert.py (+64 -36)
@@ -1,12 +1,9 @@
-from __future__ import absolute_import, division, print_function, unicode_literals
-
 import math
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import _softmax_backward_data as _softmax_backward_data
 from torch.utils import checkpoint

 from configuration_norbert import NorbertConfig
@@ -20,6 +17,7 @@ from transformers.modeling_outputs import (
     TokenClassifierOutput,
     BaseModelOutput
 )
+from transformers.pytorch_utils import softmax_backward_data


 class Encoder(nn.Module):
@@ -130,8 +128,8 @@ class MaskedSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(self, grad_output):
         output, = self.saved_tensors
-
-        return
+        input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
+        return input_grad, None, None


 class Attention(nn.Module):
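Note on the backward change: the file previously imported the private `torch._softmax_backward_data` directly, whose argument list has changed between PyTorch releases. The hunk above switches to the `softmax_backward_data` helper from `transformers.pytorch_utils`, which dispatches on the installed torch version. A rough, illustrative sketch of such a compatibility shim is shown below; this is not the transformers implementation, and the exact version cut-off is an assumption.

    # Illustrative version shim; the real helper is transformers.pytorch_utils.softmax_backward_data.
    # Assumption: newer PyTorch expects the input dtype as the last argument,
    # while older releases expected the input tensor itself.
    from packaging import version

    import torch
    from torch import _softmax_backward_data

    def softmax_backward_data_compat(ctx, grad_output, output, dim, inp):
        if version.parse(torch.__version__).release >= (1, 13):   # assumed cut-off
            return _softmax_backward_data(grad_output, output, dim, inp.dtype)
        return _softmax_backward_data(grad_output, output, dim, inp)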
@@ -188,31 +186,36 @@ class Attention(nn.Module):
         if self.position_indices.size(0) < query_len:
             position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
                 - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
-            position_indices = self.make_log_bucket_position(position_indices, self.
-            position_indices = self.
-            self.
+            position_indices = self.make_log_bucket_position(position_indices, self.position_bucket_size, 512)
+            position_indices = self.position_bucket_size - 1 + position_indices
+            self.position_indices = position_indices.to(hidden_states.device)

         hidden_states = self.pre_layer_norm(hidden_states)

         query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2) # shape: [T, B, D]
         value = self.in_proj_v(hidden_states) # shape: [T, B, D]

-        pos = self.in_proj_qk(self.dropout(relative_embedding)) # shape: [2T-1, 2D]
-        pos = F.embedding(self.position_indices[:query_len, :key_len], pos) # shape: [T, T, 2D]
-        pos = pos.view(query_len, key_len, self.num_heads, 2*self.head_size)
-        query_pos, key_pos = pos.chunk(2, dim=3)
-
         query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
         key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
         value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)

         attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)

+        pos = self.in_proj_qk(self.dropout(relative_embedding)) # shape: [2T-1, 2D]
+        query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)
         query = query.view(batch_size, self.num_heads, query_len, self.head_size)
         key = key.view(batch_size, self.num_heads, query_len, self.head_size)
+
+        attention_c_p = torch.einsum("bhqd,khd->bhqk", query, key_pos.squeeze(1) * self.scale)
+        attention_p_c = torch.einsum("bhkd,qhd->bhqk", key * self.scale, query_pos.squeeze(1))
+
+        position_indices = self.position_indices[:query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
+        attention_c_p = attention_c_p.gather(3, position_indices)
+        attention_p_c = attention_p_c.gather(2, position_indices)
+
         attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
-        attention_scores.add_(
-        attention_scores.add_(
+        attention_scores.add_(attention_c_p)
+        attention_scores.add_(attention_p_c)

         return attention_scores, value

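The attention rewrite above avoids materialising the full per-pair relative embedding: instead of `F.embedding` building a [T, T, 2D] tensor, the bucketed relative embeddings are projected once, content-to-position and position-to-content scores are computed with `torch.einsum` over the buckets, and the right bucket is then picked for each (query, key) pair with `gather`. A self-contained, toy-sized sketch of that gather pattern follows; the dimensions and names are illustrative only.

    # Toy-sized illustration of the bucketed relative-position gather.
    import torch

    B, H, T, D, R = 2, 4, 8, 16, 15          # batch, heads, seq len, head size, buckets

    query = torch.randn(B, H, T, D)
    key_pos = torch.randn(R, H, D)            # one projected relative embedding per bucket
    bucket = torch.randint(0, R, (T, T))      # bucket index for every (query, key) offset

    # Score every query position against every bucket ...
    scores_per_bucket = torch.einsum("bhqd,rhd->bhqr", query, key_pos)   # [B, H, T, R]

    # ... then pick, for each (q, k) pair, the score of the bucket it falls into.
    index = bucket.expand(B, H, T, T)
    content_to_position = scores_per_bucket.gather(3, index)             # [B, H, T, T]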
@@ -332,12 +335,16 @@ class NorbertModel(NorbertPreTrainedModel):
         sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)

         if not return_dict:
-            return
+            return (
+                sequence_output,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )

         return BaseModelOutput(
             last_hidden_state=sequence_output,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )


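With this change the base model follows the standard transformers output contract: `hidden_states` and `attentions` are only populated when explicitly requested, and `return_dict=False` yields a plain tuple without the optional entries. A hedged usage sketch; the checkpoint path, token ids and the exact forward signature are assumptions based on the code above.

    # Hypothetical usage; "path/to/norbert-checkpoint" and the token ids are placeholders.
    import torch
    from modeling_norbert import NorbertModel

    model = NorbertModel.from_pretrained("path/to/norbert-checkpoint")
    input_ids = torch.tensor([[3, 17, 52, 4]])
    attention_mask = torch.ones_like(input_ids)

    # Dict-style output: optional fields stay None unless requested.
    out = model(input_ids, attention_mask, output_hidden_states=True, output_attentions=True)
    print(out.last_hidden_state.shape)

    # Tuple-style output: optional entries are simply left out of the tuple.
    sequence_output = model(input_ids, attention_mask, return_dict=False)[0]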
@@ -375,14 +382,18 @@ class NorbertForMaskedLM(NorbertModel):
             masked_lm_loss = F.cross_entropy(subword_prediction.flatten(0, 1), labels.flatten())

         if not return_dict:
-            output = (
+            output = (
+                subword_prediction,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

         return MaskedLMOutput(
             loss=masked_lm_loss,
             logits=subword_prediction,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )


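The task heads all assemble the legacy tuple the same way: the loss (when labels are given) comes first, then the logits, then only those optional tensors that were requested. For the MLM head, a hedged sketch of the two output modes; the checkpoint path and ids are placeholders and the forward signature is assumed from the diff.

    # Hypothetical usage of the masked-LM head; paths and ids are placeholders.
    import torch
    from modeling_norbert import NorbertForMaskedLM

    model = NorbertForMaskedLM.from_pretrained("path/to/norbert-checkpoint")
    input_ids = torch.tensor([[3, 17, 52, 4]])
    attention_mask = torch.ones_like(input_ids)
    labels = input_ids.clone()

    out = model(input_ids, attention_mask, labels=labels)   # MaskedLMOutput
    print(out.loss, out.logits.shape)

    # With return_dict=False and no optional outputs requested: (loss, logits)
    loss, logits = model(input_ids, attention_mask, labels=labels, return_dict=False)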
@@ -465,14 +476,18 @@ class NorbertForSequenceClassification(NorbertModel):
                 loss = loss_fct(logits, labels)

         if not return_dict:
-            output = (
+            output = (
+                logits,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
             return ((loss,) + output) if loss is not None else output

         return SequenceClassifierOutput(
             loss=loss,
             logits=logits,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )


@@ -508,14 +523,18 @@ class NorbertForTokenClassification(NorbertModel):
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
-            output = (
+            output = (
+                logits,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
             return ((loss,) + output) if loss is not None else output

         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )


@@ -569,15 +588,20 @@ class NorbertForQuestionAnswering(NorbertModel):
             total_loss = (start_loss + end_loss) / 2

         if not return_dict:
-            output =
+            output = (
+                start_logits,
+                end_logits,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
             return ((total_loss,) + output) if total_loss is not None else output

         return QuestionAnsweringModelOutput(
             loss=total_loss,
             start_logits=start_logits,
             end_logits=end_logits,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )


@@ -598,9 +622,9 @@ class NorbertForMultipleChoice(NorbertModel):
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
-
-
-
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None
     ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         num_choices = input_ids.shape[1]
@@ -618,12 +642,16 @@ class NorbertForMultipleChoice(NorbertModel):
             loss = loss_fct(reshaped_logits, labels)

         if not return_dict:
-            output = (
+            output = (
+                reshaped_logits,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
             return ((loss,) + output) if loss is not None else output

         return MultipleChoiceModelOutput(
             loss=loss,
             logits=reshaped_logits,
-            hidden_states=contextualized_embeddings,
-            attentions=attention_probs
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
         )