Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions aten/src/ATen/cuda/Atomic.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,26 @@ struct AtomicFPOp<at::Half> {
template <typename func_t>
inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) {
  // CUDA has no native 16-bit atomicCAS, so emulate one with a 32-bit CAS on
  // the aligned 4-byte word containing the target half. Align down by
  // clearing bit 1 of the address.
  unsigned int * address_as_ui =
    (unsigned int *) ((char *)address - ((size_t)address & 2));

  // Read only 2 bytes to avoid out-of-bounds access on last array element:
  // the neighboring half inside the 32-bit word may lie past the end of the
  // allocation, so seed the loop from a 16-bit load. The other half of the
  // word starts as 0; the first atomicCAS simply fails and refreshes `old`
  // with the true word, after which the loop proceeds normally.
  unsigned short target_val = *reinterpret_cast<unsigned short*>(address);

  unsigned int old = ((size_t)address & 2)
      ? (static_cast<unsigned int>(target_val) << 16)
      : static_cast<unsigned int>(target_val);

  unsigned int assumed;
  at::Half hsum;

  do {
    assumed = old;
    // Extract the 16 bits belonging to *address (upper or lower half-word).
    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    hsum = func(hsum, val);
    // Splice the updated half back into the word, leaving the neighbor intact.
    old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);

  // On exit `old` holds the word as it was just before our successful swap;
  // return the prior value of *address, matching CUDA atomic-op convention.
  hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
  return hsum;
}
Expand All @@ -40,20 +48,28 @@ struct AtomicFPOp<at::BFloat16> {
template <typename func_t>
inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) {
  // CUDA has no native 16-bit atomicCAS, so emulate one with a 32-bit CAS on
  // the aligned 4-byte word containing the target bfloat16. Align down by
  // clearing bit 1 of the address.
  unsigned int * address_as_ui =
    (unsigned int *) ((char *)address - ((size_t)address & 2));

  // Read only 2 bytes to avoid out-of-bounds access on last array element:
  // the neighboring bfloat16 inside the 32-bit word may lie past the end of
  // the allocation, so seed the loop from a 16-bit load. The other half of
  // the word starts as 0; the first atomicCAS simply fails and refreshes
  // `old` with the true word, after which the loop proceeds normally.
  unsigned short target_val = *reinterpret_cast<unsigned short*>(address);

  unsigned int old = ((size_t)address & 2)
      ? (static_cast<unsigned int>(target_val) << 16)
      : static_cast<unsigned int>(target_val);

  unsigned int assumed;
  at::BFloat16 bsum;

  do {
    assumed = old;
    // Extract the 16 bits belonging to *address (upper or lower half-word).
    bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    bsum = func(bsum, val);
    // Splice the updated value back into the word, leaving the neighbor intact.
    old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);

  // On exit `old` holds the word as it was just before our successful swap;
  // return the prior at::BFloat16 value (not the raw bit field `bsum.x`),
  // matching CUDA atomic-op convention.
  bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
  return bsum;
}
};

Expand Down
14 changes: 14 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6307,6 +6307,20 @@ def test_index_add_bfloat16(self, device):

self.assertEqual(out_cpu, out_gpu, atol=1e-2, rtol=0)

@onlyCUDA
def test_index_add_half(self, device):
    # Run index_add on CPU as the reference, then replay the identical
    # operation on the CUDA device and compare within fp16 tolerance.
    # Exercises the half-precision atomic accumulation path on GPU.
    base = torch.randn(5, 3, device='cpu').half()
    source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.half, device='cpu')
    idx = torch.tensor([0, 4, 2], device='cpu')
    out_cpu = base.index_add(0, idx, source)

    out_gpu = base.to(device=device).index_add(
        0, idx.to(device=device), source.to(device=device))

    self.assertEqual(out_cpu, out_gpu, atol=1e-2, rtol=0)

# FIXME: move to serialization test suite
def test_device_serialization(self, device):
x = torch.randn(4, 4, device=device)
Expand Down