#ifndef _data_types_cuh
#define _data_types_cuh
#include "marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace marlin {

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;   // quantization scales
  using FragZP = Vec<half2, 4>;  // quantization zero-points

  // Widen a single scalar to fp32.
  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  // Broadcast one scalar into both lanes of a packed pair.
  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  // Pack two scalars into one packed pair.
  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  // Narrow an fp32 value to a single scalar.
  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;   // quantization scales
  using FragZP = Vec<nv_bfloat162, 4>;  // quantization zero-points

  // The bf16 device-side conversions are only compiled for Ampere (SM80)
  // and newer.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};

}  // namespace marlin

#endif
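// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of this header). A templated
// epilogue helper could use the aliases and conversion helpers above to pack
// an fp32 accumulator fragment into the kernel's native 16-bit pair type.
// Assumptions: `pack_frag_c` and `out_ptr` are hypothetical names, and Vec
// from marlin.cuh is roughly
// `template <typename T, int n> struct Vec { T elems[n]; };`.
// For nv_bfloat16 this only compiles on SM80+, matching the guard above.
//
//   template <typename scalar_t>
//   __device__ inline void pack_frag_c(
//       typename ScalarType<scalar_t>::FragC frag_c,        // 4 fp32 values
//       typename ScalarType<scalar_t>::scalar_t2* out_ptr)  // 2 packed pairs
//   {
//     using Dtype = ScalarType<scalar_t>;
//     // Narrow adjacent fp32 accumulators and pack them pairwise into one
//     // half2 / nv_bfloat162, so a fragment stores as two 32-bit writes.
//     out_ptr[0] = Dtype::nums2num2(Dtype::float2num(frag_c.elems[0]),
//                                   Dtype::float2num(frag_c.elems[1]));
//     out_ptr[1] = Dtype::nums2num2(Dtype::float2num(frag_c.elems[2]),
//                                   Dtype::float2num(frag_c.elems[3]));
//   }
//
// The same pattern works for both specializations, which is the point of the
// ScalarType indirection: kernel code stays dtype-agnostic while the
// intrinsics (__float2half vs. __float2bfloat16) are selected at compile
// time.
// ---------------------------------------------------------------------------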