HGB committed on
Commit 9c8bb9e
1 Parent(s): e9fd7b3

add bert padding + modify internVit classes

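This commit swaps the `flash_attn` imports in modeling_intern_vit.py for two local Triton-based modules, triton_flash_atn.py and triton_bert_pading.py. A minimal sketch of the resulting import guard, assuming both files sit next to modeling_intern_vit.py; the value set in the fallback branch is an assumption, since the diff below only shows the bare `except:` line:

```python
# Sketch of the import guard this commit moves to (fallback branch assumed).
try:
    from triton_flash_atn import _attention                 # Triton flash-attention autograd.Function
    from triton_bert_pading import pad_input, unpad_input   # BERT-style (un)padding helpers

    has_flash_attn = True
except Exception:
    has_flash_attn = False  # assumed: fall back to the naive attention path
```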
modeling_intern_vit.py CHANGED
@@ -20,14 +20,9 @@ from transformers.utils import logging
20
  from .configuration_intern_vit import InternVisionConfig
21
 
22
  try:
23
- try: # v1
24
- from flash_attn.flash_attn_interface import \
25
- flash_attn_unpadded_qkvpacked_func
26
- except: # v2
27
- from flash_attn.flash_attn_interface import \
28
- flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
29
 
30
- from flash_attn.bert_padding import pad_input, unpad_input
31
 
32
  has_flash_attn = True
33
  except:
@@ -74,28 +69,31 @@ class FlashAttention(nn.Module):
74
  max_s = seqlen
75
  cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
76
  device=qkv.device)
77
- output = flash_attn_unpadded_qkvpacked_func(
78
  qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
79
- softmax_scale=self.softmax_scale, causal=causal
80
  )
81
- output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
 
82
  else:
83
  nheads = qkv.shape[-2]
84
  x = rearrange(qkv, 'b s three h d -> b s (three h d)')
85
- x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
86
- x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
87
- output_unpad = flash_attn_unpadded_qkvpacked_func(
 
 
88
  x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
89
- softmax_scale=self.softmax_scale, causal=causal
90
  )
91
  output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
92
  indices, batch_size, seqlen),
93
  'b s (h d) -> b s h d', h=nheads)
94
  else:
95
  assert max_s is not None
96
- output = flash_attn_unpadded_qkvpacked_func(
97
  qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
98
- softmax_scale=self.softmax_scale, causal=causal
99
  )
100
 
101
  return output, None
@@ -111,7 +109,8 @@ class InternRMSNorm(nn.Module):
111
  input_dtype = hidden_states.dtype
112
  hidden_states = hidden_states.to(torch.float32)
113
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
114
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
115
  return self.weight * hidden_states.to(input_dtype)
116
 
117
 
@@ -120,12 +119,14 @@ try:
120
 
121
  InternRMSNorm = FusedRMSNorm # noqa
122
 
123
- logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
 
124
  except ImportError:
125
  # using the normal InternRMSNorm
126
  pass
127
  except Exception:
128
- logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
 
129
  pass
130
 
131
 
@@ -154,7 +155,8 @@ class InternVisionEmbeddings(nn.Module):
154
  self.num_patches = (self.image_size // self.patch_size) ** 2
155
  self.num_positions = self.num_patches + 1
156
 
157
- self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
158
 
159
  def _get_pos_embed(self, pos_embed, H, W):
160
  target_dtype = pos_embed.dtype
@@ -166,14 +168,17 @@ class InternVisionEmbeddings(nn.Module):
166
 
167
  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
168
  target_dtype = self.patch_embedding.weight.dtype
169
- patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
 
170
  batch_size, _, height, width = patch_embeds.shape
171
  patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
172
- class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
 
173
  embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
174
  position_embedding = torch.cat([
175
  self.position_embedding[:, :1, :],
176
- self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
 
177
  ], dim=1)
178
  embeddings = embeddings + position_embedding.to(target_dtype)
179
  return embeddings
@@ -189,38 +194,48 @@ class InternAttention(nn.Module):
189
  self.num_heads = config.num_attention_heads
190
  self.use_flash_attn = config.use_flash_attn and has_flash_attn
191
  if config.use_flash_attn and not has_flash_attn:
192
- print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
 
193
  self.head_dim = self.embed_dim // self.num_heads
194
  if self.head_dim * self.num_heads != self.embed_dim:
195
  raise ValueError(
196
- f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
 
197
  f' {self.num_heads}).'
198
  )
199
 
200
  self.scale = self.head_dim ** -0.5
201
- self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
 
202
  self.attn_drop = nn.Dropout(config.attention_dropout)
203
  self.proj_drop = nn.Dropout(config.dropout)
204
 
205
  self.qk_normalization = config.qk_normalization
206
 
207
  if self.qk_normalization:
208
- self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
209
- self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
 
 
210
 
211
  if self.use_flash_attn:
212
- self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
 
213
  self.proj = nn.Linear(self.embed_dim, self.embed_dim)
214
 
215
  def _naive_attn(self, x):
216
  B, N, C = x.shape
217
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
218
- q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
 
 
219
 
220
  if self.qk_normalization:
221
  B_, H_, N_, D_ = q.shape
222
- q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
223
- k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
 
 
224
 
225
  attn = ((q * self.scale) @ k.transpose(-2, -1))
226
  attn = attn.softmax(dim=-1)
@@ -233,7 +248,8 @@ class InternAttention(nn.Module):
233
 
234
  def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
235
  qkv = self.qkv(x)
236
- qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
 
237
 
238
  if self.qk_normalization:
239
  q, k, v = qkv.unbind(2)
@@ -249,7 +265,8 @@ class InternAttention(nn.Module):
249
  return outs
250
 
251
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
252
- x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
 
253
  return x
254
 
255
 
@@ -277,13 +294,19 @@ class InternVisionEncoderLayer(nn.Module):
277
 
278
  self.attn = InternAttention(config)
279
  self.mlp = InternMLP(config)
280
- self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
281
- self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
282
-
283
- self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
284
- self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
285
- self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
286
- self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
287
 
288
  def forward(
289
  self,
@@ -293,9 +316,11 @@ class InternVisionEncoderLayer(nn.Module):
293
  Args:
294
  hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
295
  """
296
- hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
 
297
 
298
- hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
 
299
 
300
  return hidden_states
301
 
@@ -314,7 +339,8 @@ class InternVisionEncoder(nn.Module):
314
  super().__init__()
315
  self.config = config
316
  # stochastic depth decay rule
317
- dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
 
318
  self.layers = nn.ModuleList([
319
  InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
320
  self.gradient_checkpointing = True
@@ -382,13 +408,17 @@ class InternVisionModel(PreTrainedModel):
382
  pos_emb = self.embeddings.position_embedding
383
  _, num_positions, embed_dim = pos_emb.shape
384
  cls_emb = pos_emb[:, :1, :]
385
- pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
386
- pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
387
- pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
 
 
 
388
  pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
389
  self.embeddings.position_embedding = nn.Parameter(pos_emb)
390
  self.embeddings.image_size = new_size
391
- logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
 
392
 
393
  def get_input_embeddings(self):
394
  return self.embeddings
@@ -406,7 +436,8 @@ class InternVisionModel(PreTrainedModel):
406
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
407
 
408
  if pixel_values is None and pixel_embeds is None:
409
- raise ValueError('You have to specify pixel_values or pixel_embeds')
 
410
 
411
  if pixel_embeds is not None:
412
  hidden_states = pixel_embeds
@@ -414,7 +445,8 @@ class InternVisionModel(PreTrainedModel):
414
  if len(pixel_values.shape) == 4:
415
  hidden_states = self.embeddings(pixel_values)
416
  else:
417
- raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
 
418
  encoder_outputs = self.encoder(
419
  inputs_embeds=hidden_states,
420
  output_hidden_states=output_hidden_states,
 
20
  from .configuration_intern_vit import InternVisionConfig
21
 
22
  try:
23
+ from triton_flash_atn import _attention
24
 
25
+ from triton_bert_pading import pad_input, unpad_input
26
 
27
  has_flash_attn = True
28
  except:
 
69
  max_s = seqlen
70
  cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
71
  device=qkv.device)
72
+ output = _attention.apply(
73
  qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
74
+ sm_scale=self.softmax_scale, causal=causal
75
  )
76
+ output = rearrange(
77
+ output, '(b s) ... -> b s ...', b=batch_size)
78
  else:
79
  nheads = qkv.shape[-2]
80
  x = rearrange(qkv, 'b s three h d -> b s (three h d)')
81
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(
82
+ x, key_padding_mask)
83
+ x_unpad = rearrange(
84
+ x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
85
+ output_unpad = _attention.apply(
86
  x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
87
+ sm_scale=self.softmax_scale, causal=causal
88
  )
89
  output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
90
  indices, batch_size, seqlen),
91
  'b s (h d) -> b s h d', h=nheads)
92
  else:
93
  assert max_s is not None
94
+ output = _attention.apply(
95
  qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
96
+ sm_scale=self.softmax_scale, causal=causal
97
  )
98
 
99
  return output, None
 
109
  input_dtype = hidden_states.dtype
110
  hidden_states = hidden_states.to(torch.float32)
111
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
112
+ hidden_states = hidden_states * \
113
+ torch.rsqrt(variance + self.variance_epsilon)
114
  return self.weight * hidden_states.to(input_dtype)
115
 
116
 
 
119
 
120
  InternRMSNorm = FusedRMSNorm # noqa
121
 
122
+ logger.info(
123
+ 'Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
124
  except ImportError:
125
  # using the normal InternRMSNorm
126
  pass
127
  except Exception:
128
+ logger.warning(
129
+ 'discovered apex but it failed to load, falling back to InternRMSNorm')
130
  pass
131
 
132
 
 
155
  self.num_patches = (self.image_size // self.patch_size) ** 2
156
  self.num_positions = self.num_patches + 1
157
 
158
+ self.position_embedding = nn.Parameter(
159
+ torch.randn(1, self.num_positions, self.embed_dim))
160
 
161
  def _get_pos_embed(self, pos_embed, H, W):
162
  target_dtype = pos_embed.dtype
 
168
 
169
  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
170
  target_dtype = self.patch_embedding.weight.dtype
171
+ # shape = [*, channel, width, height]
172
+ patch_embeds = self.patch_embedding(pixel_values)
173
  batch_size, _, height, width = patch_embeds.shape
174
  patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
175
+ class_embeds = self.class_embedding.expand(
176
+ batch_size, 1, -1).to(target_dtype)
177
  embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
178
  position_embedding = torch.cat([
179
  self.position_embedding[:, :1, :],
180
+ self._get_pos_embed(
181
+ self.position_embedding[:, 1:, :], height, width)
182
  ], dim=1)
183
  embeddings = embeddings + position_embedding.to(target_dtype)
184
  return embeddings
 
194
  self.num_heads = config.num_attention_heads
195
  self.use_flash_attn = config.use_flash_attn and has_flash_attn
196
  if config.use_flash_attn and not has_flash_attn:
197
+ print(
198
+ 'Warning: Flash Attention is not available, use_flash_attn is set to False.')
199
  self.head_dim = self.embed_dim // self.num_heads
200
  if self.head_dim * self.num_heads != self.embed_dim:
201
  raise ValueError(
202
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
204
  f' {self.num_heads}).'
205
  )
206
 
207
  self.scale = self.head_dim ** -0.5
208
+ self.qkv = nn.Linear(self.embed_dim, 3 *
209
+ self.embed_dim, bias=config.qkv_bias)
210
  self.attn_drop = nn.Dropout(config.attention_dropout)
211
  self.proj_drop = nn.Dropout(config.dropout)
212
 
213
  self.qk_normalization = config.qk_normalization
214
 
215
  if self.qk_normalization:
216
+ self.q_norm = InternRMSNorm(
217
+ self.embed_dim, eps=config.layer_norm_eps)
218
+ self.k_norm = InternRMSNorm(
219
+ self.embed_dim, eps=config.layer_norm_eps)
220
 
221
  if self.use_flash_attn:
222
+ self.inner_attn = FlashAttention(
223
+ attention_dropout=config.attention_dropout)
224
  self.proj = nn.Linear(self.embed_dim, self.embed_dim)
225
 
226
  def _naive_attn(self, x):
227
  B, N, C = x.shape
228
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
229
+ self.num_heads).permute(2, 0, 3, 1, 4)
230
+ # make torchscript happy (cannot use tensor as tuple)
231
+ q, k, v = qkv.unbind(0)
232
 
233
  if self.qk_normalization:
234
  B_, H_, N_, D_ = q.shape
235
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)
236
+ ).view(B_, N_, H_, D_).transpose(1, 2)
237
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)
238
+ ).view(B_, N_, H_, D_).transpose(1, 2)
239
 
240
  attn = ((q * self.scale) @ k.transpose(-2, -1))
241
  attn = attn.softmax(dim=-1)
 
248
 
249
  def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
250
  qkv = self.qkv(x)
251
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d',
252
+ three=3, h=self.num_heads)
253
 
254
  if self.qk_normalization:
255
  q, k, v = qkv.unbind(2)
 
265
  return outs
266
 
267
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
268
+ x = self._naive_attn(
269
+ hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
270
  return x
271
 
272
 
 
294
 
295
  self.attn = InternAttention(config)
296
  self.mlp = InternMLP(config)
297
+ self.norm1 = NORM2FN[self.norm_type](
298
+ self.embed_dim, eps=config.layer_norm_eps)
299
+ self.norm2 = NORM2FN[self.norm_type](
300
+ self.embed_dim, eps=config.layer_norm_eps)
301
+
302
+ self.ls1 = nn.Parameter(
303
+ config.initializer_factor * torch.ones(self.embed_dim))
304
+ self.ls2 = nn.Parameter(
305
+ config.initializer_factor * torch.ones(self.embed_dim))
306
+ self.drop_path1 = DropPath(
307
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
308
+ self.drop_path2 = DropPath(
309
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
310
 
311
  def forward(
312
  self,
 
316
  Args:
317
  hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
318
  """
319
+ hidden_states = hidden_states + \
320
+ self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
321
 
322
+ hidden_states = hidden_states + \
323
+ self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
324
 
325
  return hidden_states
326
 
 
339
  super().__init__()
340
  self.config = config
341
  # stochastic depth decay rule
342
+ dpr = [x.item() for x in torch.linspace(
343
+ 0, config.drop_path_rate, config.num_hidden_layers)]
344
  self.layers = nn.ModuleList([
345
  InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
346
  self.gradient_checkpointing = True
 
408
  pos_emb = self.embeddings.position_embedding
409
  _, num_positions, embed_dim = pos_emb.shape
410
  cls_emb = pos_emb[:, :1, :]
411
+ pos_emb = pos_emb[:, 1:, :].reshape(
412
+ 1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
413
+ pos_emb = F.interpolate(pos_emb.float(
414
+ ), size=new_size // patch_size, mode='bicubic', align_corners=False)
415
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(
416
+ 1, embed_dim, -1).permute(0, 2, 1)
417
  pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
418
  self.embeddings.position_embedding = nn.Parameter(pos_emb)
419
  self.embeddings.image_size = new_size
420
+ logger.info('Resized position embeddings from {} to {}'.format(
421
+ old_size, new_size))
422
 
423
  def get_input_embeddings(self):
424
  return self.embeddings
 
436
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
437
 
438
  if pixel_values is None and pixel_embeds is None:
439
+ raise ValueError(
440
+ 'You have to specify pixel_values or pixel_embeds')
441
 
442
  if pixel_embeds is not None:
443
  hidden_states = pixel_embeds
 
445
  if len(pixel_values.shape) == 4:
446
  hidden_states = self.embeddings(pixel_values)
447
  else:
448
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
450
  encoder_outputs = self.encoder(
451
  inputs_embeds=hidden_states,
452
  output_hidden_states=output_hidden_states,
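In the unmasked flash-attention path above, `cu_seqlens` is just the cumulative row offsets of equal-length sequences in the packed `(b s)` layout. A small standalone illustration of that arithmetic, with made-up sizes:

```python
import torch

# Hypothetical sizes, only to show the cu_seqlens layout used above.
batch_size, seqlen = 3, 4
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32)
print(cu_seqlens)  # tensor([ 0,  4,  8, 12], dtype=torch.int32)
# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the flattened (b*s) dimension.
```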
triton.py → triton-test.py RENAMED
File without changes
triton_bert_pading.py ADDED
@@ -0,0 +1,224 @@
1
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+
7
+
8
+ class IndexFirstAxis(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(ctx, input, indices):
11
+ ctx.save_for_backward(indices)
12
+ assert input.ndim >= 2
13
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
14
+ second_dim = other_shape.numel()
15
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
16
+ # return input[indices]
17
+ return torch.gather(
18
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices,
19
+ "z -> z d", d=second_dim)
20
+ ).reshape(-1, *other_shape)
21
+
22
+ @staticmethod
23
+ def backward(ctx, grad_output):
24
+ (indices,) = ctx.saved_tensors
25
+ assert grad_output.ndim >= 2
26
+ other_shape = grad_output.shape[1:]
27
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
28
+ grad_input = torch.zeros(
29
+ [ctx.first_axis_dim, grad_output.shape[1]],
30
+ device=grad_output.device,
31
+ dtype=grad_output.dtype,
32
+ )
33
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
34
+ # grad_input[indices] = grad_output
35
+ grad_input.scatter_(0, repeat(indices, "z -> z d",
36
+ d=grad_output.shape[1]), grad_output)
37
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
38
+
39
+
40
+ index_first_axis = IndexFirstAxis.apply
41
+
42
+
43
+ class IndexPutFirstAxis(torch.autograd.Function):
44
+ @staticmethod
45
+ def forward(ctx, values, indices, first_axis_dim):
46
+ ctx.save_for_backward(indices)
47
+ assert indices.ndim == 1
48
+ assert values.ndim >= 2
49
+ output = torch.zeros(
50
+ first_axis_dim, *
51
+ values.shape[1:], device=values.device, dtype=values.dtype
52
+ )
53
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
54
+ output[indices] = values
55
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
56
+ return output
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output):
60
+ (indices,) = ctx.saved_tensors
61
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
62
+ grad_values = grad_output[indices]
63
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
64
+ return grad_values, None, None
65
+
66
+
67
+ index_put_first_axis = IndexPutFirstAxis.apply
68
+
69
+
70
+ class IndexFirstAxisResidual(torch.autograd.Function):
71
+ @staticmethod
72
+ def forward(ctx, input, indices):
73
+ ctx.save_for_backward(indices)
74
+ assert input.ndim >= 2
75
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
76
+ second_dim = other_shape.numel()
77
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
78
+ output = input[indices]
79
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
80
+ # memory format to channel_first. In other words, input might not be contiguous.
81
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace
82
+ return output, input.detach()
83
+
84
+ @staticmethod
85
+ def backward(ctx, grad_output, grad_residual):
86
+ (indices,) = ctx.saved_tensors
87
+ assert grad_output.ndim >= 2
88
+ other_shape = grad_output.shape[1:]
89
+ assert grad_residual.shape[1:] == other_shape
90
+ grad_input = grad_residual
91
+ # grad_input[indices] += grad_output
92
+ indices = indices.reshape(
93
+ indices.shape[0], *((1,) * (grad_output.ndim - 1)))
94
+ indices = indices.expand_as(grad_output)
95
+ grad_input.scatter_add_(0, indices, grad_output)
96
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
97
+
98
+
99
+ index_first_axis_residual = IndexFirstAxisResidual.apply
100
+
101
+
102
+ def unpad_input(hidden_states, attention_mask):
103
+ """
104
+ Arguments:
105
+ hidden_states: (batch, seqlen, ...)
106
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
107
+ Return:
108
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
109
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
110
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
111
+ max_seqlen_in_batch: int
112
+ """
113
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
114
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
115
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
116
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0,
117
+ dtype=torch.torch.int32), (1, 0))
118
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
119
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
120
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
121
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
122
+ # so we write custom forward and backward to make it a bit faster.
123
+ return (
124
+ index_first_axis(
125
+ rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
126
+ indices,
127
+ cu_seqlens,
128
+ max_seqlen_in_batch,
129
+ )
130
+
131
+
132
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
133
+ """
134
+ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
135
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
136
+
137
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
138
+ ```
139
+ [
140
+ [2, 3, 0, 0, 0, 0],
141
+ [3, 2, 0, 0, 0, 0],
142
+ [6, 0, 0, 0, 0, 0]
143
+ ]
144
+ ```
145
+ , which refers to the 3D-attention mask:
146
+ ```
147
+ [
148
+ [
149
+ [1, 0, 0, 0, 0, 0],
150
+ [1, 1, 0, 0, 0, 0],
151
+ [0, 0, 1, 0, 0, 0],
152
+ [0, 0, 1, 1, 0, 0],
153
+ [0, 0, 1, 1, 1, 0],
154
+ [0, 0, 0, 0, 0, 1]
155
+ ],
156
+ [
157
+ [1, 0, 0, 0, 0, 0],
158
+ [1, 1, 0, 0, 0, 0],
159
+ [1, 1, 1, 0, 0, 0],
160
+ [0, 0, 0, 1, 0, 0],
161
+ [0, 0, 0, 1, 1, 0],
162
+ [0, 0, 0, 0, 0, 1]
163
+ ],
164
+ [
165
+ [1, 0, 0, 0, 0, 0],
166
+ [1, 1, 0, 0, 0, 0],
167
+ [1, 1, 1, 0, 0, 0],
168
+ [1, 1, 1, 1, 0, 0],
169
+ [1, 1, 1, 1, 1, 0],
170
+ [1, 1, 1, 1, 1, 1]
171
+ ]
172
+ ]
173
+ ```.
174
+
175
+ Arguments:
176
+ hidden_states: (batch, seqlen, ...)
177
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
178
+ Return:
179
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
180
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
181
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
182
+ max_seqlen_in_batch: int
183
+ """
184
+ length = attention_mask_in_length.sum(dim=-1)
185
+ seqlen = attention_mask_in_length.size(-1)
186
+ attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(
187
+ len(length), seqlen) < length.unsqueeze(1)
188
+ real_indices_idx = torch.nonzero(
189
+ attention_mask_in_length.flatten(), as_tuple=False).flatten()
190
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
191
+ indices = torch.nonzero(attention_mask_2d.flatten(),
192
+ as_tuple=False).flatten()
193
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
194
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0,
195
+ dtype=torch.torch.int32), (1, 0))
196
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
197
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
198
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
199
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
200
+ # so we write custom forward and backward to make it a bit faster.
201
+ return (
202
+ index_first_axis(
203
+ rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
204
+ indices,
205
+ cu_seqlens,
206
+ max_seqlen_in_batch,
207
+ )
208
+
209
+
210
+ def pad_input(hidden_states, indices, batch, seqlen):
211
+ """
212
+ Arguments:
213
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
214
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
215
+ batch: int, batch size for the padded sequence.
216
+ seqlen: int, maximum sequence length for the padded sequence.
217
+ Return:
218
+ hidden_states: (batch, seqlen, ...)
219
+ """
220
+ dim = hidden_states.shape[-1]
221
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
222
+ # output[indices] = hidden_states
223
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
224
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
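A short usage sketch of the `unpad_input`/`pad_input` pair defined above; the batch size, sequence length, and mask are made up for illustration. Unpadding keeps only the rows the mask marks as valid, and repadding scatters them back with zeros at the padded positions:

```python
import torch
from triton_bert_pading import pad_input, unpad_input

batch, seqlen, dim = 2, 6, 8
hidden_states = torch.randn(batch, seqlen, dim)
# 1 marks a real token, 0 marks padding (hypothetical mask).
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])

x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
print(x_unpad.shape)   # torch.Size([8, 8])  -> 3 + 5 real tokens kept
print(cu_seqlens)      # tensor([0, 3, 8], dtype=torch.int32)
print(max_seqlen)      # 5

# Scatter the unpadded rows back into a (batch, seqlen, dim) tensor;
# padded positions are filled with zeros.
repadded = pad_input(x_unpad, indices, batch, seqlen)
assert repadded.shape == hidden_states.shape
```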
triton_flash_atn.py CHANGED
@@ -11,62 +11,66 @@ Extra Credits:
11
 
12
  """
13
 
 
14
  import torch
15
 
16
  import triton
17
  import triton.language as tl
18
 
 
19
 
20
- def is_hip():
21
- return triton.runtime.driver.HIP
22
 
23
 
24
  @triton.jit
25
- def _attn_fwd_inner(acc, l_i, m_i, q, #
26
- K_block_ptr, V_block_ptr, #
27
- start_m, qk_scale, #
28
- BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
29
- STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
30
- N_CTX: tl.constexpr, fp8_v: tl.constexpr):
 
31
  # range of values handled by this stage
32
  if STAGE == 1:
33
  lo, hi = 0, start_m * BLOCK_M
34
  elif STAGE == 2:
35
  lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
36
  lo = tl.multiple_of(lo, BLOCK_M)
 
 
37
  # causal = False
38
  else:
39
  lo, hi = 0, N_CTX
40
- K_block_ptr = tl.advance(K_block_ptr, (0, lo))
41
- V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
42
  # loop over k, v and update accumulator
43
  for start_n in range(lo, hi, BLOCK_N):
44
  start_n = tl.multiple_of(start_n, BLOCK_N)
45
  # -- compute qk ----
46
  k = tl.load(K_block_ptr)
47
- qk = tl.dot(q, k)
 
 
48
  if STAGE == 2:
49
  mask = offs_m[:, None] >= (start_n + offs_n[None, :])
50
- qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
51
- m_ij = tl.maximum(m_i, tl.max(qk, 1))
52
- qk -= m_ij[:, None]
53
- else:
54
- m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
55
- qk = qk * qk_scale - m_ij[:, None]
56
  p = tl.math.exp2(qk)
57
- l_ij = tl.sum(p, 1)
58
- # -- update m_i and l_i
59
- alpha = tl.math.exp2(m_i - m_ij)
60
- l_i = l_i * alpha + l_ij
61
  # -- update output accumulator --
 
62
  acc = acc * alpha[:, None]
63
- # update acc
64
- v = tl.load(V_block_ptr)
65
- if fp8_v:
66
- p = p.to(tl.float8e5)
67
- else:
68
- p = p.to(tl.float16)
69
- acc = tl.dot(p, v, acc)
70
  # update m_i and l_i
71
  m_i = m_ij
72
  V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
@@ -74,80 +78,77 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
74
  return acc, l_i, m_i
75
 
76
 
77
- # We don't run auto-tuning every time to keep the tutorial fast. Keeping
78
  # the code below and commenting out the equivalent parameters is convenient for
79
  # re-tuning.
80
- configs = [
81
- triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w)
82
- for BM in [64, 128]
83
- for BN in [32, 64]
84
- for s in ([1] if is_hip() else [3, 4, 7])
85
- for w in [4, 8]
86
- ]
87
-
88
-
89
- def keep(conf):
90
- BLOCK_M = conf.kwargs["BLOCK_M"]
91
- BLOCK_N = conf.kwargs["BLOCK_N"]
92
- if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
93
- return False
94
- return True
95
-
96
-
97
- @triton.autotune(list(filter(keep, configs)), key=["N_CTX"])
 
98
  @triton.jit
99
- def _attn_fwd(Q, K, V, sm_scale, M, Out, #
100
- stride_qz, stride_qh, stride_qm, stride_qk, #
101
- stride_kz, stride_kh, stride_kn, stride_kk, #
102
- stride_vz, stride_vh, stride_vk, stride_vn, #
103
- stride_oz, stride_oh, stride_om, stride_on, #
104
- Z, H, N_CTX, #
105
- BLOCK_M: tl.constexpr, #
106
- BLOCK_N: tl.constexpr, #
107
- HEAD_DIM: tl.constexpr, #
108
- STAGE: tl.constexpr #
 
 
109
  ):
110
- tl.static_assert(BLOCK_N <= HEAD_DIM)
111
  start_m = tl.program_id(0)
112
  off_hz = tl.program_id(1)
113
- off_z = off_hz // H
114
- off_h = off_hz % H
115
- qvk_offset = off_z.to(tl.int64) * stride_qz + \
116
- off_h.to(tl.int64) * stride_qh
117
 
118
  # block pointers
119
  Q_block_ptr = tl.make_block_ptr(
120
  base=Q + qvk_offset,
121
- shape=(N_CTX, HEAD_DIM),
122
  strides=(stride_qm, stride_qk),
123
  offsets=(start_m * BLOCK_M, 0),
124
- block_shape=(BLOCK_M, HEAD_DIM),
125
  order=(1, 0),
126
  )
127
- v_order: tl.constexpr = (
128
- 0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
129
  V_block_ptr = tl.make_block_ptr(
130
  base=V + qvk_offset,
131
- shape=(N_CTX, HEAD_DIM),
132
  strides=(stride_vk, stride_vn),
133
  offsets=(0, 0),
134
- block_shape=(BLOCK_N, HEAD_DIM),
135
- order=v_order,
136
  )
137
  K_block_ptr = tl.make_block_ptr(
138
  base=K + qvk_offset,
139
- shape=(HEAD_DIM, N_CTX),
140
  strides=(stride_kk, stride_kn),
141
  offsets=(0, 0),
142
- block_shape=(HEAD_DIM, BLOCK_N),
143
  order=(0, 1),
144
  )
145
  O_block_ptr = tl.make_block_ptr(
146
  base=Out + qvk_offset,
147
- shape=(N_CTX, HEAD_DIM),
148
  strides=(stride_om, stride_on),
149
  offsets=(start_m * BLOCK_M, 0),
150
- block_shape=(BLOCK_M, HEAD_DIM),
151
  order=(1, 0),
152
  )
153
  # initialize offsets
@@ -156,82 +157,99 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
156
  # initialize pointer to m and l
157
  m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
158
  l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
159
- acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
160
- # load scales
161
- qk_scale = sm_scale
162
- qk_scale *= 1.44269504 # 1/log(2)
163
- # load q: it will stay in SRAM throughout
 
164
  q = tl.load(Q_block_ptr)
 
165
  # stage 1: off-band
166
  # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
167
  # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
168
  if STAGE & 1:
169
- acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
170
- start_m, qk_scale, #
171
- BLOCK_M, HEAD_DIM, BLOCK_N, #
172
- 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
 
173
  )
174
  # stage 2: on-band
175
  if STAGE & 2:
176
  # barrier makes it easier for the compiler to schedule the
177
  # two loops independently
178
- acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
179
- start_m, qk_scale, #
180
- BLOCK_M, HEAD_DIM, BLOCK_N, #
181
- 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
 
 
182
  )
183
  # epilogue
184
- m_i += tl.math.log2(l_i)
185
  acc = acc / l_i[:, None]
186
  m_ptrs = M + off_hz * N_CTX + offs_m
187
- tl.store(m_ptrs, m_i)
188
  tl.store(O_block_ptr, acc.to(Out.type.element_ty))
189
 
190
 
191
  @triton.jit
192
- def _attn_bwd_preprocess(O, DO, #
193
- Delta, #
194
- Z, H, N_CTX, #
195
- BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr #
196
  ):
197
  off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
198
  off_hz = tl.program_id(1)
199
- off_n = tl.arange(0, HEAD_DIM)
200
- # load
201
- o = tl.load(O + off_hz * HEAD_DIM * N_CTX +
202
- off_m[:, None] * HEAD_DIM + off_n[None, :])
203
- do = tl.load(DO + off_hz * HEAD_DIM * N_CTX +
204
- off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
205
  delta = tl.sum(o * do, axis=1)
206
- # write-back
207
  tl.store(Delta + off_hz * N_CTX + off_m, delta)
208
 
209
 
210
  # The main inner-loop logic for computing dK and dV.
211
  @triton.jit
212
- def _attn_bwd_dkdv(dk, dv, #
213
- Q, k, v, sm_scale, #
214
- DO, #
215
- M, D, #
216
  # shared by Q/K/V/DO.
217
- stride_tok, stride_d, #
218
- H, N_CTX, BLOCK_M1: tl.constexpr, #
219
- BLOCK_N1: tl.constexpr, #
220
- HEAD_DIM: tl.constexpr, #
221
  # Filled in by the wrapper.
222
- start_n, start_m, num_steps, #
223
  MASK: tl.constexpr):
224
  offs_m = start_m + tl.arange(0, BLOCK_M1)
225
  offs_n = start_n + tl.arange(0, BLOCK_N1)
226
- offs_k = tl.arange(0, HEAD_DIM)
227
- qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
228
- do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
230
  tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
231
  curr_m = start_m
232
  step_m = BLOCK_M1
233
  for blk_idx in range(num_steps):
234
- qT = tl.load(qT_ptrs)
235
  # Load m before computing qk to reduce pipeline stall.
236
  offs_m = curr_m + tl.arange(0, BLOCK_M1)
237
  m = tl.load(M + offs_m)
@@ -241,7 +259,7 @@ def _attn_bwd_dkdv(dk, dv, #
241
  if MASK:
242
  mask = (offs_m[None, :] >= offs_n[:, None])
243
  pT = tl.where(mask, pT, 0.0)
244
- do = tl.load(do_ptrs)
245
  # Compute dV.
246
  ppT = pT
247
  ppT = ppT.to(tl.float16)
@@ -249,35 +267,49 @@ def _attn_bwd_dkdv(dk, dv, #
249
  # D (= delta) is pre-divided by ds_scale.
250
  Di = tl.load(D + offs_m)
251
  # Compute dP and dS.
252
- dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
253
  dsT = pT * (dpT - Di[None, :])
254
  dsT = dsT.to(tl.float16)
255
  dk += tl.dot(dsT, tl.trans(qT))
256
  # Increment pointers.
257
  curr_m += step_m
258
- qT_ptrs += step_m * stride_tok
259
- do_ptrs += step_m * stride_tok
260
  return dk, dv
261
 
262
 
263
  # the main inner-loop logic for computing dQ
264
  @triton.jit
265
- def _attn_bwd_dq(dq, q, K, V, #
266
  do, m, D,
267
  # shared by Q/K/V/DO.
268
- stride_tok, stride_d, #
269
- H, N_CTX, #
270
- BLOCK_M2: tl.constexpr, #
271
- BLOCK_N2: tl.constexpr, #
272
- HEAD_DIM: tl.constexpr,
273
  # Filled in by the wrapper.
274
- start_m, start_n, num_steps, #
275
  MASK: tl.constexpr):
276
  offs_m = start_m + tl.arange(0, BLOCK_M2)
277
  offs_n = start_n + tl.arange(0, BLOCK_N2)
278
- offs_k = tl.arange(0, HEAD_DIM)
279
- kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
280
- vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
281
  # D (= delta) is pre-divided by ds_scale.
282
  Di = tl.load(D + offs_m)
283
  # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
@@ -285,8 +317,7 @@ def _attn_bwd_dq(dq, q, K, V, #
285
  curr_n = start_n
286
  step_n = BLOCK_N2
287
  for blk_idx in range(num_steps):
288
- kT = tl.load(kT_ptrs)
289
- vT = tl.load(vT_ptrs)
290
  qk = tl.dot(q, kT)
291
  p = tl.math.exp2(qk - m)
292
  # Autoregressive masking.
@@ -295,6 +326,7 @@ def _attn_bwd_dq(dq, q, K, V, #
295
  mask = (offs_m[:, None] >= offs_n[None, :])
296
  p = tl.where(mask, p, 0.0)
297
  # Compute dP and dS.
 
298
  dp = tl.dot(do, vT).to(tl.float32)
299
  ds = p * (dp - Di[:, None])
300
  ds = ds.to(tl.float16)
@@ -303,25 +335,49 @@ def _attn_bwd_dq(dq, q, K, V, #
303
  dq += tl.dot(ds, tl.trans(kT))
304
  # Increment pointers.
305
  curr_n += step_n
306
- kT_ptrs += step_n * stride_tok
307
- vT_ptrs += step_n * stride_tok
308
  return dq
311
  @triton.jit
312
- def _attn_bwd(Q, K, V, sm_scale, #
313
- DO, #
314
- DQ, DK, DV, #
315
  M, D,
316
  # shared by Q/K/V/DO.
317
- stride_z, stride_h, stride_tok, stride_d, #
318
- H, N_CTX, #
319
- BLOCK_M1: tl.constexpr, #
320
- BLOCK_N1: tl.constexpr, #
321
- BLOCK_M2: tl.constexpr, #
322
- BLOCK_N2: tl.constexpr, #
323
- BLK_SLICE_FACTOR: tl.constexpr, #
324
- HEAD_DIM: tl.constexpr):
 
325
  LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
326
 
327
  bhid = tl.program_id(2)
@@ -340,58 +396,91 @@ def _attn_bwd(Q, K, V, sm_scale, #
340
  M += off_chz
341
  D += off_chz
342
 
343
- # load scales
344
- offs_k = tl.arange(0, HEAD_DIM)
345
 
346
  start_n = pid * BLOCK_N1
 
 
 
347
  start_m = start_n
348
 
349
  MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
350
  offs_n = start_n + tl.arange(0, BLOCK_N1)
351
 
352
- dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
353
- dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
354
 
355
- # load K and V: they stay in SRAM throughout the inner loop.
356
- k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
357
- v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
358
 
359
  num_steps = BLOCK_N1 // MASK_BLOCK_M1
360
 
361
- dk, dv = _attn_bwd_dkdv(dk, dv, #
362
- Q, k, v, sm_scale, #
363
- DO, #
364
- M, D, #
365
- stride_tok, stride_d, #
366
- H, N_CTX, #
367
- MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, #
368
- start_n, start_m, num_steps, #
369
- MASK=True #
370
  )
371
 
372
  start_m += num_steps * MASK_BLOCK_M1
373
  num_steps = (N_CTX - start_m) // BLOCK_M1
374
 
375
  # Compute dK and dV for non-masked blocks.
376
- dk, dv = _attn_bwd_dkdv( #
377
- dk, dv, #
378
- Q, k, v, sm_scale, #
379
- DO, #
380
- M, D, #
381
- stride_tok, stride_d, #
382
- H, N_CTX, #
383
- BLOCK_M1, BLOCK_N1, HEAD_DIM, #
384
- start_n, start_m, num_steps, #
385
- MASK=False #
386
  )
387
 
388
- dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
389
- tl.store(dv_ptrs, dv)
390
 
391
  # Write back dK.
392
  dk *= sm_scale
393
- dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
394
- tl.store(dk_ptrs, dk)
395
 
396
  # THIS BLOCK DOES DQ:
397
  start_m = pid * BLOCK_M2
@@ -400,10 +489,26 @@ def _attn_bwd(Q, K, V, sm_scale, #
400
  MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
401
  offs_m = start_m + tl.arange(0, BLOCK_M2)
402
 
403
- q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
404
- dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
405
- do = tl.load(DO + offs_m[:, None] * stride_tok +
406
- offs_k[None, :] * stride_d)
407
 
408
  m = tl.load(M + offs_m)
409
  m = m[:, None]
@@ -414,29 +519,39 @@ def _attn_bwd(Q, K, V, sm_scale, #
414
  # not due to anything important. I just wanted to reuse the loop
415
  # structure for dK & dV above as much as possible.
416
  num_steps = BLOCK_M2 // MASK_BLOCK_N2
417
- dq = _attn_bwd_dq(dq, q, K, V, #
418
- do, m, D, #
419
- stride_tok, stride_d, #
420
- H, N_CTX, #
421
- BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, #
422
- start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
423
- MASK=True #
424
  )
425
  end_n -= num_steps * MASK_BLOCK_N2
426
  # stage 2
427
  num_steps = end_n // BLOCK_N2
428
- dq = _attn_bwd_dq(dq, q, K, V, #
429
- do, m, D, #
430
- stride_tok, stride_d, #
431
- H, N_CTX, #
432
- BLOCK_M2, BLOCK_N2, HEAD_DIM, #
433
- start_m, end_n - num_steps * BLOCK_N2, num_steps, #
434
- MASK=False #
435
  )
436
  # Write back dQ.
437
- dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
438
  dq *= LN2
439
- tl.store(dq_ptrs, dq)
440
 
441
 
442
  class _attention(torch.autograd.Function):
@@ -444,45 +559,58 @@ class _attention(torch.autograd.Function):
444
  @staticmethod
445
  def forward(ctx, q, k, v, causal, sm_scale):
446
  # shape constraints
447
- HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
448
- # when v is in float8_e5m2 it is transposed.
449
- HEAD_DIM_V = v.shape[-2] if v.dtype == torch.float8_e5m2 else v.shape[-1]
450
- assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
451
- assert HEAD_DIM_K in {16, 32, 64, 128, 256}
452
- o = torch.empty_like(q)
453
  stage = 3 if causal else 1
454
- extra_kern_args = {}
455
- # Tuning for AMD target
456
- if is_hip():
457
- waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
458
- extra_kern_args = {"waves_per_eu": waves_per_eu,
459
- "allow_flush_denorm": True}
460
-
461
- def grid(args): return (triton.cdiv(
462
- q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
463
- M = torch.empty((q.shape[0], q.shape[1], q.shape[2]),
464
  device=q.device, dtype=torch.float32)
465
  _attn_fwd[grid](
466
- q, k, v, sm_scale, M, o, #
467
- q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
468
- k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
469
- v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
470
- o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
471
- q.shape[0], q.shape[1], #
472
- N_CTX=q.shape[2], #
473
- HEAD_DIM=HEAD_DIM_K, #
474
- STAGE=stage, #
475
- **extra_kern_args)
476
 
477
  ctx.save_for_backward(q, k, v, o, M)
478
  ctx.grid = grid
479
  ctx.sm_scale = sm_scale
480
- ctx.HEAD_DIM = HEAD_DIM_K
481
  ctx.causal = causal
482
  return o
483
 
484
  @staticmethod
485
  def backward(ctx, do):
486
  q, k, v, o, M = ctx.saved_tensors
487
  assert do.is_contiguous()
488
  assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
@@ -491,34 +619,33 @@ class _attention(torch.autograd.Function):
491
  dv = torch.empty_like(v)
492
  BATCH, N_HEAD, N_CTX = q.shape[:3]
493
  PRE_BLOCK = 128
494
- NUM_WARPS, NUM_STAGES = 4, 5
495
- BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
496
  BLK_SLICE_FACTOR = 2
497
  RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
498
  arg_k = k
499
  arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
500
- PRE_BLOCK = 128
501
  assert N_CTX % PRE_BLOCK == 0
502
  pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
503
  delta = torch.empty_like(M)
504
  _attn_bwd_preprocess[pre_grid](
505
- o, do, #
506
- delta, #
507
- BATCH, N_HEAD, N_CTX, #
508
- BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM #
509
  )
510
- grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
511
  _attn_bwd[grid](
512
- q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #
513
- M, delta, #
514
- q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
515
- N_HEAD, N_CTX, #
516
- BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
517
- BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
518
- BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
519
- HEAD_DIM=ctx.HEAD_DIM, #
520
- num_warps=NUM_WARPS, #
521
- num_stages=NUM_STAGES #
522
  )
523
 
524
  return dq, dk, dv, None, None
 
11
 
12
  """
13
 
14
+ import pytest
15
  import torch
16
 
17
  import triton
18
  import triton.language as tl
19
 
20
+ # Pick the fp8 data type
21
 
22
+ # AMD E4M3B8
23
+ # Note: When picking this f8 data type, scaling is required when using f8
24
+ # for the second gemm
25
+ # TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')
26
+
27
+ # AMD E5M2B16
28
+ TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')
29
 
30
 
31
  @triton.jit
32
+ def _attn_fwd_inner(acc, l_i, m_i, q,
33
+ K_block_ptr, V_block_ptr,
34
+ start_m,
35
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
36
+ STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
37
+ N_CTX,
38
+ pre_load_v: tl.constexpr):
39
  # range of values handled by this stage
40
  if STAGE == 1:
41
  lo, hi = 0, start_m * BLOCK_M
42
  elif STAGE == 2:
43
  lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
44
  lo = tl.multiple_of(lo, BLOCK_M)
45
+ K_block_ptr = tl.advance(K_block_ptr, (0, lo))
46
+ V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
47
  # causal = False
48
  else:
49
  lo, hi = 0, N_CTX
 
 
50
  # loop over k, v and update accumulator
51
  for start_n in range(lo, hi, BLOCK_N):
52
  start_n = tl.multiple_of(start_n, BLOCK_N)
53
  # -- compute qk ----
54
  k = tl.load(K_block_ptr)
55
+ if pre_load_v:
56
+ v = tl.load(V_block_ptr)
57
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
58
  if STAGE == 2:
59
  mask = offs_m[:, None] >= (start_n + offs_n[None, :])
60
+ qk = tl.where(mask, qk, float("-inf"))
61
+ qk += tl.dot(q, k)
62
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
63
+ qk = qk - m_ij[:, None]
 
 
64
  p = tl.math.exp2(qk)
 
 
 
 
65
  # -- update output accumulator --
66
+ alpha = tl.math.exp2(m_i - m_ij)
67
  acc = acc * alpha[:, None]
68
+ if not pre_load_v:
69
+ v = tl.load(V_block_ptr)
70
+ acc += tl.dot(p.to(v.dtype), v)
71
+ # -- update m_i and l_i
72
+ l_ij = tl.sum(p, 1)
73
+ l_i = l_i * alpha + l_ij
 
74
  # update m_i and l_i
75
  m_i = m_ij
76
  V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
 
78
  return acc, l_i, m_i
79
 
80
 
81
+ # We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
82
  # the code below and commenting out the equivalent parameters is convenient for
83
  # re-tuning.
84
+ @triton.autotune(
85
+ configs=[
86
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
87
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
88
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
89
+ 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=2),
90
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
91
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
92
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
93
+ 'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=1),
94
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2,
95
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
96
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
97
+ 'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=1),
98
+ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
99
+ 'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
100
+ ],
101
+ key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
102
+ )
103
  @triton.jit
104
+ def _attn_fwd(Q, K, V, sm_scale, M, Out,
105
+ stride_qz, stride_qh, stride_qm, stride_qk,
106
+ stride_kz, stride_kh, stride_kn, stride_kk,
107
+ stride_vz, stride_vh, stride_vk, stride_vn,
108
+ stride_oz, stride_oh, stride_om, stride_on,
109
+ Z, H,
110
+ N_CTX,
111
+ BLOCK_DMODEL: tl.constexpr,
112
+ STAGE: tl.constexpr,
113
+ BLOCK_M: tl.constexpr,
114
+ BLOCK_N: tl.constexpr,
115
+ pre_load_v: tl.constexpr,
116
  ):
 
117
  start_m = tl.program_id(0)
118
  off_hz = tl.program_id(1)
119
+ qvk_offset = off_hz * stride_qh
 
 
 
120
 
121
  # block pointers
122
  Q_block_ptr = tl.make_block_ptr(
123
  base=Q + qvk_offset,
124
+ shape=(N_CTX, BLOCK_DMODEL),
125
  strides=(stride_qm, stride_qk),
126
  offsets=(start_m * BLOCK_M, 0),
127
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
128
  order=(1, 0),
129
  )
 
 
130
  V_block_ptr = tl.make_block_ptr(
131
  base=V + qvk_offset,
132
+ shape=(N_CTX, BLOCK_DMODEL),
133
  strides=(stride_vk, stride_vn),
134
  offsets=(0, 0),
135
+ block_shape=(BLOCK_N, BLOCK_DMODEL),
136
+ order=(1, 0),
137
  )
138
  K_block_ptr = tl.make_block_ptr(
139
  base=K + qvk_offset,
140
+ shape=(BLOCK_DMODEL, N_CTX),
141
  strides=(stride_kk, stride_kn),
142
  offsets=(0, 0),
143
+ block_shape=(BLOCK_DMODEL, BLOCK_N),
144
  order=(0, 1),
145
  )
146
  O_block_ptr = tl.make_block_ptr(
147
  base=Out + qvk_offset,
148
+ shape=(N_CTX, BLOCK_DMODEL),
149
  strides=(stride_om, stride_on),
150
  offsets=(start_m * BLOCK_M, 0),
151
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
152
  order=(1, 0),
153
  )
154
  # initialize offsets
 
157
  # initialize pointer to m and l
158
  m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
159
  l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
160
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
161
+ # scale sm_scale by log_2(e) and use
162
+ # 2^x instead of exp in the loop because CSE and LICM
163
+ # don't work as expected with `exp` in the loop
164
+ qk_scale = sm_scale * 1.44269504
165
+ # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
166
  q = tl.load(Q_block_ptr)
167
+ q = (q * qk_scale).to(q.dtype)
168
  # stage 1: off-band
169
  # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
170
  # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
171
  if STAGE & 1:
172
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
173
+ start_m,
174
+ BLOCK_M, BLOCK_DMODEL, BLOCK_N,
175
+ 4 - STAGE, offs_m, offs_n, N_CTX,
176
+ pre_load_v,
177
  )
178
  # stage 2: on-band
179
  if STAGE & 2:
180
  # barrier makes it easier for the compiler to schedule the
181
  # two loops independently
182
+ tl.debug_barrier()
183
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
184
+ start_m,
185
+ BLOCK_M, BLOCK_DMODEL, BLOCK_N,
186
+ 2, offs_m, offs_n, N_CTX,
187
+ pre_load_v,
188
  )
189
  # epilogue
190
+ # write back m
191
  acc = acc / l_i[:, None]
192
  m_ptrs = M + off_hz * N_CTX + offs_m
193
+ tl.store(m_ptrs, m_i + tl.math.log2(l_i))
194
  tl.store(O_block_ptr, acc.to(Out.type.element_ty))
195
 
196
 
197
  @triton.jit
198
+ def _attn_bwd_preprocess(O, DO,
199
+ Delta,
200
+ Z, H, N_CTX,
201
+ BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
202
  ):
203
  off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
204
  off_hz = tl.program_id(1)
205
+ off_n = tl.arange(0, D_HEAD)
206
+ o = tl.load(O + off_hz * D_HEAD * N_CTX +
207
+ off_m[:, None] * D_HEAD + off_n[None, :])
208
+ do = tl.load(DO + off_hz * D_HEAD * N_CTX +
209
+ off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
 
210
  delta = tl.sum(o * do, axis=1)
 
211
  tl.store(Delta + off_hz * N_CTX + off_m, delta)
212
 
213
 
214
  # The main inner-loop logic for computing dK and dV.
215
  @triton.jit
216
+ def _attn_bwd_dkdv(dk, dv,
217
+ Q, k, v, sm_scale,
218
+ DO,
219
+ M, D,
220
  # shared by Q/K/V/DO.
221
+ stride_tok, stride_d,
222
+ H, N_CTX, BLOCK_M1: tl.constexpr,
223
+ BLOCK_N1: tl.constexpr,
224
+ BLOCK_DMODEL: tl.constexpr,
225
  # Filled in by the wrapper.
226
+ start_n, start_m, num_steps,
227
  MASK: tl.constexpr):
228
  offs_m = start_m + tl.arange(0, BLOCK_M1)
229
  offs_n = start_n + tl.arange(0, BLOCK_N1)
230
+ offs_k = tl.arange(0, BLOCK_DMODEL)
231
+ QT_block_ptr = tl.make_block_ptr(
232
+ base=Q,
233
+ shape=(BLOCK_DMODEL, N_CTX),
234
+ strides=(stride_d, stride_tok),
235
+ offsets=(0, start_m),
236
+ block_shape=(BLOCK_DMODEL, BLOCK_M1),
237
+ order=(0, 1)
238
+ )
239
+ DO_block_ptr = tl.make_block_ptr(
240
+ base=DO,
241
+ shape=(N_CTX, BLOCK_DMODEL),
242
+ strides=(stride_tok, stride_d),
243
+ offsets=(start_m, 0),
244
+ block_shape=(BLOCK_M1, BLOCK_DMODEL),
245
+ order=(1, 0)
246
+ )
247
  # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
248
  tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
249
  curr_m = start_m
250
  step_m = BLOCK_M1
251
  for blk_idx in range(num_steps):
252
+ qT = tl.load(QT_block_ptr)
253
  # Load m before computing qk to reduce pipeline stall.
254
  offs_m = curr_m + tl.arange(0, BLOCK_M1)
255
  m = tl.load(M + offs_m)
 
259
  if MASK:
260
  mask = (offs_m[None, :] >= offs_n[:, None])
261
  pT = tl.where(mask, pT, 0.0)
262
+ do = tl.load(DO_block_ptr)
263
  # Compute dV.
264
  ppT = pT
265
  ppT = ppT.to(tl.float16)
 
267
  # D (= delta) is pre-divided by ds_scale.
268
  Di = tl.load(D + offs_m)
269
  # Compute dP and dS.
270
+ dpT = tl.dot(v, tl.trans(do))
271
  dsT = pT * (dpT - Di[None, :])
272
  dsT = dsT.to(tl.float16)
273
  dk += tl.dot(dsT, tl.trans(qT))
274
  # Increment pointers.
275
  curr_m += step_m
276
+ QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))
277
+ DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))
278
  return dk, dv
279
 
280
 
281
  # the main inner-loop logic for computing dQ
282
  @triton.jit
283
+ def _attn_bwd_dq(dq, q, K, V,
284
  do, m, D,
285
  # shared by Q/K/V/DO.
286
+ stride_tok, stride_d,
287
+ H, N_CTX,
288
+ BLOCK_M2: tl.constexpr,
289
+ BLOCK_N2: tl.constexpr,
290
+ BLOCK_DMODEL: tl.constexpr,
291
  # Filled in by the wrapper.
292
+ start_m, start_n, num_steps,
293
  MASK: tl.constexpr):
294
  offs_m = start_m + tl.arange(0, BLOCK_M2)
295
  offs_n = start_n + tl.arange(0, BLOCK_N2)
296
+ offs_k = tl.arange(0, BLOCK_DMODEL)
297
+ KT_block_ptr = tl.make_block_ptr(
298
+ base=K,
299
+ shape=(BLOCK_DMODEL, N_CTX),
300
+ strides=(stride_d, stride_tok),
301
+ offsets=(0, start_n),
302
+ block_shape=(BLOCK_DMODEL, BLOCK_N2),
303
+ order=(0, 1)
304
+ )
305
+ VT_block_ptr = tl.make_block_ptr(
306
+ base=V,
307
+ shape=(BLOCK_DMODEL, N_CTX),
308
+ strides=(stride_d, stride_tok),
309
+ offsets=(0, start_n),
310
+ block_shape=(BLOCK_DMODEL, BLOCK_N2),
311
+ order=(0, 1)
312
+ )
313
  # D (= delta) is pre-divided by ds_scale.
314
  Di = tl.load(D + offs_m)
315
  # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
 
317
  curr_n = start_n
318
  step_n = BLOCK_N2
319
  for blk_idx in range(num_steps):
320
+ kT = tl.load(KT_block_ptr)
 
321
  qk = tl.dot(q, kT)
322
  p = tl.math.exp2(qk - m)
323
  # Autoregressive masking.
 
326
  mask = (offs_m[:, None] >= offs_n[None, :])
327
  p = tl.where(mask, p, 0.0)
328
  # Compute dP and dS.
329
+ vT = tl.load(VT_block_ptr)
330
  dp = tl.dot(do, vT).to(tl.float32)
331
  ds = p * (dp - Di[:, None])
332
  ds = ds.to(tl.float16)
 
335
  dq += tl.dot(ds, tl.trans(kT))
336
  # Increment pointers.
337
  curr_n += step_n
338
+ KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))
339
+ VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
340
  return dq
341
 
342
 
343
+ @triton.autotune(
344
+ configs=[
345
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
346
+ num_stages=1, num_warps=4),
347
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
348
+ num_stages=1, num_warps=4),
349
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
350
+ num_stages=1, num_warps=4),
351
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
352
+ num_stages=1, num_warps=4),
353
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
354
+ num_stages=1, num_warps=4),
355
+ triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
356
+ num_stages=1, num_warps=4),
357
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
358
+ num_stages=1, num_warps=4),
359
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
360
+ num_stages=1, num_warps=4),
361
+ triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
362
+ num_stages=1, num_warps=8),
363
+ ],
364
+ key=['H', 'N_CTX', 'BLOCK_DMODEL'],
365
+ )
366
  @triton.jit
367
+ def _attn_bwd(Q, K, V, sm_scale,
368
+ DO,
369
+ DQ, DK, DV,
370
  M, D,
371
  # shared by Q/K/V/DO.
372
+ stride_z, stride_h, stride_tok, stride_d,
373
+ # H: number of heads, N_CTX: sequence length (e.g. H = 16, N_CTX = 1024)
374
+ H, N_CTX,
375
+ BLOCK_DMODEL: tl.constexpr,
376
+ BLOCK_M1: tl.constexpr,
377
+ BLOCK_N1: tl.constexpr,
378
+ BLOCK_M2: tl.constexpr,
379
+ BLOCK_N2: tl.constexpr,
380
+ BLK_SLICE_FACTOR: tl.constexpr):
381
  LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
382
 
383
  bhid = tl.program_id(2)
 
396
  M += off_chz
397
  D += off_chz
398
 
399
+ offs_k = tl.arange(0, BLOCK_DMODEL)
 
400
 
401
  start_n = pid * BLOCK_N1
402
+ # This assignment is important. It is what allows us to pick the diagonal
403
+ # blocks. Later, when we process the remaining lower-triangular (unmasked) blocks, we update start_m
404
+ # after the first dkdv call.
405
  start_m = start_n
406
 
407
  MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
408
  offs_n = start_n + tl.arange(0, BLOCK_N1)
409
 
410
+ dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
411
+ dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
412
 
413
+ K_block_ptr = tl.make_block_ptr(
414
+ base=K,
415
+ shape=(N_CTX, BLOCK_DMODEL),
416
+ strides=(stride_tok, stride_d),
417
+ offsets=(start_n, 0),
418
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
419
+ order=(1, 0),
420
+ )
421
+ V_block_ptr = tl.make_block_ptr(
422
+ base=V,
423
+ shape=(N_CTX, BLOCK_DMODEL),
424
+ strides=(stride_tok, stride_d),
425
+ offsets=(start_n, 0),
426
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
427
+ order=(1, 0),
428
+ )
429
+
430
+ # load K and V: they stay in SRAM throughout the inner loop for dkdv.
431
+ k = tl.load(K_block_ptr)
432
+ v = tl.load(V_block_ptr)
433
 
434
  num_steps = BLOCK_N1 // MASK_BLOCK_M1
435
 
436
+ dk, dv = _attn_bwd_dkdv(dk, dv,
437
+ Q, k, v, sm_scale,
438
+ DO,
439
+ M, D,
440
+ stride_tok, stride_d,
441
+ H, N_CTX,
442
+ MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
443
+ start_n, start_m, num_steps,
444
+ MASK=True
445
  )
446
 
447
  start_m += num_steps * MASK_BLOCK_M1
448
  num_steps = (N_CTX - start_m) // BLOCK_M1
449
 
450
  # Compute dK and dV for non-masked blocks.
451
+ dk, dv = _attn_bwd_dkdv(
452
+ dk, dv,
453
+ Q, k, v, sm_scale,
454
+ DO,
455
+ M, D,
456
+ stride_tok, stride_d,
457
+ H, N_CTX,
458
+ BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
459
+ start_n, start_m, num_steps,
460
+ MASK=False
461
  )
462
 
463
+ DV_block_ptrs = tl.make_block_ptr(
464
+ base=DV,
465
+ shape=(N_CTX, BLOCK_DMODEL),
466
+ strides=(stride_tok, stride_d),
467
+ offsets=(start_n, 0),
468
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
469
+ order=(1, 0)
470
+ )
471
+ tl.store(DV_block_ptrs, dv.to(tl.float16))
472
 
473
  # Write back dK.
474
  dk *= sm_scale
475
+ DK_block_ptrs = tl.make_block_ptr(
476
+ base=DK,
477
+ shape=(N_CTX, BLOCK_DMODEL),
478
+ strides=(stride_tok, stride_d),
479
+ offsets=(start_n, 0),
480
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
481
+ order=(1, 0)
482
+ )
483
+ tl.store(DK_block_ptrs, dk.to(tl.float16))
484
 
485
  # THIS BLOCK DOES DQ:
486
  start_m = pid * BLOCK_M2
 
489
  MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
490
  offs_m = start_m + tl.arange(0, BLOCK_M2)
491
 
492
+ Q_block_ptr = tl.make_block_ptr(
493
+ base=Q,
494
+ shape=(N_CTX, BLOCK_DMODEL),
495
+ strides=(stride_tok, stride_d),
496
+ offsets=(start_m, 0),
497
+ block_shape=(BLOCK_M2, BLOCK_DMODEL),
498
+ order=(1, 0)
499
+ )
500
+
501
+ DO_block_ptr = tl.make_block_ptr(
502
+ base=DO,
503
+ shape=(N_CTX, BLOCK_DMODEL),
504
+ strides=(stride_tok, stride_d),
505
+ offsets=(start_m, 0),
506
+ block_shape=(BLOCK_M2, BLOCK_DMODEL),
507
+ order=(1, 0)
508
+ )
509
+ q = tl.load(Q_block_ptr)
510
+ do = tl.load(DO_block_ptr)
511
+ dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
512
 
513
  m = tl.load(M + offs_m)
514
  m = m[:, None]
 
519
  # not due to anything important. I just wanted to reuse the loop
520
  # structure for dK & dV above as much as possible.
521
  num_steps = BLOCK_M2 // MASK_BLOCK_N2
522
+ dq = _attn_bwd_dq(dq, q, K, V,
523
+ do, m, D,
524
+ stride_tok, stride_d,
525
+ H, N_CTX,
526
+ BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
527
+ start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
528
+ MASK=True
529
  )
530
  end_n -= num_steps * MASK_BLOCK_N2
531
  # stage 2
532
  num_steps = end_n // BLOCK_N2
533
+ dq = _attn_bwd_dq(dq, q, K, V,
534
+ do, m, D,
535
+ stride_tok, stride_d,
536
+ H, N_CTX,
537
+ BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
538
+ start_m, end_n - num_steps * BLOCK_N2, num_steps,
539
+ MASK=False
540
  )
541
  # Write back dQ.
542
+ DQ_block_ptr = tl.make_block_ptr(
543
+ base=DQ,
544
+ shape=(N_CTX, BLOCK_DMODEL),
545
+ strides=(stride_tok, stride_d),
546
+ offsets=(start_m, 0),
547
+ block_shape=(BLOCK_M2, BLOCK_DMODEL),
548
+ order=(1, 0)
549
+ )
550
  dq *= LN2
551
+ tl.store(DQ_block_ptr, dq.to(tl.float16))
552
+
553
+
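
A quick worked example of the dK/dV block bookkeeping in _attn_bwd above, using one of the autotune configs (BLOCK_M1=32, BLOCK_N1=64, BLK_SLICE_FACTOR=2) and N_CTX=1024; the numbers are editorial and purely illustrative:

BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, N_CTX = 32, 64, 2, 1024
pid = 0
start_n = pid * BLOCK_N1                       # 0: this program owns keys [0, 64)
start_m = start_n                              # begin on the diagonal tile
MASK_BLOCK_M1 = BLOCK_M1 // BLK_SLICE_FACTOR   # 16
num_steps = BLOCK_N1 // MASK_BLOCK_M1          # 4 masked slices cover the diagonal
start_m += num_steps * MASK_BLOCK_M1           # 64: rows below need no causal mask
num_steps = (N_CTX - start_m) // BLOCK_M1      # 30 full BLOCK_M1 slices remain
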
554
+ empty = torch.empty(128, device="cuda")
555
 
556
 
557
  class _attention(torch.autograd.Function):
 
559
  @staticmethod
560
  def forward(ctx, q, k, v, causal, sm_scale):
561
  # shape constraints
562
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
563
+ assert Lq == Lk and Lk == Lv
564
+ assert Lk in {16, 32, 64, 128}
565
+ o = torch.empty_like(q, dtype=v.dtype)
566
+ if torch.version.hip is None:
567
+ BLOCK_M = 128
568
+ BLOCK_N = 64 if Lk <= 64 else 32
569
+ num_stages = 4 if Lk <= 64 else 3
570
+ num_warps = 4 if Lk <= 64 else 8
571
+ # Tuning for H100
572
+ if torch.cuda.get_device_capability()[0] == 9:
573
+ num_warps = 8
574
+ num_stages = 7 if Lk >= 64 else 3
575
  stage = 3 if causal else 1
576
+
577
+ def grid(META): return (
578
+ triton.cdiv(q.shape[2], META['BLOCK_M']),
579
+ q.shape[0] * q.shape[1],
580
+ 1
581
+ )
582
+ M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]),
583
  device=q.device, dtype=torch.float32)
584
  _attn_fwd[grid](
585
+ q, k, v, sm_scale, M, o,
586
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
587
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
588
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
589
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
590
+ q.shape[0], q.shape[1],
591
+ N_CTX=q.shape[2],
592
+ BLOCK_DMODEL=Lk,
593
+ STAGE=stage,
594
+ )
595
+
596
+ # restore the grid for bwd kernel
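+ # (the forward kernel is autotuned, so its BLOCK_M is only known after the
+ # launch above; it is recovered by parsing the repr of the selected triton.Config)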
597
+ best_config = _attn_fwd.get_best_config()
598
+ block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
599
+ grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
600
 
601
  ctx.save_for_backward(q, k, v, o, M)
602
  ctx.grid = grid
603
  ctx.sm_scale = sm_scale
604
+ ctx.BLOCK_DMODEL = Lk
605
  ctx.causal = causal
606
  return o
607
 
608
  @staticmethod
609
  def backward(ctx, do):
610
+ if torch.version.hip is not None:
611
+ BLOCK = 64
612
+ else:
613
+ BLOCK = 128
614
  q, k, v, o, M = ctx.saved_tensors
615
  assert do.is_contiguous()
616
  assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
 
619
  dv = torch.empty_like(v)
620
  BATCH, N_HEAD, N_CTX = q.shape[:3]
621
  PRE_BLOCK = 128
622
+ NUM_WARPS, NUM_STAGES = 4, 1
623
+ BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
624
  BLK_SLICE_FACTOR = 2
625
  RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
626
  arg_k = k
627
  arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
 
628
  assert N_CTX % PRE_BLOCK == 0
629
  pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
630
  delta = torch.empty_like(M)
631
  _attn_bwd_preprocess[pre_grid](
632
+ o, do,
633
+ delta,
634
+ BATCH, N_HEAD, N_CTX,
635
+ BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL
636
+ )
637
+
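+ # one program per BLOCK_N1-wide block of keys/values (grid axis 0), replicated
+ # over batch * heads on grid axis 2 (read back via tl.program_id(2) in the kernel)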
638
+ def grid(META): return (
639
+ triton.cdiv(N_CTX, META['BLOCK_N1']),
640
+ 1,
641
+ BATCH * N_HEAD
642
  )
 
643
  _attn_bwd[grid](
644
+ q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
645
+ M, delta,
646
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
647
+ N_HEAD, N_CTX,
648
+ BLOCK_DMODEL=ctx.BLOCK_DMODEL
649
  )
650
 
651
  return dq, dk, dv, None, None
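
If the op is exposed as attention = _attention.apply, as in the upstream Triton fused-attention tutorial this file closely follows, a minimal smoke test could look like the sketch below; shapes, names and the fp32 reference are editorial illustrations, not part of this diff:

import torch

attention = _attention.apply  # assumed binding; the class above must be in scope

B, H, N_CTX, D_HEAD = 2, 16, 1024, 64
q, k, v = (torch.randn(B, H, N_CTX, D_HEAD, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3))
sm_scale = D_HEAD ** -0.5

out = attention(q, k, v, True, sm_scale)       # causal=True

# fp32 reference for a quick correctness check
s = (q.float() @ k.float().transpose(-2, -1)) * sm_scale
causal_mask = torch.ones(N_CTX, N_CTX, device="cuda", dtype=torch.bool).tril()
ref = torch.softmax(s.masked_fill(~causal_mask, float("-inf")), dim=-1) @ v.float()
print("max abs err:", (out.float() - ref).abs().max().item())

out.backward(torch.randn_like(out))            # exercises _attn_bwd as well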