Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class Vectorized16 {
Vectorized16() {}
Vectorized16(__m256i v) : values(v) {}
Vectorized16(T val) {
value_type uw = val.x;
value_type uw = val.bits();
values = _mm256_set1_epi16(uw);
}
Vectorized16(
Expand All @@ -219,22 +219,22 @@ class Vectorized16 {
T val15,
T val16) {
values = _mm256_setr_epi16(
val1.x,
val2.x,
val3.x,
val4.x,
val5.x,
val6.x,
val7.x,
val8.x,
val9.x,
val10.x,
val11.x,
val12.x,
val13.x,
val14.x,
val15.x,
val16.x);
val1.bits(),
val2.bits(),
val3.bits(),
val4.bits(),
val5.bits(),
val6.bits(),
val7.bits(),
val8.bits(),
val9.bits(),
val10.bits(),
val11.bits(),
val12.bits(),
val13.bits(),
val14.bits(),
val15.bits(),
val16.bits());
}
operator __m256i() const {
return values;
Expand Down
66 changes: 33 additions & 33 deletions aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class Vectorized16 {
}
Vectorized16(__m512i v) : values(v) {}
Vectorized16(T val) {
value_type uw = val.x;
value_type uw = val.bits();
values = _mm512_set1_epi16(uw);
}
Vectorized16(
Expand Down Expand Up @@ -234,38 +234,38 @@ class Vectorized16 {
T val31,
T val32) {
values = _mm512_set_epi16(
val32.x,
val31.x,
val30.x,
val29.x,
val28.x,
val27.x,
val26.x,
val25.x,
val24.x,
val23.x,
val22.x,
val21.x,
val20.x,
val19.x,
val18.x,
val17.x,
val16.x,
val15.x,
val14.x,
val13.x,
val12.x,
val11.x,
val10.x,
val9.x,
val8.x,
val7.x,
val6.x,
val5.x,
val4.x,
val3.x,
val2.x,
val1.x);
val32.bits(),
val31.bits(),
val30.bits(),
val29.bits(),
val28.bits(),
val27.bits(),
val26.bits(),
val25.bits(),
val24.bits(),
val23.bits(),
val22.bits(),
val21.bits(),
val20.bits(),
val19.bits(),
val18.bits(),
val17.bits(),
val16.bits(),
val15.bits(),
val14.bits(),
val13.bits(),
val12.bits(),
val11.bits(),
val10.bits(),
val9.bits(),
val8.bits(),
val7.bits(),
val6.bits(),
val5.bits(),
val4.bits(),
val3.bits(),
val2.bits(),
val1.bits());
}
operator __m512i() const {
return values;
Expand Down
8 changes: 4 additions & 4 deletions aten/src/ATen/cuda/Atomic.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ struct AtomicFPOp<at::BFloat16> {
at::BFloat16 bsum;
do {
assumed = old;
bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
bsum.bits() = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
bsum = func(bsum, val);
old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x;
old = (size_t)address & 2 ? (old & 0xffff) | (bsum.bits() << 16) : (old & 0xffff0000) | bsum.bits();
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
return bsum.x;
bsum.bits() = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
return bsum.bits();
}
};

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/tunable/TunableGemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ inline bool IsZero(T v) {

template <>
inline bool IsZero(BFloat16 v) {
return v.x == 0;
return v.bits() == 0;
}

template <>
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/FillKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace {

template <typename scalar_t>
void fill_non_native_type(TensorIterator& iter, const Scalar& value_scalar) {
auto value = value_scalar.to<scalar_t>().x;
auto value = value_scalar.to<scalar_t>().bits();
using H = typename std::make_signed_t<decltype(value)>; // Signed type has more acceleration
// Reserve the representation of value. static_cast<H>(value) is implementation defined.
H val = *reinterpret_cast<H*>(std::addressof(value));
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/SortingRadixSelect.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ struct TopKTypeConfig<at::BFloat16> {
typedef uint32_t RadixType;

static inline __device__ RadixType convert(at::BFloat16 v) {
RadixType x = v.x;
RadixType x = v.bits();
RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
return (v == v) ? (x ^ mask) : 0xffff;
}

static inline __device__ at::BFloat16 deconvert(RadixType v) {
RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
at::BFloat16 r;
r.x = (v ^ mask);
r.bits() = (v ^ mask);
return r;
}
};
Expand Down
24 changes: 12 additions & 12 deletions c10/test/util/bfloat16_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ TEST(BFloat16Conversion, FloatToBFloat16AndBack) {
float out[100];

for (const auto i : c10::irange(100)) {
bfloats[i].x = c10::detail::bits_from_f32(in[i]);
out[i] = c10::detail::f32_from_bits(bfloats[i].x);
bfloats[i].bits() = c10::detail::bits_from_f32(in[i]);
out[i] = c10::detail::f32_from_bits(bfloats[i].bits());

// The relative error should be less than 1/(2^7) since BFloat16
// has 7 bits mantissa.
Expand All @@ -56,8 +56,8 @@ TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) {
float out[100];

for (const auto i : c10::irange(100)) {
bfloats[i].x = c10::detail::round_to_nearest_even(in[i]);
out[i] = c10::detail::f32_from_bits(bfloats[i].x);
bfloats[i].bits() = c10::detail::round_to_nearest_even(in[i]);
out[i] = c10::detail::f32_from_bits(bfloats[i].bits());

// The relative error should be less than 1/(2^7) since BFloat16
// has 7 bits mantissa.
Expand All @@ -70,7 +70,7 @@ TEST(BFloat16Conversion, NaN) {
EXPECT_TRUE(std::isnan(inNaN));

c10::BFloat16 a = c10::BFloat16(inNaN);
float out = c10::detail::f32_from_bits(a.x);
float out = c10::detail::f32_from_bits(a.bits());

EXPECT_TRUE(std::isnan(out));
}
Expand All @@ -80,7 +80,7 @@ TEST(BFloat16Conversion, Inf) {
EXPECT_TRUE(std::isinf(inInf));

c10::BFloat16 a = c10::BFloat16(inInf);
float out = c10::detail::f32_from_bits(a.x);
float out = c10::detail::f32_from_bits(a.bits());

EXPECT_TRUE(std::isinf(out));
}
Expand All @@ -89,7 +89,7 @@ TEST(BFloat16Conversion, SmallestDenormal) {
float in = std::numeric_limits<float>::denorm_min(); // The smallest non-zero
// subnormal number
c10::BFloat16 a = c10::BFloat16(in);
float out = c10::detail::f32_from_bits(a.x);
float out = c10::detail::f32_from_bits(a.bits());

EXPECT_FLOAT_EQ(in, out);
}
Expand All @@ -109,10 +109,10 @@ TEST(BFloat16Math, Addition) {
float expected = float_from_bytes(0, 0, 0x40c80000);

c10::BFloat16 b{};
b.x = c10::detail::bits_from_f32(input);
b.bits() = c10::detail::bits_from_f32(input);
b = b + b;

float res = c10::detail::f32_from_bits(b.x);
float res = c10::detail::f32_from_bits(b.bits());
EXPECT_EQ(res, expected);
}

Expand All @@ -131,10 +131,10 @@ TEST(BFloat16Math, Subtraction) {
float expected = float_from_bytes(0, 0, 0x40280000);

c10::BFloat16 b{};
b.x = c10::detail::bits_from_f32(input);
b.bits() = c10::detail::bits_from_f32(input);
b = b - 5;

float res = c10::detail::f32_from_bits(b.x);
float res = c10::detail::f32_from_bits(b.bits());
EXPECT_EQ(res, expected);
}

Expand All @@ -145,7 +145,7 @@ TEST(BFloat16Math, NextAfterZero) {
[](c10::BFloat16 from, c10::BFloat16 to, c10::BFloat16 expected) {
c10::BFloat16 actual = std::nextafter(from, to);
// Check for bitwise equality!
ASSERT_EQ(actual.x ^ expected.x, uint16_t{0});
ASSERT_EQ(actual.bits() ^ expected.bits(), uint16_t{0});
};
check_nextafter(zero, zero, /*expected=*/zero);
check_nextafter(zero, -zero, /*expected=*/-zero);
Expand Down
5 changes: 0 additions & 5 deletions c10/util/copysign.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,4 @@ inline c10::Half copysign(c10::Half a, c10::Half b) {
return c10::Half((a.x & 0x7fff) | (b.x & 0x8000), c10::Half::from_bits());
}

inline c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) {
return c10::BFloat16(
(a.x & 0x7fff) | (b.x & 0x8000), c10::BFloat16::from_bits());
}

} // namespace c10
Loading