|
#include <metal_stdlib> |
|
using namespace metal; |
|
|
|
|
|
// Returns the raw 32-bit pattern of `x` as an unsigned integer.
// This is a bit-level reinterpretation (as_type), not a numeric conversion.
static inline uint as_bits(float x) {
  const uint raw = as_type<uint>(x);
  return raw;
}
|
// Returns the float whose raw 32-bit pattern is `b`.
// This is a bit-level reinterpretation (as_type), not a numeric conversion.
static inline float from_bits(uint b) {
  const float value = as_type<float>(b);
  return value;
}
|
|
|
|
|
|
|
|
|
// Decodes an 8-bit float laid out as [sign:1][exponent:4][mantissa:3]
// (exponent bias 7) into a 32-bit float.
//
//   exponent == 0, mantissa == 0 -> signed zero
//   exponent == 0, mantissa != 0 -> subnormal: (mantissa / 8) * 2^-6
//   exponent == 15, mantissa == 0 -> signed infinity
//   exponent == 15, mantissa != 0 -> NaN (sign is not propagated)
//   otherwise                     -> (1 + mantissa / 8) * 2^(exponent - 7)
//
// NOTE(review): the all-ones exponent is treated as Inf/NaN here. FP8
// variants that reuse those codes for finite values would decode
// differently -- confirm which variant callers expect.
inline float fp8_e4m3_to_float(uchar v) {
  const bool negative = (v >> 7) != 0;
  const uint exp_field = (v >> 3) & 0xF;
  const uint frac_field = v & 0x7;

  // Both zeros are returned as literals so the sign is preserved exactly.
  if ((v & 0x7F) == 0) {
    return negative ? -0.f : 0.f;
  }

  float magnitude;
  if (exp_field == 0xF) {
    if (frac_field != 0) {
      return NAN;
    }
    magnitude = INFINITY;
  } else if (exp_field == 0) {
    // Subnormal: no implicit leading 1, exponent pinned to 1 - bias.
    magnitude = ldexp(float(frac_field) / 8.f, 1 - 7);
  } else {
    // Normal: implicit leading 1, exponent rebased by the bias.
    magnitude = ldexp(1.f + float(frac_field) / 8.f, int(exp_field) - 7);
  }

  return negative ? -magnitude : magnitude;
}
|
|
|
|
|
|
|
|
|
// Decodes an 8-bit float laid out as [sign:1][exponent:5][mantissa:2]
// (exponent bias 15) into a 32-bit float.
//
//   exponent == 0, mantissa == 0 -> signed zero
//   exponent == 0, mantissa != 0 -> subnormal: (mantissa / 4) * 2^-14
//   exponent == 31, mantissa == 0 -> signed infinity
//   exponent == 31, mantissa != 0 -> NaN (sign is not propagated)
//   otherwise                     -> (1 + mantissa / 4) * 2^(exponent - 15)
inline float fp8_e5m2_to_float(uchar v) {
  const bool negative = (v >> 7) != 0;
  const uint exp_field = (v >> 2) & 0x1F;
  const uint frac_field = v & 0x3;

  // Both zeros are returned as literals so the sign is preserved exactly.
  if ((v & 0x7F) == 0) {
    return negative ? -0.f : 0.f;
  }

  float magnitude;
  if (exp_field == 0x1F) {
    if (frac_field != 0) {
      return NAN;
    }
    magnitude = INFINITY;
  } else if (exp_field == 0) {
    // Subnormal: no implicit leading 1, exponent pinned to 1 - bias.
    magnitude = ldexp(float(frac_field) / 4.f, 1 - 15);
  } else {
    // Normal: implicit leading 1, exponent rebased by the bias.
    magnitude = ldexp(1.f + float(frac_field) / 4.f, int(exp_field) - 15);
  }

  return negative ? -magnitude : magnitude;
}
|
|
|
|
|
|
|
|
|
namespace detail {

// Rounds a 32-bit float to an 8-bit float laid out as
// [sign:1][exponent:EXP_BITS][mantissa:MAN_BITS] with exponent bias BIAS,
// using round-to-nearest, ties-to-even for both normal and subnormal results.
//
// The all-ones exponent is reserved for Inf (mantissa 0) and NaN (mantissa 1).
//
//   NaN                          -> sign | all-ones exponent | mantissa 1
//   +/-Inf                       -> sign | all-ones exponent
//   finite, too large            -> sign | all-ones exponent (Inf), whether the
//                                   exponent is out of range up front or the
//                                   mantissa rounding carries past the largest
//                                   finite value
//   below half the smallest
//   subnormal (incl. fp32 zero
//   and fp32 subnormals)         -> signed zero
//   otherwise                    -> nearest representable normal/subnormal
//
// Returns the 8-bit encoding.
template <int EXP_BITS, int MAN_BITS, int BIAS>
inline uchar fp32_to_fp8(float f) {
  const uint bits = as_bits(f);
  const uint sign = (bits >> 31) << 7;
  const uint mag = bits & 0x7FFFFFFFu;
  const uint inf_code = ((1u << EXP_BITS) - 1u) << MAN_BITS;

  // fp32 Inf / NaN: keep the class, collapse any NaN payload to mantissa 1.
  if (mag >= 0x7F800000u) {
    const uint nan_bit = (mag != 0x7F800000u) ? 1u : 0u;
    return uchar(sign | inf_code | nan_bit);
  }

  // Exponent rebased to the target format; EXP_MAX is the largest finite
  // exponent field (all-ones is reserved).
  const int e_fp8 = int(mag >> 23) - 127 + BIAS;
  const int EXP_MAX = (1 << EXP_BITS) - 2;

  // Finite but beyond the target range: overflow to infinity, matching what
  // happens when rounding carries out of the largest finite value. Without
  // this guard such inputs would reach the subnormal shift below with a
  // too-small (or negative) shift count.
  if (e_fp8 > EXP_MAX) {
    return uchar(sign | inf_code);
  }

  // Strictly below half of the smallest subnormal: rounds to zero. At
  // e_fp8 == -MAN_BITS the value lies in [half, one) smallest-subnormal units
  // and must still go through rounding (the exact half ties to even -> 0,
  // anything above rounds up to the smallest subnormal).
  if (e_fp8 < -MAN_BITS) {
    return uchar(sign);
  }

  const uint frac = mag & 0x7FFFFFu;
  const bool is_normal = e_fp8 >= 1;

  // Normal results round the 23-bit fraction; subnormal results round the
  // full 24-bit significand (implicit 1 made explicit), shifted further right
  // by the distance below the smallest normal exponent.
  const uint sig = is_normal ? frac : (0x800000u | frac);
  const int shift = (23 - MAN_BITS) + (is_normal ? 0 : (1 - e_fp8));
  const uint base = is_normal ? (uint(e_fp8) << MAN_BITS) : 0u;

  // Round to nearest, ties to even. `shift` is in [23 - MAN_BITS, 24] here.
  const uint kept = sig >> shift;
  const uint half = 1u << (shift - 1);
  const uint rem = sig & ((1u << shift) - 1u);
  const uint round_up =
      (rem > half || (rem == half && (kept & 1u) != 0u)) ? 1u : 0u;

  // Adding (rather than OR-ing) lets a mantissa carry bump the exponent:
  // largest subnormal + 1 becomes the smallest normal, and a carry out of the
  // largest finite value lands exactly on the Inf encoding.
  return uchar(sign | (base + kept + round_up));
}

}  // namespace detail
|
|
|
// Encodes `f` as an 8-bit float with 4 exponent bits, 3 mantissa bits and
// exponent bias 7 by delegating to detail::fp32_to_fp8.
inline uchar float_to_fp8_e4m3(float f) {
  constexpr int kExpBits = 4;
  constexpr int kManBits = 3;
  constexpr int kBias = 7;
  const uchar encoded = detail::fp32_to_fp8<kExpBits, kManBits, kBias>(f);
  return encoded;
}
|
// Encodes `f` as an 8-bit float with 5 exponent bits, 2 mantissa bits and
// exponent bias 15 by delegating to detail::fp32_to_fp8.
inline uchar float_to_fp8_e5m2(float f) {
  constexpr int kExpBits = 5;
  constexpr int kManBits = 2;
  constexpr int kBias = 15;
  const uchar encoded = detail::fp32_to_fp8<kExpBits, kManBits, kBias>(f);
  return encoded;
}