From 642a9d2827218bd5465d553933f9d75afcc3ce31 Mon Sep 17 00:00:00 2001 From: FNU AKSHANSH <105249360+akshansh47@users.noreply.github.com> Date: Wed, 24 Jun 2026 11:59:08 -0700 Subject: [PATCH] Warn when zero.Init silently falls back to a single rank (#8084) When a multi-process launcher sets WORLD_SIZE>1 but the distributed process group is not initialized before zero.Init runs (e.g. from_pretrained before deepspeed.init_distributed()), the resolved group collapses to a single rank. zero.Init then materializes every parameter whole on every rank instead of partitioning, so each rank loads the full model and OOMs with no diagnostic. Detect this case and emit an actionable warning pointing at the missing init_distributed() call. Co-authored-by: Cursor Signed-off-by: FNU AKSHANSH <105249360+akshansh47@users.noreply.github.com> Co-authored-by: Cursor --- .../runtime/zero/partition_parameters.py | 33 +++++++++++++++ .../zero/test_zero_init_unsharded_warning.py | 41 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 tests/unit/runtime/zero/test_zero_init_unsharded_warning.py diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 7b7c50454874..f3b07eb81eae 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -880,6 +880,35 @@ def _no_gather_coalesced(params: Iterable[Parameter]) -> AllGatherCoalescedHandl return NoGatherCoalescedHandle(params) +def _unsharded_single_rank_warning(dp_world_size, data_parallel_group, env=None): + """Detect the silent single-rank fallback described in #8084. + + When a multi-process launcher (``deepspeed``, ``torchrun``, accelerate, ...) sets ``WORLD_SIZE > 1`` but the + distributed process group was not initialized before ``zero.Init`` ran, the group resolved here collapses to a + single rank. ``zero.Init`` then creates every parameter whole on every rank instead of partitioning it, so each + rank allocates the full (unsharded) model and typically OOMs. The failure is otherwise silent and looks exactly + like a "model too big" OOM. Return an actionable warning message in that case, else ``None``. + + Only the default (world-group) path is checked: an explicitly supplied ``data_parallel_group`` of size 1 is + treated as intentional. + """ + if dp_world_size != 1 or data_parallel_group is not None: + return None + env = os.environ if env is None else env + try: + launcher_world_size = int(env.get("WORLD_SIZE", "0") or "0") + except (TypeError, ValueError): + return None + if launcher_world_size <= 1: + return None + return ( + "zero.Init resolved a process group of world_size=1, but the launcher environment reports " + f"WORLD_SIZE={launcher_world_size}. The distributed process group was likely not initialized before " + "zero.Init ran (for example, `from_pretrained` executed before `deepspeed.init_distributed()`). Parameters " + "will NOT be partitioned: every rank allocates the full model and will likely OOM. Call " + "`deepspeed.init_distributed()` before constructing the model under zero.Init.") + + # Replaces all parameters in module with Scattered Parameters class Init(InsertPostInitMethodToModuleSubClasses): param_id = 0 @@ -1035,6 +1064,10 @@ def __init__(self, self.rank = dist.get_rank(group=self.ds_process_group) self.dp_world_size = dist.get_world_size(group=self.ds_process_group) + _unsharded_warning = _unsharded_single_rank_warning(self.dp_world_size, data_parallel_group) + if _unsharded_warning is not None: + logger.warning(_unsharded_warning) + self.zero_param_process_group = zero_param_parallel_group if _ds_config is not None and _ds_config.zero_config.zero_hpz_partition_size > 1 and self.zero_param_process_group is None: groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_partition_size) diff --git a/tests/unit/runtime/zero/test_zero_init_unsharded_warning.py b/tests/unit/runtime/zero/test_zero_init_unsharded_warning.py new file mode 100644 index 000000000000..11fc045b4345 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_init_unsharded_warning.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Regression coverage for #8084: zero.Init silently falls back to a single-rank (unsharded) group when the +# distributed process group is not initialized before it runs (e.g. `from_pretrained` before +# `deepspeed.init_distributed()`), so every rank allocates the full model and OOMs. The detection helper must warn +# only when the launcher reports a multi-process world but the resolved group collapsed to one rank. + +import pytest + +from deepspeed.runtime.zero.partition_parameters import _unsharded_single_rank_warning + + +def test_warns_when_launcher_multiprocess_but_group_is_single_rank(): + msg = _unsharded_single_rank_warning(dp_world_size=1, data_parallel_group=None, env={"WORLD_SIZE": "8"}) + assert msg is not None + assert "WORLD_SIZE=8" in msg + assert "init_distributed" in msg + + +def test_no_warning_for_genuine_single_process(): + assert _unsharded_single_rank_warning(dp_world_size=1, data_parallel_group=None, env={"WORLD_SIZE": "1"}) is None + assert _unsharded_single_rank_warning(dp_world_size=1, data_parallel_group=None, env={}) is None + + +def test_no_warning_when_group_actually_shards(): + assert _unsharded_single_rank_warning(dp_world_size=8, data_parallel_group=None, env={"WORLD_SIZE": "8"}) is None + + +def test_no_warning_when_explicit_dp_group_supplied(): + # An explicitly provided size-1 data_parallel_group is treated as intentional. + sentinel_group = object() + assert _unsharded_single_rank_warning(dp_world_size=1, data_parallel_group=sentinel_group, env={"WORLD_SIZE": + "8"}) is None + + +@pytest.mark.parametrize("bad", ["", "not-an-int", None]) +def test_malformed_world_size_does_not_raise(bad): + assert _unsharded_single_rank_warning(dp_world_size=1, data_parallel_group=None, env={"WORLD_SIZE": bad}) is None