GeeeekExplorer committed
Commit 74b7f11 · verified · 1 Parent(s): 1b79fbe

Update inference/kernel.py

Files changed (1):
  1. inference/kernel.py +0 -78
inference/kernel.py CHANGED
@@ -194,81 +194,3 @@ def fp8_gemm(
     kernel = fp8_gemm_kernel(N, K)
     kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
     return c
-
-
-@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
-def fp8_index_kernel(h: int, d: int):
-    b = T.symbolic("b")
-    m = T.symbolic("m")
-    n = T.symbolic("n")
-
-    blk_n1 = 512
-    blk_n2 = 128
-
-    @T.prim_func
-    def fp8_index_kernel_(
-        q: T.Tensor[(b, m, h, d), FP8],
-        q_s: T.Tensor[(b, m, h), FP32],
-        k: T.Tensor[(b, n, d), FP8],
-        k_s: T.Tensor[(b, n), FP32],
-        o: T.Tensor[(b, m, n), FP32],
-    ) -> None:
-        with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
-            q_smem = T.alloc_shared((h, d), FP8)
-            T.copy(q[i_b, i_m, 0, 0], q_smem)
-
-            q_s_frag = T.alloc_fragment(h, FP32)
-            T.copy(q_s[i_b, i_m, 0], q_s_frag)
-
-            for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
-                k_smem = T.alloc_shared((blk_n2, d), FP8)
-                T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
-
-                k_s_frag = T.alloc_fragment(blk_n2, FP32)
-                T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
-
-                logits = T.alloc_fragment((blk_n2, h), FP32)
-                T.gemm(
-                    k_smem,
-                    q_smem,
-                    logits,
-                    transpose_A=False,
-                    transpose_B=True,
-                    clear_accum=True,
-                )
-
-                for i_h, i3_n in T.Parallel(h, blk_n2):
-                    logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
-
-                logits_sum = T.alloc_fragment(blk_n2, FP32)
-                T.reduce_sum(logits, logits_sum, dim=1)
-
-                for i3_n in T.Parallel(blk_n2):
-                    logits_sum[i3_n] *= k_s_frag[i3_n]
-
-                T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
-
-    return fp8_index_kernel_
-
-
-def fp8_index(
-    q: torch.Tensor,
-    q_s: torch.Tensor,
-    k: torch.Tensor,
-    k_s: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Perform index score using FP8 precision.
-
-    Args:
-        q (torch.Tensor): The Q tensor, must be contiguous.
-        q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
-        k (torch.Tensor): The K tensor, must be contiguous.
-        k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
-
-    fp8 q @ fp8 k -> fp32 logits
-    relu(fp32 logits) * q_s (weights) -> fp32 logits
-    fp32 logits -> fp32 logits_sum
-    fp32 logits_sum * k_s (e8m0) -> fp32 index_score
-    """
-    return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
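The removed docstring spells out the index-score pipeline step by step. For readers, here is a minimal PyTorch sketch of the same math (`fp8_index_reference` is a hypothetical name, not the shipped kernel; FP8 inputs are upcast to fp32 because PyTorch does not matmul float8 tensors directly):

import torch

def fp8_index_reference(q: torch.Tensor, q_s: torch.Tensor,
                        k: torch.Tensor, k_s: torch.Tensor) -> torch.Tensor:
    # q: (b, m, h, d) fp8, q_s: (b, m, h) fp32, k: (b, n, d) fp8, k_s: (b, n) fp32
    # fp8 q @ fp8 k -> fp32 logits of shape (b, m, h, n)
    logits = torch.einsum("bmhd,bnd->bmhn", q.float(), k.float())
    # relu(fp32 logits) * q_s (per-head weights)
    logits = torch.relu(logits) * q_s.unsqueeze(-1)
    # reduce over heads: fp32 logits -> fp32 logits_sum of shape (b, m, n)
    logits_sum = logits.sum(dim=2)
    # fp32 logits_sum * k_s (per-key e8m0 scale) -> fp32 index_score
    return logits_sum * k_s.unsqueeze(1)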
 
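And a hedged usage sketch for the removed fp8_index wrapper, assuming CUDA tensors and torch.float8_e4m3fn as the FP8 dtype (sizes are illustrative only; n is kept a multiple of the kernel's 512-wide outer block):

import torch

b, m, h, d, n = 1, 4, 64, 128, 2048  # illustrative sizes, not from the source
q = torch.randn(b, m, h, d, device="cuda").to(torch.float8_e4m3fn).contiguous()
q_s = torch.rand(b, m, h, device="cuda", dtype=torch.float32)
k = torch.randn(b, n, d, device="cuda").to(torch.float8_e4m3fn).contiguous()
k_s = torch.ones(b, n, device="cuda", dtype=torch.float32)  # e8m0-style scales
score = fp8_index(q, q_s, k, k_s)  # fp32 index scores of shape (b, m, n)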