|
#include "float8.metal" |
|
#include "utils.metal" |
|
#include <metal_stdlib> |
|
|
|
using namespace metal; |
|
|
|
// Convert between different precision formats for cache tensors |
|
// This kernel handles conversions like float->fp8, fp8->float, etc. |
|
|
|
template <typename SRC_T, typename DST_T> |
|
[[kernel]] void convert_fp8_kernel( |
|
const device SRC_T *__restrict__ src [[buffer(0)]], |
|
device DST_T *__restrict__ dst [[buffer(1)]], |
|
const device float &scale [[buffer(2)]], |
|
const device uint32_t &num_elements [[buffer(3)]], |
|
uint gid [[thread_position_in_grid]]) { |
|
|
|
if (gid >= num_elements) { |
|
return; |
|
} |
|
|
|
// Load source value |
|
SRC_T src_val = src[gid]; |
|
|
|
// Convert based on source and destination types |
|
if constexpr (is_same_v<SRC_T, uchar> && !is_same_v<DST_T, uchar>) { |
|
// FP8 -> higher precision (dequantization) |
|
float fp32_val = fp8_e4m3_to_float(src_val) * scale; |
|
dst[gid] = static_cast<DST_T>(fp32_val); |
|
} else if constexpr (!is_same_v<SRC_T, uchar> && is_same_v<DST_T, uchar>) { |
|
// Higher precision -> FP8 (quantization) |
|
float fp32_val = static_cast<float>(src_val) / scale; |
|
dst[gid] = float_to_fp8_e4m3(fp32_val); |
|
} else if constexpr (is_same_v<SRC_T, uchar> && is_same_v<DST_T, uchar>) { |
|
// FP8 -> FP8 (with rescaling) |
|
float fp32_val = fp8_e4m3_to_float(src_val) * scale; |
|
dst[gid] = float_to_fp8_e4m3(fp32_val); |
|
} else { |
|
// Regular precision -> regular precision (with scaling) |
|
float fp32_val = static_cast<float>(src_val) * scale; |
|
dst[gid] = static_cast<DST_T>(fp32_val); |
|
} |
|
} |
|
|
|
// Instantiate all required combinations |
|
#define INSTANTIATE_CONVERT_FP8(src_type, dst_type) \ |
|
template [[host_name("convert_fp8_" #src_type "_to_" #dst_type)]] \ |
|
[[kernel]] void convert_fp8_kernel<src_type, dst_type>( \ |
|
const device src_type *__restrict__ src [[buffer(0)]], \ |
|
device dst_type *__restrict__ dst [[buffer(1)]], \ |
|
const device float &scale [[buffer(2)]], \ |
|
const device uint32_t &num_elements [[buffer(3)]], \ |
|
uint gid [[thread_position_in_grid]]); |
|
|
|
// FP8 to other formats (dequantization) |
|
INSTANTIATE_CONVERT_FP8(uchar, float); |
|
INSTANTIATE_CONVERT_FP8(uchar, half); |
|
INSTANTIATE_CONVERT_FP8(uchar, bfloat16_t); |
|
|
|
// Other formats to FP8 (quantization) |
|
INSTANTIATE_CONVERT_FP8(float, uchar); |
|
INSTANTIATE_CONVERT_FP8(half, uchar); |
|
INSTANTIATE_CONVERT_FP8(bfloat16_t, uchar); |
|
|
|
// FP8 to FP8 (rescaling) |
|
INSTANTIATE_CONVERT_FP8(uchar, uchar); |
|
|
|
// Regular precision conversions with scaling |
|
INSTANTIATE_CONVERT_FP8(float, float); |
|
INSTANTIATE_CONVERT_FP8(float, half); |
|
INSTANTIATE_CONVERT_FP8(float, bfloat16_t); |
|
INSTANTIATE_CONVERT_FP8(half, float); |
|
INSTANTIATE_CONVERT_FP8(half, half); |
|
INSTANTIATE_CONVERT_FP8(half, bfloat16_t); |
|
INSTANTIATE_CONVERT_FP8(bfloat16_t, float); |
|
INSTANTIATE_CONVERT_FP8(bfloat16_t, half); |
|
INSTANTIATE_CONVERT_FP8(bfloat16_t, bfloat16_t); |