// Provenance: "Add metal paged attention" (commit ed30f9d, author EricB, HF Staff)
#include "float8.metal"
#include "utils.metal"
#include <metal_stdlib>
using namespace metal;
// Element-wise precision conversion for cache tensors.
//
// Each thread handles the single element at index `gid`; threads whose index
// is at or beyond `num_elements` do nothing. `uchar` is treated as the FP8
// storage type (converted through fp8_e4m3_to_float / float_to_fp8_e4m3);
// every other element type is converted with a plain static_cast.
//
// Conversion is done in float in three steps:
//   1. decode : FP8 source -> float via fp8_e4m3_to_float, otherwise cast.
//   2. scale  : divide by `scale` only when quantizing (non-FP8 -> FP8);
//               multiply by `scale` in every other combination.
//   3. encode : FP8 destination via float_to_fp8_e4m3, otherwise cast.
//
// Buffers:
//   0: src          - input elements of type SRC_T
//   1: dst          - output elements of type DST_T
//   2: scale        - scaling factor applied as described above
//   3: num_elements - number of elements to convert
template <typename SRC_T, typename DST_T>
[[kernel]] void convert_fp8_kernel(
    const device SRC_T *__restrict__ src [[buffer(0)]],
    device DST_T *__restrict__ dst [[buffer(1)]],
    const device float &scale [[buffer(2)]],
    const device uint32_t &num_elements [[buffer(3)]],
    uint gid [[thread_position_in_grid]]) {
  // Compile-time classification of both element types.
  constexpr bool src_is_fp8 = is_same_v<SRC_T, uchar>;
  constexpr bool dst_is_fp8 = is_same_v<DST_T, uchar>;

  // The grid may be larger than the tensor; surplus threads fall through.
  if (gid < num_elements) {
    const SRC_T in_val = src[gid];

    // Step 1: bring the source value up to float.
    float widened;
    if constexpr (src_is_fp8) {
      widened = fp8_e4m3_to_float(in_val);
    } else {
      widened = static_cast<float>(in_val);
    }

    // Step 2: apply the scale. Only quantization (non-FP8 -> FP8) divides.
    float scaled;
    if constexpr (!src_is_fp8 && dst_is_fp8) {
      scaled = widened / scale;
    } else {
      scaled = widened * scale;
    }

    // Step 3: narrow to the destination representation and store.
    if constexpr (dst_is_fp8) {
      dst[gid] = float_to_fp8_e4m3(scaled);
    } else {
      dst[gid] = static_cast<DST_T>(scaled);
    }
  }
}
// Explicitly instantiates convert_fp8_kernel<from_t, to_t> and exposes it to
// the host as "convert_fp8_<from_t>_to_<to_t>". The name is built by
// stringifying the macro arguments, so the host must look kernels up using the
// exact type spellings passed here. The parameter list has to mirror the
// template's declaration.
#define INSTANTIATE_CONVERT_FP8(from_t, to_t)                                  \
  template [[host_name("convert_fp8_" #from_t "_to_" #to_t)]]                  \
  [[kernel]] void convert_fp8_kernel<from_t, to_t>(                            \
      const device from_t *__restrict__ src [[buffer(0)]],                     \
      device to_t *__restrict__ dst [[buffer(1)]],                             \
      const device float &scale [[buffer(2)]],                                 \
      const device uint32_t &num_elements [[buffer(3)]],                       \
      uint gid [[thread_position_in_grid]]);
// Explicit instantiations: every (source, destination) pairing of uchar (FP8
// storage), float, half and bfloat16_t, grouped by source type.

// Source uchar (FP8): FP8 -> FP8 rescaling, then dequantization to wider types
INSTANTIATE_CONVERT_FP8(uchar, uchar);
INSTANTIATE_CONVERT_FP8(uchar, float);
INSTANTIATE_CONVERT_FP8(uchar, half);
INSTANTIATE_CONVERT_FP8(uchar, bfloat16_t);

// Source float: quantization to FP8, then scaled regular-precision conversions
INSTANTIATE_CONVERT_FP8(float, uchar);
INSTANTIATE_CONVERT_FP8(float, float);
INSTANTIATE_CONVERT_FP8(float, half);
INSTANTIATE_CONVERT_FP8(float, bfloat16_t);

// Source half: quantization to FP8, then scaled regular-precision conversions
INSTANTIATE_CONVERT_FP8(half, uchar);
INSTANTIATE_CONVERT_FP8(half, float);
INSTANTIATE_CONVERT_FP8(half, half);
INSTANTIATE_CONVERT_FP8(half, bfloat16_t);

// Source bfloat16_t: quantization to FP8, then scaled regular-precision
// conversions
INSTANTIATE_CONVERT_FP8(bfloat16_t, uchar);
INSTANTIATE_CONVERT_FP8(bfloat16_t, float);
INSTANTIATE_CONVERT_FP8(bfloat16_t, half);
INSTANTIATE_CONVERT_FP8(bfloat16_t, bfloat16_t);