diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
index 087c2831b4edb..3958c5d7eb3cd 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
@@ -241,12 +241,24 @@ std::chrono::milliseconds ProcessGroupGloo::AsyncWork::getTimeout() const {
 namespace {
 c10::intrusive_ptr<at::ivalue::Future> createFutureAsOutput(
     const std::vector<std::vector<at::Tensor>>& outputTensors) {
+  // We need to set device in future construction otherwise CUDA streams in
+  // futures are ignored.
+  std::vector<c10::Device> devices{};
+  for (const auto& outputTensor : outputTensors) {
+    for (const auto& tensor : outputTensor) {
+      auto device = tensor.device();
+      if (!device.is_cpu()) {
+        devices.push_back(device);
+      }
+    }
+  }
   if (outputTensors.size() > 1) {
     return c10::make_intrusive<at::ivalue::Future>(
-        c10::ListType::create(c10::ListType::create(c10::TensorType::get())));
+        c10::ListType::create(c10::ListType::create(c10::TensorType::get())),
+        devices);
   }
   return c10::make_intrusive<at::ivalue::Future>(
-      c10::ListType::create(c10::TensorType::get()));
+      c10::ListType::create(c10::TensorType::get()), devices);
 }
 
 void returnFutureWithOutput(
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 28b761a37d58c..bc1ffcf167e68 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -4791,7 +4791,12 @@ def _test_ddp_apply_optim_in_backward(
         # Test a simple linear as well as a ResNet model.
         models_to_test = [
-            nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda()
+            nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda(),
+            # run model of at least 1M parameters to hit potential race conditions in
+            # stream semantics
+            nn.Sequential(
+                nn.Linear(3, 1024), nn.Linear(1024, 1024), nn.Linear(1024, 3)
+            ).cuda(),
         ]
         if HAS_TORCHVISION:
             models_to_test.append(torchvision.models.resnet50().cuda())
 
@@ -4831,7 +4836,7 @@ def _test_ddp_apply_optim_in_backward(
             for i in range(8):
                 inp = (
                     torch.randn(1, 3, 1000, 1000, device="cuda")
-                    if j == 1
+                    if j == 2
                     else torch.randn(10, 3, device="cuda")
                 )
                 model(inp).sum().backward()