From d79c7cebb01313352b1acf9e91ef6d2c886950bb Mon Sep 17 00:00:00 2001 From: Prachi Gupta Date: Mon, 16 Mar 2026 17:09:08 -0500 Subject: [PATCH] [release/2.10] ProcessGroupGloo: fix CUDA tensor stream handling with futures (#170812) (#3073) Fixes #155714 There's a very subtle bug in Gloo where CUDA future streams aren't preserved correctly leading to silent corruption when using Gloo with a CUDA model using the DDP reducer. Test plan: ```python import os RANK = int(os.environ["RANK"]) WORLD_SIZE = int(os.environ["WORLD_SIZE"]) os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["LOCAL_RANK"] import torch import torch.distributed as dist torch.manual_seed(0) dist.init_process_group("gloo") N = 10 expected = torch.sum(torch.arange(0, WORLD_SIZE, dtype=torch.float)).item() t = torch.full((1000000,), RANK, device="cuda", dtype=torch.float) tensors = [ t.clone() for _ in range(N) ] futs = [] for tensor in tensors: work = dist.all_reduce(tensor, async_op=True) futs.append(work.get_future()) # create high priority stream to do the CPU copy and preempt the default stream stream = torch.cuda.Stream(priority=-1) for fut, tensor in zip(futs, tensors): with torch.cuda.stream(stream): fut.wait() val = tensor[-1].item() assert val == expected, f"Expected {expected}, got {val}" ``` ``` torchrun --nnodes 1 --nproc_per_node=gpu ~/scripts/gloo_future_stream.py ``` ``` BACKEND=gloo WORLD_SIZE=4 TEMP_DIR=/tmp/foo pytest test/distributed/test_distributed_spawn.py -v -s -x -k 'test_ddp_apply_optim_in_backward' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/170812 Approved by: https://github.com/fduwjj, https://github.com/jeffdaily (cherry picked from commit 398d338e32b80b9da04c9ac0edc89d6aa90f4e90) ## Motivation ## Technical Details ## Test Plan ## Test Result ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
Co-authored-by: Tristan Rice <rice@fn.lc> --- torch/csrc/distributed/c10d/ProcessGroupGloo.cpp | 16 ++++++++++++++-- .../_internal/distributed/distributed_test.py | 9 +++++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 087c2831b4edb..3958c5d7eb3cd 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -241,12 +241,24 @@ std::chrono::milliseconds ProcessGroupGloo::AsyncWork::getTimeout() const { namespace { c10::intrusive_ptr<c10::ivalue::Future> createFutureAsOutput( const std::vector<std::vector<at::Tensor>>& outputTensors) { + // We need to set device in future construction otherwise CUDA streams in + // futures are ignored. + std::vector<c10::Device> devices{}; + for (const auto& outputTensor : outputTensors) { + for (const auto& tensor : outputTensor) { + auto device = tensor.device(); + if (!device.is_cpu()) { + devices.push_back(device); + } + } + } if (outputTensors.size() > 1) { return c10::make_intrusive<c10::ivalue::Future>( - c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); + c10::ListType::create(c10::ListType::create(c10::TensorType::get())), + devices); } return c10::make_intrusive<c10::ivalue::Future>( - c10::ListType::create(c10::TensorType::get())); + c10::ListType::create(c10::TensorType::get()), devices); } void returnFutureWithOutput( diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 28b761a37d58c..bc1ffcf167e68 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -4791,7 +4791,12 @@ def _test_ddp_apply_optim_in_backward( # Test a simple linear as well as a ResNet model. 
models_to_test = [ - nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda() + nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda(), + # run model of at least 1M parameters to hit potential race conditions in + # stream semantics + nn.Sequential( + nn.Linear(3, 1024), nn.Linear(1024, 1024), nn.Linear(1024, 3) + ).cuda(), ] if HAS_TORCHVISION: models_to_test.append(torchvision.models.resnet50().cuda()) @@ -4831,7 +4836,7 @@ def _test_ddp_apply_optim_in_backward( for i in range(8): inp = ( torch.randn(1, 3, 1000, 1000, device="cuda") - if j == 1 + if j == 2 else torch.randn(10, 3, device="cuda") ) model(inp).sum().backward()