#include "float8.metal"
#include "utils.metal"
#include <metal_stdlib>

using namespace metal;

// Element-wise precision conversion for cache tensors.
//
// One thread converts one element: thread `gid` reads src[gid], applies
// `scale`, and writes dst[gid]. Threads with gid >= num_elements exit
// without touching memory, so the grid may be larger than the tensor.
//
// FP8 values are carried as `uchar`; the branch taken is selected at compile
// time from which of SRC_T / DST_T is `uchar`:
//
//   SRC_T   DST_T    operation
//   uchar   other    dequantize: fp8_e4m3_to_float(src) * scale
//   other   uchar    quantize:   float_to_fp8_e4m3(float(src) / scale)
//   uchar   uchar    rescale:    decode, multiply by scale, re-encode
//   other   other    cast:       float(src) * scale, cast to DST_T
//
// Note the asymmetry: quantization DIVIDES by `scale`, every other path
// MULTIPLIES by it. `scale` is not checked against zero here.
//
// Template parameters:
//   SRC_T         element type of `src` (uchar for FP8 input).
//   DST_T         element type of `dst` (uchar for FP8 output).
// Buffers:
//   src           [[buffer(0)]] input elements.
//   dst           [[buffer(1)]] output elements.
//   scale         [[buffer(2)]] single float scaling factor.
//   num_elements  [[buffer(3)]] number of elements to convert.
//
// fp8_e4m3_to_float / float_to_fp8_e4m3 are not defined in this file
// (presumably provided by float8.metal — confirm).
template <typename SRC_T, typename DST_T>
[[kernel]] void convert_fp8_kernel(
    const device SRC_T *__restrict__ src [[buffer(0)]],
    device DST_T *__restrict__ dst [[buffer(1)]],
    const device float &scale [[buffer(2)]],
    const device uint32_t &num_elements [[buffer(3)]],
    uint gid [[thread_position_in_grid]]) {
  // Guard against threads past the end of the tensor.
  if (gid >= num_elements) {
    return;
  }

  // Load source value
  const SRC_T src_val = src[gid];

  // Convert based on source and destination types
  if constexpr (is_same_v<SRC_T, uchar> && !is_same_v<DST_T, uchar>) {
    // FP8 -> higher precision (dequantization)
    const float fp32_val = fp8_e4m3_to_float(src_val) * scale;
    dst[gid] = static_cast<DST_T>(fp32_val);
  } else if constexpr (!is_same_v<SRC_T, uchar> && is_same_v<DST_T, uchar>) {
    // Higher precision -> FP8 (quantization)
    const float fp32_val = static_cast<float>(src_val) / scale;
    dst[gid] = float_to_fp8_e4m3(fp32_val);
  } else if constexpr (is_same_v<SRC_T, uchar> && is_same_v<DST_T, uchar>) {
    // FP8 -> FP8 (with rescaling)
    const float fp32_val = fp8_e4m3_to_float(src_val) * scale;
    dst[gid] = float_to_fp8_e4m3(fp32_val);
  } else {
    // Regular precision -> regular precision (with scaling)
    const float fp32_val = static_cast<float>(src_val) * scale;
    dst[gid] = static_cast<DST_T>(fp32_val);
  }
}

// Instantiate all required combinations.
//
// Each use explicitly instantiates convert_fp8_kernel<src_type, dst_type> and
// exposes it under the host-visible name "convert_fp8_<src_type>_to_<dst_type>"
// (the type tokens are stringized verbatim, so the spelling used at the call
// site is the spelling the host must look up).
#define INSTANTIATE_CONVERT_FP8(src_type, dst_type)                            \
  template [[host_name("convert_fp8_" #src_type "_to_" #dst_type)]]            \
  [[kernel]] void convert_fp8_kernel<src_type, dst_type>(                      \
      const device src_type *__restrict__ src [[buffer(0)]],                   \
      device dst_type *__restrict__ dst [[buffer(1)]],                         \
      const device float &scale [[buffer(2)]],                                 \
      const device uint32_t &num_elements [[buffer(3)]],                       \
      uint gid [[thread_position_in_grid]]);

// FP8 to other formats (dequantization)
INSTANTIATE_CONVERT_FP8(uchar, float);
INSTANTIATE_CONVERT_FP8(uchar, half);
INSTANTIATE_CONVERT_FP8(uchar, bfloat16_t);

// Other formats to FP8 (quantization)
INSTANTIATE_CONVERT_FP8(float, uchar);
INSTANTIATE_CONVERT_FP8(half, uchar);
INSTANTIATE_CONVERT_FP8(bfloat16_t, uchar);

// FP8 -> FP8: same storage type on both sides, value is rescaled only.
INSTANTIATE_CONVERT_FP8(uchar, uchar);

// Non-FP8 -> non-FP8: scaled cast between every pair of
// float / half / bfloat16_t, including the same-type pairs.
// Source type: float
INSTANTIATE_CONVERT_FP8(float, float);
INSTANTIATE_CONVERT_FP8(float, half);
INSTANTIATE_CONVERT_FP8(float, bfloat16_t);
// Source type: half
INSTANTIATE_CONVERT_FP8(half, float);
INSTANTIATE_CONVERT_FP8(half, half);
INSTANTIATE_CONVERT_FP8(half, bfloat16_t);
// Source type: bfloat16_t
INSTANTIATE_CONVERT_FP8(bfloat16_t, float);
INSTANTIATE_CONVERT_FP8(bfloat16_t, half);
INSTANTIATE_CONVERT_FP8(bfloat16_t, bfloat16_t);