diff --git a/aten/src/ATen/cuda/Atomic.cuh b/aten/src/ATen/cuda/Atomic.cuh index f16be30f8b71b..b29e0a0dd3c47 100644 --- a/aten/src/ATen/cuda/Atomic.cuh +++ b/aten/src/ATen/cuda/Atomic.cuh @@ -18,11 +18,18 @@ struct AtomicFPOp { template <typename func_t> inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) { unsigned int * address_as_ui = - (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; + (unsigned int *) ((char *)address - ((size_t)address & 2)); + + // Read only 2 bytes to avoid out-of-bounds access on last array element + unsigned short target_val = *reinterpret_cast<unsigned short*>(address); + + unsigned int old = ((size_t)address & 2) + ? (static_cast<unsigned int>(target_val) << 16) + : static_cast<unsigned int>(target_val); + unsigned int assumed; at::Half hsum; + do { assumed = old; hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); @@ -30,6 +37,7 @@ struct AtomicFPOp { old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; old = atomicCAS(address_as_ui, assumed, old); } while (assumed != old); + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); return hsum; } @@ -40,11 +48,18 @@ struct AtomicFPOp { template <typename func_t> inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) { unsigned int * address_as_ui = - (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; + (unsigned int *) ((char *)address - ((size_t)address & 2)); + + // Read only 2 bytes to avoid out-of-bounds access on last array element + unsigned short target_val = *reinterpret_cast<unsigned short*>(address); + + unsigned int old = ((size_t)address & 2) + ? (static_cast<unsigned int>(target_val) << 16) + : static_cast<unsigned int>(target_val); + unsigned int assumed; at::BFloat16 bsum; + do { assumed = old; bsum.x = (size_t)address & 2 ? 
(old >> 16) : (old & 0xffff); @@ -52,8 +67,9 @@ struct AtomicFPOp { old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x; old = atomicCAS(address_as_ui, assumed, old); } while (assumed != old); + bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - return bsum.x; + return bsum; } }; diff --git a/test/test_torch.py b/test/test_torch.py index 675b190a4e0a9..9e1df3d7e2d5f 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6307,6 +6307,20 @@ def test_index_add_bfloat16(self, device): self.assertEqual(out_cpu, out_gpu, atol=1e-2, rtol=0) + @onlyCUDA + def test_index_add_half(self, device): + inp_tensor = torch.randn(5, 3, device='cpu').half() + t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.half, device='cpu') + index = torch.tensor([0, 4, 2], device='cpu') + out_cpu = inp_tensor.index_add(0, index, t) + + inp_tensor = inp_tensor.to(device=device) + t = t.to(device=device) + index = index.to(device=device) + out_gpu = inp_tensor.index_add(0, index, t) + + self.assertEqual(out_cpu, out_gpu, atol=1e-2, rtol=0) + # FIXME: move to serialization test suite def test_device_serialization(self, device): x = torch.randn(4, 4, device=device)