Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions torch/headeronly/util/BFloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include <iosfwd>
#include <ostream>

#if defined(__CUDACC__) && !defined(USE_ROCM)
#if defined(__CUDACC__)
#include <cuda_bf16.h>
#endif

Expand All @@ -25,12 +25,16 @@
namespace c10 {

struct alignas(2) BFloat16 {
uint16_t x;

#if defined(__HIPCC__)
static_assert(sizeof(__hip_bfloat16) == sizeof(uint16_t));
union {
uint16_t x;
__hip_bfloat16 x__hip_bfloat16;
};
// HIP requires the defaulted constructor to carry the __host__ __device__ tag; CUDA does not.
#if defined(USE_ROCM) && defined(__HIPCC__)
C10_HOST_DEVICE BFloat16() = default;
#else
uint16_t x;
BFloat16() = default;
#endif

Expand All @@ -41,10 +45,10 @@ struct alignas(2) BFloat16 {

// Builds a BFloat16 directly from its raw 16-bit pattern: `bits` is stored
// into `x` unchanged. The unnamed from_bits_t parameter is not used in the
// body; it only distinguishes this overload from the other constructors.
constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
: x(bits) {}
/* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
inline C10_HOST_DEVICE BFloat16(float value);
inline C10_HOST_DEVICE operator float() const;

#if defined(__CUDACC__) && !defined(USE_ROCM)
#if defined(__CUDACC__)
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif
Expand All @@ -66,7 +70,7 @@ inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
uint32_t tmp = src;
tmp <<= 16;

#if defined(USE_ROCM) && defined(__HIPCC__)
#if defined(__HIPCC__)
float* tempRes;

// We should be using memcpy in order to respect the strict aliasing rule
Expand All @@ -83,7 +87,7 @@ inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
uint32_t res = 0;

#if defined(USE_ROCM) && defined(__HIPCC__)
#if defined(__HIPCC__)
// memcpy would be the way to reinterpret the bits while respecting the
// strict aliasing rule, but it fails in the HIP environment, so a pointer
// cast is used instead.
uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
Expand All @@ -96,7 +100,7 @@ inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
}

inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#if defined(USE_ROCM) && defined(__HIPCC__)
#if defined(__HIPCC__)
if (src != src) {
#elif defined(_MSC_VER)
if (isnan(src)) {
Expand All @@ -122,8 +126,9 @@ C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
/// Constructors
inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
:
#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
__CUDA_ARCH__ >= 800
#if defined(__HIPCC__)
x__hip_bfloat16(value)
#elif defined(__CUDACC__) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
x(__bfloat16_as_ushort(__float2bfloat16(value)))
#elif defined(__SYCL_DEVICE_ONLY__) && \
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
Expand All @@ -137,7 +142,9 @@ inline C10_HOST_DEVICE BFloat16::BFloat16(float value)

/// Implicit conversions
inline C10_HOST_DEVICE BFloat16::operator float() const {
#if defined(__CUDACC__) && !defined(USE_ROCM)
#if defined(__HIPCC__)
return float(x__hip_bfloat16);
#elif defined(__CUDACC__)
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
#elif defined(__SYCL_DEVICE_ONLY__) && \
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
Expand All @@ -147,7 +154,13 @@ inline C10_HOST_DEVICE BFloat16::operator float() const {
#endif
}

#if defined(__CUDACC__) && !defined(USE_ROCM)
#if defined(__HIPCC__)
/// Wraps a native __hip_bfloat16 by copying it into the x__hip_bfloat16
/// union member; no bit manipulation is performed here.
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value)
: x__hip_bfloat16(value) {}
/// Returns a copy of the stored x__hip_bfloat16 member.
inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
return x__hip_bfloat16;
}
#elif defined(__CUDACC__)
/// Adopts the raw 16-bit storage of a native __nv_bfloat16 as this value's
/// bit pattern. The object is read through an unsigned short pointer, so the
/// bits are copied verbatim with no numeric conversion.
inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value)
    : x(*reinterpret_cast<const unsigned short*>(&value)) {}
Expand All @@ -170,7 +183,7 @@ inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {

#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if !defined(__HIPCC__) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
#else
return *ptr;
Expand Down