#pragma once // Defines the bloat16 type (brain floating-point). This representation uses // 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. #include #include #include namespace c10 { namespace detail { inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { float res = 0; uint32_t tmp = src; tmp <<= 16; #ifdef __HIP_PLATFORM_HCC__ float* tempRes; // We should be using memcpy in order to respect the strict aliasing rule // but it fails in the HIP environment. tempRes = reinterpret_cast(&tmp); res = *tempRes; #else std::memcpy(&res, &tmp, sizeof(tmp)); #endif return res; } inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { uint32_t res = 0; #ifdef __HIP_PLATFORM_HCC__ // We should be using memcpy in order to respect the strict aliasing rule // but it fails in the HIP environment. uint32_t* tempRes = reinterpret_cast(&src); res = *tempRes; #else std::memcpy(&res, &src, sizeof(res)); #endif return res >> 16; } inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { if (std::isnan(src)) { return 0x7FC0; } else { union { uint32_t U32; float F32; }; F32 = src; uint32_t rounding_bias = ((U32 >> 16) & 1) + 0x7FFF; return static_cast((U32 + rounding_bias) >> 16); } } } // namespace detail struct alignas(2) BFloat16 { uint16_t x; // HIP wants __host__ __device__ tag, CUDA does not #ifdef __HIP_PLATFORM_HCC__ C10_HOST_DEVICE BFloat16() = default; #else BFloat16() = default; #endif struct from_bits_t {}; static constexpr from_bits_t from_bits() { return from_bits_t(); } constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) : x(bits){}; inline C10_HOST_DEVICE BFloat16(float value); inline C10_HOST_DEVICE operator float() const; }; } // namespace c10 #include