diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h index 425fb6aa79e13..b151fe01359d1 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h @@ -198,7 +198,7 @@ class Vectorized16 { Vectorized16() {} Vectorized16(__m256i v) : values(v) {} Vectorized16(T val) { - value_type uw = val.x; + value_type uw = val.bits(); values = _mm256_set1_epi16(uw); } Vectorized16( @@ -219,22 +219,22 @@ class Vectorized16 { T val15, T val16) { values = _mm256_setr_epi16( - val1.x, - val2.x, - val3.x, - val4.x, - val5.x, - val6.x, - val7.x, - val8.x, - val9.x, - val10.x, - val11.x, - val12.x, - val13.x, - val14.x, - val15.x, - val16.x); + val1.bits(), + val2.bits(), + val3.bits(), + val4.bits(), + val5.bits(), + val6.bits(), + val7.bits(), + val8.bits(), + val9.bits(), + val10.bits(), + val11.bits(), + val12.bits(), + val13.bits(), + val14.bits(), + val15.bits(), + val16.bits()); } operator __m256i() const { return values; diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h index 844b3b1fcc1e8..29f6ab1d55d9b 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h @@ -197,7 +197,7 @@ class Vectorized16 { } Vectorized16(__m512i v) : values(v) {} Vectorized16(T val) { - value_type uw = val.x; + value_type uw = val.bits(); values = _mm512_set1_epi16(uw); } Vectorized16( @@ -234,38 +234,38 @@ class Vectorized16 { T val31, T val32) { values = _mm512_set_epi16( - val32.x, - val31.x, - val30.x, - val29.x, - val28.x, - val27.x, - val26.x, - val25.x, - val24.x, - val23.x, - val22.x, - val21.x, - val20.x, - val19.x, - val18.x, - val17.x, - val16.x, - val15.x, - val14.x, - val13.x, - val12.x, - val11.x, - val10.x, - val9.x, - val8.x, - val7.x, - val6.x, - val5.x, - val4.x, - val3.x, - val2.x, - val1.x); + val32.bits(), + val31.bits(), + val30.bits(), + 
val29.bits(), + val28.bits(), + val27.bits(), + val26.bits(), + val25.bits(), + val24.bits(), + val23.bits(), + val22.bits(), + val21.bits(), + val20.bits(), + val19.bits(), + val18.bits(), + val17.bits(), + val16.bits(), + val15.bits(), + val14.bits(), + val13.bits(), + val12.bits(), + val11.bits(), + val10.bits(), + val9.bits(), + val8.bits(), + val7.bits(), + val6.bits(), + val5.bits(), + val4.bits(), + val3.bits(), + val2.bits(), + val1.bits()); } operator __m512i() const { return values; diff --git a/aten/src/ATen/cuda/Atomic.cuh b/aten/src/ATen/cuda/Atomic.cuh index f16be30f8b71b..aa71061584df2 100644 --- a/aten/src/ATen/cuda/Atomic.cuh +++ b/aten/src/ATen/cuda/Atomic.cuh @@ -47,13 +47,13 @@ struct AtomicFPOp { at::BFloat16 bsum; do { assumed = old; - bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + bsum.bits() = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); bsum = func(bsum, val); - old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x; + old = (size_t)address & 2 ? (old & 0xffff) | (bsum.bits() << 16) : (old & 0xffff0000) | bsum.bits(); old = atomicCAS(address_as_ui, assumed, old); } while (assumed != old); - bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - return bsum.x; + bsum.bits() = (size_t)address & 2 ? 
(old >> 16) : (old & 0xffff); + return bsum.bits(); } }; diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index d941c230630c4..90eccd59931d3 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -121,7 +121,7 @@ inline bool IsZero(T v) { template <> inline bool IsZero(BFloat16 v) { - return v.x == 0; + return v.bits() == 0; } template <> diff --git a/aten/src/ATen/native/cpu/FillKernel.cpp b/aten/src/ATen/native/cpu/FillKernel.cpp index e22df01635fe0..0ef9300f3d869 100644 --- a/aten/src/ATen/native/cpu/FillKernel.cpp +++ b/aten/src/ATen/native/cpu/FillKernel.cpp @@ -15,7 +15,7 @@ namespace { template void fill_non_native_type(TensorIterator& iter, const Scalar& value_scalar) { - auto value = value_scalar.to().x; + auto value = value_scalar.to().bits(); using H = typename std::make_signed_t; // Signed type has more acceleration // Reserve the representation of value. static_cast(value) is implementation defined. H val = *reinterpret_cast(std::addressof(value)); diff --git a/aten/src/ATen/native/cuda/SortingRadixSelect.cuh b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh index 71f9a3aa13c8e..31338135ec28b 100644 --- a/aten/src/ATen/native/cuda/SortingRadixSelect.cuh +++ b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh @@ -152,7 +152,7 @@ struct TopKTypeConfig { typedef uint32_t RadixType; static inline __device__ RadixType convert(at::BFloat16 v) { - RadixType x = v.x; + RadixType x = v.bits(); RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; return (v == v) ? (x ^ mask) : 0xffff; } @@ -160,7 +160,7 @@ struct TopKTypeConfig { static inline __device__ at::BFloat16 deconvert(RadixType v) { RadixType mask = (v & 0x00008000) ? 
0x00008000 : 0x0000ffff; at::BFloat16 r; - r.x = (v ^ mask); + r.bits() = (v ^ mask); return r; } }; diff --git a/c10/test/util/bfloat16_test.cpp b/c10/test/util/bfloat16_test.cpp index 062d339cdd8e1..577eb365c46c4 100644 --- a/c10/test/util/bfloat16_test.cpp +++ b/c10/test/util/bfloat16_test.cpp @@ -33,8 +33,8 @@ TEST(BFloat16Conversion, FloatToBFloat16AndBack) { float out[100]; for (const auto i : c10::irange(100)) { - bfloats[i].x = c10::detail::bits_from_f32(in[i]); - out[i] = c10::detail::f32_from_bits(bfloats[i].x); + bfloats[i].bits() = c10::detail::bits_from_f32(in[i]); + out[i] = c10::detail::f32_from_bits(bfloats[i].bits()); // The relative error should be less than 1/(2^7) since BFloat16 // has 7 bits mantissa. @@ -56,8 +56,8 @@ TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) { float out[100]; for (const auto i : c10::irange(100)) { - bfloats[i].x = c10::detail::round_to_nearest_even(in[i]); - out[i] = c10::detail::f32_from_bits(bfloats[i].x); + bfloats[i].bits() = c10::detail::round_to_nearest_even(in[i]); + out[i] = c10::detail::f32_from_bits(bfloats[i].bits()); // The relative error should be less than 1/(2^7) since BFloat16 // has 7 bits mantissa. 
@@ -70,7 +70,7 @@ TEST(BFloat16Conversion, NaN) { EXPECT_TRUE(std::isnan(inNaN)); c10::BFloat16 a = c10::BFloat16(inNaN); - float out = c10::detail::f32_from_bits(a.x); + float out = c10::detail::f32_from_bits(a.bits()); EXPECT_TRUE(std::isnan(out)); } @@ -80,7 +80,7 @@ TEST(BFloat16Conversion, Inf) { EXPECT_TRUE(std::isinf(inInf)); c10::BFloat16 a = c10::BFloat16(inInf); - float out = c10::detail::f32_from_bits(a.x); + float out = c10::detail::f32_from_bits(a.bits()); EXPECT_TRUE(std::isinf(out)); } @@ -89,7 +89,7 @@ TEST(BFloat16Conversion, SmallestDenormal) { float in = std::numeric_limits::denorm_min(); // The smallest non-zero // subnormal number c10::BFloat16 a = c10::BFloat16(in); - float out = c10::detail::f32_from_bits(a.x); + float out = c10::detail::f32_from_bits(a.bits()); EXPECT_FLOAT_EQ(in, out); } @@ -109,10 +109,10 @@ TEST(BFloat16Math, Addition) { float expected = float_from_bytes(0, 0, 0x40c80000); c10::BFloat16 b{}; - b.x = c10::detail::bits_from_f32(input); + b.bits() = c10::detail::bits_from_f32(input); b = b + b; - float res = c10::detail::f32_from_bits(b.x); + float res = c10::detail::f32_from_bits(b.bits()); EXPECT_EQ(res, expected); } @@ -131,10 +131,10 @@ TEST(BFloat16Math, Subtraction) { float expected = float_from_bytes(0, 0, 0x40280000); c10::BFloat16 b{}; - b.x = c10::detail::bits_from_f32(input); + b.bits() = c10::detail::bits_from_f32(input); b = b - 5; - float res = c10::detail::f32_from_bits(b.x); + float res = c10::detail::f32_from_bits(b.bits()); EXPECT_EQ(res, expected); } @@ -145,7 +145,7 @@ TEST(BFloat16Math, NextAfterZero) { [](c10::BFloat16 from, c10::BFloat16 to, c10::BFloat16 expected) { c10::BFloat16 actual = std::nextafter(from, to); // Check for bitwise equality! 
- ASSERT_EQ(actual.x ^ expected.x, uint16_t{0}); + ASSERT_EQ(actual.bits() ^ expected.bits(), uint16_t{0}); }; check_nextafter(zero, zero, /*expected=*/zero); check_nextafter(zero, -zero, /*expected=*/-zero); diff --git a/c10/util/copysign.h b/c10/util/copysign.h index e77b7e61b0aef..f1131f83df754 100644 --- a/c10/util/copysign.h +++ b/c10/util/copysign.h @@ -19,9 +19,4 @@ inline c10::Half copysign(c10::Half a, c10::Half b) { return c10::Half((a.x & 0x7fff) | (b.x & 0x8000), c10::Half::from_bits()); } -inline c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) { - return c10::BFloat16( - (a.x & 0x7fff) | (b.x & 0x8000), c10::BFloat16::from_bits()); -} - } // namespace c10 diff --git a/torch/headeronly/util/BFloat16.h b/torch/headeronly/util/BFloat16.h index 2c1f805ac7b73..5bb6c5a373d85 100644 --- a/torch/headeronly/util/BFloat16.h +++ b/torch/headeronly/util/BFloat16.h @@ -12,7 +12,9 @@ #include #include -#if defined(__CUDACC__) && !defined(USE_ROCM) +#if defined(__CUDACC__) #include +#elif defined(__HIPCC__) +#include <hip/hip_bf16.h> #endif @@ -25,13 +27,17 @@ namespace c10 { struct alignas(2) BFloat16 { +#if defined(__HIPCC__) +protected: + __hip_bfloat16 x; +public: +// HIP wants __host__ __device__ tag, CUDA does not + C10_HOST_DEVICE BFloat16() = default; + C10_HOST_DEVICE inline unsigned short& bits() { return x.__x; } +#else uint16_t x; - - // HIP wants __host__ __device__ tag, CUDA does not -#if defined(USE_ROCM) && defined(__HIPCC__) - C10_HOST_DEVICE BFloat16() = default; -#else BFloat16() = default; + C10_HOST_DEVICE uint16_t& bits() { return x; } #endif struct from_bits_t {}; @@ -41,10 +47,13 @@ struct alignas(2) BFloat16 { constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) : x(bits) {} - /* implicit */ inline C10_HOST_DEVICE BFloat16(float value); + inline C10_HOST_DEVICE BFloat16(float value); inline C10_HOST_DEVICE operator float() const; -#if defined(__CUDACC__) && !defined(USE_ROCM) +#if defined(__HIPCC__) + inline C10_HOST_DEVICE BFloat16(const __hip_bfloat16& value); + explicit inline C10_HOST_DEVICE operator __hip_bfloat16() const; +#elif defined(__CUDACC__) inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); explicit inline C10_HOST_DEVICE operator 
__nv_bfloat16() const; #endif @@ -53,6 +57,11 @@ struct alignas(2) BFloat16 { inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const; #endif + +friend inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b); +friend inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b); +friend inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b); +friend inline C10_HOST_DEVICE BFloat16 copysign(BFloat16 a, BFloat16 b); }; inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) { @@ -66,9 +75,8 @@ inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { uint32_t tmp = src; tmp <<= 16; -#if defined(USE_ROCM) && defined(__HIPCC__) +#if defined(__HIPCC__) float* tempRes; - // We should be using memcpy in order to respect the strict aliasing rule // but it fails in the HIP environment. tempRes = reinterpret_cast(&tmp); @@ -83,7 +91,7 @@ inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { uint32_t res = 0; -#if defined(USE_ROCM) && defined(__HIPCC__) +#if defined(__HIPCC__) // We should be using memcpy in order to respect the strict aliasing rule // but it fails in the HIP environment. 
uint32_t* tempRes = reinterpret_cast(&src); @@ -96,7 +104,7 @@ inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { } inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { -#if defined(USE_ROCM) && defined(__HIPCC__) +#if defined(__HIPCC__) if (src != src) { #elif defined(_MSC_VER) if (isnan(src)) { @@ -113,7 +121,6 @@ inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { } // namespace detail -//-------- the following is copied from c10/util/BFloat16-inl.h ---------// C10_CLANG_DIAGNOSTIC_PUSH() #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") @@ -122,8 +129,9 @@ C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") /// Constructors inline C10_HOST_DEVICE BFloat16::BFloat16(float value) : -#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \ - __CUDA_ARCH__ >= 800 +#if defined(__HIPCC__) + x(value) +#elif defined(__CUDACC__) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 x(__bfloat16_as_ushort(__float2bfloat16(value))) #elif defined(__SYCL_DEVICE_ONLY__) && \ defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) @@ -137,7 +145,9 @@ inline C10_HOST_DEVICE BFloat16::BFloat16(float value) /// Implicit conversions inline C10_HOST_DEVICE BFloat16::operator float() const { -#if defined(__CUDACC__) && !defined(USE_ROCM) +#if defined(__HIPCC__) + return x; +#elif defined(__CUDACC__) return __bfloat162float(*reinterpret_cast(&x)); #elif defined(__SYCL_DEVICE_ONLY__) && \ defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) @@ -147,7 +157,11 @@ inline C10_HOST_DEVICE BFloat16::operator float() const { #endif } -#if defined(__CUDACC__) && !defined(USE_ROCM) +#if defined(__HIPCC__) +inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) : x(value) {} +inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const { + return x; } +#elif defined(__CUDACC__) inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { 
x = *reinterpret_cast(&value); } @@ -170,7 +184,7 @@ inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const { #if defined(__CUDACC__) || defined(__HIPCC__) inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) { -#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#if !defined(__HIPCC__) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 return __ldg(reinterpret_cast(ptr)); #else return *ptr; @@ -225,20 +239,25 @@ inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) { } inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) { - a.x = a.x | b.x; + a.bits() = a.bits() | b.bits(); return a; } inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) { - a.x = a.x ^ b.x; + a.bits() = a.bits() ^ b.bits(); return a; } inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) { - a.x = a.x & b.x; + a.bits() = a.bits() & b.bits(); return a; } +inline C10_HOST_DEVICE BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) { + return BFloat16( + (a.bits() & 0x7fff) | (b.bits() & 0x8000), c10::BFloat16::from_bits()); +} + /// Arithmetic with floats inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) { diff --git a/torch/headeronly/util/Float8_e4m3fn.h b/torch/headeronly/util/Float8_e4m3fn.h index d54a8f40a6c16..969cebf7d1e97 100644 --- a/torch/headeronly/util/Float8_e4m3fn.h +++ b/torch/headeronly/util/Float8_e4m3fn.h @@ -46,6 +46,9 @@ struct alignas(1) Float8_e4m3fn { constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t) : x(bits) {} + + C10_HOST_DEVICE uint8_t& bits() { return x; } + inline C10_HOST_DEVICE Float8_e4m3fn(float value); inline C10_HOST_DEVICE operator float() const; inline C10_HOST_DEVICE bool isnan() const; diff --git a/torch/headeronly/util/Float8_e4m3fnuz.h b/torch/headeronly/util/Float8_e4m3fnuz.h index 772ffd9e96c6a..0de741329711a 100644 --- a/torch/headeronly/util/Float8_e4m3fnuz.h +++ 
b/torch/headeronly/util/Float8_e4m3fnuz.h @@ -47,6 +47,9 @@ struct alignas(1) Float8_e4m3fnuz { constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t) : x(bits) {} + + C10_HOST_DEVICE uint8_t& bits() { return x; } + inline C10_HOST_DEVICE Float8_e4m3fnuz(float value); inline C10_HOST_DEVICE operator float() const; inline C10_HOST_DEVICE bool isnan() const; diff --git a/torch/headeronly/util/Float8_e5m2.h b/torch/headeronly/util/Float8_e5m2.h index aeee40d8e5b80..fc044999fc671 100644 --- a/torch/headeronly/util/Float8_e5m2.h +++ b/torch/headeronly/util/Float8_e5m2.h @@ -31,6 +31,9 @@ struct alignas(1) Float8_e5m2 { Float8_e5m2() = default; constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits) {} + + C10_HOST_DEVICE uint8_t& bits() { return x; } + inline C10_HOST_DEVICE Float8_e5m2(float value); inline C10_HOST_DEVICE operator float() const; inline C10_HOST_DEVICE bool isnan() const; diff --git a/torch/headeronly/util/Float8_e5m2fnuz.h b/torch/headeronly/util/Float8_e5m2fnuz.h index 8bcb2ac07f76d..de44c3286e4d3 100644 --- a/torch/headeronly/util/Float8_e5m2fnuz.h +++ b/torch/headeronly/util/Float8_e5m2fnuz.h @@ -46,6 +46,9 @@ struct alignas(1) Float8_e5m2fnuz { constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t) : x(bits) {} + + C10_HOST_DEVICE uint8_t& bits() { return x; } + inline C10_HOST_DEVICE Float8_e5m2fnuz(float value); inline C10_HOST_DEVICE operator float() const; inline C10_HOST_DEVICE bool isnan() const; diff --git a/torch/headeronly/util/Float8_e8m0fnu.h b/torch/headeronly/util/Float8_e8m0fnu.h index c5a70525f2f2f..54701d241c8e0 100644 --- a/torch/headeronly/util/Float8_e8m0fnu.h +++ b/torch/headeronly/util/Float8_e8m0fnu.h @@ -41,6 +41,9 @@ struct alignas(1) Float8_e8m0fnu { constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t) : x(bits) {} + + C10_HOST_DEVICE uint8_t& bits() { return x; } + inline C10_HOST_DEVICE Float8_e8m0fnu(float value); inline C10_HOST_DEVICE operator float() const; 
inline C10_HOST_DEVICE bool isnan() const; diff --git a/torch/headeronly/util/Half.h b/torch/headeronly/util/Half.h index 59a86f07e333e..b41a3cf51dd1e 100644 --- a/torch/headeronly/util/Half.h +++ b/torch/headeronly/util/Half.h @@ -68,6 +68,8 @@ namespace c10 { struct alignas(2) Half { unsigned short x; + C10_HOST_DEVICE uint16_t& bits() { return x; } + struct from_bits_t {}; C10_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t();