Update inference/kernel.py
inference/kernel.py   +0 -78   CHANGED
@@ -194,81 +194,3 @@ def fp8_gemm(
     kernel = fp8_gemm_kernel(N, K)
     kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
     return c
-
-
-@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
-def fp8_index_kernel(h: int, d: int):
-    b = T.symbolic("b")
-    m = T.symbolic("m")
-    n = T.symbolic("n")
-
-    blk_n1 = 512
-    blk_n2 = 128
-
-    @T.prim_func
-    def fp8_index_kernel_(
-        q: T.Tensor[(b, m, h, d), FP8],
-        q_s: T.Tensor[(b, m, h), FP32],
-        k: T.Tensor[(b, n, d), FP8],
-        k_s: T.Tensor[(b, n), FP32],
-        o: T.Tensor[(b, m, n), FP32],
-    ) -> None:
-        with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
-            q_smem = T.alloc_shared((h, d), FP8)
-            T.copy(q[i_b, i_m, 0, 0], q_smem)
-
-            q_s_frag = T.alloc_fragment(h, FP32)
-            T.copy(q_s[i_b, i_m, 0], q_s_frag)
-
-            for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
-                k_smem = T.alloc_shared((blk_n2, d), FP8)
-                T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
-
-                k_s_frag = T.alloc_fragment(blk_n2, FP32)
-                T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
-
-                logits = T.alloc_fragment((blk_n2, h), FP32)
-                T.gemm(
-                    k_smem,
-                    q_smem,
-                    logits,
-                    transpose_A=False,
-                    transpose_B=True,
-                    clear_accum=True,
-                )
-
-                for i_h, i3_n in T.Parallel(h, blk_n2):
-                    logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
-
-                logits_sum = T.alloc_fragment(blk_n2, FP32)
-                T.reduce_sum(logits, logits_sum, dim=1)
-
-                for i3_n in T.Parallel(blk_n2):
-                    logits_sum[i3_n] *= k_s_frag[i3_n]
-
-                T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
-
-    return fp8_index_kernel_
-
-
-def fp8_index(
-    q: torch.Tensor,
-    q_s: torch.Tensor,
-    k: torch.Tensor,
-    k_s: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Perform index score using FP8 precision.
-
-    Args:
-        q (torch.Tensor): The Q tensor, must be contiguous.
-        q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
-        k (torch.Tensor): The K tensor, must be contiguous.
-        k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
-
-    fp8 q @ fp8 k -> fp32 logits
-    relu(fp32 logits) * q_s (weights) -> fp32 logits
-    fp32 logits -> fp32 logits_sum
-    fp32 logits_sum * k_s (e8m0) -> fp32 index_score
-    """
-    return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
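
For reference, the removed fp8_index computed what its docstring calls the index score: relu(q @ k^T), weighted per head by q_s, reduced over the head dimension, then scaled per key token by k_s. A minimal dense-PyTorch sketch of the same math (an illustration only: it ignores the FP8 quantization and the shared-memory tiling; the function name index_score_reference is hypothetical, while shapes follow the deleted kernel's signature):

import torch

def index_score_reference(q, q_s, k, k_s):
    # Hypothetical plain-float reference for the deleted kernel.
    # q:   (b, m, h, d)  queries (FP8 in the kernel)
    # q_s: (b, m, h)     per-head query scales / weights (FP32)
    # k:   (b, n, d)     keys (FP8 in the kernel)
    # k_s: (b, n)        per-token key scales (e8m0 in the kernel)
    # fp8 q @ fp8 k -> fp32 logits, shape (b, m, h, n)
    logits = torch.einsum("bmhd,bnd->bmhn", q.float(), k.float())
    # relu(logits) * q_s, broadcast over the n dimension
    logits = torch.relu(logits) * q_s[..., None]
    # reduce over heads h -> (b, m, n)
    logits_sum = logits.sum(dim=2)
    # scale by k_s -> fp32 index_score, shape (b, m, n)
    return logits_sum * k_s[:, None, :]

The tiled kernel produced the same (b, m, n) output blockwise: each (i_b, i_m, i1_n) program kept one (h, d) query tile in shared memory and streamed K through it in blk_n2 = 128-row chunks under a two-stage pipeline, writing each chunk's reduced scores straight to o.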
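
A quick shape check of the sketch, with arbitrary illustrative sizes (the real fp8_index additionally required contiguous FP8 inputs):

b, m, n, h, d = 1, 4, 1024, 8, 64          # hypothetical sizes
q, k = torch.randn(b, m, h, d), torch.randn(b, n, d)
q_s, k_s = torch.rand(b, m, h), torch.rand(b, n)
assert index_score_reference(q, q_s, k, k_s).shape == (b, m, n)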