Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,35 @@ def _no_gather_coalesced(params: Iterable[Parameter]) -> AllGatherCoalescedHandl
return NoGatherCoalescedHandle(params)


def _unsharded_single_rank_warning(dp_world_size, data_parallel_group, env=None):
"""Detect the silent single-rank fallback described in #8084.

When a multi-process launcher (``deepspeed``, ``torchrun``, accelerate, ...) sets ``WORLD_SIZE > 1`` but the
distributed process group was not initialized before ``zero.Init`` ran, the group resolved here collapses to a
single rank. ``zero.Init`` then creates every parameter whole on every rank instead of partitioning it, so each
rank allocates the full (unsharded) model and typically OOMs. The failure is otherwise silent and looks exactly
like a "model too big" OOM. Return an actionable warning message in that case, else ``None``.

Only the default (world-group) path is checked: an explicitly supplied ``data_parallel_group`` of size 1 is
treated as intentional.
"""
if dp_world_size != 1 or data_parallel_group is not None:
return None
env = os.environ if env is None else env
try:
launcher_world_size = int(env.get("WORLD_SIZE", "0") or "0")
except (TypeError, ValueError):
return None
if launcher_world_size <= 1:
return None
return (
"zero.Init resolved a process group of world_size=1, but the launcher environment reports "
f"WORLD_SIZE={launcher_world_size}. The distributed process group was likely not initialized before "
"zero.Init ran (for example, `from_pretrained` executed before `deepspeed.init_distributed()`). Parameters "
"will NOT be partitioned: every rank allocates the full model and will likely OOM. Call "
"`deepspeed.init_distributed()` before constructing the model under zero.Init.")


# Replaces all parameters in module with Scattered Parameters
class Init(InsertPostInitMethodToModuleSubClasses):
param_id = 0
Expand Down Expand Up @@ -1035,6 +1064,10 @@ def __init__(self,
self.rank = dist.get_rank(group=self.ds_process_group)
self.dp_world_size = dist.get_world_size(group=self.ds_process_group)

_unsharded_warning = _unsharded_single_rank_warning(self.dp_world_size, data_parallel_group)
if _unsharded_warning is not None:
logger.warning(_unsharded_warning)

self.zero_param_process_group = zero_param_parallel_group
if _ds_config is not None and _ds_config.zero_config.zero_hpz_partition_size > 1 and self.zero_param_process_group is None:
groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_partition_size)
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/runtime/zero/test_zero_init_unsharded_warning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Regression coverage for #8084: zero.Init silently falls back to a single-rank (unsharded) group when the
# distributed process group is not initialized before it runs (e.g. `from_pretrained` before
# `deepspeed.init_distributed()`), so every rank allocates the full model and OOMs. The detection helper must warn
# only when the launcher reports a multi-process world but the resolved group collapsed to one rank.

import pytest

from deepspeed.runtime.zero.partition_parameters import _unsharded_single_rank_warning


def test_warns_when_launcher_multiprocess_but_group_is_single_rank():
    """A launcher world of 8 combined with a single-rank default group must yield the warning."""
    launcher_env = {"WORLD_SIZE": "8"}
    warning = _unsharded_single_rank_warning(dp_world_size=1, data_parallel_group=None, env=launcher_env)
    assert warning is not None
    # The message has to name the launcher's world size and point at the fix.
    for fragment in ("WORLD_SIZE=8", "init_distributed"):
        assert fragment in warning


def test_no_warning_for_genuine_single_process():
    """A launcher world of 1, or no ``WORLD_SIZE`` at all, is an ordinary single-process run."""
    for single_process_env in ({"WORLD_SIZE": "1"}, {}):
        result = _unsharded_single_rank_warning(dp_world_size=1, data_parallel_group=None, env=single_process_env)
        assert result is None


def test_no_warning_when_group_actually_shards():
    """When the resolved group matches the launcher world size there is nothing to warn about."""
    launcher_env = {"WORLD_SIZE": "8"}
    result = _unsharded_single_rank_warning(dp_world_size=8, data_parallel_group=None, env=launcher_env)
    assert result is None


def test_no_warning_when_explicit_dp_group_supplied():
    """A caller-supplied size-1 ``data_parallel_group`` counts as intentional, so no warning is produced."""
    explicit_group = object()  # stand-in for a real process group; only its non-None-ness matters here
    launcher_env = {"WORLD_SIZE": "8"}
    result = _unsharded_single_rank_warning(dp_world_size=1, data_parallel_group=explicit_group, env=launcher_env)
    assert result is None


@pytest.mark.parametrize("bad", ["", "not-an-int", None])
def test_malformed_world_size_does_not_raise(bad):
    """Empty, non-numeric or ``None`` ``WORLD_SIZE`` values must be tolerated and produce no warning."""
    malformed_env = {"WORLD_SIZE": bad}
    result = _unsharded_single_rank_warning(dp_world_size=1, data_parallel_group=None, env=malformed_env)
    assert result is None
Loading