removed SDPA
modeling_gptbert.py  CHANGED  (+30 -14)
@@ -3,6 +3,7 @@ from __future__ import annotations
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch import _softmax_backward_data as _softmax_backward_data
 
 from functools import partial, lru_cache
 
@@ -37,17 +38,11 @@ try:
     logger.warning_once(
         "NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
     )
-    torch.backends.cuda.enable_flash_sdp(False)
-    torch.backends.cuda.enable_mem_efficient_sdp(False)
-    torch.backends.cuda.enable_math_sdp(True)
 except ImportError:
     flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
     logger.warning_once(
         "NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
     )
-    torch.backends.cuda.enable_flash_sdp(False)
-    torch.backends.cuda.enable_mem_efficient_sdp(False)
-    torch.backends.cuda.enable_math_sdp(True)
 
 
 # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
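(The repeated Norwegian warning translates roughly to: "NorBERT4 supports FlashAttention, but it was not found in your environment. You should consider updating your environment for faster and less memory-intensive processing.")

The three deleted calls were process-wide switches: they turned off PyTorch's fused SDPA kernels and forced the math backend whenever FlashAttention was unavailable. Since the attention hunk further down replaces the SDPA fallback with an explicit implementation, these toggles no longer have anything to act on. A minimal sketch of what they controlled, assuming PyTorch 2.3+ for the scoped variant; this is illustration, not code from the commit:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

query = key = value = torch.randn(1, 2, 4, 8)  # toy [batch, heads, seq, head_dim]

# Process-wide toggles, as in the removed lines: every later SDPA call falls back to the math backend.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

# Scoped alternative: restrict backend selection only inside this block.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(query, key, value)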
@@ -318,6 +313,25 @@ class RotaryPositionalEmbeddings(nn.Module):
         return out.type_as(x)
 
 
+class MaskedSoftmax(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
+        ctx.dim = dim
+        x.masked_fill_(mask, float('-inf'))
+        x = torch.softmax(x, ctx.dim)
+        x.masked_fill_(mask, 0.0)
+        ctx.save_for_backward(x)
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
+        output: torch.Tensor
+
+        output, = ctx.saved_tensors
+        inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
+        return inputGrad, None, None
+
+
 class SelfAttention(nn.Module):
     def __init__(self, config: GptBertConfig, layer_idx: int):
         super().__init__()
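The new `MaskedSoftmax` does the masking that SDPA previously handled internally: masked positions are filled with -inf before the softmax and zeroed afterwards, and the backward pass calls PyTorch's private `_softmax_backward_data` kernel on the saved probabilities (hence the new import in the first hunk), so the mask itself never needs to be saved for backward. A quick forward-pass sanity check with made-up shapes, where the boolean mask marks positions to drop; this is a sketch, not part of the commit:

import torch

scores = torch.randn(2, 4, 3, 6)                    # [batch, heads, query_len, key_len]
drop = torch.zeros(2, 1, 1, 6, dtype=torch.bool)    # True = mask this key out
drop[..., -1] = True                                # hide the last key everywhere

# clone() because MaskedSoftmax.forward fills its input in place
probs = MaskedSoftmax.apply(scores.clone(), drop, -1)

reference = torch.softmax(scores.masked_fill(drop, float("-inf")), dim=-1)
print(torch.allclose(probs, reference))             # True: masked keys get zero probability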
@@ -347,6 +361,7 @@ class SelfAttention(nn.Module):
         self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, self.d_qk))
         self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, self.d_qk))
 
+        self.attention_dropout = nn.Dropout(config.attention_dropout)
         self.dropout = nn.Dropout(config.hidden_dropout)
 
         theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000
@@ -390,14 +405,15 @@ class SelfAttention(nn.Module):
         else:
             attention_mask = window_mask
 
-        # [7 removed lines whose content is not preserved in this view: per the commit title, the old SDPA attention call]
-        )
+        attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale  # shape: [B*H, Q_T, K_T]
+        attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
+
+        attention_probabilities = MaskedSoftmax.apply(attention_scores, ~attention_mask, -1)
+        attention_probabilities = self.attention_dropout(attention_probabilities)
+
+        output = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
+        output = output.view(batch_size, self.num_attention_heads, query_length, self.d_v)
+
         return output
 
     def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
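The new path spells out what the removed SDPA call did internally: scaled QK^T scores via a batched matmul, masked softmax, dropout on the attention probabilities, then a second batched matmul with the values. A sketch of that equivalence with dropout disabled, toy shapes, and hypothetical local names (the real code uses `self.scale`, `self.attention_dropout`, and `MaskedSoftmax` instead of plain `masked_fill`/`softmax`); not part of the commit:

import torch
import torch.nn.functional as F

batch_size, num_heads, q_len, k_len, d_head = 2, 4, 5, 5, 8
scale = d_head ** -0.5                                   # matches SDPA's default scaling

query = torch.randn(batch_size, num_heads, q_len, d_head)
key = torch.randn(batch_size, num_heads, k_len, d_head)
value = torch.randn(batch_size, num_heads, k_len, d_head)
attention_mask = torch.rand(batch_size, 1, q_len, k_len) > 0.1   # True = may attend

# Manual path, mirroring the new code (dropout omitted).
scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * scale
scores = scores.view(batch_size, num_heads, q_len, k_len)
probs = scores.masked_fill(~attention_mask, float("-inf")).softmax(dim=-1)
manual = torch.bmm(probs.flatten(0, 1), value.flatten(0, 1))
manual = manual.view(batch_size, num_heads, q_len, d_head)

# The call this commit removes, for comparison.
sdpa = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)

print(torch.allclose(manual, sdpa, atol=1e-6))           # True, barring fully masked query rows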