diff --git a/torch/headeronly/util/BFloat16.h b/torch/headeronly/util/BFloat16.h
index 2c1f805ac7b73..d71bb7b454fd0 100644
--- a/torch/headeronly/util/BFloat16.h
+++ b/torch/headeronly/util/BFloat16.h
@@ -12,7 +12,7 @@
 #include
 #include
 
-#if defined(__CUDACC__) && !defined(USE_ROCM)
+#if defined(__CUDACC__)
 #include <cuda_bf16.h>
 #endif
 
@@ -25,12 +25,16 @@ namespace c10 {
 
 struct alignas(2) BFloat16 {
-  uint16_t x;
-
+#if defined(__HIPCC__)
+  static_assert(sizeof(__hip_bfloat16) == sizeof(uint16_t));
+  union {
+    uint16_t x;
+    __hip_bfloat16 x__hip_bfloat16;
+  };
   // HIP wants __host__ __device__ tag, CUDA does not
-#if defined(USE_ROCM) && defined(__HIPCC__)
   C10_HOST_DEVICE BFloat16() = default;
 #else
+  uint16_t x;
   BFloat16() = default;
 #endif
 
@@ -41,10 +45,10 @@ struct alignas(2) BFloat16 {
   constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
       : x(bits) {}
-  /* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
+  inline C10_HOST_DEVICE BFloat16(float value);
   inline C10_HOST_DEVICE operator float() const;
 
-#if defined(__CUDACC__) && !defined(USE_ROCM)
+#if defined(__CUDACC__)
   inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
   explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
 #endif
 
@@ -66,7 +70,7 @@ inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
   uint32_t tmp = src;
   tmp <<= 16;
 
-#if defined(USE_ROCM) && defined(__HIPCC__)
+#if defined(__HIPCC__)
   float* tempRes;
 
   // We should be using memcpy in order to respect the strict aliasing rule
@@ -83,7 +87,7 @@ inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
 inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
   uint32_t res = 0;
 
-#if defined(USE_ROCM) && defined(__HIPCC__)
+#if defined(__HIPCC__)
   // We should be using memcpy in order to respect the strict aliasing rule
   // but it fails in the HIP environment.
   uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
@@ -96,7 +100,7 @@ inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
 }
 
 inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
-#if defined(USE_ROCM) && defined(__HIPCC__)
+#if defined(__HIPCC__)
   if (src != src) {
 #elif defined(_MSC_VER)
   if (isnan(src)) {
@@ -122,8 +126,9 @@ C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
 /// Constructors
 inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
     :
-#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
-    __CUDA_ARCH__ >= 800
+#if defined(__HIPCC__)
+      x__hip_bfloat16(value)
+#elif defined(__CUDACC__) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
       x(__bfloat16_as_ushort(__float2bfloat16(value)))
 #elif defined(__SYCL_DEVICE_ONLY__) && \
     defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
@@ -137,7 +142,9 @@ inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
 
 /// Implicit conversions
 inline C10_HOST_DEVICE BFloat16::operator float() const {
-#if defined(__CUDACC__) && !defined(USE_ROCM)
+#if defined(__HIPCC__)
+  return float(x__hip_bfloat16);
+#elif defined(__CUDACC__)
   return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
 #elif defined(__SYCL_DEVICE_ONLY__) && \
     defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
@@ -147,7 +154,13 @@ inline C10_HOST_DEVICE BFloat16::operator float() const {
 #endif
 }
 
-#if defined(__CUDACC__) && !defined(USE_ROCM)
+#if defined(__HIPCC__)
+inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value)
+    : x__hip_bfloat16(value) {}
+inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
+  return x__hip_bfloat16;
+}
+#elif defined(__CUDACC__)
 inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
   x = *reinterpret_cast<const unsigned short*>(&value);
 }
@@ -170,7 +183,7 @@ inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
 
 #if defined(__CUDACC__) || defined(__HIPCC__)
 inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
-#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if !defined(__HIPCC__) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
 #else
   return *ptr;
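Not part of the patch: for reviewers, a minimal, self-contained sketch of the CPU fallback semantics the hunks above touch (`f32_from_bits`/`round_to_nearest_even`): bfloat16 is stored as the upper 16 bits of an IEEE-754 binary32 value, and float-to-bfloat16 conversion rounds to nearest with ties to even. The function names and the sample value below are illustrative only, not the header's exact code.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Convert a float to bfloat16 storage bits, rounding to nearest (ties to even).
static uint16_t round_to_nearest_even(float src) {
  if (src != src) {  // NaN: return a canonical quiet-NaN bit pattern
    return UINT16_C(0x7FC0);
  }
  uint32_t u32;
  std::memcpy(&u32, &src, sizeof(u32));
  // Adding this bias before truncating the low 16 bits implements
  // round-to-nearest-even on the retained 7-bit mantissa.
  uint32_t rounding_bias = ((u32 >> 16) & 1) + UINT32_C(0x7FFF);
  return static_cast<uint16_t>((u32 + rounding_bias) >> 16);
}

// Expand bfloat16 storage bits back to float: the 16 bits become the high
// half of a binary32 value, the low half is zero.
static float f32_from_bits(uint16_t src) {
  uint32_t tmp = static_cast<uint32_t>(src) << 16;
  float res;
  std::memcpy(&res, &tmp, sizeof(res));
  return res;
}

int main() {
  float value = 1.00390625f;  // 1 + 1/256: exactly halfway between two bfloat16 values
  uint16_t bits = round_to_nearest_even(value);
  std::printf("%.8f -> 0x%04X -> %.8f\n", value, bits, f32_from_bits(bits));
  return 0;
}
```

Running this should print `1.00390625 -> 0x3F80 -> 1.00000000`: the tie at 1 + 1/256 rounds down to the even mantissa, which is the rounding behavior the non-GPU constructor path relies on.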