From 278c9194892cfa3f3b25857cd26fe8b04289fcd9 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sat, 13 Jun 2026 10:34:02 -0700 Subject: [PATCH 01/16] Add AutoEP + AutoTP parallel folding Allow tensor parallelism (AutoTP) for the dense/attention path to coexist with expert parallelism (AutoEP) for routed experts on the same rank set, without requiring EP to be a subset of DP. - Treat dense and MoE as independent partitionings: dense view tp*dp, expert view ep*etp*edp, with dp/edp derived so tp*dp == ep*etp*edp == stage_size. expert_tensor_parallel_size is reserved (must currently be 1). - Express folding via the existing tensor_parallel/expert_parallel config sections, with divisibility, TP/sequence-parallel exclusivity, and preset_model consistency validation. - Add the route-full / partition-dispatch MoE path and AutoTP skipping of AutoEP subtrees; derive folded process groups via the generalized expert/data-parallel group creation. - Reduce TP-replicated router/gate gradients in a mode-aware manner (sum when tokens are partitioned, average when replicated); record per-parameter-family ZeRO checkpoint metadata and handle folded ZeRO-1/2 optimizer state. - Add folding unit tests (config, groups, dispatch, runtime, gradient parity, checkpoint), including multi-rank GPU-gated cases. 
Signed-off-by: Masahiro Tanaka --- deepspeed/checkpoint/autoep_universal.py | 106 +++++- deepspeed/checkpoint/constants.py | 19 + deepspeed/comm/torch.py | 2 +- deepspeed/module_inject/auto_ep_config.py | 40 +- deepspeed/module_inject/auto_ep_folding.py | 261 +++++++++++++ deepspeed/module_inject/auto_ep_layer.py | 134 ++++++- .../module_inject/auto_ep_presets/base.py | 1 + deepspeed/module_inject/auto_tp.py | 23 ++ deepspeed/moe/ep_tp_dispatch.py | 306 +++++++++++++++ deepspeed/runtime/engine.py | 353 ++++++++++++++++-- deepspeed/runtime/zero/stage_1_and_2.py | 28 ++ deepspeed/utils/groups.py | 12 +- tests/unit/v1/moe/autoep_test_utils.py | 116 +++++- .../v1/moe/test_autoep_autotp_checkpoint.py | 309 +++++++++++++++ .../v1/moe/test_autoep_autotp_dispatch.py | 223 +++++++++++ .../moe/test_autoep_autotp_folding_config.py | 135 +++++++ .../moe/test_autoep_autotp_folding_groups.py | 74 ++++ .../v1/moe/test_autoep_autotp_grad_parity.py | 301 +++++++++++++++ .../unit/v1/moe/test_autoep_autotp_runtime.py | 243 ++++++++++++ tests/unit/v1/moe/test_autoep_unit.py | 13 +- 20 files changed, 2637 insertions(+), 62 deletions(-) create mode 100644 deepspeed/module_inject/auto_ep_folding.py create mode 100644 deepspeed/moe/ep_tp_dispatch.py create mode 100644 tests/unit/v1/moe/test_autoep_autotp_checkpoint.py create mode 100644 tests/unit/v1/moe/test_autoep_autotp_dispatch.py create mode 100644 tests/unit/v1/moe/test_autoep_autotp_folding_config.py create mode 100644 tests/unit/v1/moe/test_autoep_autotp_folding_groups.py create mode 100644 tests/unit/v1/moe/test_autoep_autotp_grad_parity.py create mode 100644 tests/unit/v1/moe/test_autoep_autotp_runtime.py diff --git a/deepspeed/checkpoint/autoep_universal.py b/deepspeed/checkpoint/autoep_universal.py index 3c19ab0c4183..cf984ecb86d4 100644 --- a/deepspeed/checkpoint/autoep_universal.py +++ b/deepspeed/checkpoint/autoep_universal.py @@ -17,9 +17,98 @@ CAT_DIM, EP_IS_EXPERT_PARAM, EP_NUM_EXPERTS, + FOLDING_METADATA_KEY, + 
FOLDING_METADATA_VERSION, + FOLDING_TP_SIZE, + FOLDING_TP_RANK, + FOLDING_EP_SIZE, + FOLDING_EP_RANK, + FOLDING_ETP_SIZE, + FOLDING_ETP_RANK, + FOLDING_ZERO_PARTITION_GROUP, + FOLDING_ZERO_PARTITION_RANK, + FOLDING_ZERO_PARTITION_COUNT, + FOLDING_DISPATCH_STRATEGY, + FOLDING_SHARED_EXPERT_PLACEMENT, + FOLDING_FAMILY, + FOLDING_PARAM_FAMILIES, ) +def make_folding_metadata(*, + tp_size, + tp_rank, + ep_size, + ep_rank, + zero_partition_group, + zero_partition_rank, + zero_partition_count, + family, + param_families=None): + metadata = { + "version": FOLDING_METADATA_VERSION, + FOLDING_TP_SIZE: tp_size, + FOLDING_TP_RANK: tp_rank, + FOLDING_EP_SIZE: ep_size, + FOLDING_EP_RANK: ep_rank, + FOLDING_ETP_SIZE: 1, + FOLDING_ETP_RANK: 0, + FOLDING_ZERO_PARTITION_GROUP: zero_partition_group, + FOLDING_ZERO_PARTITION_RANK: zero_partition_rank, + FOLDING_ZERO_PARTITION_COUNT: zero_partition_count, + FOLDING_DISPATCH_STRATEGY: "route_full_partition_dispatch", + FOLDING_SHARED_EXPERT_PLACEMENT: "tp_sharded", + FOLDING_FAMILY: family, + } + if param_families is not None: + metadata[FOLDING_PARAM_FAMILIES] = dict(param_families) + return metadata + + +def validate_folding_metadata(metadata, + *, + tp_size, + ep_size, + etp_size=1, + tp_rank=None, + ep_rank=None, + etp_rank=None, + zero_partition_group=None, + zero_partition_rank=None, + zero_partition_count=None, + family=None, + param_families=None, + shared_expert_placement=None, + dispatch_strategy=None): + if not isinstance(metadata, dict) or FOLDING_METADATA_KEY not in metadata: + raise RuntimeError("Missing AutoEP+AutoTP folding metadata in folded checkpoint.") + folding = metadata[FOLDING_METADATA_KEY] + if folding.get("version") != FOLDING_METADATA_VERSION: + raise RuntimeError(f"Unsupported folding metadata version: {folding.get('version')}") + expected = { + FOLDING_TP_SIZE: tp_size, + FOLDING_EP_SIZE: ep_size, + FOLDING_ETP_SIZE: etp_size, + } + optional_expected = { + FOLDING_TP_RANK: tp_rank, + FOLDING_EP_RANK: 
ep_rank, + FOLDING_ETP_RANK: etp_rank, + FOLDING_ZERO_PARTITION_GROUP: zero_partition_group, + FOLDING_ZERO_PARTITION_RANK: zero_partition_rank, + FOLDING_ZERO_PARTITION_COUNT: zero_partition_count, + FOLDING_FAMILY: family, + FOLDING_PARAM_FAMILIES: param_families, + FOLDING_SHARED_EXPERT_PLACEMENT: shared_expert_placement, + FOLDING_DISPATCH_STRATEGY: dispatch_strategy, + } + expected.update({key: value for key, value in optional_expected.items() if value is not None}) + for key, value in expected.items(): + if folding.get(key) != value: + raise RuntimeError(f"Folding metadata mismatch for {key}: saved={folding.get(key)} runtime={value}") + return folding + + def _state_entry(state, param_id): """Get optimizer state entry by param id, handling int/str key variants.""" if param_id in state: @@ -102,6 +191,13 @@ def resolve_expert_ckpt_path(checkpoint_dir, moe_layer_id, global_expert_id): raise FileNotFoundError(f"Expert checkpoint file not found: layer_{moe_layer_id} " f"expert_{global_expert_id} in {checkpoint_dir}") if len(matches) > 1: + for match in matches: + state = torch.load(match, map_location='cpu', weights_only=False) + if FOLDING_METADATA_KEY in state: + raise NotImplementedError("Universal checkpoint conversion for folded AutoEP+AutoTP expert shards " + "is not supported yet. Load this checkpoint with a matching folded " + "runtime, or consolidate the tensor-parallel expert shards before " + "running ds_to_universal.") raise NotImplementedError(f"Multiple expert checkpoint files found for layer_{moe_layer_id} " f"expert_{global_expert_id}: {matches}. 
Multi-mp_rank expert files " f"are not yet supported.") @@ -138,9 +234,12 @@ def consolidate_autoep_expert_files(checkpoint_dir, output_dir, autoep_layers_me for wname in ('w1', 'w2', 'w3'): expert_tensors = [] + folding_metadata = None for global_eid in range(num_experts): ckpt_path = resolve_expert_ckpt_path(checkpoint_dir, moe_layer_id, global_eid) sd = torch.load(ckpt_path, map_location='cpu', weights_only=False) + if folding_metadata is None: + folding_metadata = sd.get(FOLDING_METADATA_KEY) key = f"{prefix}.{wname}.{global_eid}" if key not in sd: raise RuntimeError(f"Expected key '{key}' not found in {ckpt_path}") @@ -153,12 +252,15 @@ def consolidate_autoep_expert_files(checkpoint_dir, output_dir, autoep_layers_me param_name = f"{prefix}.{wname}" param_dir = os.path.join(output_dir, "zero", param_name) os.makedirs(param_dir, exist_ok=True) - torch.save({ + universal_state = { PARAM: full_tensor, CAT_DIM: 0, EP_IS_EXPERT_PARAM: True, EP_NUM_EXPERTS: num_experts, - }, os.path.join(param_dir, "fp32.pt")) + } + if folding_metadata is not None: + universal_state[FOLDING_METADATA_KEY] = folding_metadata + torch.save(universal_state, os.path.join(param_dir, "fp32.pt")) def consolidate_autoep_optimizer_states(checkpoint_dir, output_dir, autoep_layers_metadata, ep_size): diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index 0f83458a713d..b649332006f9 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -100,3 +100,22 @@ EP_IS_EXPERT_PARAM = 'is_expert_param' EP_NUM_EXPERTS = 'ep_num_experts' EXPERT_PARAMETER_PATTERNS = 'expert_parameter_patterns' + +######################################### +# AutoEP + AutoTP folding metadata keys +######################################### +FOLDING_METADATA_KEY = 'folding' +FOLDING_METADATA_VERSION = 1 +FOLDING_TP_SIZE = 'tp_size' +FOLDING_TP_RANK = 'tp_rank' +FOLDING_EP_SIZE = 'ep_size' +FOLDING_EP_RANK = 'ep_rank' +FOLDING_ETP_SIZE = 'etp_size' 
+FOLDING_ETP_RANK = 'etp_rank' +FOLDING_ZERO_PARTITION_GROUP = 'zero_partition_group' +FOLDING_ZERO_PARTITION_RANK = 'zero_partition_rank' +FOLDING_ZERO_PARTITION_COUNT = 'zero_partition_count' +FOLDING_DISPATCH_STRATEGY = 'dispatch_strategy' +FOLDING_SHARED_EXPERT_PLACEMENT = 'shared_expert_placement' +FOLDING_FAMILY = 'family' +FOLDING_PARAM_FAMILIES = 'param_families' diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index 39e3f65fbe92..01765d3f34d3 100755 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -30,7 +30,7 @@ def disable_compiler_collective(func): def build_shm_op(): builder = get_accelerator().create_op_builder("ShareMemCommBuilder") - if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]: + if builder is None or not deepspeed.ops.__compatible_ops__.get(builder.NAME, False): return None shm_cpp_module = builder.load() print(f'DeepSpeed {builder.absolute_name()} built successfully') diff --git a/deepspeed/module_inject/auto_ep_config.py b/deepspeed/module_inject/auto_ep_config.py index a0e6e3a9b36b..067c1da4219c 100644 --- a/deepspeed/module_inject/auto_ep_config.py +++ b/deepspeed/module_inject/auto_ep_config.py @@ -18,6 +18,7 @@ available_preset_names, resolve_autoep_config_defaults, ) +from deepspeed.module_inject.auto_ep_folding import build_folding_spec, validate_folding_global from deepspeed.utils import logger __all__ = [ @@ -45,6 +46,7 @@ def parse_autoep_config(param_dict: dict) -> AutoEPConfig: config = AutoEPConfig() config.enabled = param_dict.get("enabled", False) config.autoep_size = param_dict.get("autoep_size", 1) + config.expert_tensor_parallel_size = param_dict.get("expert_tensor_parallel_size", 1) config.preset_model = param_dict.get("preset_model", None) config.moe_layer_pattern = param_dict.get("moe_layer_pattern", None) config.expert_pattern = param_dict.get("expert_pattern", None) @@ -95,6 +97,13 @@ def validate_autoep_config( pp_size: int, tp_size: int, sp_size: int, + *, + 
zero_stage: int = 0, + tp_preset_model: str | None = None, + use_data_before_expert_parallel: bool = False, + mpu=None, + zero_offload_optimizer: bool = False, + zero_offload_param: bool = False, ) -> None: """Validate config constraints. Raises ValueError on invalid config.""" if config.load_balance_coeff is not None: @@ -103,17 +112,26 @@ def validate_autoep_config( if not config.enabled: return - if tp_size > 1: - raise ValueError("AutoEP does not currently support AutoTP " - f"(tensor_parallel.autotp_size={tp_size}). Disable AutoTP for this run; " - "AutoEP+AutoTP support is planned as follow-up work.") - - # ep_size must divide the stage size (world_size / pp_size) - stage_size = world_size // pp_size - if stage_size % config.autoep_size != 0: - raise ValueError(f"autoep_size={config.autoep_size} must divide the stage size " - f"(world_size={world_size} / pp_size={pp_size} = {stage_size}). " - f"Valid autoep_size values: {_divisors(stage_size)}") + folding_spec = build_folding_spec( + world_size=world_size, + pp_size=pp_size, + tp_size=max(tp_size, 1), + ep_size=config.autoep_size, + etp_size=config.expert_tensor_parallel_size, + mp_mode="tp" if tp_size > 1 else "sp", + ) + validate_folding_global( + folding_spec, + zero_stage=zero_stage, + sp_size=sp_size, + use_data_before_expert_parallel=use_data_before_expert_parallel, + mpu=mpu, + autoep_enabled=config.enabled, + tp_preset=tp_preset_model, + ep_preset=config.preset_model, + zero_offload_optimizer=zero_offload_optimizer, + zero_offload_param=zero_offload_param, + ) # Validate preset_model if specified if config.preset_model is not None and config.preset_model not in PRESET_MODELS: diff --git a/deepspeed/module_inject/auto_ep_folding.py b/deepspeed/module_inject/auto_ep_folding.py new file mode 100644 index 000000000000..6e884d2bdb57 --- /dev/null +++ b/deepspeed/module_inject/auto_ep_folding.py @@ -0,0 +1,261 @@ +# Copyright (c) DeepSpeed Team. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""AutoEP + AutoTP folding topology helpers. + +The functions in this module are pure topology math unless a caller passes +runtime process-group handles into :class:`FoldingGroupHandles`. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable + + +@dataclass(frozen=True) +class ParallelFoldingSpec: + world_size: int + pp_size: int + stage_size: int + tp_size: int + dp_size: int + ep_size: int + etp_size: int + edp_size: int + mp_mode: str = "tp" + + +@dataclass(frozen=True) +class FoldingGroupTables: + tp_groups: tuple[tuple[int, ...], ...] + dense_dp_groups: tuple[tuple[int, ...], ...] + ep_groups: tuple[tuple[int, ...], ...] + edp_groups: tuple[tuple[int, ...], ...] + + +@dataclass(frozen=True) +class FoldingGroupHandles: + spec: ParallelFoldingSpec + tp_group: object + dense_dp_group: object + ep_group: object + edp_group: object + ep_group_name: str + tp_ranks: tuple[int, ...] + dense_dp_ranks: tuple[int, ...] + ep_ranks: tuple[int, ...] + edp_ranks: tuple[int, ...] 
+ + +def _divisors(value: int) -> list[int]: + return [candidate for candidate in range(1, value + 1) if value % candidate == 0] + + +def _require_positive(name: str, value: int) -> None: + if not isinstance(value, int) or value < 1: + raise ValueError(f"{name} must be a positive integer, got {value!r}") + + +def build_folding_spec( + *, + world_size: int, + pp_size: int, + tp_size: int, + ep_size: int, + etp_size: int = 1, + mp_mode: str = "tp", +) -> ParallelFoldingSpec: + """Build the immutable per-stage folding spec from public config sizes.""" + for name, value in ( + ("world_size", world_size), + ("pp_size", pp_size), + ("tensor_parallel.autotp_size", tp_size), + ("expert_parallel.autoep_size", ep_size), + ("expert_parallel.expert_tensor_parallel_size", etp_size), + ): + _require_positive(name, value) + + if world_size % pp_size != 0: + raise ValueError(f"pp_size={pp_size} must divide world_size={world_size}. " + f"Valid pp_size values: {_divisors(world_size)}") + + stage_size = world_size // pp_size + if stage_size % tp_size != 0: + raise ValueError(f"tensor_parallel.autotp_size={tp_size} must divide the stage size " + f"(world_size={world_size} / pp_size={pp_size} = {stage_size}). " + f"Computed dp would be non-integral. Valid autotp_size values: {_divisors(stage_size)}") + + expert_width = ep_size * etp_size + if stage_size % expert_width != 0: + raise ValueError(f"expert_parallel.autoep_size * expert_parallel.expert_tensor_parallel_size " + f"({ep_size} * {etp_size} = {expert_width}) must divide the stage size " + f"(world_size={world_size} / pp_size={pp_size} = {stage_size}). " + f"Computed edp would be non-integral. 
Valid expert-width values: {_divisors(stage_size)}") + + return ParallelFoldingSpec( + world_size=world_size, + pp_size=pp_size, + stage_size=stage_size, + tp_size=tp_size, + dp_size=stage_size // tp_size, + ep_size=ep_size, + etp_size=etp_size, + edp_size=stage_size // expert_width, + mp_mode=mp_mode, + ) + + +def expected_folding_group_tables(spec: ParallelFoldingSpec) -> FoldingGroupTables: + """Derive TP, dense-DP, EP, and EDP rank tables without process groups.""" + tp_groups: list[tuple[int, ...]] = [] + dense_dp_groups: list[tuple[int, ...]] = [] + ep_groups: list[tuple[int, ...]] = [] + edp_groups: list[tuple[int, ...]] = [] + + for stage_start in range(0, spec.world_size, spec.stage_size): + stage_ranks = list(range(stage_start, stage_start + spec.stage_size)) + + for dp_idx in range(spec.dp_size): + start = dp_idx * spec.tp_size + tp_groups.append(tuple(stage_ranks[start:start + spec.tp_size])) + for tp_lane in range(spec.tp_size): + dense_dp_groups.append(tuple(stage_ranks[tp_lane::spec.tp_size])) + + if spec.mp_mode == "tp" and spec.tp_size > 1: + ordered_stage_ranks = [] + for tp_lane in range(spec.tp_size): + ordered_stage_ranks.extend(stage_ranks[tp_lane::spec.tp_size]) + else: + ordered_stage_ranks = stage_ranks + + local_ep_groups = [ + tuple(ordered_stage_ranks[start:start + spec.ep_size]) + for start in range(0, len(ordered_stage_ranks), spec.ep_size) + ] + ep_groups.extend(local_ep_groups) + for pos in range(spec.ep_size): + edp_groups.append(tuple(group[pos] for group in local_ep_groups)) + + return FoldingGroupTables( + tp_groups=tuple(tp_groups), + dense_dp_groups=tuple(dense_dp_groups), + ep_groups=tuple(ep_groups), + edp_groups=tuple(edp_groups), + ) + + +def local_folding_ranks(global_rank: int, spec: ParallelFoldingSpec) -> dict[str, tuple[int, ...]]: + tables = expected_folding_group_tables(spec) + result = {} + for name, groups in ( + ("tp", tables.tp_groups), + ("dense_dp", tables.dense_dp_groups), + ("ep", tables.ep_groups), + ("edp", 
tables.edp_groups), + ): + result[name] = next(group for group in groups if global_rank in group) + return result + + +def _mpu_world_size(mpu, *names: str) -> int | None: + if mpu is None: + return None + for name in names: + getter = getattr(mpu, name, None) + if getter is not None: + return getter() + return None + + +def validate_folding_global( + spec: ParallelFoldingSpec, + *, + zero_stage: int = 0, + sp_size: int = 1, + use_data_before_expert_parallel: bool = False, + mpu=None, + autoep_enabled: bool = True, + tp_preset: str | None = None, + ep_preset: str | None = None, + zero_offload_optimizer: bool = False, + zero_offload_param: bool = False, +) -> None: + """Validate global folding policy before any process group is created.""" + if not autoep_enabled: + return + + if spec.tp_size > 1 and spec.pp_size > 1: + raise ValueError("AutoEP+AutoTP folding currently supports pp_size=1 only; " + f"got pp_size={spec.pp_size}. Pipeline-parallel validation is planned separately.") + + if spec.tp_size > 1 and sp_size > 1: + raise ValueError("tensor_parallel.autotp_size and Ulysses sequence parallelism are mutually exclusive " + f"for AutoEP folding (autotp_size={spec.tp_size}, sp_size={sp_size}).") + + if spec.etp_size != 1: + raise ValueError(f"expert_parallel.expert_tensor_parallel_size={spec.etp_size} is reserved for " + "expert-internal tensor parallelism and is not supported yet. Use 1; ETP support " + "is planned as follow-up work.") + + expert_width = spec.ep_size * spec.etp_size + if spec.tp_size > 1 and expert_width > spec.dp_size: + raise ValueError("AutoEP+AutoTP folding does not yet support cross-lane expert-parallel groups where " + "expert_parallel.autoep_size * expert_parallel.expert_tensor_parallel_size exceeds " + f"the derived dense data-parallel size (ep * etp = {expert_width}, dp = {spec.dp_size}, " + f"stage_size = {spec.stage_size}). 
This is a temporary limitation; use a shape with " + "ep * etp <= dp or run a follow-up implementation for cross-lane EP groups.") + + if tp_preset is not None and ep_preset is not None and tp_preset != ep_preset: + raise ValueError("tensor_parallel.preset_model and expert_parallel.preset_model must match when both " + f"are set (tensor_parallel.preset_model={tp_preset!r}, " + f"expert_parallel.preset_model={ep_preset!r}).") + + if spec.tp_size > 1 and spec.ep_size == 1: + raise ValueError("AutoEP+AutoTP folding requires expert_parallel.autoep_size > 1. " + "The ep=1 local-computation path would duplicate routed-token gradients across TP lanes.") + + if spec.tp_size > 1 and use_data_before_expert_parallel: + raise ValueError("expert_parallel with use_data_before_expert_parallel_ is not supported with " + "AutoEP+AutoTP folding. Disable use_data_before_expert_parallel_.") + + if spec.tp_size > 1 and zero_stage == 3: + raise ValueError("AutoEP+AutoTP with ZeRO stage 3 is reserved for the separate ZeRO-3 composition lane. " + "Use ZeRO stage 0, 1, or 2 for this folding MVP.") + + if spec.tp_size > 1 and (zero_offload_optimizer or zero_offload_param): + raise ValueError("ZeRO optimizer/parameter offload with AutoEP+AutoTP folding is not validated yet. 
" + "Disable offload or run a follow-up proof for per-family replica groups.") + + mpu_tp = _mpu_world_size(mpu, "get_tensor_model_parallel_world_size", "get_model_parallel_world_size") + if mpu_tp not in (None, 1, spec.tp_size): + raise ValueError(f"mpu tensor/model parallel world size ({mpu_tp}) conflicts with " + f"tensor_parallel.autotp_size={spec.tp_size}.") + mpu_pp = _mpu_world_size(mpu, "get_pipeline_model_parallel_world_size", "get_pipeline_parallel_world_size") + if mpu_pp not in (None, spec.pp_size): + raise ValueError(f"mpu pipeline parallel world size ({mpu_pp}) conflicts with pp_size={spec.pp_size}.") + + +def _normalize_rank_groups(groups: Iterable[Iterable[int]]) -> set[tuple[int, ...]]: + return {tuple(int(rank) for rank in group) for group in groups} + + +def assert_group_matches_spec(existing_rank_lists, spec: ParallelFoldingSpec, *, group_kind: str = "ep_edp") -> None: + """Ensure cached ``ep_size_N`` rank lists match the requested folding spec.""" + tables = expected_folding_group_tables(spec) + expected_ep = _normalize_rank_groups(tables.ep_groups) + expected_edp = _normalize_rank_groups(tables.edp_groups) + + if isinstance(existing_rank_lists, dict): + observed_ep = existing_rank_lists.get("ep", []) + observed_edp = existing_rank_lists.get("edp", []) + else: + observed_ep, observed_edp = existing_rank_lists + + for group in _normalize_rank_groups(observed_ep): + if group not in expected_ep: + raise RuntimeError(f"Cached expert-parallel group {group} does not match folding spec {spec}.") + for group in _normalize_rank_groups(observed_edp): + if group not in expected_edp: + raise RuntimeError(f"Cached expert-data-parallel group {group} does not match folding spec {spec}.") diff --git a/deepspeed/module_inject/auto_ep_layer.py b/deepspeed/module_inject/auto_ep_layer.py index f4bba73d3d9e..e1f18633b85b 100644 --- a/deepspeed/module_inject/auto_ep_layer.py +++ b/deepspeed/module_inject/auto_ep_layer.py @@ -165,6 +165,35 @@ def compute_split_plan( 
) +def compute_split_plan_from_expert_indices( + expert_indices: torch.Tensor, + num_experts: int, + ep_size: int, + num_local_experts: int, + ep_group: dist.ProcessGroup | None, +) -> SplitPlan: + """Compute EP AllToAllV splits for an already partitioned assignment list.""" + if ep_size == 1: + counts = count_tokens_per_expert(expert_indices, num_experts, out_dtype=torch.int32) + return SplitPlan([int(expert_indices.numel())], [int(expert_indices.numel())], counts, + counts.view(1, num_local_experts)) + + counts = count_tokens_per_expert(expert_indices, num_experts, out_dtype=torch.int32) + count_matrix = counts.view(ep_size, num_local_experts) + input_splits = count_matrix.sum(dim=1).cpu().tolist() + local_counts_tensor = count_matrix.sum(dim=1).clone() + remote_counts_tensor = torch.zeros_like(local_counts_tensor) + dist.all_to_all_single(remote_counts_tensor, local_counts_tensor, group=ep_group) + output_splits = remote_counts_tensor.cpu().tolist() + + local_expert_counts_flat = count_matrix.reshape(-1).contiguous() + received_counts_flat = torch.zeros_like(local_expert_counts_flat) + dist.all_to_all_single(received_counts_flat, local_expert_counts_flat, group=ep_group) + received_counts = received_counts_flat.view(ep_size, num_local_experts) + local_counts = received_counts.sum(dim=0) + return SplitPlan(input_splits, output_splits, local_counts, received_counts) + + class _AllToAllV(torch.autograd.Function): """Autograd-compatible all-to-all with variable split sizes.""" @@ -369,6 +398,8 @@ def __init__( self.hidden_size = spec.hidden_size self.ep_group_name = f"ep_size_{ep_size}" self.ep_group = None # Set by set_deepspeed_parallelism() + self.folding_group_handles = None + self.tp_group = None resolved_config = resolve_autoep_config_defaults(config, spec.model_family) # Router: copy gate weights from source @@ -508,11 +539,20 @@ def hook_fn(module, input, output): def set_deepspeed_parallelism( self, use_data_before_expert_parallel_: bool = False, + 
folding_group_handles=None, ) -> None: """Bind EP group handle to this module.""" from deepspeed.utils import groups from deepspeed.utils.bwc import bwc_pipeline_parallel_world_size + if folding_group_handles is not None: + self.folding_group_handles = folding_group_handles + self.ep_group_name = folding_group_handles.ep_group_name + self.ep_group = folding_group_handles.ep_group + self.tp_group = folding_group_handles.tp_group + self.ep_rank = dist.get_rank(group=self.ep_group) + return + if self.ep_group_name not in groups._get_expert_parallel_group_dict(): mp_size = max( getattr(groups, '_get_model_parallel_world_size', lambda: 1)(), @@ -554,10 +594,55 @@ def forward( # Reorder tokens by expert top_scores_sorted, token_indices_sorted, _ = self.reorderer(ro.top_scores, ro.selected_experts) + expert_indices_sorted = ro.selected_experts.reshape(-1).index_select(0, token_indices_sorted) + + folded_tp = self.folding_group_handles is not None and self.folding_group_handles.spec.tp_size > 1 + restore_ctx = None + if folded_tp: + from deepspeed.moe.ep_tp_dispatch import ( + RoutedAssignmentPayload, + assignment_ordinals_by_expert, + assert_tp_payload_consistent, + dispatch_counters, + partition_assignments, + restore_combined, + ) + payload = RoutedAssignmentPayload( + token_indices=(token_indices_sorted // self.top_k).to(torch.long), + expert_indices=expert_indices_sorted.to(torch.long), + assignment_indices=assignment_ordinals_by_expert(expert_indices_sorted.to(torch.long)), + capacity_slots=(token_indices_sorted % self.top_k).to(torch.long), + combine_weights=top_scores_sorted + if self.score_apply == "post" else torch.ones_like(top_scores_sorted), + drop_mask=torch.zeros_like(top_scores_sorted, dtype=torch.bool), + pad_mask=torch.zeros_like(top_scores_sorted, dtype=torch.bool), + input_splits=[0 for _ in range(self.ep_size)], + output_splits=[0 for _ in range(self.ep_size)], + extra={ + "destination_ranks": (expert_indices_sorted // 
self.num_local_experts).to(torch.long), + "top_scores": top_scores_sorted, + "num_tokens": torch.tensor(bsz * seqlen, device=hidden_states.device, dtype=torch.long), + }, + ) + assert_tp_payload_consistent(payload, + tp_group=self.tp_group, + tp_size=self.folding_group_handles.spec.tp_size) + tp_rank = dist.get_rank(group=self.tp_group) + local_payload, restore_ctx = partition_assignments(payload, + tp_group=self.tp_group, + tp_rank=tp_rank, + tp_size=self.folding_group_handles.spec.tp_size) + token_indices_for_compute = token_indices_sorted.index_select(0, restore_ctx.local_indices) + top_scores_for_compute = top_scores_sorted.index_select(0, restore_ctx.local_indices) + expert_indices_for_plan = local_payload.expert_indices + else: + token_indices_for_compute = token_indices_sorted + top_scores_for_compute = top_scores_sorted + expert_indices_for_plan = expert_indices_sorted - routed_input = x[token_indices_sorted // self.top_k] # [N, H] + routed_input = x[token_indices_for_compute // self.top_k] # [N, H] routed_input = apply_scores_before_experts_if_enabled(routed_input, - top_scores_sorted, + top_scores_for_compute, score_apply=self.score_apply) if self.ep_size == 1: @@ -574,13 +659,22 @@ def forward( expert_output = unpermute_by_local_expert(expert_output, perm_indices, n_tokens) else: # EP dispatch/compute/combine - plan = compute_split_plan( - selected_experts=ro.selected_experts, - num_experts=self.num_experts, - ep_size=self.ep_size, - num_local_experts=self.num_local_experts, - ep_group=self.ep_group, - ) + if folded_tp: + plan = compute_split_plan_from_expert_indices( + expert_indices=expert_indices_for_plan, + num_experts=self.num_experts, + ep_size=self.ep_size, + num_local_experts=self.num_local_experts, + ep_group=self.ep_group, + ) + else: + plan = compute_split_plan( + selected_experts=ro.selected_experts, + num_experts=self.num_experts, + ep_size=self.ep_size, + num_local_experts=self.num_local_experts, + ep_group=self.ep_group, + ) routed_input = 
_AllToAllV.apply(self.ep_group, routed_input, plan.input_splits, plan.output_splits) @@ -591,15 +685,19 @@ def forward( expert_output = _AllToAllV.apply(self.ep_group, expert_output, plan.output_splits, plan.input_splits) - output = combine_from_routed( - expert_output, - top_scores=ro.top_scores, - token_indices_sorted=token_indices_sorted, - top_k=self.top_k, - score_apply=self.score_apply, - combine_impl=self.combine_impl, - shape=(bsz, seqlen, hdim), - ) + if folded_tp: + output = restore_combined(expert_output, restore_ctx, tp_group=self.tp_group).reshape(bsz, seqlen, hdim) + self._last_folding_dispatch_counters = dispatch_counters(restore_ctx) + else: + output = combine_from_routed( + expert_output, + top_scores=ro.top_scores, + token_indices_sorted=token_indices_sorted, + top_k=self.top_k, + score_apply=self.score_apply, + combine_impl=self.combine_impl, + shape=(bsz, seqlen, hdim), + ) if self.moe_output_shape == "flat": output = output.reshape(-1, hdim) diff --git a/deepspeed/module_inject/auto_ep_presets/base.py b/deepspeed/module_inject/auto_ep_presets/base.py index 342e6ff1abb5..c023498109c9 100644 --- a/deepspeed/module_inject/auto_ep_presets/base.py +++ b/deepspeed/module_inject/auto_ep_presets/base.py @@ -98,6 +98,7 @@ class AutoEPConfig: enabled: bool = False autoep_size: int = 1 + expert_tensor_parallel_size: int = 1 preset_model: str | None = None moe_layer_pattern: str | None = None expert_pattern: str | None = None diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 4e47278e52c5..f26b10d8d670 100755 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -548,9 +548,32 @@ def update_linear_policies(self): else: self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding} + def _replace_autoep_shared_experts(self, autoep_layer, autoep_name): + for child_name in ("shared_experts", "shared_experts_gate"): + child = getattr(autoep_layer, child_name, None) + 
if child is None: + continue + full_name = f"{autoep_name}.{child_name}" if autoep_name else child_name + if self.partition_config is not None and hasattr(child, "weight") and getattr( + child.weight, "dim", lambda: 0)() == 2: + new_child = self._replace_with_config(child, full_name) + if new_child is not None: + setattr(autoep_layer, child_name, new_child) + elif child.__class__ in self.linear_policies: + setattr(autoep_layer, child_name, self.linear_policies[child.__class__](child, full_name, + self.conv_linear_layer)) + elif any(isinstance(child, lp) for lp in self.linear_policies): + key = next(lp for lp in self.linear_policies if isinstance(child, lp)) + setattr(autoep_layer, child_name, self.linear_policies[key](child, full_name, self.conv_linear_layer)) + else: + self.update_mp_params(child) + self._replace_module(child, full_name, "") + def _replace_module(self, r_module, prev_name='', prev_class_name=''): for name, child in r_module.named_children(): if getattr(child, "_is_autoep_layer", False): + full_name = prev_name + '.' + name if prev_name else name + self._replace_autoep_shared_experts(child, full_name) continue if prev_class_name == "": diff --git a/deepspeed/moe/ep_tp_dispatch.py b/deepspeed/moe/ep_tp_dispatch.py new file mode 100644 index 000000000000..2b7b4df94ce1 --- /dev/null +++ b/deepspeed/moe/ep_tp_dispatch.py @@ -0,0 +1,306 @@ +# Copyright (c) DeepSpeed Team. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Route-full / partition-dispatch helpers for AutoEP + AutoTP folding.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import deepspeed.comm as dist + +_FOLDING_DIGEST_MOD_A = 2147483647 +_FOLDING_DIGEST_MOD_B = 2147483629 + + +@dataclass +class RoutedAssignmentPayload: + token_indices: torch.Tensor + expert_indices: torch.Tensor + assignment_indices: torch.Tensor + capacity_slots: torch.Tensor + combine_weights: torch.Tensor + drop_mask: torch.Tensor + pad_mask: torch.Tensor + input_splits: list[int] + output_splits: list[int] + extra: dict[str, torch.Tensor] + + +@dataclass +class RestoreContext: + original_payload: RoutedAssignmentPayload + local_indices: torch.Tensor + tp_rank: int + tp_size: int + num_tokens: int + counters: dict[str, int] + + +def assignment_ordinals_by_expert(expert_indices: torch.Tensor) -> torch.Tensor: + """Return stable ordinals within each contiguous expert segment.""" + if expert_indices.numel() == 0: + return expert_indices.to(torch.long) + positions = torch.arange(expert_indices.numel(), device=expert_indices.device, dtype=torch.long) + starts = torch.zeros_like(positions) + starts[0] = 0 + segment_start = torch.zeros(expert_indices.numel(), device=expert_indices.device, dtype=torch.bool) + segment_start[0] = True + segment_start[1:] = expert_indices[1:] != expert_indices[:-1] + starts = torch.where(segment_start, positions, starts) + starts = torch.cummax(starts, dim=0).values + return positions - starts + + +def _take(payload: RoutedAssignmentPayload, indices: torch.Tensor) -> RoutedAssignmentPayload: + extra = { + key: + value.index_select(0, indices) + if torch.is_tensor(value) and value.shape[:1] == payload.token_indices.shape[:1] else value + for key, value in payload.extra.items() + } + return RoutedAssignmentPayload( + token_indices=payload.token_indices.index_select(0, indices), + 
expert_indices=payload.expert_indices.index_select(0, indices), + assignment_indices=payload.assignment_indices.index_select(0, indices), + capacity_slots=payload.capacity_slots.index_select(0, indices), + combine_weights=payload.combine_weights.index_select(0, indices), + drop_mask=payload.drop_mask.index_select(0, indices), + pad_mask=payload.pad_mask.index_select(0, indices), + input_splits=list(payload.input_splits), + output_splits=list(payload.output_splits), + extra=extra, + ) + + +def _recompute_input_splits(payload: RoutedAssignmentPayload) -> list[int]: + destinations = payload.extra.get("destination_ranks") + if destinations is None: + return list(payload.input_splits) + if len(payload.input_splits) == 0: + return [] + counts = torch.bincount(destinations.to(torch.long), minlength=len(payload.input_splits)) + return [int(value) for value in counts[:len(payload.input_splits)].cpu().tolist()] + + +def _tensor_digest_words(tensor: torch.Tensor) -> torch.Tensor: + tensor = tensor.detach() + if tensor.is_floating_point(): + words = torch.nan_to_num(tensor.float(), nan=0.0, posinf=3.4028235e38, + neginf=-3.4028235e38).mul(1000003.0).round().to(torch.long) + else: + words = tensor.to(torch.long) + return words.reshape(-1) + + +def _digest_words(words: torch.Tensor, *, salt: int, modulus: int) -> torch.Tensor: + if words.numel() == 0: + return torch.tensor(salt, device=words.device, dtype=torch.long) + positions = torch.arange(1, words.numel() + 1, device=words.device, dtype=torch.long) + positions = positions.add_(salt).remainder_(modulus) + values = words.remainder(modulus) + return (values.mul(positions).remainder_(modulus).sum().add_(words.numel() * salt).remainder_(modulus)) + + +def _payload_digest(payload: RoutedAssignmentPayload) -> torch.Tensor: + device = payload.token_indices.device + active = (~payload.drop_mask & ~payload.pad_mask).to(torch.long) + digest = torch.tensor( + [payload.token_indices.numel(), + int(sum(payload.input_splits)), + 
int(sum(payload.output_splits)), 0, 0], + device=device, + dtype=torch.long) + fields = ( + payload.token_indices, + payload.expert_indices, + payload.assignment_indices, + payload.capacity_slots, + payload.combine_weights, + payload.drop_mask, + payload.pad_mask, + active, + payload.extra.get("destination_ranks", torch.empty(0, device=device, dtype=torch.long)), + ) + for index, field in enumerate(fields, start=1): + if not torch.is_tensor(field): + continue + words = _tensor_digest_words(field) + digest[3] = digest[3].add(_digest_words(words, salt=17 * index, + modulus=_FOLDING_DIGEST_MOD_A)).remainder_(_FOLDING_DIGEST_MOD_A) + digest[4] = digest[4].add(_digest_words(words, salt=31 * index, + modulus=_FOLDING_DIGEST_MOD_B)).remainder_(_FOLDING_DIGEST_MOD_B) + return digest + + +def assert_tp_payload_consistent(payload: RoutedAssignmentPayload, *, tp_group, tp_size: int) -> None: + if tp_size <= 1 or not dist.is_initialized(): + return + + digest = _payload_digest(payload) + max_digest = digest.clone() + min_digest = digest.clone() + dist.all_reduce(max_digest, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(min_digest, op=dist.ReduceOp.MIN, group=tp_group) + if not torch.equal(max_digest, min_digest): + raise RuntimeError("AutoEP+AutoTP routing decisions differ across tensor-parallel lanes. 
" + "Folded dispatch requires identical routed-token payloads before TP partitioning.") + + +def partition_assignments( + payload: RoutedAssignmentPayload, + *, + tp_group, + tp_rank: int, + tp_size: int, +) -> tuple[RoutedAssignmentPayload, RestoreContext]: + """Partition routed assignments across TP peers by stable per-expert ordinal.""" + active = ~payload.drop_mask & ~payload.pad_mask + if tp_size <= 1: + keep = active + else: + keep = (payload.assignment_indices.remainder(tp_size) == tp_rank) & active + local_indices = torch.nonzero(keep, as_tuple=False).flatten() + + local = _take(payload, local_indices) + local.input_splits = _recompute_input_splits(local) + local.output_splits = list(local.input_splits) + ctx = RestoreContext( + original_payload=payload, + local_indices=local_indices, + tp_rank=tp_rank, + tp_size=tp_size, + num_tokens=int(payload.extra.get("num_tokens", torch.tensor(0)).item()) if torch.is_tensor( + payload.extra.get("num_tokens")) else int(payload.extra.get("num_tokens", 0)), + counters={ + "assignments_total": int((~payload.drop_mask & ~payload.pad_mask).sum().item()), + "assignments_local": int(local_indices.numel()), + "padded": int(payload.pad_mask.sum().item()), + "dropped": int(payload.drop_mask.sum().item()), + "split_sum_in": int(sum(local.input_splits)), + "split_sum_out": int(sum(local.output_splits)), + }, + ) + return local, ctx + + +def _pad_rows(tensor: torch.Tensor, rows: int) -> torch.Tensor: + if tensor.shape[0] == rows: + return tensor + pad_shape = (rows - tensor.shape[0], *tensor.shape[1:]) + return torch.cat((tensor, tensor.new_zeros(pad_shape)), dim=0) + + +class _AllGatherVariableRows(torch.autograd.Function): + + @staticmethod + def forward(ctx, tensor, group, counts, max_rows): + ctx.group = group + ctx.counts = tuple(counts) + ctx.max_rows = max_rows + ctx.group_rank = dist.get_rank(group=group) + if max_rows == 0: + return tensor.new_empty((0, *tensor.shape[1:])) + padded = _pad_rows(tensor, max_rows) + gathered 
= [torch.zeros_like(padded) for _ in counts] + dist.all_gather(gathered, padded, group=group) + return torch.cat([chunk[:count] for chunk, count in zip(gathered, counts)], dim=0) + + @staticmethod + def backward(ctx, grad_output): + local_count = ctx.counts[ctx.group_rank] + if ctx.max_rows == 0: + return grad_output.new_empty((0, *grad_output.shape[1:])), None, None, None + chunks = torch.split(grad_output, ctx.counts, dim=0) + grad_padded = grad_output.new_zeros((ctx.max_rows, *grad_output.shape[1:])) + if local_count: + grad_padded[:local_count].copy_(chunks[ctx.group_rank]) + return grad_padded[:local_count].contiguous(), None, None, None + + +def _all_gather_variable_rows(tensor: torch.Tensor, + group, + tp_size: int, + *, + preserve_grad: bool = False) -> torch.Tensor: + if tp_size <= 1 or not dist.is_initialized(): + return tensor + + local_rows = torch.tensor([tensor.shape[0]], dtype=torch.long, device=tensor.device) + row_counts = [torch.zeros_like(local_rows) for _ in range(tp_size)] + dist.all_gather(row_counts, local_rows, group=group) + counts = [int(item.item()) for item in row_counts] + max_rows = max(counts) if counts else tensor.shape[0] + if preserve_grad: + return _AllGatherVariableRows.apply(tensor, group, tuple(counts), max_rows) + else: + padded = _pad_rows(tensor, max_rows) + gathered = [torch.zeros_like(padded) for _ in range(tp_size)] + dist.all_gather(gathered, padded, group=group) + return torch.cat([chunk[:count] for chunk, count in zip(gathered, counts)], dim=0) + + +def _debug_validate_restore_coverage(payload: RoutedAssignmentPayload, ctx: RestoreContext, + all_token_indices: torch.Tensor, all_expert_indices: torch.Tensor, + all_assignment_indices: torch.Tensor) -> None: + active = ~payload.drop_mask & ~payload.pad_mask + expected_rows = torch.stack(( + payload.token_indices[active].to(torch.long), + payload.expert_indices[active].to(torch.long), + payload.assignment_indices[active].to(torch.long), + ), + dim=1) + observed_rows = 
torch.stack(( + all_token_indices.to(torch.long), + all_expert_indices.to(torch.long), + all_assignment_indices.to(torch.long), + ), + dim=1) + if expected_rows.numel() == 0 and observed_rows.numel() == 0: + return + if observed_rows.shape[0] != expected_rows.shape[0]: + raise RuntimeError("AutoEP+AutoTP restore coverage mismatch: gathered assignment count " + f"{observed_rows.shape[0]} != expected {expected_rows.shape[0]}") + if observed_rows.shape[0] <= 4096: + expected = {tuple(row) for row in expected_rows.detach().cpu().tolist()} + observed = {tuple(row) for row in observed_rows.detach().cpu().tolist()} + if observed != expected: + missing = sorted(expected - observed)[:5] + duplicate_or_stale = sorted(observed - expected)[:5] + raise RuntimeError("AutoEP+AutoTP restore coverage mismatch: " + f"missing={missing} unexpected={duplicate_or_stale}") + + +def restore_combined(local_combined: torch.Tensor, ctx: RestoreContext, *, tp_group) -> torch.Tensor: + """Gather TP-partitioned assignment outputs and combine back by token index.""" + payload = ctx.original_payload + local_token_indices = payload.token_indices.index_select(0, ctx.local_indices) + local_expert_indices = payload.expert_indices.index_select(0, ctx.local_indices) + local_assignment_indices = payload.assignment_indices.index_select(0, ctx.local_indices) + local_weights = payload.combine_weights.index_select(0, ctx.local_indices).to(local_combined.dtype) + + all_outputs = _all_gather_variable_rows(local_combined, + tp_group, + ctx.tp_size, + preserve_grad=local_combined.requires_grad) + all_token_indices = _all_gather_variable_rows(local_token_indices, tp_group, ctx.tp_size).to(torch.long) + all_expert_indices = _all_gather_variable_rows(local_expert_indices, tp_group, ctx.tp_size).to(torch.long) + all_assignment_indices = _all_gather_variable_rows(local_assignment_indices, tp_group, ctx.tp_size).to(torch.long) + all_weights = _all_gather_variable_rows(local_weights, + tp_group, + ctx.tp_size, + 
preserve_grad=local_weights.requires_grad).to(local_combined.dtype) + _debug_validate_restore_coverage(payload, ctx, all_token_indices, all_expert_indices, all_assignment_indices) + + if ctx.num_tokens <= 0: + ctx.num_tokens = int(payload.token_indices.max().item()) + 1 if payload.token_indices.numel() else 0 + output = local_combined.new_zeros((ctx.num_tokens, local_combined.shape[-1])) + if all_outputs.numel() > 0: + output.index_add_(0, all_token_indices, all_outputs * all_weights.reshape(-1, 1)) + return output + + +def dispatch_counters(ctx: RestoreContext) -> dict[str, int]: + return dict(ctx.counters) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7708999fcdf7..a0b49e01f0e9 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -24,7 +24,7 @@ import deepspeed from deepspeed import comm as dist -from deepspeed.runtime.utils import see_memory_usage, DummyOptim, register_output_backward_hooks, check_internal_apis_for_count_used_parameters +from deepspeed.runtime.utils import see_memory_usage, DummyOptim, register_output_backward_hooks, check_internal_apis_for_count_used_parameters, is_model_parallel_parameter from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer @@ -70,7 +70,7 @@ WEIGHT_QUANTIZE_ROUNDING, \ WEIGHT_QUANTIZE_VERBOSE, \ WEIGHT_QUANTIZE_KERNEL -from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS, UNIVERSAL_CHECKPOINT_INFO +from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS, UNIVERSAL_CHECKPOINT_INFO, FOLDING_METADATA_KEY from deepspeed.checkpoint.utils import clone_tensors_for_torch_save from deepspeed.checkpoint.ds_to_universal import dp_index_to_str from deepspeed.runtime.sparse_tensor import SparseTensor @@ -264,6 +264,8 @@ def __init__(self, self.scale_wrt_gas = None 
self.losses = None self.mesh_device = mesh_device + self._autoep_folding_spec = None + self._autoep_folding_group_handles = None # Flag to indicate that scale() was called before manual backward pass self._manual_backward_expected = False @@ -507,6 +509,7 @@ def _configure_expert_parallel(self, model): from deepspeed.module_inject.auto_ep import AutoEP from deepspeed.module_inject.auto_ep_config import validate_autoep_config, validate_autoep_post_detection + from deepspeed.module_inject.auto_ep_folding import build_folding_spec, validate_folding_global ep_size = autoep_config.autoep_size tp_size = self.autotp_size() @@ -523,6 +526,27 @@ def _configure_expert_parallel(self, model): world_size = dist.get_world_size() validate_autoep_config(autoep_config, world_size, pp_size, tp_size, sp_size) + folding_spec = build_folding_spec( + world_size=world_size, + pp_size=pp_size, + tp_size=max(tp_size, 1), + ep_size=ep_size, + etp_size=autoep_config.expert_tensor_parallel_size, + mp_mode="tp" if tp_size > 1 else "sp", + ) + validate_folding_global( + folding_spec, + zero_stage=self.zero_optimization_stage(), + sp_size=sp_size, + use_data_before_expert_parallel=self._config.use_data_before_expert_parallel_, + mpu=self.mpu, + autoep_enabled=autoep_config.enabled, + tp_preset=getattr(self._config.tensor_parallel_config, "preset_model", None), + ep_preset=autoep_config.preset_model, + zero_offload_optimizer=self.zero_offload_optimizer() is not None, + zero_offload_param=self.zero_offload_param() is not None, + ) + self._autoep_folding_spec = folding_spec # Create EP/EDP process groups mp_size = max(tp_size, sp_size, 1) @@ -533,6 +557,7 @@ def _configure_expert_parallel(self, model): pp_size=pp_size, mp_mode=mp_mode, use_data_before_expert_parallel_=self._config.use_data_before_expert_parallel_, + folding_spec=folding_spec if tp_size > 1 else None, ) # Derive EP rank @@ -1556,9 +1581,37 @@ def _configure_distributed_model(self, model): if self.mpu is not None: groups.mpu = 
self.mpu + folding_group_handles = None + try: + from deepspeed.module_inject.auto_ep_folding import FoldingGroupHandles, local_folding_ranks + except ImportError: + FoldingGroupHandles = None + local_folding_ranks = None + if (FoldingGroupHandles is not None and self._autoep_folding_spec is not None + and self._autoep_folding_spec.tp_size > 1): + ep_group_name = f"ep_size_{self._autoep_folding_spec.ep_size}" + rank = dist.get_rank() + local_ranks = local_folding_ranks(rank, self._autoep_folding_spec) + folding_group_handles = FoldingGroupHandles( + spec=self._autoep_folding_spec, + tp_group=groups.get_tensor_model_parallel_group(), + dense_dp_group=groups._get_data_parallel_group(), + ep_group=groups._get_expert_parallel_group(ep_group_name), + edp_group=groups._get_expert_data_parallel_group(ep_group_name), + ep_group_name=ep_group_name, + tp_ranks=local_ranks["tp"], + dense_dp_ranks=local_ranks["dense_dp"], + ep_ranks=local_ranks["ep"], + edp_ranks=local_ranks["edp"], + ) + self._autoep_folding_group_handles = folding_group_handles + # Set deepspeed parallelism spec. 
for the model including expert parallelism for _, module in self.module.named_modules(): - if hasattr(module, 'set_deepspeed_parallelism'): + if _AutoEPMoELayer is not None and isinstance(module, _AutoEPMoELayer): + module.set_deepspeed_parallelism(self._config.use_data_before_expert_parallel_, + folding_group_handles=folding_group_handles) + elif hasattr(module, 'set_deepspeed_parallelism'): module.set_deepspeed_parallelism(self._config.use_data_before_expert_parallel_) # Query the groups module to get information about various parallel groups @@ -1708,11 +1761,18 @@ def _configure_optimizer(self, client_optimizer, model_parameters): else: self.optimizer = basic_optimizer + self._configure_autoep_folding_optimizer_gradient_reduction() log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer.__class__.__name__), ranks=[0]) self.compression_scheduler = self._configure_compression_scheduler() self.quantizer = self._configure_quantization() + def _configure_autoep_folding_optimizer_gradient_reduction(self): + configure = getattr(self.optimizer, "configure_autoep_folding_tp_gradient_reduction", None) + if configure is None: + return + configure(getattr(self, "_autoep_folding_spec", None)) + def _configure_basic_optimizer(self, model_parameters): # Copy so the pop() calls below (torch_adam, adam_w_mode, fp32_optimizer_states) do not # mutate the shared config dict returned by optimizer_params(). 
@@ -2515,6 +2575,8 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): # Pass (PP) gas boundary flag to optimizer (required for zero) self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() + if self.is_gradient_accumulation_boundary(): + self._reduce_autoep_folding_tp_replicated_gradients() # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well if self.zero_optimization_partition_gradients(): self.optimizer.overlapping_partition_gradients_reduce_epilogue() @@ -2530,6 +2592,38 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): elif self.zenflow: self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism) + def _reduce_autoep_folding_tp_replicated_gradients(self): + folding_spec = getattr(self, "_autoep_folding_spec", None) + if folding_spec is None or folding_spec.tp_size <= 1 or not dist.is_initialized(): + return + if (isinstance(self.optimizer, ZeROOptimizer) and getattr(self.optimizer, "partition_gradients", False) + and getattr(self.optimizer, "autoep_folding_tp_group", None) is not None): + return + tp_group = groups.get_tensor_model_parallel_group() + if tp_group is None: + return + + # TP and SP folding modes produce disjoint per-lane router/shared partials. + # Duplicated-token modes already hold a full replicated gradient per lane. 
+ partitioned_grad_mode = getattr(folding_spec, "mp_mode", "tp") in ("tp", "sp") + tp_world_size = dist.get_world_size(group=tp_group) + for _, param in self.module.named_parameters(): + if not param.requires_grad or param.grad is None: + continue + if is_moe_param(param) or is_model_parallel_parameter(param): + continue + if param.grad.data.is_sparse: + continue + grad = param.grad.data + if partitioned_grad_mode and grad.dtype != torch.float32: + reduced = grad.float() + dist.all_reduce(reduced, group=tp_group) + grad.copy_(reduced.to(grad.dtype)) + continue + dist.all_reduce(grad, group=tp_group) + if not partitioned_grad_mode: + grad.div_(tp_world_size) + def _backward_prologue(self): if is_functorch_transforming(): return @@ -3438,6 +3532,112 @@ def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclud sd = remove_random_ltd_state_dict(sd) return sd + @staticmethod + def _make_autoep_folding_metadata(folding_spec, + *, + family, + ep_rank, + zero_partition_group, + zero_partition_rank, + zero_partition_count, + param_families=None): + from deepspeed.checkpoint.autoep_universal import make_folding_metadata + + return make_folding_metadata(tp_size=folding_spec.tp_size, + tp_rank=groups.get_tensor_model_parallel_rank(), + ep_size=folding_spec.ep_size, + ep_rank=ep_rank, + zero_partition_group=zero_partition_group, + zero_partition_rank=zero_partition_rank, + zero_partition_count=zero_partition_count, + family=family, + param_families=param_families) + + @staticmethod + def _autoep_non_expert_param_families(state_dict): + families = {} + for key in state_dict.keys(): + if ".router." in key: + families[key] = "router_gate_replicated" + elif ".shared_experts" in key: + families[key] = "shared_expert" + else: + families[key] = "dense" + return families + + @staticmethod + def _autoep_param_family(param_name): + if ".experts." in param_name: + return "routed_expert" + if ".router." 
in param_name: + return "router_gate_replicated" + if ".shared_experts" in param_name: + return "shared_expert" + return "dense" + + def _autoep_zero_optimizer_param_families(self): + optimizer = self.optimizer + real_dp_groups = getattr(optimizer, "real_dp_process_group", []) + partition_counts = getattr(optimizer, "partition_count", []) + param_families = {} + for group_idx, param_shapes in enumerate(self._get_zero_param_shapes()): + process_group = real_dp_groups[group_idx] if group_idx < len( + real_dp_groups) else optimizer.dp_process_group + partition_count = (partition_counts[group_idx] + if group_idx < len(partition_counts) else dist.get_world_size(group=process_group)) + partition_rank = dist.get_rank(group=process_group) + for param_name in param_shapes.keys(): + family = DeepSpeedEngine._autoep_param_family(param_name) + zero_partition_group = "edp" if family == "routed_expert" else "dense_dp" + param_families[param_name] = { + "family": family, + "zero_partition_group": zero_partition_group, + "zero_partition_rank": partition_rank, + "zero_partition_count": partition_count, + } + return param_families + + @staticmethod + def _validate_autoep_folding_checkpoint_metadata(state, + *, + folding_spec, + family, + zero_partition_group, + zero_partition_count, + tp_rank=None, + ep_rank=None, + zero_partition_rank=None, + param_families=None, + require_when_folded=True): + has_metadata = isinstance(state, dict) and FOLDING_METADATA_KEY in state + folded_runtime = folding_spec is not None and folding_spec.tp_size > 1 + if has_metadata and not folded_runtime: + raise RuntimeError("Folded AutoEP+AutoTP checkpoint requires a folded runtime with matching " + "tensor_parallel.autotp_size and expert_parallel.autoep_size.") + if not folded_runtime: + return + if require_when_folded and not has_metadata: + raise RuntimeError("Missing AutoEP+AutoTP folding metadata in folded checkpoint.") + if not has_metadata: + return + + from deepspeed.checkpoint.autoep_universal 
import validate_folding_metadata + + validate_folding_metadata(state, + tp_size=folding_spec.tp_size, + ep_size=folding_spec.ep_size, + etp_size=folding_spec.etp_size, + etp_rank=0, + tp_rank=tp_rank, + ep_rank=ep_rank, + zero_partition_group=zero_partition_group, + zero_partition_rank=zero_partition_rank, + zero_partition_count=zero_partition_count, + param_families=param_families, + family=family, + shared_expert_placement="tp_sharded", + dispatch_strategy="route_full_partition_dispatch") + @staticmethod def load_moe_state_dict(checkpoint_path, tag, @@ -3447,7 +3647,8 @@ def load_moe_state_dict(checkpoint_path, mpu=None, num_experts=1, checkpoint_engine=TorchCheckpointEngine(), - autoep_layers=None): + autoep_layers=None, + folding_spec=None): try: from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer as _AutoEPMoELayer except ImportError: @@ -3455,6 +3656,7 @@ def load_moe_state_dict(checkpoint_path, has_autoep_layers = _AutoEPMoELayer is not None and model is not None and any( isinstance(m, _AutoEPMoELayer) for _, m in model.named_modules()) + folded_autoep_tp = folding_spec is not None and folding_spec.tp_size > 1 if old_moe_load: if has_autoep_layers: @@ -3534,6 +3736,7 @@ def load_moe_state_dict(checkpoint_path, group_name = module.ep_group_name num_local_experts = module.num_local_experts expp_rank = groups._get_expert_parallel_rank(group_name) + exp_dp_rank = groups._get_expert_data_parallel_rank(group_name) module_prefix = f"{n_module}." if n_module else "" # Collect per-expert tensors to stack @@ -3547,6 +3750,14 @@ def load_moe_state_dict(checkpoint_path, raise FileNotFoundError(f"Expert checkpoint file not found: {expert_ckpt_path}. 
" f"Expected layer_{moe_layer_id} expert_{global_expert_id}.") expert_sd = checkpoint_engine.load(expert_ckpt_path, map_location=torch.device('cpu')) + DeepSpeedEngine._validate_autoep_folding_checkpoint_metadata( + expert_sd, + folding_spec=folding_spec, + family="routed_expert", + zero_partition_group="edp", + zero_partition_count=folding_spec.edp_size if folded_autoep_tp else None, + tp_rank=groups.get_tensor_model_parallel_rank() if folded_autoep_tp else None, + ep_rank=expp_rank if folded_autoep_tp else None) for wname in ('w1', 'w2', 'w3'): fused_key = f"{module_prefix}experts.{wname}" @@ -3788,6 +3999,17 @@ def _load_checkpoint(self, if checkpoint is None: return None, None + folding_spec = getattr(self, "_autoep_folding_spec", None) + folded_autoep_tp = folding_spec is not None and folding_spec.tp_size > 1 + ep_group_name = f"ep_size_{folding_spec.ep_size}" if folded_autoep_tp else None + DeepSpeedEngine._validate_autoep_folding_checkpoint_metadata( + checkpoint, + folding_spec=folding_spec, + family="dense", + zero_partition_group="dense_dp", + zero_partition_count=folding_spec.dp_size if folded_autoep_tp else None, + tp_rank=groups.get_tensor_model_parallel_rank() if folded_autoep_tp else None) + fetch_z3_params = False if self.zero_optimization_partition_weights() and not load_optimizer_states: checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir) @@ -3816,7 +4038,8 @@ def _load_checkpoint(self, mpu=self.mpu, num_experts=self.num_experts, checkpoint_engine=self.checkpoint_engine, - autoep_layers=autoep_layers) + autoep_layers=autoep_layers, + folding_spec=folding_spec) if not self.load_universal_checkpoint(): self.load_module_state_dict(checkpoint=checkpoint, strict=load_module_strict, @@ -3836,8 +4059,17 @@ def _load_checkpoint(self, if self.has_moe_layers: largest_group_name = groups._get_max_expert_size_name() expp_rank = groups._get_expert_parallel_rank(largest_group_name) + exp_dp_rank = 
groups._get_expert_data_parallel_rank(largest_group_name) optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu')) + DeepSpeedEngine._validate_autoep_folding_checkpoint_metadata( + optim_checkpoint, + folding_spec=folding_spec, + family="routed_expert", + zero_partition_group="edp", + zero_partition_count=folding_spec.edp_size if folded_autoep_tp else None, + tp_rank=groups.get_tensor_model_parallel_rank() if folded_autoep_tp else None, + ep_rank=expp_rank if folded_autoep_tp else None) else: optim_checkpoint = checkpoint @@ -3845,9 +4077,7 @@ def _load_checkpoint(self, self.optimizer.load_state_dict(optim_checkpoint['optimizer'], load_optimizer_states=load_optimizer_states) else: - optim_checkpoint = checkpoint - - self.optimizer.load_state_dict(optim_checkpoint['optimizer']) + self.optimizer.load_state_dict(optim_checkpoint['optimizer']) if load_lr_scheduler_states and self.lr_scheduler is not None: self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) @@ -3981,6 +4211,13 @@ def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode): def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): zero_sd_list = [] + folding_spec = getattr(self, "_autoep_folding_spec", None) + folded_autoep_tp = folding_spec is not None and folding_spec.tp_size > 1 + zero_partition_count = None + zero_param_families = None + if folded_autoep_tp: + zero_partition_count = dist.get_world_size(group=self.optimizer.dp_process_group) + zero_param_families = self._autoep_zero_optimizer_param_families() for i, ckpt_name in enumerate(zero_ckpt_names): _state = None if ckpt_name is None: @@ -3993,6 +4230,16 @@ def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): ) else: _state = {OPTIMIZER_STATE_DICT: None} + if _state.get(OPTIMIZER_STATE_DICT) is not None or FOLDING_METADATA_KEY in _state: + 
DeepSpeedEngine._validate_autoep_folding_checkpoint_metadata( + _state, + folding_spec=folding_spec, + family="zero_optimizer_state", + zero_partition_group="per_family", + zero_partition_count=zero_partition_count, + tp_rank=groups.get_tensor_model_parallel_rank() if folded_autoep_tp else None, + zero_partition_rank=i if folded_autoep_tp else None, + param_families=zero_param_families) zero_sd_list.append(_state) zero_optimizer_sd = [sd[OPTIMIZER_STATE_DICT] for sd in zero_sd_list] @@ -4182,6 +4429,31 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa except ImportError: _AutoEPMoELayer = None + folding_spec = getattr(self, "_autoep_folding_spec", None) + folded_autoep_tp = folding_spec is not None and folding_spec.tp_size > 1 + + def folding_metadata(*, + family, + ep_rank, + zero_partition_group, + zero_partition_rank, + zero_partition_count, + param_families=None): + if not folded_autoep_tp: + return None + return DeepSpeedEngine._make_autoep_folding_metadata(folding_spec, + family=family, + ep_rank=ep_rank, + zero_partition_group=zero_partition_group, + zero_partition_rank=zero_partition_rank, + zero_partition_count=zero_partition_count, + param_families=param_families) + + def autoep_expert_writer() -> bool: + if folded_autoep_tp: + return groups._get_data_parallel_rank() < folding_spec.ep_size + return self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank) + # A hack to save the checkpointing directory. Pipeline parallelism overrides # module_state_dict() and uses this path to save the model. module_state_dict() # then instead just returns None. @@ -4262,8 +4534,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa f"multiple groups: {sorted(autoep_group_names)}. " f"All AutoEPMoELayer instances must use the same ep_size.") - # Gate file writes behind writer guard - if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank): + # Gate file writes behind writer guard. 
Folded AutoEP+AutoTP needs + # one expert shard per (TP rank, EP rank), not only mp_rank_00. + if not autoep_expert_writer(): moe_layer_id += 1 continue @@ -4276,6 +4549,14 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa param = getattr(module.experts, wname) expert_state_dict[f"{fused_key}.{global_expert_id}"] = ( param[local_expert_id].clone().detach()) + if folded_autoep_tp: + expert_state_dict[FOLDING_METADATA_KEY] = folding_metadata( + family="routed_expert", + ep_rank=expp_rank, + zero_partition_group="edp", + zero_partition_rank=exp_dp_rank, + zero_partition_count=folding_spec.edp_size, + ) moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu) saveable = expert_state_dict @@ -4290,23 +4571,32 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa largest_group_name = groups._get_max_expert_size_name() expp_rank = groups._get_expert_parallel_rank(largest_group_name) exp_dp_rank = groups._get_expert_data_parallel_rank(largest_group_name) + expert_checkpoint_writer = (groups._get_data_parallel_rank() < folding_spec.ep_size if folded_autoep_tp else + self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank)) # In the case of E + D parallelism, only the # first expert parallel group should save the expert weights # since each expert parallel group is a copy of the model's experts - if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank): + if not expert_checkpoint_writer and not folded_autoep_tp: return # Save optimizer states. They are different across each exp parallel rank. 
- optimizer_state = { - 'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None - } - # TODO: why use BufferedWriter not the path - file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) - saveable_state_dict = optimizer_state - if self.checkpoint_engine.preserves_storage_sharing(): - saveable_state_dict = clone_tensors_for_torch_save(optimizer_state) - self.checkpoint_engine.save(saveable_state_dict, file_path) + if expert_checkpoint_writer: + optimizer_state = { + 'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None + } + if folded_autoep_tp: + optimizer_state[FOLDING_METADATA_KEY] = folding_metadata(family="routed_expert", + ep_rank=expp_rank, + zero_partition_group="edp", + zero_partition_rank=exp_dp_rank, + zero_partition_count=folding_spec.edp_size) + # TODO: why use BufferedWriter not the path + file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) + saveable_state_dict = optimizer_state + if self.checkpoint_engine.preserves_storage_sharing(): + saveable_state_dict = clone_tensors_for_torch_save(optimizer_state) + self.checkpoint_engine.save(saveable_state_dict, file_path) # Load flow uses below saved file for model parameters, RNG and more if groups._get_data_parallel_rank() == 0: @@ -4345,8 +4635,17 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa 'ds_autoep_layers': autoep_layer_info if autoep_layer_info else None, } + if folded_autoep_tp: + ep_group_name = f"ep_size_{folding_spec.ep_size}" + state[FOLDING_METADATA_KEY] = folding_metadata( + family="dense", + ep_rank=groups._get_expert_parallel_rank(ep_group_name), + zero_partition_group="dense_dp", + zero_partition_rank=groups._get_data_parallel_rank(), + zero_partition_count=folding_spec.dp_size, + param_families=DeepSpeedEngine._autoep_non_expert_param_families(model_state_dict)) # Check for reserved-key collisions with client_state - reserved_keys 
= {'ds_autoep_layers', 'autoep_layers'} + reserved_keys = {'ds_autoep_layers', 'autoep_layers', FOLDING_METADATA_KEY} collisions = reserved_keys.intersection(client_state.keys()) if collisions: raise KeyError(f"client_state contains reserved checkpoint keys: {sorted(collisions)}. " @@ -4569,6 +4868,18 @@ def _change_recovery_script_permissions(self, dst): def _save_zero_checkpoint(self, save_path, tag): zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag) zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), ds_config=self.config, ds_version=version) + folding_spec = getattr(self, "_autoep_folding_spec", None) + if folding_spec is not None and folding_spec.tp_size > 1: + ep_group_name = f"ep_size_{folding_spec.ep_size}" + zero_sd[FOLDING_METADATA_KEY] = DeepSpeedEngine._make_autoep_folding_metadata( + folding_spec, + family="zero_optimizer_state", + ep_rank=groups._get_expert_parallel_rank(ep_group_name), + zero_partition_group="per_family", + zero_partition_rank=dist.get_rank(group=self.optimizer.dp_process_group), + zero_partition_count=dist.get_world_size(group=self.optimizer.dp_process_group), + param_families=self._autoep_zero_optimizer_param_families(), + ) self.checkpoint_engine.save(zero_sd, zero_checkpoint_name) if self.global_rank == 0: diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 89a45aa0fa41..65907121d288 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -249,6 +249,8 @@ def __init__(self, self.contiguous_gradients = contiguous_gradients or self.cpu_offload self.has_moe_layers = has_moe_layers + self.autoep_folding_tp_group = None + self.autoep_folding_partitioned_grad_mode = False if self.has_moe_layers: self._configure_moe_settings() self._global_grad_norm = 0. 
def configure_autoep_folding_tp_gradient_reduction(self, folding_spec):
    """Enable or disable the extra tensor-parallel gradient all-reduce for folded AutoEP+AutoTP.

    Args:
        folding_spec: Folding description exposing ``tp_size`` and, optionally,
            ``mp_mode``. ``None`` or ``tp_size <= 1`` disables the extra reduction.
    """
    if folding_spec is None or folding_spec.tp_size <= 1:
        # No folded tensor parallelism: clear both switches so that
        # _maybe_reduce_autoep_folding_tp_gradient returns immediately.
        self.autoep_folding_tp_group = None
        self.autoep_folding_partitioned_grad_mode = False
        return
    self.autoep_folding_tp_group = groups.get_tensor_model_parallel_group()
    # A spec without an mp_mode attribute is treated as "tp". Both "tp" and "sp"
    # select the partitioned mode, in which the reduced gradient is not divided.
    self.autoep_folding_partitioned_grad_mode = getattr(folding_spec, "mp_mode", "tp") in ("tp", "sp")

def _maybe_reduce_autoep_folding_tp_gradient(self, param, grad):
    """All-reduce ``grad`` in place across the folding tensor-parallel group when applicable.

    Nothing happens unless ``self.partition_gradients`` is set, a folding TP group
    has been configured, and ``grad`` is not ``None``. Parameters reported by
    ``is_moe_param`` or ``is_model_parallel_parameter`` and sparse gradients are
    skipped as well.

    Args:
        param: Parameter that owns the gradient; only used for the skip checks.
        grad: Gradient tensor whose ``.data`` is reduced in place.
    """
    if not self.partition_gradients or self.autoep_folding_tp_group is None or grad is None:
        return
    if is_moe_param(param) or is_model_parallel_parameter(param):
        return
    if grad.data.is_sparse:
        return
    grad_data = grad.data
    if self.autoep_folding_partitioned_grad_mode and grad_data.dtype != torch.float32:
        # Partitioned mode with a non-fp32 gradient: reduce an fp32 copy and write
        # the result back in the gradient's own dtype.
        reduced = grad_data.float()
        dist.all_reduce(reduced, group=self.autoep_folding_tp_group)
        grad_data.copy_(reduced.to(grad_data.dtype))
        return
    # NOTE(review): no reduce op is passed, so this relies on the default of
    # dist.all_reduce -- presumably SUM; confirm.
    dist.all_reduce(grad_data, group=self.autoep_folding_tp_group)
    if not self.autoep_folding_partitioned_grad_mode:
        # Non-partitioned mode turns the reduction into a mean over the TP group.
        # NOTE(review): unlike the branch above, this path reduces in the gradient's
        # native dtype before dividing; confirm that is intended for fp16 gradients.
        grad_data.div_(dist.get_world_size(group=self.autoep_folding_tp_group))
H100_TEST_ENV_VARS = ("DEEPSPEED_RUN_H100_TESTS", "DEVDS_RUN_H100_TESTS")
# Values that explicitly switch the opt-in off. Compared after strip() + lower().
_H100_DISABLED_VALUES = frozenset({"", "0", "false", "no", "off"})


def h100_tests_enabled():
    """Return True when any opt-in environment variable requests the H100 tests.

    A variable counts as set only when its value is not an explicit "off" value
    (empty, ``0``, ``false``, ``no``, ``off``; case-insensitive), so that
    ``DEEPSPEED_RUN_H100_TESTS=0`` no longer enables the GPU-gated tests.
    """
    return any(
        os.environ.get(name, "").strip().lower() not in _H100_DISABLED_VALUES for name in H100_TEST_ENV_VARS)


def skip_unless_h100_tests_enabled(reason):
    """Skip the calling pytest test unless the H100 tests were opted into.

    Args:
        reason: Short description placed at the start of the skip message.
    """
    if not h100_tests_enabled():
        pytest.skip(f"{reason}; set DEEPSPEED_RUN_H100_TESTS=1 or DEVDS_RUN_H100_TESTS=1")
def run_cpu_gloo_test(worker, tmpdir, *, world_size=4, timeout=DEEPSPEED_TEST_TIMEOUT):
    """Run a small CPU/Gloo distributed test without requiring visible GPU devices.

    Spawns ``world_size`` processes that each run ``_cpu_gloo_worker_entry`` and
    fails the calling pytest test if any worker times out, reports a traceback,
    or exits with a non-zero code.

    Args:
        worker: Callable invoked in every process as ``worker(rank, world_size, shared_tmpdir)``.
        tmpdir: Directory shared by all workers; also holds the rendezvous file.
        world_size: Number of processes to spawn.
        timeout: Seconds allowed for the whole group of workers to finish.
    """
    import time  # local import: only needed for the shared join deadline

    ctx = mp.get_context("spawn")
    error_queue = ctx.Queue()
    with tempfile.NamedTemporaryFile(delete=False, dir=str(tmpdir), suffix="_filestore") as fp:
        init_method = f"file://{fp.name}"
    master_port = get_master_port()
    shared_tmpdir = str(tmpdir)
    processes = [
        ctx.Process(target=_cpu_gloo_worker_entry,
                    args=(rank, world_size, init_method, master_port, worker, shared_tmpdir, error_queue))
        for rank in range(world_size)
    ]
    try:
        for process in processes:
            process.start()
        # The workers run concurrently, so they share one deadline instead of each
        # getting a fresh `timeout` (which allowed up to world_size * timeout).
        deadline = time.monotonic() + timeout
        for process in processes:
            process.join(max(0.0, deadline - time.monotonic()))
        hung = [process for process in processes if process.is_alive()]
    finally:
        # Never leave workers behind: a peer stuck in a collective would otherwise
        # outlive the test once the first timeout is reported.
        for process in processes:
            if process.is_alive():
                process.terminate()
        for process in processes:
            if process.pid is not None:  # join() is only legal on started processes
                process.join(5)

    errors = []
    while True:
        try:
            errors.append(error_queue.get_nowait())
        except Empty:
            break

    if hung:
        # Include any traceback already reported: a crash on one rank is the usual
        # reason the remaining ranks hang.
        message = ("CPU/Gloo worker " + ", ".join(str(process.pid) for process in hung) +
                   f" timed out after {timeout}s")
        if errors:
            message += "\n" + "\n".join(errors)
        pytest.fail(message, pytrace=False)
    failed = [process for process in processes if process.exitcode]
    if errors:
        pytest.fail("\n".join(errors), pytrace=False)
    if failed:
        pytest.fail("CPU/Gloo worker failures: " + ", ".join(str(process.exitcode) for process in failed),
                    pytrace=False)
assert folding[FOLDING_EP_SIZE] == 4 + assert folding[FOLDING_ETP_SIZE] == 1 + assert folding[FOLDING_ETP_RANK] == 0 + assert folding[FOLDING_ZERO_PARTITION_GROUP] == "edp" + assert folding[FOLDING_DISPATCH_STRATEGY] == "route_full_partition_dispatch" + assert folding[FOLDING_SHARED_EXPERT_PLACEMENT] == "tp_sharded" + assert folding[FOLDING_FAMILY] == "routed_expert" + + +def test_validate_folding_metadata_rejects_missing_and_mismatched_topology(): + folding = make_folding_metadata(tp_size=2, + tp_rank=0, + ep_size=4, + ep_rank=0, + zero_partition_group="dense_dp", + zero_partition_rank=0, + zero_partition_count=4, + family="dense") + wrapped = {FOLDING_METADATA_KEY: folding} + + assert validate_folding_metadata(wrapped, + tp_size=2, + ep_size=4, + zero_partition_group="dense_dp", + zero_partition_count=4, + family="dense", + shared_expert_placement="tp_sharded")[FOLDING_TP_SIZE] == 2 + with pytest.raises(RuntimeError, match="Missing AutoEP\\+AutoTP folding metadata"): + validate_folding_metadata({}, tp_size=2, ep_size=4) + with pytest.raises(RuntimeError, match="tp_size"): + validate_folding_metadata(wrapped, tp_size=1, ep_size=4) + with pytest.raises(RuntimeError, match="ep_size"): + validate_folding_metadata(wrapped, tp_size=2, ep_size=2) + wrapped[FOLDING_METADATA_KEY][FOLDING_ETP_SIZE] = 2 + with pytest.raises(RuntimeError, match="etp_size"): + validate_folding_metadata(wrapped, tp_size=2, ep_size=4) + wrapped[FOLDING_METADATA_KEY][FOLDING_ETP_SIZE] = 1 + wrapped[FOLDING_METADATA_KEY][FOLDING_ZERO_PARTITION_COUNT] = 2 + with pytest.raises(RuntimeError, match="zero_partition_count"): + validate_folding_metadata(wrapped, tp_size=2, ep_size=4, zero_partition_count=4) + wrapped[FOLDING_METADATA_KEY][FOLDING_ZERO_PARTITION_COUNT] = 4 + wrapped[FOLDING_METADATA_KEY][FOLDING_FAMILY] = "routed_expert" + with pytest.raises(RuntimeError, match="family"): + validate_folding_metadata(wrapped, tp_size=2, ep_size=4, family="dense") + 
def _folded_checkpoint_config(*, ep_size=2, mixed_precision=True):
    """Return an AutoEP config (ZeRO stage 0) extended with a ``tensor_parallel`` section.

    The section requests ``autotp_size`` 2, turns the default partition specs off
    and registers one layer spec of partition type ``skip``. Without mixed
    precision the optimizer params additionally get ``torch_adam`` set.
    """
    skip_weights_spec = {
        "patterns": [".*\\.weight$"],
        "partition_type": "skip",
    }
    tensor_parallel_section = {
        "autotp_size": 2,
        "partition_config": {
            "use_default_specs": False,
            "layer_specs": [skip_weights_spec],
        },
    }

    config = make_autoep_config(zero_stage=0, ep_size=ep_size, mixed_precision=mixed_precision)
    if not mixed_precision:
        config["optimizer"]["params"]["torch_adam"] = True
    config["tensor_parallel"] = tensor_parallel_section
    return config
def _drop_folding_metadata_from_model_checkpoints(checkpoint_dir):
    """Strip the folding metadata entry from every saved ``mp_rank_*`` model-state file.

    Rank 0 rewrites the files under ``<checkpoint_dir>/folded``; every rank then
    meets at a barrier before returning.
    """
    pattern = os.path.join(str(checkpoint_dir), "folded", "mp_rank_*_model_states.pt")
    model_paths = sorted(glob.glob(pattern))
    assert model_paths

    # Only one rank may rewrite the shared files; the others just wait below.
    paths_to_rewrite = model_paths if dist.get_rank() == 0 else []
    for checkpoint_path in paths_to_rewrite:
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        checkpoint.pop(FOLDING_METADATA_KEY, None)
        torch.save(checkpoint, checkpoint_path)
    dist.barrier()
sorted(glob.glob(os.path.join(str(checkpoint_dir), "folded", "*_optim_states.pt"))) + + assert model_paths + assert expert_paths + assert optim_paths + + for path in model_paths: + state = _load_torch_checkpoint(path) + folding = validate_folding_metadata(state, tp_size=2, ep_size=ep_size, etp_rank=0) + assert folding[FOLDING_FAMILY] == "dense" + assert folding[FOLDING_ZERO_PARTITION_GROUP] == "dense_dp" + param_families = folding[FOLDING_PARAM_FAMILIES] + assert param_families["model.layers.0.mlp.router.gate.weight"] == "router_gate_replicated" + assert param_families["model.layers.0.dense.weight"] == "dense" + assert all(not key.endswith("experts.w1") for key in param_families) + + for path in expert_paths: + state = _load_torch_checkpoint(path) + folding = validate_folding_metadata(state, tp_size=2, ep_size=ep_size) + assert folding[FOLDING_FAMILY] == "routed_expert" + assert folding[FOLDING_ZERO_PARTITION_GROUP] == "edp" + + assert any(FOLDING_METADATA_KEY in _load_torch_checkpoint(path) for path in optim_paths) + + +def _run_folded_checkpoint_same_topology_resume(checkpoint_dir, *, ep_size=2, mixed_precision=True): + config = _folded_checkpoint_config(ep_size=ep_size, mixed_precision=mixed_precision) + seed_everything(1234) + engine, _, _, _ = deepspeed.initialize(model=MockMoEOnlyTransformer(), config=config) + folded_layers = [module for module in engine.module.modules() if isinstance(module, AutoEPMoELayer)] + assert folded_layers + torch.manual_seed(1234) + x = torch.randn(1, 4, 64, device=engine.device, dtype=engine_input_dtype(engine)) + dist.broadcast(x, groups.get_tensor_model_parallel_src_rank(), group=groups.get_tensor_model_parallel_group()) + loss = engine(x).float().mean() + engine.backward(loss) + engine.step() + engine.save_checkpoint(str(checkpoint_dir), tag="folded") + _assert_saved_checkpoint_metadata(checkpoint_dir, ep_size=ep_size) + + seed_everything(1234) + reloaded, _, _, _ = deepspeed.initialize(model=MockMoEOnlyTransformer(), 
config=config) + reloaded.load_checkpoint(str(checkpoint_dir), tag="folded") + resumed_loss = reloaded(x).float().mean() + assert torch.isfinite(resumed_loss.detach()).item() + + _drop_folding_metadata_from_model_checkpoints(checkpoint_dir) + seed_everything(1234) + missing_metadata, _, _, _ = deepspeed.initialize(model=MockMoEOnlyTransformer(), config=config) + with pytest.raises(RuntimeError, match="Missing AutoEP\\+AutoTP folding metadata"): + missing_metadata.load_checkpoint(str(checkpoint_dir), tag="folded") + + +def _cpu_folded_checkpoint_worker(_rank, _world_size, shared_tmpdir): + _run_folded_checkpoint_same_topology_resume(shared_tmpdir, mixed_precision=False) + + +def test_cpu_gloo_folded_checkpoint_edp_replica_resume(tmpdir): + run_cpu_gloo_test(_cpu_folded_checkpoint_worker, tmpdir, world_size=8) + + +class TestH100FoldedCheckpoint(DistributedTest): + world_size = 4 + reuse_dist_env = False + + def test_h100_folded_checkpoint_same_topology_resume(self, tmpdir): + skip_unless_h100_tests_enabled("H100 checkpoint resume node") + + _run_folded_checkpoint_same_topology_resume(tmpdir) + + +class TestH100FoldedCheckpointTP2EP4(DistributedTest): + world_size = 8 + reuse_dist_env = False + + def test_h100_folded_tp2_ep4_checkpoint_same_topology_resume(self, tmpdir): + skip_unless_h100_tests_enabled("H100 TP2-EP4 checkpoint resume node") + + _run_folded_checkpoint_same_topology_resume(tmpdir, ep_size=4) diff --git a/tests/unit/v1/moe/test_autoep_autotp_dispatch.py b/tests/unit/v1/moe/test_autoep_autotp_dispatch.py new file mode 100644 index 000000000000..8136a6e95468 --- /dev/null +++ b/tests/unit/v1/moe/test_autoep_autotp_dispatch.py @@ -0,0 +1,223 @@ +# Copyright (c) DeepSpeed Team. 
def _payload():
    """Build the shared fixture: nine routed assignments over experts 0-2 and tokens 0-4."""
    experts = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2], dtype=torch.long)
    tokens = torch.tensor([0, 1, 2, 0, 3, 1, 2, 3, 4], dtype=torch.long)
    weights = torch.tensor([0.5, 0.25, 1.0, 0.5, 1.0, 0.75, 0.5, 0.25, 1.0])
    per_expert_counts = (3, 2, 4)

    fields = {
        "token_indices": tokens,
        "expert_indices": experts,
        "assignment_indices": assignment_ordinals_by_expert(experts),
        "capacity_slots": torch.arange(experts.numel(), dtype=torch.long),
        "combine_weights": weights,
        # No assignment starts out dropped or padded; tests flip entries as needed.
        "drop_mask": torch.zeros_like(experts, dtype=torch.bool),
        "pad_mask": torch.zeros_like(experts, dtype=torch.bool),
        "input_splits": list(per_expert_counts),
        "output_splits": list(per_expert_counts),
        "extra": {
            "destination_ranks": experts,
            "num_tokens": torch.tensor(5, dtype=torch.long),
        },
    }
    return RoutedAssignmentPayload(**fields)
def test_partition_excludes_padded_and_dropped_assignments_from_stats():
    """One dropped and one padded assignment must vanish from the local payload and totals."""
    payload = _payload()
    payload.drop_mask[1] = True
    payload.pad_mask[6] = True
    expected_kept = payload.token_indices.numel() - 2

    local_payload, context = partition_assignments(payload, tp_group=None, tp_rank=0, tp_size=1)
    assert local_payload.token_indices.numel() == expected_kept

    stats = dispatch_counters(context)
    assert stats["assignments_total"] == expected_kept
    assert stats["padded"] == 1
    assert stats["dropped"] == 1
def _tp_payload_for_backward_parity():
    """Return ``(payload, top_scores, token_indices_sorted)`` for the TP backward-parity test.

    Eight flattened (token, top-k slot) assignments with ``top_k`` 2 are spread
    evenly over four experts.
    """
    top_k = 2
    sorted_flat = torch.tensor([0, 3, 5, 1, 2, 7, 4, 6], dtype=torch.long)
    experts = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.long)
    scores = torch.tensor([[0.65, 0.35], [0.25, 0.75], [0.55, 0.45], [0.30, 0.70]], dtype=torch.float32)
    # Router weights laid out in the same order as the sorted assignments.
    weights = scores.reshape(-1).index_select(0, sorted_flat)
    not_dropped = torch.zeros_like(experts, dtype=torch.bool)
    not_padded = torch.zeros_like(experts, dtype=torch.bool)

    payload = RoutedAssignmentPayload(
        token_indices=(sorted_flat // top_k).to(torch.long),
        expert_indices=experts,
        assignment_indices=assignment_ordinals_by_expert(experts),
        capacity_slots=(sorted_flat % top_k).to(torch.long),
        combine_weights=weights,
        drop_mask=not_dropped,
        pad_mask=not_padded,
        input_splits=[2, 2, 2, 2],
        output_splits=[2, 2, 2, 2],
        extra={
            "destination_ranks": experts,
            "num_tokens": torch.tensor(4, dtype=torch.long),
        },
    )
    return payload, scores, sorted_flat
combine_from_routed(expected_expert_output, + top_scores=expected_top_scores, + token_indices_sorted=token_indices_sorted, + top_k=2, + score_apply="post", + combine_impl="weighted_sum", + shape=(1, 4, 3)) + expected.square().sum().backward() + expected_weight_grad = expected_top_scores.grad.reshape(-1).index_select(0, token_indices_sorted) + + actual_expert_output = full_expert_output.clone().requires_grad_(True) + actual_top_scores = top_scores.clone().requires_grad_(True) + payload.combine_weights = actual_top_scores.reshape(-1).index_select(0, token_indices_sorted) + local_values = actual_expert_output.index_select(0, ctx.local_indices) + restored = restore_combined(local_values, ctx, tp_group=tp_group) + + torch.testing.assert_close(restored.reshape(1, 4, 3), expected.detach(), rtol=0.0, atol=0.0) + restored.square().sum().backward() + + actual_value_grad = actual_expert_output.grad.detach().clone() + actual_top_score_grad = actual_top_scores.grad.detach().clone() + dist.all_reduce(actual_value_grad, group=tp_group) + dist.all_reduce(actual_top_score_grad, group=tp_group) + + torch.testing.assert_close(actual_value_grad, expected_expert_output.grad, rtol=1e-6, atol=1e-6) + torch.testing.assert_close(actual_top_score_grad, expected_top_scores.grad, rtol=1e-6, atol=1e-6) + torch.testing.assert_close(actual_top_score_grad.reshape(-1).index_select(0, token_indices_sorted), + expected_weight_grad, + rtol=1e-6, + atol=1e-6) + + +def test_restore_combined_tp_backward_matches_non_partitioned_combine(tmpdir): + run_cpu_gloo_test(_restore_combined_backward_parity_worker, tmpdir, world_size=2) + + +def test_restore_coverage_assertion_detects_missing_assignment(): + payload = _payload() + local, ctx = partition_assignments(payload, tp_group=None, tp_rank=0, tp_size=1) + ctx.local_indices = ctx.local_indices[:-1] + values = torch.ones((local.token_indices.numel() - 1, 2), dtype=torch.float32) + + with pytest.raises(RuntimeError, match="restore coverage mismatch"): + 
def test_tp_payload_consistency_detects_divergent_large_payload(monkeypatch):
    """A 4097-row payload perturbed by the first all-reduce must be reported as divergent."""
    rows = 4097
    experts = torch.zeros(rows, dtype=torch.long)
    payload = RoutedAssignmentPayload(
        token_indices=torch.arange(rows, dtype=torch.long),
        expert_indices=experts,
        assignment_indices=torch.arange(rows, dtype=torch.long),
        capacity_slots=torch.arange(rows, dtype=torch.long),
        combine_weights=torch.ones(rows),
        drop_mask=torch.zeros(rows, dtype=torch.bool),
        pad_mask=torch.zeros(rows, dtype=torch.bool),
        input_splits=[rows],
        output_splits=[rows],
        extra={
            "destination_ranks": experts,
            "num_tokens": torch.tensor(rows, dtype=torch.long),
        },
    )

    seen_ops = []

    def fake_all_reduce(tensor, op=None, group=None):
        # Only the very first collective corrupts its buffer; later ones are no-ops.
        is_first_call = len(seen_ops) == 0
        if is_first_call:
            tensor[3].add_(1)
        seen_ops.append(op)

    monkeypatch.setattr(dispatch.dist, "is_initialized", lambda: True)
    monkeypatch.setattr(dispatch.dist, "all_reduce", fake_all_reduce)

    with pytest.raises(RuntimeError, match="routing decisions differ"):
        assert_tp_payload_consistent(payload, tp_group=object(), tp_size=2)
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""AutoEP + AutoTP folding config and topology tests.""" + +from types import SimpleNamespace + +import pytest + +from deepspeed.module_inject.auto_ep_config import AutoEPConfig, parse_autoep_config, validate_autoep_config +from deepspeed.module_inject.auto_ep_folding import ( + FoldingGroupHandles, + FoldingGroupTables, + ParallelFoldingSpec, + build_folding_spec, + expected_folding_group_tables, +) + + +def test_folding_symbols_import_without_distributed_init(): + spec = build_folding_spec(world_size=8, pp_size=1, tp_size=2, ep_size=4, etp_size=1) + assert isinstance(spec, ParallelFoldingSpec) + assert isinstance(expected_folding_group_tables(spec), FoldingGroupTables) + assert FoldingGroupHandles.__name__ == "FoldingGroupHandles" + + +@pytest.mark.parametrize( + "world_size,tp_size,ep_size,expected_dp,expected_edp", + [(8, 2, 4, 4, 2), (16, 2, 4, 8, 4), (4, 2, 2, 2, 2)], +) +def test_valid_folding_spec_derives_dense_and_expert_dp(world_size, tp_size, ep_size, expected_dp, expected_edp): + config = AutoEPConfig(enabled=True, autoep_size=ep_size, expert_tensor_parallel_size=1) + validate_autoep_config(config, world_size=world_size, pp_size=1, tp_size=tp_size, sp_size=1) + spec = build_folding_spec(world_size=world_size, pp_size=1, tp_size=tp_size, ep_size=ep_size, etp_size=1) + assert spec.dp_size == expected_dp + assert spec.edp_size == expected_edp + + +def test_backward_config_compatibility_defaults_etp_to_one(): + config = parse_autoep_config({"enabled": True, "autoep_size": 2, "preset_model": "mixtral"}) + assert config.expert_tensor_parallel_size == 1 + validate_autoep_config(config, world_size=4, pp_size=1, tp_size=1, sp_size=1) + + +def test_expected_folding_tables_match_design_examples(): + spec8 = build_folding_spec(world_size=8, pp_size=1, tp_size=2, ep_size=4, etp_size=1) + tables8 = expected_folding_group_tables(spec8) + assert tables8.tp_groups == ((0, 1), (2, 3), (4, 5), (6, 7)) + assert 
tables8.dense_dp_groups == ((0, 2, 4, 6), (1, 3, 5, 7)) + assert tables8.ep_groups == ((0, 2, 4, 6), (1, 3, 5, 7)) + assert tables8.edp_groups == ((0, 1), (2, 3), (4, 5), (6, 7)) + + spec16 = build_folding_spec(world_size=16, pp_size=1, tp_size=2, ep_size=4, etp_size=1) + tables16 = expected_folding_group_tables(spec16) + assert tables16.ep_groups == ((0, 2, 4, 6), (8, 10, 12, 14), (1, 3, 5, 7), (9, 11, 13, 15)) + assert tables16.edp_groups == ((0, 8, 1, 9), (2, 10, 3, 11), (4, 12, 5, 13), (6, 14, 7, 15)) + + +def _assert_rejects(match, **kwargs): + config_kwargs = { + "enabled": True, + "autoep_size": kwargs.pop("ep_size", 2), + "expert_tensor_parallel_size": kwargs.pop("etp_size", 1), + "preset_model": kwargs.pop("ep_preset", None), + } + validate_kwargs = { + "world_size": kwargs.pop("world_size", 4), + "pp_size": kwargs.pop("pp_size", 1), + "tp_size": kwargs.pop("tp_size", 2), + "sp_size": kwargs.pop("sp_size", 1), + } + validate_kwargs.update(kwargs) + with pytest.raises(ValueError, match=match): + validate_autoep_config(AutoEPConfig(**config_kwargs), **validate_kwargs) + + +def test_validation_rule_g1_pp_divisibility_and_pp_rejection(): + _assert_rejects("pp_size=2 must divide world_size=7", world_size=7, pp_size=2, tp_size=1, ep_size=1) + _assert_rejects("pp_size=1 only", world_size=8, pp_size=2, tp_size=2, ep_size=2) + + +def test_nonfolded_autoep_preserves_pipeline_parallel_compatibility(): + config = AutoEPConfig(enabled=True, autoep_size=2, expert_tensor_parallel_size=1) + validate_autoep_config(config, world_size=8, pp_size=2, tp_size=1, sp_size=1) + + +def test_validation_rule_g2_tp_divisibility_names_valid_divisors(): + _assert_rejects("autotp_size=3.*Valid autotp_size values", world_size=8, tp_size=3, ep_size=1) + + +def test_validation_rule_g3_expert_width_divisibility_names_valid_divisors(): + _assert_rejects("autoep_size \\* expert_parallel\\.expert_tensor_parallel_size.*Valid expert-width values", + world_size=8, + tp_size=2, + ep_size=3) + + 
+def test_validation_rule_g4_etp_reserved_message(): + _assert_rejects("expert_tensor_parallel_size=2 is reserved", world_size=8, tp_size=2, ep_size=2, etp_size=2) + + +def test_validation_rule_g5_tp_sp_exclusive(): + _assert_rejects("mutually exclusive", world_size=4, tp_size=2, ep_size=2, sp_size=2) + + +def test_validation_rule_g6_preset_consistency(): + _assert_rejects("must match", world_size=4, tp_size=2, ep_size=2, ep_preset="mixtral", tp_preset_model="qwen3_moe") + + +def test_validation_rule_g7_zero3_lane_pointer(): + _assert_rejects("ZeRO stage 3.*separate ZeRO-3 composition lane", zero_stage=3) + + +def test_validation_rule_g8_mpu_conflict(): + mpu = SimpleNamespace(get_tensor_model_parallel_world_size=lambda: 4, + get_pipeline_model_parallel_world_size=lambda: 1) + _assert_rejects("mpu tensor/model parallel world size", mpu=mpu) + + +def test_validation_rule_g9_ep_one_rejected_with_autotp(): + _assert_rejects("autoep_size > 1", world_size=4, tp_size=2, ep_size=1) + + +def test_validation_rule_g10_data_before_expert_parallel_rejected(): + _assert_rejects("use_data_before_expert_parallel_", use_data_before_expert_parallel=True) + + +@pytest.mark.parametrize("offload_key", ["zero_offload_optimizer", "zero_offload_param"]) +def test_validation_rule_g11_zero_offload_rejected(offload_key): + _assert_rejects("offload", **{offload_key: True}) + + +def test_validation_rule_g12_cross_lane_ep_groups_temporarily_rejected(): + _assert_rejects("temporary limitation", world_size=8, tp_size=4, ep_size=4) diff --git a/tests/unit/v1/moe/test_autoep_autotp_folding_groups.py b/tests/unit/v1/moe/test_autoep_autotp_folding_groups.py new file mode 100644 index 000000000000..d8f120552190 --- /dev/null +++ b/tests/unit/v1/moe/test_autoep_autotp_folding_groups.py @@ -0,0 +1,74 @@ +# Copyright (c) DeepSpeed Team. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""AutoEP + AutoTP folding group-table and handle contract tests.""" + +import pytest + +from deepspeed.module_inject.auto_ep_folding import ( + FoldingGroupHandles, + assert_group_matches_spec, + build_folding_spec, + expected_folding_group_tables, + local_folding_ranks, +) + + +def test_8gpu_tp2_ep4_tables_match_design(): + spec = build_folding_spec(world_size=8, pp_size=1, tp_size=2, ep_size=4, etp_size=1) + tables = expected_folding_group_tables(spec) + + assert tables.tp_groups == ((0, 1), (2, 3), (4, 5), (6, 7)) + assert tables.dense_dp_groups == ((0, 2, 4, 6), (1, 3, 5, 7)) + assert tables.ep_groups == ((0, 2, 4, 6), (1, 3, 5, 7)) + assert tables.edp_groups == ((0, 1), (2, 3), (4, 5), (6, 7)) + + +def test_16gpu_tp2_ep4_tables_match_design(): + spec = build_folding_spec(world_size=16, pp_size=1, tp_size=2, ep_size=4, etp_size=1) + tables = expected_folding_group_tables(spec) + + assert tables.ep_groups == ((0, 2, 4, 6), (8, 10, 12, 14), (1, 3, 5, 7), (9, 11, 13, 15)) + assert tables.edp_groups == ((0, 8, 1, 9), (2, 10, 3, 11), (4, 12, 5, 13), (6, 14, 7, 15)) + + +def test_local_folding_ranks_match_helper_tables(): + spec = build_folding_spec(world_size=8, pp_size=1, tp_size=2, ep_size=4, etp_size=1) + assert local_folding_ranks(5, spec) == { + "tp": (4, 5), + "dense_dp": (1, 3, 5, 7), + "ep": (1, 3, 5, 7), + "edp": (4, 5), + } + + +def test_stale_registry_rank_lists_are_rejected(): + spec = build_folding_spec(world_size=8, pp_size=1, tp_size=2, ep_size=4, etp_size=1) + stale_legacy_ep = ((0, 1, 2, 3), ) + stale_legacy_edp = ((0, 4), ) + + with pytest.raises(RuntimeError, match="does not match folding spec"): + assert_group_matches_spec((stale_legacy_ep, stale_legacy_edp), spec) + + +def test_group_handle_container_carries_explicit_groups_and_rank_tables(): + spec = build_folding_spec(world_size=4, pp_size=1, tp_size=2, ep_size=2, etp_size=1) + local = local_folding_ranks(2, spec) + handles = 
FoldingGroupHandles( + spec=spec, + tp_group=object(), + dense_dp_group=object(), + ep_group=object(), + edp_group=object(), + ep_group_name="ep_size_2", + tp_ranks=local["tp"], + dense_dp_ranks=local["dense_dp"], + ep_ranks=local["ep"], + edp_ranks=local["edp"], + ) + + assert handles.ep_group_name == "ep_size_2" + assert handles.tp_ranks == (2, 3) + assert handles.ep_ranks == (0, 2) + assert handles.edp_ranks == (2, 3) diff --git a/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py b/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py new file mode 100644 index 000000000000..0babb164cf26 --- /dev/null +++ b/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py @@ -0,0 +1,301 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Gradient and optimizer policy checks for AutoEP + AutoTP folding.""" + +import glob +import json +import os + +import pytest +import torch + +import deepspeed +import deepspeed.comm as dist +from deepspeed.checkpoint.autoep_universal import validate_folding_metadata +from deepspeed.checkpoint.constants import FOLDING_FAMILY, FOLDING_METADATA_KEY, FOLDING_PARAM_FAMILIES +from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer +from deepspeed.runtime.engine import DeepSpeedEngine +from deepspeed.utils import safe_get_full_grad +from deepspeed.utils import groups +from unit.common import DistributedTest +from unit.v1.moe.autoep_test_utils import ( + MockMoEOnlyTransformer, + engine_input_dtype, + make_autoep_config, + run_cpu_gloo_test, + seed_everything, + skip_unless_h100_tests_enabled, +) + +from deepspeed.module_inject.auto_ep_config import AutoEPConfig, validate_autoep_config + + +def test_zero_offload_paths_fail_fast_until_per_family_replica_groups_are_proven(): + for kwargs in ({"zero_offload_optimizer": True}, {"zero_offload_param": True}): + with pytest.raises(ValueError, match="offload"): + validate_autoep_config(AutoEPConfig(enabled=True, autoep_size=2), + world_size=4, + 
pp_size=1, + tp_size=2, + sp_size=1, + zero_stage=2, + **kwargs) + + +def test_zero3_composition_remains_separate_lane(): + with pytest.raises(ValueError, match="separate ZeRO-3 composition lane"): + validate_autoep_config(AutoEPConfig(enabled=True, autoep_size=2), + world_size=4, + pp_size=1, + tp_size=2, + sp_size=1, + zero_stage=3) + + +def _folded_zero2_config(*, mixed_precision=True): + config = make_autoep_config(zero_stage=2, ep_size=2, mixed_precision=mixed_precision) + config["gradient_accumulation_steps"] = 2 + config["gradient_clipping"] = 0.0 + if not mixed_precision: + config["optimizer"]["params"]["torch_adam"] = True + config["tensor_parallel"] = { + "autotp_size": 2, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + }, + } + return config + + +@pytest.mark.parametrize(("mp_mode", "expected_grad"), (("tp", 2.0), ("sp", 2.0), ("replicated", 1.0))) +def test_tp_replicated_gradient_reducer_respects_parallel_mode(monkeypatch, mp_mode, expected_grad): + param = torch.nn.Parameter(torch.ones(2)) + param.grad = torch.ones_like(param) + engine = object.__new__(DeepSpeedEngine) + engine._autoep_folding_spec = type("Spec", (), {"tp_size": 2, "mp_mode": mp_mode})() + engine.__dict__["optimizer"] = None + engine.__dict__["module"] = type("ModuleStub", (), + {"named_parameters": lambda self: iter([("dense.weight", param)])})() + monkeypatch.setattr(dist, "is_initialized", lambda: True) + monkeypatch.setattr(groups, "get_tensor_model_parallel_group", lambda: object()) + monkeypatch.setattr(dist, "get_world_size", lambda group=None: 2) + + def fake_all_reduce(tensor, group=None): + tensor.mul_(2) + + monkeypatch.setattr(dist, "all_reduce", fake_all_reduce) + + engine._reduce_autoep_folding_tp_replicated_gradients() + + assert torch.equal(param.grad, torch.full_like(param.grad, expected_grad)) + + +def _folded_zero2_tp2_ep4_config(): + config = 
_folded_zero2_config(mixed_precision=False) + config["expert_parallel"]["autoep_size"] = 4 + config["communication_data_type"] = "fp32" + return config + + +def _zero2_baseline_config(): + config = { + **{ + key: value + for key, value in make_autoep_config(zero_stage=2, ep_size=1, mixed_precision=False).items() if key != "expert_parallel" + }, + "gradient_accumulation_steps": 2, + "gradient_clipping": 0.0, + } + config["communication_data_type"] = "fp32" + config["optimizer"]["params"]["torch_adam"] = True + return config + + +def _router_grad_model(): + return MockMoEOnlyTransformer(num_layers=1, num_experts=4, hidden_size=64, intermediate_size=128) + + +def _make_logical_batches(engine, *, logical_dp_world_size, logical_dp_rank, grad_accum, seed): + batches = [] + for accum_idx in range(grad_accum): + batch_idx = accum_idx * logical_dp_world_size + logical_dp_rank + generator = torch.Generator().manual_seed(seed + batch_idx) + batch = torch.randn((1, 4, 64), generator=generator, dtype=engine_input_dtype(engine)) + batches.append(batch.to(engine.device)) + return batches + + +def _run_router_grad_boundary(engine, *, logical_dp_world_size, logical_dp_rank, seed): + batches = _make_logical_batches(engine, + logical_dp_world_size=logical_dp_world_size, + logical_dp_rank=logical_dp_rank, + grad_accum=2, + seed=seed) + for batch_idx, batch in enumerate(batches): + loss = engine(batch).float().mean() + engine.backward(loss) + if batch_idx + 1 < len(batches): + engine.step() + + +def _full_grad_by_suffix(engine, suffix): + for name, param in engine.module.named_parameters(): + if name.endswith(suffix): + grad = safe_get_full_grad(param) + assert grad is not None, f"Expected full grad for {name}" + return grad.detach().float().cpu().clone() + raise AssertionError(f"Missing parameter ending with {suffix}") + + +def _grad_parity_metrics(actual, expected): + diff = actual - expected + expected_norm_sq = expected.square().sum().item() + actual_norm = actual.norm().item() + 
expected_norm = expected.norm().item() + scale = actual.mul(expected).sum().item() / expected_norm_sq if expected_norm_sq else 0.0 + return { + "scale_vs_expected": scale, + "scale_vs_baseline": scale, + "max_abs": diff.abs().max().item(), + "rel_norm": diff.norm().item() / expected_norm, + "actual_norm": actual_norm, + "expected_norm": expected_norm, + "folded_norm": actual_norm, + "baseline_norm": expected_norm, + } + + +def _assert_zero_optimizer_folding_metadata(checkpoint_dir): + optim_paths = sorted(glob.glob(os.path.join(str(checkpoint_dir), "folded-zero2", "*_optim_states.pt"))) + assert optim_paths + saw_metadata = False + for path in optim_paths: + state = torch.load(path, map_location="cpu", weights_only=False) + if FOLDING_METADATA_KEY not in state: + continue + if state[FOLDING_METADATA_KEY][FOLDING_FAMILY] != "zero_optimizer_state": + continue + saw_metadata = True + folding = validate_folding_metadata(state, + tp_size=2, + ep_size=2, + family="zero_optimizer_state", + zero_partition_group="per_family", + zero_partition_count=2) + assert folding[FOLDING_FAMILY] == "zero_optimizer_state" + param_families = folding[FOLDING_PARAM_FAMILIES] + routed_entries = {name: meta for name, meta in param_families.items() if ".experts." in name} + assert routed_entries + assert all(meta["family"] == "routed_expert" for meta in routed_entries.values()) + assert all(meta["zero_partition_group"] == "edp" for meta in routed_entries.values()) + dense_entries = {name: meta for name, meta in param_families.items() if ".dense." 
in name} + assert dense_entries + assert all(meta["family"] == "dense" for meta in dense_entries.values()) + assert all(meta["zero_partition_group"] == "dense_dp" for meta in dense_entries.values()) + assert saw_metadata + + +def _cpu_folded_zero2_worker(_rank, _world_size, _shared_tmpdir): + seed_everything(1234) + engine, _, _, _ = deepspeed.initialize(model=MockMoEOnlyTransformer(), + config=_folded_zero2_config(mixed_precision=False)) + folded_layers = [module for module in engine.module.modules() if isinstance(module, AutoEPMoELayer)] + assert folded_layers + assert all(layer.folding_group_handles is not None for layer in folded_layers) + torch.manual_seed(1234) + x = torch.randn(1, 4, 64, device=engine.device, dtype=engine_input_dtype(engine)) + dist.broadcast(x, groups.get_tensor_model_parallel_src_rank(), group=groups.get_tensor_model_parallel_group()) + loss = engine(x).float().mean() + engine.backward(loss) + engine.step() + engine.save_checkpoint(str(_shared_tmpdir), tag="folded-zero2") + dist.barrier() + _assert_zero_optimizer_folding_metadata(_shared_tmpdir) + assert torch.isfinite(loss.detach()).item() + + +def test_cpu_gloo_folded_zero2_optimizer_state_smoke(tmpdir): + run_cpu_gloo_test(_cpu_folded_zero2_worker, tmpdir, world_size=4) + + +class TestH100FoldedZero12(DistributedTest): + world_size = 4 + reuse_dist_env = False + + def test_h100_zero12_per_family_optimizer_state(self): + skip_unless_h100_tests_enabled("H100 optimizer-state node") + + seed_everything(1234) + engine, _, _, _ = deepspeed.initialize(model=MockMoEOnlyTransformer(), config=_folded_zero2_config()) + folded_layers = [module for module in engine.module.modules() if isinstance(module, AutoEPMoELayer)] + assert folded_layers + assert all(layer.folding_group_handles is not None for layer in folded_layers) + torch.manual_seed(1234) + x = torch.randn(1, 4, 64, device=engine.device, dtype=engine_input_dtype(engine)) + dist.broadcast(x, groups.get_tensor_model_parallel_src_rank(), 
group=groups.get_tensor_model_parallel_group()) + loss = engine(x).float().mean() + engine.backward(loss) + engine.step() + assert torch.isfinite(loss.detach()).item() + + +class TestH100FoldedRouterGateGradParityTP2EP4(DistributedTest): + world_size = 8 + reuse_dist_env = False + + def test_folded_router_gate_grad_matches_nonfolded_zero2_baseline(self): + skip_unless_h100_tests_enabled("H100 folded router/gate gradient parity node") + + seed = 1234 + tp_size = 2 + logical_dp_world_size = self.world_size // tp_size + logical_dp_rank = dist.get_rank() // tp_size + + seed_everything(seed) + reference_state = _router_grad_model().state_dict() + baseline_model = _router_grad_model() + baseline_model.load_state_dict(reference_state) + baseline_engine, _, _, _ = deepspeed.initialize(model=baseline_model, config=_zero2_baseline_config()) + _run_router_grad_boundary(baseline_engine, + logical_dp_world_size=logical_dp_world_size, + logical_dp_rank=logical_dp_rank, + seed=seed) + baseline_grad = _full_grad_by_suffix(baseline_engine, "model.layers.0.mlp.gate.weight") + + folded_model = _router_grad_model() + folded_model.load_state_dict(reference_state) + folded_engine, _, _, _ = deepspeed.initialize(model=folded_model, config=_folded_zero2_tp2_ep4_config()) + _run_router_grad_boundary(folded_engine, + logical_dp_world_size=logical_dp_world_size, + logical_dp_rank=logical_dp_rank, + seed=seed) + + folded_grad = _full_grad_by_suffix(folded_engine, "model.layers.0.mlp.router.gate.weight") + metrics = { + **_grad_parity_metrics(folded_grad, baseline_grad), + "nodeid": + "tests/unit/v1/moe/test_autoep_autotp_grad_parity.py::" + "TestH100FoldedRouterGateGradParityTP2EP4::" + "test_folded_router_gate_grad_matches_nonfolded_zero2_baseline", + "rank": + dist.get_rank(), + "target_param": + "model.layers.0.mlp.gate.weight", + "folded_param": + "model.layers.0.mlp.router.gate.weight", + } + if dist.get_rank() == 0: + print("FOLDED_ROUTER_GATE_GRAD_PARITY " + json.dumps(metrics, 
sort_keys=True)) + + torch.testing.assert_close(folded_grad, + baseline_grad, + atol=1e-1, + rtol=5e-3, + msg=("Folded TP2-EP4 router/gate grad must match the non-folded ZeRO-2 " + f"baseline; metrics={metrics}")) diff --git a/tests/unit/v1/moe/test_autoep_autotp_runtime.py b/tests/unit/v1/moe/test_autoep_autotp_runtime.py new file mode 100644 index 000000000000..040e0329c70f --- /dev/null +++ b/tests/unit/v1/moe/test_autoep_autotp_runtime.py @@ -0,0 +1,243 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Runtime wiring checks for AutoEP + AutoTP folding.""" + +import torch +import torch.nn as nn + +import deepspeed +import deepspeed.comm as dist +from deepspeed.module_inject.auto_ep_config import AutoEPConfig, MoELayerSpec +from deepspeed.module_inject.auto_ep_folding import FoldingGroupHandles, build_folding_spec, local_folding_ranks +from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer +from deepspeed.module_inject.auto_tp import AutoTP +from deepspeed.utils import groups +from unit.common import DistributedTest +from unit.v1.moe.autoep_test_utils import ( + MockMoEOnlyTransformer, + engine_input_dtype, + make_autoep_config, + run_cpu_gloo_test, + seed_everything, + skip_unless_h100_tests_enabled, +) + + +def _make_spec(**overrides): + defaults = dict( + moe_module_name="model.layers.0.mlp", + model_family="mixtral", + router_name="gate", + experts_name="experts", + expert_storage="fused_3d", + expert_w1_name="gate_up_proj", + expert_w2_name="down_proj", + expert_w3_name=None, + num_experts=4, + top_k=2, + hidden_size=8, + ffn_hidden_size=16, + score_func="softmax", + score_apply="post", + route_norm=True, + gate_bias=False, + return_router_logits=False, + router_logits_capture_target="none", + router_logits_capture_index=None, + router_logits_capture_layer_name=None, + has_shared_experts=False, + shared_experts_name="", + shared_experts_gate_name="", + ) + defaults.update(overrides) + return 
MoELayerSpec(**defaults) + + +class TinySourceMoE(nn.Module): + + def __init__(self): + super().__init__() + self.gate = nn.Linear(8, 4, bias=False) + self.experts = nn.Module() + self.experts.gate_up_proj = nn.Parameter(torch.randn(4, 32, 8)) + self.experts.down_proj = nn.Parameter(torch.randn(4, 8, 16)) + + +def test_folded_layer_binds_explicit_group_handles(monkeypatch): + layer = AutoEPMoELayer(_make_spec(), TinySourceMoE(), ep_size=2, ep_rank=0, config=AutoEPConfig(enabled=True)) + spec = build_folding_spec(world_size=4, pp_size=1, tp_size=2, ep_size=2, etp_size=1) + local = local_folding_ranks(0, spec) + handles = FoldingGroupHandles( + spec=spec, + tp_group=object(), + dense_dp_group=object(), + ep_group=object(), + edp_group=object(), + ep_group_name="ep_size_2", + tp_ranks=local["tp"], + dense_dp_ranks=local["dense_dp"], + ep_ranks=local["ep"], + edp_ranks=local["edp"], + ) + monkeypatch.setattr("deepspeed.module_inject.auto_ep_layer.dist.get_rank", lambda group=None: 0) + + layer.set_deepspeed_parallelism(folding_group_handles=handles) + + assert layer.folding_group_handles is handles + assert layer.tp_group is handles.tp_group + assert layer.ep_group is handles.ep_group + assert layer.ep_group_name == "ep_size_2" + + +def test_autotp_reaches_autoep_shared_experts(monkeypatch): + + class AutoEPLike(nn.Module): + + def __init__(self): + super().__init__() + self._is_autoep_layer = True + self.shared_experts = nn.Linear(8, 8, bias=False) + self.shared_experts_gate = nn.Linear(8, 8, bias=False) + + model = nn.Module() + model.moe = AutoEPLike() + autotp = AutoTP.__new__(AutoTP) + calls = [] + monkeypatch.setattr(autotp, "_replace_autoep_shared_experts", lambda child, name: calls.append((child, name))) + + AutoTP._replace_module(autotp, model) + + assert calls == [(model.moe, "moe")] + + +def _folded_config(zero_stage=0, *, ep_size=2, mixed_precision=True): + config = make_autoep_config(zero_stage=zero_stage, ep_size=ep_size, mixed_precision=mixed_precision) 
+ if not mixed_precision: + config["optimizer"]["params"]["torch_adam"] = True + config["tensor_parallel"] = { + "autotp_size": 2, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + }, + } + return config + + +def _tp_consistent_input(engine, *, seed=1234): + torch.manual_seed(seed) + x = torch.randn(1, 4, 64, device=engine.device, dtype=engine_input_dtype(engine)) + dist.broadcast(x, groups.get_tensor_model_parallel_src_rank(), group=groups.get_tensor_model_parallel_group()) + return x + + +def _initialize_folded_engine(*, zero_stage=0, ep_size=2, mixed_precision=True): + seed_everything(1234) + return deepspeed.initialize(model=MockMoEOnlyTransformer(), + config=_folded_config(zero_stage=zero_stage, + ep_size=ep_size, + mixed_precision=mixed_precision)) + + +def _assert_nonzero_named_grad(engine, *name_fragments): + grad_total = 0.0 + matched = False + for name, param in engine.module.named_parameters(): + if not any(fragment in name for fragment in name_fragments): + continue + if param.grad is None: + continue + matched = True + grad_total += param.grad.detach().float().abs().sum().item() + assert matched, f"no gradients found for parameters matching {name_fragments}" + assert grad_total > 0.0 + + +def _cpu_folded_runtime_worker(_rank, _world_size, _shared_tmpdir): + engine, _, _, _ = _initialize_folded_engine(zero_stage=0, mixed_precision=False) + assert engine.autotp_size() == 2 + assert groups.get_tensor_model_parallel_world_size() == 2 + folded_layers = [module for module in engine.module.modules() if isinstance(module, AutoEPMoELayer)] + assert folded_layers + assert all(layer.folding_group_handles is not None for layer in folded_layers) + + x = _tp_consistent_input(engine) + loss = engine(x).float().mean() + engine.backward(loss) + _assert_nonzero_named_grad(engine, "experts.") + _assert_nonzero_named_grad(engine, "router", "gate") + engine.step() + assert 
torch.isfinite(loss.detach()).item() + + +def test_cpu_gloo_folded_runtime_smoke(tmpdir): + run_cpu_gloo_test(_cpu_folded_runtime_worker, tmpdir, world_size=4) + + +class TestH100FoldedRuntime(DistributedTest): + world_size = 4 + reuse_dist_env = False + + def test_h100_folded_tp2_ep2_runtime(self): + skip_unless_h100_tests_enabled("H100 runtime node") + + engine, _, _, _ = _initialize_folded_engine(zero_stage=0) + assert engine.autotp_size() == 2 + assert groups.get_tensor_model_parallel_world_size() == 2 + folded_layers = [module for module in engine.module.modules() if isinstance(module, AutoEPMoELayer)] + assert folded_layers + assert all(layer.folding_group_handles is not None for layer in folded_layers) + + x = _tp_consistent_input(engine) + loss = engine(x).float().mean() + engine.backward(loss) + _assert_nonzero_named_grad(engine, "experts.") + _assert_nonzero_named_grad(engine, "router", "gate") + engine.step() + assert torch.isfinite(loss.detach()).item() + + +class TestH100FoldedRuntimeReference(DistributedTest): + world_size = 4 + reuse_dist_env = False + + def test_h100_folded_tp2_ep2_finite_loss_smoke(self): + skip_unless_h100_tests_enabled("H100 benchmark node") + + engine, _, _, _ = _initialize_folded_engine(zero_stage=0) + x = _tp_consistent_input(engine) + losses = [] + for _ in range(2): + loss = engine(x).float().mean() + engine.backward(loss) + engine.step() + losses.append(float(loss.detach().cpu())) + assert all(torch.isfinite(torch.tensor(value)) for value in losses) + + +class TestH100FoldedRuntimeTP2EP4(DistributedTest): + world_size = 8 + reuse_dist_env = False + + def test_h100_folded_tp2_ep4_runtime(self): + skip_unless_h100_tests_enabled("H100 TP2-EP4 runtime node") + + engine, _, _, _ = _initialize_folded_engine(zero_stage=0, ep_size=4) + assert engine.autotp_size() == 2 + assert groups.get_tensor_model_parallel_world_size() == 2 + folded_layers = [module for module in engine.module.modules() if isinstance(module, AutoEPMoELayer)] + 
assert folded_layers + assert all(layer.folding_group_handles is not None for layer in folded_layers) + + x = _tp_consistent_input(engine) + loss = engine(x).float().mean() + engine.backward(loss) + _assert_nonzero_named_grad(engine, "experts.") + _assert_nonzero_named_grad(engine, "router", "gate") + engine.step() + assert torch.isfinite(loss.detach()).item() diff --git a/tests/unit/v1/moe/test_autoep_unit.py b/tests/unit/v1/moe/test_autoep_unit.py index 3104f8f8a9ed..528531a57a2f 100644 --- a/tests/unit/v1/moe/test_autoep_unit.py +++ b/tests/unit/v1/moe/test_autoep_unit.py @@ -187,12 +187,11 @@ def test_load_balance_coeff_rejected_by_validate(self, enabled, value): assert_load_balance_coeff_rejection_message(exc_info.value, value) def test_ep_size_validation_rejects_invalid_topology(self): - with pytest.raises(ValueError, match="AutoTP"): - validate_autoep_config(AutoEPConfig(enabled=True, autoep_size=2), - world_size=8, - pp_size=1, - tp_size=2, - sp_size=1) + validate_autoep_config(AutoEPConfig(enabled=True, autoep_size=2), + world_size=8, + pp_size=1, + tp_size=2, + sp_size=1) with pytest.raises(ValueError, match="must divide the stage size"): validate_autoep_config(AutoEPConfig(enabled=True, autoep_size=3), world_size=8, @@ -248,6 +247,8 @@ def record_create(**kwargs): expert_parallel_config=AutoEPConfig(enabled=True, autoep_size=2), tensor_parallel_config=SimpleNamespace(autotp_size=1), use_data_before_expert_parallel_=False, + zero_config=SimpleNamespace(offload_optimizer=None, offload_param=None), + zero_optimization_stage=0, ) engine._configure_expert_parallel(model=nn.Module()) From 0d44a1d24d1c0d790302edc471f0d11a0de7cf5e Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sat, 13 Jun 2026 11:44:14 -0700 Subject: [PATCH 02/16] Fix folded TP gradient reductions Signed-off-by: Masahiro Tanaka --- deepspeed/moe/ep_tp_dispatch.py | 12 ++++++---- deepspeed/runtime/zero/stage_1_and_2.py | 4 +++- .../v1/moe/test_autoep_autotp_dispatch.py | 5 ++-- 
.../v1/moe/test_autoep_autotp_grad_parity.py | 23 +++++++++++++++++++ 4 files changed, 37 insertions(+), 7 deletions(-) diff --git a/deepspeed/moe/ep_tp_dispatch.py b/deepspeed/moe/ep_tp_dispatch.py index 2b7b4df94ce1..abde9ae821d9 100644 --- a/deepspeed/moe/ep_tp_dispatch.py +++ b/deepspeed/moe/ep_tp_dispatch.py @@ -213,10 +213,14 @@ def backward(ctx, grad_output): local_count = ctx.counts[ctx.group_rank] if ctx.max_rows == 0: return grad_output.new_empty((0, *grad_output.shape[1:])), None, None, None - chunks = torch.split(grad_output, ctx.counts, dim=0) - grad_padded = grad_output.new_zeros((ctx.max_rows, *grad_output.shape[1:])) - if local_count: - grad_padded[:local_count].copy_(chunks[ctx.group_rank]) + reduced_chunks = [] + for chunk, count in zip(torch.split(grad_output, ctx.counts, dim=0), ctx.counts): + grad_padded = grad_output.new_zeros((ctx.max_rows, *grad_output.shape[1:])) + if count: + grad_padded[:count].copy_(chunk) + dist.all_reduce(grad_padded, group=ctx.group) + reduced_chunks.append(grad_padded) + grad_padded = reduced_chunks[ctx.group_rank] return grad_padded[:local_count].contiguous(), None, None, None diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 65907121d288..8cb60323103d 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1031,6 +1031,8 @@ def configure_autoep_folding_tp_gradient_reduction(self, folding_spec): def _maybe_reduce_autoep_folding_tp_gradient(self, param, grad): if not self.partition_gradients or self.autoep_folding_tp_group is None or grad is None: return + if not getattr(param, "ds_grad_is_ready", True): + return if is_moe_param(param) or is_model_parallel_parameter(param): return if grad.data.is_sparse: @@ -1118,7 +1120,6 @@ def flatten_dense_tensors_aligned(self, tensor_list, alignment, use_cpu_data=Fal def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): grad_reduc = 
self.get_gradient_for_reduction(param) - self._maybe_reduce_autoep_folding_tp_gradient(param, grad_reduc) comm_dtype = self.get_param_comm_dtype(param) bucket = self.ipg_buckets[comm_dtype] if bucket.elements + param.numel() > self.reduce_bucket_size: @@ -1133,6 +1134,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): if not getattr(param, "ds_grad_is_ready", True): return + self._maybe_reduce_autoep_folding_tp_gradient(param, grad_reduc) param_id = self.get_param_id(param) assert self.params_already_reduced[param_id] == False, \ f"The parameter {debug_param2name(param)} has already been reduced. \ diff --git a/tests/unit/v1/moe/test_autoep_autotp_dispatch.py b/tests/unit/v1/moe/test_autoep_autotp_dispatch.py index 8136a6e95468..5a9d9c970498 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_dispatch.py +++ b/tests/unit/v1/moe/test_autoep_autotp_dispatch.py @@ -151,7 +151,8 @@ def _restore_combined_backward_parity_worker(rank, world_size, _shared_tmpdir): score_apply="post", combine_impl="weighted_sum", shape=(1, 4, 3)) - expected.square().sum().backward() + expected_loss = sum((expected * float(peer_rank + 1)).square().sum() for peer_rank in range(world_size)) + expected_loss.backward() expected_weight_grad = expected_top_scores.grad.reshape(-1).index_select(0, token_indices_sorted) actual_expert_output = full_expert_output.clone().requires_grad_(True) @@ -161,7 +162,7 @@ def _restore_combined_backward_parity_worker(rank, world_size, _shared_tmpdir): restored = restore_combined(local_values, ctx, tp_group=tp_group) torch.testing.assert_close(restored.reshape(1, 4, 3), expected.detach(), rtol=0.0, atol=0.0) - restored.square().sum().backward() + (restored * float(rank + 1)).square().sum().backward() actual_value_grad = actual_expert_output.grad.detach().clone() actual_top_score_grad = actual_top_scores.grad.detach().clone() diff --git a/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py b/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py 
index 0babb164cf26..009fc88a3f68 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py +++ b/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py @@ -17,6 +17,7 @@ from deepspeed.checkpoint.constants import FOLDING_FAMILY, FOLDING_METADATA_KEY, FOLDING_PARAM_FAMILIES from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer from deepspeed.runtime.engine import DeepSpeedEngine +from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer from deepspeed.utils import safe_get_full_grad from deepspeed.utils import groups from unit.common import DistributedTest @@ -96,6 +97,28 @@ def fake_all_reduce(tensor, group=None): assert torch.equal(param.grad, torch.full_like(param.grad, expected_grad)) +def test_zero2_tp_gradient_reducer_skips_incomplete_ds_grad(monkeypatch): + param = torch.nn.Parameter(torch.ones(2)) + param.grad = torch.ones_like(param) + param.ds_grad_is_ready = False + optimizer = object.__new__(DeepSpeedZeroOptimizer) + optimizer.partition_gradients = True + optimizer.autoep_folding_tp_group = object() + optimizer.autoep_folding_partitioned_grad_mode = True + calls = [] + + def fake_all_reduce(tensor, group=None): + calls.append(tensor.clone()) + tensor.mul_(2) + + monkeypatch.setattr(dist, "all_reduce", fake_all_reduce) + + optimizer._maybe_reduce_autoep_folding_tp_gradient(param, param.grad) + + assert calls == [] + torch.testing.assert_close(param.grad, torch.ones_like(param.grad)) + + def _folded_zero2_tp2_ep4_config(): config = _folded_zero2_config(mixed_precision=False) config["expert_parallel"]["autoep_size"] = 4 From 8b1c0428364ce4c769f9edc3e674fcd9ecc72f58 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sat, 13 Jun 2026 11:57:36 -0700 Subject: [PATCH 03/16] Normalize folded TP ZeRO gradients Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/zero/stage_1_and_2.py | 6 +++++- .../v1/moe/test_autoep_autotp_grad_parity.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git 
a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 8cb60323103d..3aefa2566bed 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1038,13 +1038,17 @@ def _maybe_reduce_autoep_folding_tp_gradient(self, param, grad): if grad.data.is_sparse: return grad_data = grad.data + tp_world_size = dist.get_world_size(group=self.autoep_folding_tp_group) if self.autoep_folding_partitioned_grad_mode and grad_data.dtype != torch.float32: reduced = grad_data.float() dist.all_reduce(reduced, group=self.autoep_folding_tp_group) + reduced.div_(tp_world_size) grad_data.copy_(reduced.to(grad_data.dtype)) return dist.all_reduce(grad_data, group=self.autoep_folding_tp_group) - if not self.autoep_folding_partitioned_grad_mode: + if self.autoep_folding_partitioned_grad_mode: + grad_data.div_(tp_world_size) + else: grad_data.div_(dist.get_world_size(group=self.autoep_folding_tp_group)) def _fill_param_grad_accum_attribute(self, param): diff --git a/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py b/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py index 009fc88a3f68..3fd7ef008c88 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py +++ b/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py @@ -119,6 +119,25 @@ def fake_all_reduce(tensor, group=None): torch.testing.assert_close(param.grad, torch.ones_like(param.grad)) +def test_zero2_tp_gradient_reducer_normalizes_partitioned_mode(monkeypatch): + param = torch.nn.Parameter(torch.ones(2)) + param.grad = torch.ones_like(param) + optimizer = object.__new__(DeepSpeedZeroOptimizer) + optimizer.partition_gradients = True + optimizer.autoep_folding_tp_group = object() + optimizer.autoep_folding_partitioned_grad_mode = True + + def fake_all_reduce(tensor, group=None): + tensor.mul_(2) + + monkeypatch.setattr(dist, "all_reduce", fake_all_reduce) + monkeypatch.setattr(dist, "get_world_size", lambda group=None: 2) + + 
optimizer._maybe_reduce_autoep_folding_tp_gradient(param, param.grad) + + torch.testing.assert_close(param.grad, torch.ones_like(param.grad)) + + def _folded_zero2_tp2_ep4_config(): config = _folded_zero2_config(mixed_precision=False) config["expert_parallel"]["autoep_size"] = 4 From 7246b146081f1a017202ca065c50c3ca9cdb2d3c Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 23 Jun 2026 01:03:49 -0700 Subject: [PATCH 04/16] Fix AutoEP folded gradient strategy Signed-off-by: Masahiro Tanaka --- deepspeed/module_inject/auto_ep_folding.py | 126 +++++++++ deepspeed/module_inject/auto_ep_layer.py | 2 + deepspeed/runtime/engine.py | 23 +- deepspeed/runtime/zero/stage_1_and_2.py | 25 +- .../moe/test_autoep_autotp_folding_config.py | 4 + .../v1/moe/test_autoep_autotp_grad_parity.py | 259 ++++++++++++++++-- 6 files changed, 375 insertions(+), 64 deletions(-) diff --git a/deepspeed/module_inject/auto_ep_folding.py b/deepspeed/module_inject/auto_ep_folding.py index 6e884d2bdb57..467dc9f8d5cc 100644 --- a/deepspeed/module_inject/auto_ep_folding.py +++ b/deepspeed/module_inject/auto_ep_folding.py @@ -13,6 +13,16 @@ from dataclasses import dataclass from typing import Iterable +import torch + +AUTOEP_FOLDING_PARAM_FAMILY_ATTR = "ds_autoep_folding_param_family" +AUTOEP_FOLDING_ROUTER_GATE_REPLICATED_PARAM = "router_gate_replicated" +AUTOEP_FOLDING_ROUTER_GATE_PARTIAL_PARAM = "router_gate_partial" +AUTOEP_FOLDING_SP_SHARDED_LAYERNORM_PARAM = "sp_sharded_layernorm" +AUTOEP_FOLDING_GRAD_REDUCE_SKIP = "skip" +AUTOEP_FOLDING_GRAD_REDUCE_SUM = "sum" +AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE = "average" + @dataclass(frozen=True) class ParallelFoldingSpec: @@ -207,6 +217,14 @@ def validate_folding_global( f"stage_size = {spec.stage_size}). 
This is a temporary limitation; use a shape with " "ep * etp <= dp or run a follow-up implementation for cross-lane EP groups.") + if spec.tp_size > 1 and spec.dp_size % expert_width != 0: + raise ValueError("AutoEP+AutoTP folding requires the derived dense data-parallel size to be divisible by " + "expert_parallel.autoep_size * expert_parallel.expert_tensor_parallel_size so expert " + "groups stay within dense data-parallel lanes " + f"(dp = {spec.dp_size}, ep * etp = {expert_width}, stage_size = {spec.stage_size}). " + "Use a shape where dp % (ep * etp) == 0 or run a follow-up implementation for " + "cross-lane EP groups.") + if tp_preset is not None and ep_preset is not None and tp_preset != ep_preset: raise ValueError("tensor_parallel.preset_model and expert_parallel.preset_model must match when both " f"are set (tensor_parallel.preset_model={tp_preset!r}, " @@ -237,6 +255,114 @@ def validate_folding_global( raise ValueError(f"mpu pipeline parallel world size ({mpu_pp}) conflicts with pp_size={spec.pp_size}.") +def mark_autoep_folding_router_parameter(param) -> None: + setattr(param, AUTOEP_FOLDING_PARAM_FAMILY_ATTR, AUTOEP_FOLDING_ROUTER_GATE_REPLICATED_PARAM) + + +def mark_autoep_folding_partial_router_parameter(param) -> None: + setattr(param, AUTOEP_FOLDING_PARAM_FAMILY_ATTR, AUTOEP_FOLDING_ROUTER_GATE_PARTIAL_PARAM) + + +def mark_autoep_folding_sp_sharded_layernorm_parameter(param) -> None: + setattr(param, AUTOEP_FOLDING_PARAM_FAMILY_ATTR, AUTOEP_FOLDING_SP_SHARDED_LAYERNORM_PARAM) + + +def _is_moe_param_marker(param) -> bool: + return hasattr(param, "allreduce") and not param.allreduce + + +def _is_model_parallel_param_marker(param) -> bool: + return bool(getattr(param, "model_parallel", False) or getattr(param, "tensor_model_parallel", False)) + + +def _autoep_folding_param_family(param, *, param_name: str | None = None) -> str | None: + family = getattr(param, AUTOEP_FOLDING_PARAM_FAMILY_ATTR, None) + if family is not None: + return family + if 
param_name is not None and ".router." in param_name: + return AUTOEP_FOLDING_ROUTER_GATE_REPLICATED_PARAM + return None + + +def autoep_folding_gradient_reduction_strategy( + folding_spec: ParallelFoldingSpec | None, + param, + *, + param_name: str | None = None, +) -> str: + """Classify one folded TP/SP gradient as ``sum``, ``average``, or ``skip``. + + TP means Tensor Parallel and SP means Sequence Parallel. The parallel mode + alone is not a safe SUM-vs-AVG selector because different parameter + families see different backward semantics: + + - Router/gate parameters that are explicitly marked as routed-token + partials in TP/SP token-partitioned modes receive one partial gradient per + lane, so their TP/SP reduction is a SUM. The current AutoEP folded router + gate is marked ``router_gate_replicated`` because the full-flow backward + reaches this reducer as a lane-replicated gradient; that family uses the + same AVERAGE normalization as other replicated parameters. + - Dense and LayerNorm parameters that are merely replicated by TP folding + are not routed-token partials; blindly SUMing them scales gradients by + the TP size, so their extra TP reduction is an AVERAGE. + - A true SP-sharded LayerNorm would be a partial-gradient parameter and + should SUM. The current AutoEP folding path does not mark runtime + LayerNorm parameters that way; the marker and strategy boundary exist so + future SP support has an explicit contract instead of reusing the dense + replicated default by accident. + - Expert parameters and model-parallel parameters are SKIP because their + EP/TP-specific paths own their reductions. + + Both the DeepSpeedEngine path and the ZeRO-2 path call this helper so the + policy cannot silently drift between optimizers. 
+ """ + if folding_spec is None or getattr(folding_spec, "tp_size", 1) <= 1: + return AUTOEP_FOLDING_GRAD_REDUCE_SKIP + if _is_moe_param_marker(param) or _is_model_parallel_param_marker(param): + return AUTOEP_FOLDING_GRAD_REDUCE_SKIP + + family = _autoep_folding_param_family(param, param_name=param_name) + mp_mode = getattr(folding_spec, "mp_mode", "tp") + token_partitioned_mode = mp_mode in ("tp", "sp") + if family == AUTOEP_FOLDING_ROUTER_GATE_PARTIAL_PARAM: + return AUTOEP_FOLDING_GRAD_REDUCE_SUM if token_partitioned_mode else AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE + if family == AUTOEP_FOLDING_ROUTER_GATE_REPLICATED_PARAM: + return AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE + if family == AUTOEP_FOLDING_SP_SHARDED_LAYERNORM_PARAM and mp_mode == "sp": + return AUTOEP_FOLDING_GRAD_REDUCE_SUM + return AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE + + +def reduce_autoep_folding_gradient( + folding_spec: ParallelFoldingSpec | None, + param, + grad, + *, + tp_group, + param_name: str | None = None, +) -> str: + strategy = autoep_folding_gradient_reduction_strategy(folding_spec, param, param_name=param_name) + if strategy == AUTOEP_FOLDING_GRAD_REDUCE_SKIP or grad is None or grad.data.is_sparse: + return strategy + + from deepspeed import comm as dist + + grad_data = grad.data + tp_world_size = dist.get_world_size(group=tp_group) + if grad_data.dtype != torch.float32: + reduced = grad_data.float() + dist.all_reduce(reduced, group=tp_group) + if strategy == AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE: + reduced.div_(tp_world_size) + grad_data.copy_(reduced.to(grad_data.dtype)) + return strategy + + dist.all_reduce(grad_data, group=tp_group) + if strategy == AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE: + grad_data.div_(tp_world_size) + return strategy + + def _normalize_rank_groups(groups: Iterable[Iterable[int]]) -> set[tuple[int, ...]]: return {tuple(int(rank) for rank in group) for group in groups} diff --git a/deepspeed/module_inject/auto_ep_layer.py b/deepspeed/module_inject/auto_ep_layer.py index 
e1f18633b85b..2b493264408c 100644 --- a/deepspeed/module_inject/auto_ep_layer.py +++ b/deepspeed/module_inject/auto_ep_layer.py @@ -20,6 +20,7 @@ import torch.nn as nn import deepspeed.comm as dist from deepspeed.module_inject.auto_ep_config import AutoEPConfig, MoELayerSpec, resolve_autoep_config_defaults +from deepspeed.module_inject.auto_ep_folding import mark_autoep_folding_router_parameter from deepspeed.utils import logger from deepspeed.moe.ep_router import TokenChoiceTopKRouter from deepspeed.moe.ep_count import count_tokens_per_expert @@ -491,6 +492,7 @@ def __init__( # Mark shared expert and router params for global DP reduction for param in self.router.parameters(): param.allreduce = True + mark_autoep_folding_router_parameter(param) if self.shared_experts is not None: for param in self.shared_experts.parameters(): param.allreduce = True diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index a0b49e01f0e9..fb3a2639ed0c 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -24,7 +24,7 @@ import deepspeed from deepspeed import comm as dist -from deepspeed.runtime.utils import see_memory_usage, DummyOptim, register_output_backward_hooks, check_internal_apis_for_count_used_parameters, is_model_parallel_parameter +from deepspeed.runtime.utils import see_memory_usage, DummyOptim, register_output_backward_hooks, check_internal_apis_for_count_used_parameters from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer @@ -43,6 +43,7 @@ from deepspeed.linear.optimized_linear import LoRAOptimizedLinear from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime, collect_autotp_universal_checkpoint_info +from deepspeed.module_inject.auto_ep_folding import reduce_autoep_folding_gradient from deepspeed.runtime.config import 
DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \ @@ -2603,26 +2604,10 @@ def _reduce_autoep_folding_tp_replicated_gradients(self): if tp_group is None: return - # TP and SP folding modes produce disjoint per-lane router/shared partials. - # Duplicated-token modes already hold a full replicated gradient per lane. - partitioned_grad_mode = getattr(folding_spec, "mp_mode", "tp") in ("tp", "sp") - tp_world_size = dist.get_world_size(group=tp_group) - for _, param in self.module.named_parameters(): + for param_name, param in self.module.named_parameters(): if not param.requires_grad or param.grad is None: continue - if is_moe_param(param) or is_model_parallel_parameter(param): - continue - if param.grad.data.is_sparse: - continue - grad = param.grad.data - if partitioned_grad_mode and grad.dtype != torch.float32: - reduced = grad.float() - dist.all_reduce(reduced, group=tp_group) - grad.copy_(reduced.to(grad.dtype)) - continue - dist.all_reduce(grad, group=tp_group) - if not partitioned_grad_mode: - grad.div_(tp_world_size) + reduce_autoep_folding_gradient(folding_spec, param, param.grad, tp_group=tp_group, param_name=param_name) def _backward_prologue(self): if is_functorch_transforming(): diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 3aefa2566bed..734d84025294 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -35,6 +35,7 @@ from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.accelerator import get_accelerator from deepspeed.runtime.zero.muon.original_muon import muon_update +from deepspeed.module_inject.auto_ep_folding import reduce_autoep_folding_gradient from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, 
PARTITION_COUNT, LOSS_SCALER, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS) @@ -250,7 +251,7 @@ def __init__(self, self.has_moe_layers = has_moe_layers self.autoep_folding_tp_group = None - self.autoep_folding_partitioned_grad_mode = False + self.autoep_folding_spec = None if self.has_moe_layers: self._configure_moe_settings() self._global_grad_norm = 0. @@ -1023,33 +1024,17 @@ def overlapping_partition_gradients_reduce_epilogue(self): def configure_autoep_folding_tp_gradient_reduction(self, folding_spec): if folding_spec is None or folding_spec.tp_size <= 1: self.autoep_folding_tp_group = None - self.autoep_folding_partitioned_grad_mode = False + self.autoep_folding_spec = None return self.autoep_folding_tp_group = groups.get_tensor_model_parallel_group() - self.autoep_folding_partitioned_grad_mode = getattr(folding_spec, "mp_mode", "tp") in ("tp", "sp") + self.autoep_folding_spec = folding_spec def _maybe_reduce_autoep_folding_tp_gradient(self, param, grad): if not self.partition_gradients or self.autoep_folding_tp_group is None or grad is None: return if not getattr(param, "ds_grad_is_ready", True): return - if is_moe_param(param) or is_model_parallel_parameter(param): - return - if grad.data.is_sparse: - return - grad_data = grad.data - tp_world_size = dist.get_world_size(group=self.autoep_folding_tp_group) - if self.autoep_folding_partitioned_grad_mode and grad_data.dtype != torch.float32: - reduced = grad_data.float() - dist.all_reduce(reduced, group=self.autoep_folding_tp_group) - reduced.div_(tp_world_size) - grad_data.copy_(reduced.to(grad_data.dtype)) - return - dist.all_reduce(grad_data, group=self.autoep_folding_tp_group) - if self.autoep_folding_partitioned_grad_mode: - grad_data.div_(tp_world_size) - else: - grad_data.div_(dist.get_world_size(group=self.autoep_folding_tp_group)) + reduce_autoep_folding_gradient(self.autoep_folding_spec, param, grad, 
tp_group=self.autoep_folding_tp_group) def _fill_param_grad_accum_attribute(self, param): if param.grad is not None: diff --git a/tests/unit/v1/moe/test_autoep_autotp_folding_config.py b/tests/unit/v1/moe/test_autoep_autotp_folding_config.py index 6c972bb4efac..6dee57889862 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_folding_config.py +++ b/tests/unit/v1/moe/test_autoep_autotp_folding_config.py @@ -133,3 +133,7 @@ def test_validation_rule_g11_zero_offload_rejected(offload_key): def test_validation_rule_g12_cross_lane_ep_groups_temporarily_rejected(): _assert_rejects("temporary limitation", world_size=8, tp_size=4, ep_size=4) + + +def test_validation_rule_g13_expert_width_must_tile_dense_dp_lane(): + _assert_rejects("dp % \\(ep \\* etp\\) == 0", world_size=12, tp_size=3, ep_size=3) diff --git a/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py b/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py index 3fd7ef008c88..34f8b1450eca 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py +++ b/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py @@ -15,6 +15,15 @@ import deepspeed.comm as dist from deepspeed.checkpoint.autoep_universal import validate_folding_metadata from deepspeed.checkpoint.constants import FOLDING_FAMILY, FOLDING_METADATA_KEY, FOLDING_PARAM_FAMILIES +from deepspeed.module_inject.auto_ep_folding import ( + AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE, + AUTOEP_FOLDING_GRAD_REDUCE_SKIP, + AUTOEP_FOLDING_GRAD_REDUCE_SUM, + autoep_folding_gradient_reduction_strategy, + mark_autoep_folding_partial_router_parameter, + mark_autoep_folding_router_parameter, + mark_autoep_folding_sp_sharded_layernorm_parameter, +) from deepspeed.module_inject.auto_ep_layer import AutoEPMoELayer from deepspeed.runtime.engine import DeepSpeedEngine from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer @@ -33,6 +42,10 @@ from deepspeed.module_inject.auto_ep_config import AutoEPConfig, validate_autoep_config +def _folding_spec(mp_mode="tp", tp_size=2): + 
return type("Spec", (), {"tp_size": tp_size, "mp_mode": mp_mode})() + + def test_zero_offload_paths_fail_fast_until_per_family_replica_groups_are_proven(): for kwargs in ({"zero_offload_optimizer": True}, {"zero_offload_param": True}): with pytest.raises(ValueError, match="offload"): @@ -74,15 +87,62 @@ def _folded_zero2_config(*, mixed_precision=True): return config -@pytest.mark.parametrize(("mp_mode", "expected_grad"), (("tp", 2.0), ("sp", 2.0), ("replicated", 1.0))) -def test_tp_replicated_gradient_reducer_respects_parallel_mode(monkeypatch, mp_mode, expected_grad): +def test_autoep_folding_gradient_strategy_uses_parameter_family(): + router = torch.nn.Parameter(torch.ones(2)) + mark_autoep_folding_router_parameter(router) + assert (autoep_folding_gradient_reduction_strategy(_folding_spec("tp"), + router) == AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE) + + partial_router = torch.nn.Parameter(torch.ones(2)) + mark_autoep_folding_partial_router_parameter(partial_router) + assert (autoep_folding_gradient_reduction_strategy(_folding_spec("tp"), + partial_router) == AUTOEP_FOLDING_GRAD_REDUCE_SUM) + assert (autoep_folding_gradient_reduction_strategy(_folding_spec("sp"), + partial_router) == AUTOEP_FOLDING_GRAD_REDUCE_SUM) + assert (autoep_folding_gradient_reduction_strategy(_folding_spec("replicated"), + router) == AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE) + + dense_or_layernorm = torch.nn.Parameter(torch.ones(2)) + assert (autoep_folding_gradient_reduction_strategy(_folding_spec("tp"), + dense_or_layernorm) == AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE) + + sp_layernorm = torch.nn.Parameter(torch.ones(2)) + mark_autoep_folding_sp_sharded_layernorm_parameter(sp_layernorm) + assert (autoep_folding_gradient_reduction_strategy(_folding_spec("sp"), + sp_layernorm) == AUTOEP_FOLDING_GRAD_REDUCE_SUM) + assert (autoep_folding_gradient_reduction_strategy(_folding_spec("tp"), + sp_layernorm) == AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE) + + expert = torch.nn.Parameter(torch.ones(2)) + expert.allreduce = 
False + assert autoep_folding_gradient_reduction_strategy(_folding_spec("tp"), expert) == AUTOEP_FOLDING_GRAD_REDUCE_SKIP + + model_parallel = torch.nn.Parameter(torch.ones(2)) + model_parallel.tensor_model_parallel = True + assert (autoep_folding_gradient_reduction_strategy(_folding_spec("tp"), + model_parallel) == AUTOEP_FOLDING_GRAD_REDUCE_SKIP) + + +@pytest.mark.parametrize( + ("param_name", "mark_router", "mp_mode", "expected_grad"), + ( + ("model.layers.0.mlp.router.gate.weight", True, "tp", 1.0), + ("model.layers.0.mlp.router.gate.weight", True, "sp", 1.0), + ("model.layers.0.mlp.router.gate.weight", True, "replicated", 1.0), + ("model.layers.0.input_layernorm.weight", False, "tp", 1.0), + ), +) +def test_tp_replicated_gradient_reducer_respects_param_family(monkeypatch, param_name, mark_router, mp_mode, + expected_grad): param = torch.nn.Parameter(torch.ones(2)) param.grad = torch.ones_like(param) + if mark_router: + mark_autoep_folding_router_parameter(param) engine = object.__new__(DeepSpeedEngine) - engine._autoep_folding_spec = type("Spec", (), {"tp_size": 2, "mp_mode": mp_mode})() + engine._autoep_folding_spec = _folding_spec(mp_mode) engine.__dict__["optimizer"] = None engine.__dict__["module"] = type("ModuleStub", (), - {"named_parameters": lambda self: iter([("dense.weight", param)])})() + {"named_parameters": lambda self: iter([(param_name, param)])})() monkeypatch.setattr(dist, "is_initialized", lambda: True) monkeypatch.setattr(groups, "get_tensor_model_parallel_group", lambda: object()) monkeypatch.setattr(dist, "get_world_size", lambda group=None: 2) @@ -104,7 +164,7 @@ def test_zero2_tp_gradient_reducer_skips_incomplete_ds_grad(monkeypatch): optimizer = object.__new__(DeepSpeedZeroOptimizer) optimizer.partition_gradients = True optimizer.autoep_folding_tp_group = object() - optimizer.autoep_folding_partitioned_grad_mode = True + optimizer.autoep_folding_spec = _folding_spec("tp") calls = [] def fake_all_reduce(tensor, group=None): @@ -119,13 
+179,16 @@ def fake_all_reduce(tensor, group=None): torch.testing.assert_close(param.grad, torch.ones_like(param.grad)) -def test_zero2_tp_gradient_reducer_normalizes_partitioned_mode(monkeypatch): +@pytest.mark.parametrize(("mark_router", "expected_grad"), ((True, 1.0), (False, 1.0))) +def test_zero2_tp_gradient_reducer_uses_shared_param_family_strategy(monkeypatch, mark_router, expected_grad): param = torch.nn.Parameter(torch.ones(2)) param.grad = torch.ones_like(param) + if mark_router: + mark_autoep_folding_router_parameter(param) optimizer = object.__new__(DeepSpeedZeroOptimizer) optimizer.partition_gradients = True optimizer.autoep_folding_tp_group = object() - optimizer.autoep_folding_partitioned_grad_mode = True + optimizer.autoep_folding_spec = _folding_spec("tp") def fake_all_reduce(tensor, group=None): tensor.mul_(2) @@ -135,7 +198,7 @@ def fake_all_reduce(tensor, group=None): optimizer._maybe_reduce_autoep_folding_tp_gradient(param, param.grad) - torch.testing.assert_close(param.grad, torch.ones_like(param.grad)) + torch.testing.assert_close(param.grad, torch.full_like(param.grad, expected_grad)) def _folded_zero2_tp2_ep4_config(): @@ -145,6 +208,26 @@ def _folded_zero2_tp2_ep4_config(): return config +def _folded_zero0_tp2_ep4_config(): + config = make_autoep_config(zero_stage=0, ep_size=4, mixed_precision=False) + config["gradient_accumulation_steps"] = 2 + config["gradient_clipping"] = 0.0 + config["communication_data_type"] = "fp32" + config["optimizer"]["params"]["torch_adam"] = True + config["expert_parallel"]["autoep_size"] = 4 + config["tensor_parallel"] = { + "autotp_size": 2, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + }, + } + return config + + def _zero2_baseline_config(): config = { **{ @@ -159,6 +242,25 @@ def _zero2_baseline_config(): return config +def _zero0_baseline_config(): + config = { + **{ + key: value + for key, value in 
make_autoep_config(zero_stage=0, ep_size=1, mixed_precision=False).items() if key != "expert_parallel" + }, + "gradient_accumulation_steps": 2, + "gradient_clipping": 0.0, + } + config["communication_data_type"] = "fp32" + config["optimizer"]["params"]["torch_adam"] = True + return config + + +GATE_BASELINE = "model.layers.0.mlp.gate.weight" +GATE_FOLDED = "model.layers.0.mlp.router.gate.weight" +INPUT_LAYERNORM = "model.layers.0.input_layernorm.weight" + + def _router_grad_model(): return MockMoEOnlyTransformer(num_layers=1, num_experts=4, hidden_size=64, intermediate_size=128) @@ -195,6 +297,102 @@ def _full_grad_by_suffix(engine, suffix): raise AssertionError(f"Missing parameter ending with {suffix}") +def _cpu_folded_zero0_router_gate_and_layernorm_worker(rank, world_size, _shared_tmpdir): + seed = 1234 + tp_size = 2 + logical_dp_world_size = world_size // tp_size + logical_dp_rank = rank // tp_size + + seed_everything(seed) + reference_state = _router_grad_model().state_dict() + + baseline_model = _router_grad_model() + baseline_model.load_state_dict(reference_state) + baseline_engine, _, _, _ = deepspeed.initialize(model=baseline_model, config=_zero0_baseline_config()) + _run_router_grad_boundary(baseline_engine, + logical_dp_world_size=logical_dp_world_size, + logical_dp_rank=logical_dp_rank, + seed=seed) + baseline_gate = _full_grad_by_suffix(baseline_engine, GATE_BASELINE) + baseline_layernorm = _full_grad_by_suffix(baseline_engine, INPUT_LAYERNORM) + + folded_model = _router_grad_model() + folded_model.load_state_dict(reference_state) + folded_engine, _, _, _ = deepspeed.initialize(model=folded_model, config=_folded_zero0_tp2_ep4_config()) + _run_router_grad_boundary(folded_engine, + logical_dp_world_size=logical_dp_world_size, + logical_dp_rank=logical_dp_rank, + seed=seed) + folded_gate = _full_grad_by_suffix(folded_engine, GATE_FOLDED) + folded_layernorm = _full_grad_by_suffix(folded_engine, INPUT_LAYERNORM) + + metrics = { + "rank": rank, + "gate": 
_grad_parity_metrics(folded_gate, baseline_gate), + "layernorm": _grad_parity_metrics(folded_layernorm, baseline_layernorm), + } + if rank == 0: + print("FOLDED_ENGINE_ZERO0_ROUTER_GATE_LAYERNORM_GRAD_PARITY " + json.dumps(metrics, sort_keys=True)) + torch.testing.assert_close(folded_gate, + baseline_gate, + atol=1e-1, + rtol=5e-3, + msg=f"Folded zero_stage=0 router/gate grad must match baseline; metrics={metrics}") + torch.testing.assert_close(folded_layernorm, + baseline_layernorm, + atol=1e-1, + rtol=5e-3, + msg=f"Folded zero_stage=0 LayerNorm grad must match baseline; metrics={metrics}") + + +def _cpu_folded_zero2_router_gate_and_layernorm_worker(rank, world_size, _shared_tmpdir): + seed = 1234 + tp_size = 2 + logical_dp_world_size = world_size // tp_size + logical_dp_rank = rank // tp_size + + seed_everything(seed) + reference_state = _router_grad_model().state_dict() + + baseline_model = _router_grad_model() + baseline_model.load_state_dict(reference_state) + baseline_engine, _, _, _ = deepspeed.initialize(model=baseline_model, config=_zero2_baseline_config()) + _run_router_grad_boundary(baseline_engine, + logical_dp_world_size=logical_dp_world_size, + logical_dp_rank=logical_dp_rank, + seed=seed) + baseline_gate = _full_grad_by_suffix(baseline_engine, GATE_BASELINE) + baseline_layernorm = _full_grad_by_suffix(baseline_engine, INPUT_LAYERNORM) + + folded_model = _router_grad_model() + folded_model.load_state_dict(reference_state) + folded_engine, _, _, _ = deepspeed.initialize(model=folded_model, config=_folded_zero2_tp2_ep4_config()) + _run_router_grad_boundary(folded_engine, + logical_dp_world_size=logical_dp_world_size, + logical_dp_rank=logical_dp_rank, + seed=seed) + folded_gate = _full_grad_by_suffix(folded_engine, GATE_FOLDED) + folded_layernorm = _full_grad_by_suffix(folded_engine, INPUT_LAYERNORM) + + metrics = { + "rank": rank, + "gate": _grad_parity_metrics(folded_gate, baseline_gate), + "layernorm": _grad_parity_metrics(folded_layernorm, 
baseline_layernorm), + } + if rank == 0: + print("FOLDED_ZERO2_ROUTER_GATE_LAYERNORM_GRAD_PARITY " + json.dumps(metrics, sort_keys=True)) + torch.testing.assert_close(folded_gate, + baseline_gate, + atol=1e-1, + rtol=5e-3, + msg=f"Folded ZeRO-2 router/gate grad must match baseline; metrics={metrics}") + torch.testing.assert_close(folded_layernorm, + baseline_layernorm, + atol=1e-1, + rtol=5e-3, + msg=f"Folded ZeRO-2 LayerNorm grad must match baseline; metrics={metrics}") + + def _grad_parity_metrics(actual, expected): diff = actual - expected expected_norm_sq = expected.square().sum().item() @@ -213,6 +411,14 @@ def _grad_parity_metrics(actual, expected): } +def test_cpu_gloo_folded_zero0_router_gate_and_layernorm_grad_parity(tmpdir): + run_cpu_gloo_test(_cpu_folded_zero0_router_gate_and_layernorm_worker, tmpdir, world_size=8) + + +def test_cpu_gloo_folded_zero2_router_gate_and_layernorm_grad_parity(tmpdir): + run_cpu_gloo_test(_cpu_folded_zero2_router_gate_and_layernorm_worker, tmpdir, world_size=8) + + def _assert_zero_optimizer_folding_metadata(checkpoint_dir): optim_paths = sorted(glob.glob(os.path.join(str(checkpoint_dir), "folded-zero2", "*_optim_states.pt"))) assert optim_paths @@ -291,8 +497,8 @@ class TestH100FoldedRouterGateGradParityTP2EP4(DistributedTest): world_size = 8 reuse_dist_env = False - def test_folded_router_gate_grad_matches_nonfolded_zero2_baseline(self): - skip_unless_h100_tests_enabled("H100 folded router/gate gradient parity node") + def test_folded_router_gate_and_layernorm_grad_match_nonfolded_zero2_baseline(self): + skip_unless_h100_tests_enabled("H100 folded router/gate and LayerNorm gradient parity node") seed = 1234 tp_size = 2 @@ -308,7 +514,8 @@ def test_folded_router_gate_grad_matches_nonfolded_zero2_baseline(self): logical_dp_world_size=logical_dp_world_size, logical_dp_rank=logical_dp_rank, seed=seed) - baseline_grad = _full_grad_by_suffix(baseline_engine, "model.layers.0.mlp.gate.weight") + baseline_gate_grad = 
_full_grad_by_suffix(baseline_engine, GATE_BASELINE) + baseline_layernorm_grad = _full_grad_by_suffix(baseline_engine, INPUT_LAYERNORM) folded_model = _router_grad_model() folded_model.load_state_dict(reference_state) @@ -318,26 +525,28 @@ def test_folded_router_gate_grad_matches_nonfolded_zero2_baseline(self): logical_dp_rank=logical_dp_rank, seed=seed) - folded_grad = _full_grad_by_suffix(folded_engine, "model.layers.0.mlp.router.gate.weight") + folded_gate_grad = _full_grad_by_suffix(folded_engine, GATE_FOLDED) + folded_layernorm_grad = _full_grad_by_suffix(folded_engine, INPUT_LAYERNORM) metrics = { - **_grad_parity_metrics(folded_grad, baseline_grad), - "nodeid": - "tests/unit/v1/moe/test_autoep_autotp_grad_parity.py::" + "nodeid": "tests/unit/v1/moe/test_autoep_autotp_grad_parity.py::" "TestH100FoldedRouterGateGradParityTP2EP4::" - "test_folded_router_gate_grad_matches_nonfolded_zero2_baseline", - "rank": - dist.get_rank(), - "target_param": - "model.layers.0.mlp.gate.weight", - "folded_param": - "model.layers.0.mlp.router.gate.weight", + "test_folded_router_gate_and_layernorm_grad_match_nonfolded_zero2_baseline", + "rank": dist.get_rank(), + "gate": _grad_parity_metrics(folded_gate_grad, baseline_gate_grad), + "layernorm": _grad_parity_metrics(folded_layernorm_grad, baseline_layernorm_grad), } if dist.get_rank() == 0: - print("FOLDED_ROUTER_GATE_GRAD_PARITY " + json.dumps(metrics, sort_keys=True)) + print("FOLDED_ROUTER_GATE_LAYERNORM_GRAD_PARITY " + json.dumps(metrics, sort_keys=True)) - torch.testing.assert_close(folded_grad, - baseline_grad, + torch.testing.assert_close(folded_gate_grad, + baseline_gate_grad, atol=1e-1, rtol=5e-3, msg=("Folded TP2-EP4 router/gate grad must match the non-folded ZeRO-2 " f"baseline; metrics={metrics}")) + torch.testing.assert_close(folded_layernorm_grad, + baseline_layernorm_grad, + atol=1e-1, + rtol=5e-3, + msg=("Folded TP2-EP4 LayerNorm grad must match the non-folded ZeRO-2 " + f"baseline; metrics={metrics}")) From 
e5c1ba2319ff824f8ae86afa90e4bbdfdb848ad7 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 23 Jun 2026 12:02:27 -0700 Subject: [PATCH 05/16] Document AutoEP folding gradient-reduction rationale Comment-only; no behavior change. Make the folded TP/SP gradient SUM-vs-AVERAGE policy and its reasoning explicit in the code: - mark_autoep_folding_* contracts: router_gate_replicated is the only marker applied on the live forward path (AVERAGE); routed-token-partial and SP-sharded-LayerNorm are future-only SUM contracts pinned by unit tests. - _AllGatherVariableRows / restore_combined: document the tp_size factor the restore all-gather backward injects, which makes the folded router/gate arrive replicated (AVERAGE) and why SUM reproduces the 2.0x parity regression the CPU/Gloo tests guard. - partition_assignments: the drop-before-all-to-all keeps EP dispatch at 1x volume, and the reconstruction is what makes the gradient replicated. - Strategy classifier: the router-vs-LayerNorm asymmetry (router can be a tp/sp partial because it rides the dispatch all-to-all; LayerNorm only under true SP) and the underlying gathered=AVERAGE / sharded=SUM rule. Signed-off-by: Masahiro Tanaka --- deepspeed/module_inject/auto_ep_folding.py | 62 ++++++++++++++++++++++ deepspeed/module_inject/auto_ep_layer.py | 6 ++- deepspeed/moe/ep_tp_dispatch.py | 51 +++++++++++++++++- 3 files changed, 116 insertions(+), 3 deletions(-) diff --git a/deepspeed/module_inject/auto_ep_folding.py b/deepspeed/module_inject/auto_ep_folding.py index 467dc9f8d5cc..99ab2c2f3cc2 100644 --- a/deepspeed/module_inject/auto_ep_folding.py +++ b/deepspeed/module_inject/auto_ep_folding.py @@ -256,14 +256,51 @@ def validate_folding_global( def mark_autoep_folding_router_parameter(param) -> None: + """Tag a router/gate parameter as the *replicated* folded family (AVERAGE). + + This is the ONLY family marker applied on the live forward path today: + ``AutoEPMoELayer.__init__`` marks every ``router.*`` parameter with it. 
The + folded router runs redundantly on every TP peer (same tokens, same routing) + and its gradient is reconstructed into a replicated full view by the restore + all-gather (see ``deepspeed.moe.ep_tp_dispatch._AllGatherVariableRows`` and + ``restore_combined``). That all-gather backward scales each peer's slice by + ``tp_size``, so the extra TP reduction must AVERAGE (all_reduce then divide + by ``tp_size``); SUM would leave the ``tp_size`` factor, i.e. the 2.0x + parity regression the CPU/Gloo tests guard. + """ setattr(param, AUTOEP_FOLDING_PARAM_FAMILY_ATTR, AUTOEP_FOLDING_ROUTER_GATE_REPLICATED_PARAM) def mark_autoep_folding_partial_router_parameter(param) -> None: + """Tag a router/gate parameter as a *routed-token partial* family (SUM). + + Forward-looking contract; NOT used on the current forward path -- only the + unit tests in ``tests/unit/v1/moe/test_autoep_autotp_grad_parity.py`` set + it. Use it only for a future design where the router's per-token work is + genuinely partitioned across peers and the slices are NOT all-gathered back + into a replicated full view, so each peer holds a real partial gradient that + must be SUMed. Such a router is a SUM partial in any token-partitioned mode + (``mp_mode in {"tp", "sp"}``) because its partition can ride the existing + expert-dispatch all-to-all without changing the dense activation layout. + Prove the SUM with a parity test (like the existing router/gate cases) + before enabling it on a real forward path. + """ setattr(param, AUTOEP_FOLDING_PARAM_FAMILY_ATTR, AUTOEP_FOLDING_ROUTER_GATE_PARTIAL_PARAM) def mark_autoep_folding_sp_sharded_layernorm_parameter(param) -> None: + """Tag a LayerNorm parameter as *SP-sequence-sharded* family (SUM under SP). + + Forward-looking contract; NOT used on the current forward path -- only the + unit tests set it. 
Unlike the router, a LayerNorm has no adjacent dispatch + all-to-all to ride on, so the only way to token-partition it is to shard the + sequence dimension of the dense activations, which is Sequence Parallel by + definition. It therefore becomes a SUM partial only when ``mp_mode == "sp"`` + and otherwise falls back to the replicated AVERAGE. Today ``tp_size > 1`` + with sequence parallelism is rejected in ``validate_folding_global``; this + marker is the explicit contract for when that restriction is lifted, and + must be backed by a parity test before use. + """ setattr(param, AUTOEP_FOLDING_PARAM_FAMILY_ATTR, AUTOEP_FOLDING_SP_SHARDED_LAYERNORM_PARAM) @@ -276,6 +313,16 @@ def _is_model_parallel_param_marker(param) -> bool: def _autoep_folding_param_family(param, *, param_name: str | None = None) -> str | None: + """Resolve a parameter's folded reduction family. + + An explicit ``mark_autoep_folding_*`` tag always wins. The ``.router.`` name + match is only a redundant safety net: ``AutoEPMoELayer`` already tags router + params, so this fallback merely keeps the conservative *replicated* (AVERAGE) + classification if some router param ever reaches the reducer untagged. It + never returns a SUM family by name -- SUM families are opt-in via explicit + markers only, so any unrecognized replicated/dense/LayerNorm param falls + through to the AVERAGE default rather than being silently over-scaled. + """ family = getattr(param, AUTOEP_FOLDING_PARAM_FAMILY_ATTR, None) if family is not None: return family @@ -313,6 +360,21 @@ def autoep_folding_gradient_reduction_strategy( - Expert parameters and model-parallel parameters are SKIP because their EP/TP-specific paths own their reductions. + Underlying rule and mechanism: a folded parameter is replicated (AVERAGE) + when the forward reconstructs its partitioned work into an identical full + view inside the layer, and a genuine partial (SUM) only when the shard is + kept all the way to the loss. 
Today the router/gate is partitioned across + TP peers for dispatch but then all-gathered back by ``restore_combined`` + (see ``deepspeed.moe.ep_tp_dispatch``), whose backward scales each peer's + gradient by ``tp_size``; the TP all_reduce then yields ``tp_size * + full_grad`` and AVERAGE divides it out. Reducing with SUM would leave that + factor -- the 2.0x router/gate parity regression the CPU/Gloo tests guard. + The router can be a SUM partial in either ``tp`` or ``sp`` mode because its + token partition can ride the existing dispatch all-to-all, whereas a + LayerNorm becomes a partial only under true ``sp`` (sequence sharding): it + has no adjacent all-to-all, so partitioning it requires changing the dense + activation layout, which is Sequence Parallel by definition. + Both the DeepSpeedEngine path and the ZeRO-2 path call this helper so the policy cannot silently drift between optimizers. """ diff --git a/deepspeed/module_inject/auto_ep_layer.py b/deepspeed/module_inject/auto_ep_layer.py index 2b493264408c..8e70872cc426 100644 --- a/deepspeed/module_inject/auto_ep_layer.py +++ b/deepspeed/module_inject/auto_ep_layer.py @@ -489,7 +489,11 @@ def __init__( param.allreduce = False param.group_name = self.ep_group_name - # Mark shared expert and router params for global DP reduction + # Mark shared expert and router params for global DP reduction. + # The router runs redundantly on every TP peer and its gradient is + # rebuilt into a replicated full view by the restore all-gather, so it + # is tagged as the replicated family (AVERAGE TP reduction); a SUM would + # double it under tp_size=2. See mark_autoep_folding_router_parameter. 
for param in self.router.parameters(): param.allreduce = True mark_autoep_folding_router_parameter(param) diff --git a/deepspeed/moe/ep_tp_dispatch.py b/deepspeed/moe/ep_tp_dispatch.py index abde9ae821d9..c085ee7b1cc3 100644 --- a/deepspeed/moe/ep_tp_dispatch.py +++ b/deepspeed/moe/ep_tp_dispatch.py @@ -156,7 +156,16 @@ def partition_assignments( tp_rank: int, tp_size: int, ) -> tuple[RoutedAssignmentPayload, RestoreContext]: - """Partition routed assignments across TP peers by stable per-expert ordinal.""" + """Partition routed assignments across TP peers by stable per-expert ordinal. + + Each peer keeps only ``assignment_index % tp_size == tp_rank`` of the + (token, expert) assignments and drops the rest *before* the EP dispatch + all-to-all, so the dispatch carries the full token set exactly once (split + across peers) instead of ``tp_size`` redundant copies. The dropped work is + reconstructed afterwards by ``restore_combined``'s all-gather; that + reconstruction is what makes the folded router/gate gradient replicated + (AVERAGE) rather than a true SUM partial -- see ``_AllGatherVariableRows``. + """ active = ~payload.drop_mask & ~payload.pad_mask if tp_size <= 1: keep = active @@ -194,6 +203,31 @@ def _pad_rows(tensor: torch.Tensor, rows: int) -> torch.Tensor: class _AllGatherVariableRows(torch.autograd.Function): + """Differentiable all-gather of row-variable tensors across the TP folding group. + + Forward concatenates every TP peer's local rows into one tensor that is + identical on every peer: a replicated full view of the rows that + ``partition_assignments`` had split across peers before the EP dispatch. + + Backward is the matching reduce-scatter. Because the forward output is + consumed identically on every peer, each peer holds the same ``grad_output``; + summing those replicas with ``all_reduce`` and keeping this peer's own + row-slice is the correct vector-Jacobian product. 
+ + Gradient-reduction consequence (important -- this is why the folded + router/gate uses AVERAGE, not SUM): the ``all_reduce`` in backward scales + each peer's slice gradient by ``tp_size``. A parameter whose gradient flows + through this restore all-gather -- the folded router/gate scores, see + ``restore_combined`` -- therefore reaches the optimizer's TP reducer carrying + ``tp_size`` times its own routed-token slice. The TP reducer all_reduce then + produces ``tp_size * full_grad``, and the AVERAGE strategy in + ``auto_ep_folding.autoep_folding_gradient_reduction_strategy`` divides by + ``tp_size`` to recover the true gradient. Reducing with SUM instead leaves + the uncancelled ``tp_size`` factor -- exactly the 2.0x router/gate gradient + regression the CPU/Gloo parity tests guard against. The partition is + reconstructed into a replicated full view here, so it is not a genuine SUM + partial; a future true-SP path that kept the shard to the loss would be. + """ @staticmethod def forward(ctx, tensor, group, counts, max_rows): @@ -218,6 +252,9 @@ def backward(ctx, grad_output): grad_padded = grad_output.new_zeros((ctx.max_rows, *grad_output.shape[1:])) if count: grad_padded[:count].copy_(chunk) + # grad_output is replicated across TP peers (the gathered full view + # is consumed identically), so this all_reduce sums tp_size copies + # and injects the tp_size factor documented in the class docstring. dist.all_reduce(grad_padded, group=ctx.group) reduced_chunks.append(grad_padded) grad_padded = reduced_chunks[ctx.group_rank] @@ -278,7 +315,17 @@ def _debug_validate_restore_coverage(payload: RoutedAssignmentPayload, ctx: Rest def restore_combined(local_combined: torch.Tensor, ctx: RestoreContext, *, tp_group) -> torch.Tensor: - """Gather TP-partitioned assignment outputs and combine back by token index.""" + """Gather TP-partitioned assignment outputs and combine back by token index. 
+ + The all-gather rebuilds an identical full output on every TP peer, so all + downstream compute (and the router/gate score gradient) is replicated across + the folding group. Its differentiable backward injects a ``tp_size`` factor + (see ``_AllGatherVariableRows``) that the optimizer's TP gradient reducer + cancels with the AVERAGE strategy. A future true-SP path that kept + activations sequence-sharded instead of gathering them here would make those + parameters genuine SUM partials -- the reason the SUM family markers exist + in ``deepspeed.module_inject.auto_ep_folding``. + """ payload = ctx.original_payload local_token_indices = payload.token_indices.index_select(0, ctx.local_indices) local_expert_indices = payload.expert_indices.index_select(0, ctx.local_indices) From 13dceecf5d3b6d0a535d16c21f1ffc8315c40e1f Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 23 Jun 2026 16:02:52 -0700 Subject: [PATCH 06/16] Generalize AutoEP+AutoTP folding to cross-lane expert parallelism Lift the temporary fail-fast that forced expert parallelism to be a subset of data parallelism (expert_width = ep*etp <= dp and dp % expert_width == 0) and support cross-lane EP, where an EP group may span TP lanes and dense-DP ranks (e.g. world=4 TP4/EP4/dp1, or world=4 TP2/DP2/EP4). The folded group tables already lay EP groups across the tp-lane-major rank ordering, so only the validation gate and the routed-expert gradient reduction needed to change. Fix routed-expert gradient over-scaling under folding. The folded forward all-gathers expert outputs into a replicated full view in restore_combined, whose backward injects a tp_size factor (the same factor the replicated router cancels via AVERAGE). Routed experts were classified SKIP, so that factor survived and the expert weight gradient reached the optimizer tp_size times too large. 
This was invisible to scale-invariant Adam but real for non-adaptive optimizers and for gradient clipping (it inflates the expert contribution to the global grad norm), and it grows with tp_size in cross-lane shapes. Routed experts now use a dedicated EXPERT_TP_CANCEL reduction that divides the gradient by tp_size with no TP all_reduce; the division is linear and composes with the existing expert-data-parallel all_reduce in any order. The per-family gradient convention is otherwise unchanged and is now justified for the general cross-lane layout: each family's reduction is keyed to its replication structure, not the EP layout. Router/gate and dense/LayerNorm AVERAGE over the TP (token-replication) group; routed experts cancel the restore tp_size factor and reduce data-parallel over the EDP group. Add CPU/Gloo parity coverage for cross-lane shapes (router/gate and LayerNorm vs a non-folded baseline at world=4 TP2/EP4 and TP4/EP4/dp1) and SGD-based expert-weight parity (post-step weight equality vs the non-folded baseline) for the MVP shape and cross-lane shapes including edp>1. Adam is scale-invariant, so the expert test uses SGD to actually exercise the gradient magnitude. Update the folding config/group-table tests: cross-lane shapes are now accepted, and the cross-lane EP/EDP rank tables are asserted. 
Signed-off-by: Masahiro Tanaka --- deepspeed/module_inject/auto_ep_folding.py | 73 +++-- .../moe/test_autoep_autotp_folding_config.py | 42 ++- .../v1/moe/test_autoep_autotp_grad_parity.py | 261 +++++++++++++++++- 3 files changed, 353 insertions(+), 23 deletions(-) diff --git a/deepspeed/module_inject/auto_ep_folding.py b/deepspeed/module_inject/auto_ep_folding.py index 99ab2c2f3cc2..2c26542e331e 100644 --- a/deepspeed/module_inject/auto_ep_folding.py +++ b/deepspeed/module_inject/auto_ep_folding.py @@ -22,6 +22,15 @@ AUTOEP_FOLDING_GRAD_REDUCE_SKIP = "skip" AUTOEP_FOLDING_GRAD_REDUCE_SUM = "sum" AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE = "average" +# Divide by tp_size with NO TP all_reduce. Used for routed-expert parameters: the +# folded forward all-gathers expert outputs into a replicated full view in +# ``restore_combined``, whose backward injects a ``tp_size`` factor (same factor the +# replicated router cancels via AVERAGE). Routed experts are not TP-replicated, so +# they must not be TP all_reduced; they only need that spurious ``tp_size`` factor +# divided out. The remaining data-parallel reduction is owned by the expert-data +# -parallel (EDP) path, and ``/tp_size`` is linear so it composes with that EDP +# all_reduce in either order. +AUTOEP_FOLDING_GRAD_REDUCE_EXPERT_TP_CANCEL = "expert_tp_cancel" @dataclass(frozen=True) @@ -209,21 +218,19 @@ def validate_folding_global( "expert-internal tensor parallelism and is not supported yet. Use 1; ETP support " "is planned as follow-up work.") - expert_width = spec.ep_size * spec.etp_size - if spec.tp_size > 1 and expert_width > spec.dp_size: - raise ValueError("AutoEP+AutoTP folding does not yet support cross-lane expert-parallel groups where " - "expert_parallel.autoep_size * expert_parallel.expert_tensor_parallel_size exceeds " - f"the derived dense data-parallel size (ep * etp = {expert_width}, dp = {spec.dp_size}, " - f"stage_size = {spec.stage_size}). 
This is a temporary limitation; use a shape with " - "ep * etp <= dp or run a follow-up implementation for cross-lane EP groups.") - - if spec.tp_size > 1 and spec.dp_size % expert_width != 0: - raise ValueError("AutoEP+AutoTP folding requires the derived dense data-parallel size to be divisible by " - "expert_parallel.autoep_size * expert_parallel.expert_tensor_parallel_size so expert " - "groups stay within dense data-parallel lanes " - f"(dp = {spec.dp_size}, ep * etp = {expert_width}, stage_size = {spec.stage_size}). " - "Use a shape where dp % (ep * etp) == 0 or run a follow-up implementation for " - "cross-lane EP groups.") + # Cross-lane expert parallelism (expert_width = ep * etp need NOT be a subset of + # the dense data-parallel size) is supported: ``expected_folding_group_tables`` + # lays EP groups across the tp-lane-major rank ordering, so an EP group may span + # TP lanes and dense-DP ranks. The only structural requirement is that the + # expert width tiles the stage cleanly, which ``build_folding_spec`` already + # enforces (``stage_size % expert_width == 0``, so ``edp`` is integral). The + # gradient convention holds across the pool because each family's reduction is + # keyed to its replication structure, not the EP layout: router/gate and dense + # /LayerNorm AVERAGE over the TP (token-replication) group; routed experts cancel + # the restore ``tp_size`` factor (EXPERT_TP_CANCEL) and reduce data-parallel over + # the EDP group. The earlier ``expert_width <= dp`` / ``dp % expert_width == 0`` + # fail-fast limitation is therefore removed; only genuinely non-tiling shapes are + # rejected above (in ``build_folding_spec``). 
if tp_preset is not None and ep_preset is not None and tp_preset != ep_preset: raise ValueError("tensor_parallel.preset_model and expert_parallel.preset_model must match when both " @@ -357,8 +364,18 @@ def autoep_folding_gradient_reduction_strategy( LayerNorm parameters that way; the marker and strategy boundary exist so future SP support has an explicit contract instead of reusing the dense replicated default by accident. - - Expert parameters and model-parallel parameters are SKIP because their - EP/TP-specific paths own their reductions. + - Model-parallel (genuinely TP-sharded) parameters are SKIP because the + TP-specific path owns their reduction. + - Routed-expert parameters are EXPERT_TP_CANCEL: their data-parallel + reduction is owned by the EP/EDP path, but the folded forward all-gathers + their outputs into a replicated full view in ``restore_combined`` (whose + backward injects a ``tp_size`` factor), so the expert-weight gradient + reaches the optimizer ``tp_size`` times too large. Experts are not + TP-replicated, so the fix is a plain ``/tp_size`` (no TP all_reduce), which + is linear and composes with the EDP all_reduce in any order. Without this, + folded expert gradients are over-scaled by ``tp_size`` -- invisible to + scale-invariant Adam but real for SGD/Lion/Muon and for gradient clipping + (it inflates the expert contribution to the global grad norm). Underlying rule and mechanism: a folded parameter is replicated (AVERAGE) when the forward reconstructs its partitioned work into an identical full @@ -380,8 +397,18 @@ def autoep_folding_gradient_reduction_strategy( """ if folding_spec is None or getattr(folding_spec, "tp_size", 1) <= 1: return AUTOEP_FOLDING_GRAD_REDUCE_SKIP - if _is_moe_param_marker(param) or _is_model_parallel_param_marker(param): + if _is_model_parallel_param_marker(param): + # Genuinely TP-sharded (column/row-parallel) params: the TP-specific path + # owns their reduction. Not produced by the folded skip-partition MVP. 
return AUTOEP_FOLDING_GRAD_REDUCE_SKIP + if _is_moe_param_marker(param): + # Routed-expert params. Their EP/EDP data-parallel reduction is owned by + # the expert path, but the folded forward routes their outputs through the + # ``restore_combined`` all-gather, whose backward leaves a ``tp_size`` + # factor on the expert-weight gradient (the same factor the replicated + # router cancels with AVERAGE). Experts are NOT TP-replicated, so they must + # not be TP all_reduced; the factor is cancelled with a plain ``/tp_size``. + return AUTOEP_FOLDING_GRAD_REDUCE_EXPERT_TP_CANCEL family = _autoep_folding_param_family(param, param_name=param_name) mp_mode = getattr(folding_spec, "mp_mode", "tp") @@ -411,6 +438,16 @@ def reduce_autoep_folding_gradient( grad_data = grad.data tp_world_size = dist.get_world_size(group=tp_group) + + # Routed experts: cancel the ``tp_size`` factor the restore all-gather leaves, + # WITHOUT a TP all_reduce (experts are not TP-replicated; cross-TP summation of + # disjoint expert-token slices is owned by the EDP all_reduce). ``/tp_size`` is + # linear, so it composes with that EDP reduction in either order. 
+ if strategy == AUTOEP_FOLDING_GRAD_REDUCE_EXPERT_TP_CANCEL: + if tp_world_size > 1: + grad_data.div_(tp_world_size) + return strategy + if grad_data.dtype != torch.float32: reduced = grad_data.float() dist.all_reduce(reduced, group=tp_group) diff --git a/tests/unit/v1/moe/test_autoep_autotp_folding_config.py b/tests/unit/v1/moe/test_autoep_autotp_folding_config.py index 6dee57889862..42d26884a755 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_folding_config.py +++ b/tests/unit/v1/moe/test_autoep_autotp_folding_config.py @@ -131,9 +131,43 @@ def test_validation_rule_g11_zero_offload_rejected(offload_key): _assert_rejects("offload", **{offload_key: True}) -def test_validation_rule_g12_cross_lane_ep_groups_temporarily_rejected(): - _assert_rejects("temporary limitation", world_size=8, tp_size=4, ep_size=4) +@pytest.mark.parametrize( + "world_size,tp_size,ep_size,expected_dp,expected_edp", + [ + (4, 4, 4, 1, 1), # EP group == TP group == {0,1,2,3} + (4, 2, 4, 2, 1), # ep>dp AND dp % ep != 0; EP spans both TP lanes and both DP ranks + (8, 4, 4, 2, 2), # cross-lane with expert replication (edp>1) + ], +) +def test_cross_lane_ep_groups_accepted(world_size, tp_size, ep_size, expected_dp, expected_edp): + # Cross-lane EP (expert_width = ep*etp may exceed dp, and need not divide dp) is now + # supported: EP groups may span TP lanes. The earlier "temporary limitation" and + # "dp % (ep*etp) == 0" fail-fasts are removed; only non-tiling shapes are rejected. 
+ config = AutoEPConfig(enabled=True, autoep_size=ep_size, expert_tensor_parallel_size=1) + validate_autoep_config(config, world_size=world_size, pp_size=1, tp_size=tp_size, sp_size=1) + spec = build_folding_spec(world_size=world_size, pp_size=1, tp_size=tp_size, ep_size=ep_size, etp_size=1) + assert spec.dp_size == expected_dp + assert spec.edp_size == expected_edp -def test_validation_rule_g13_expert_width_must_tile_dense_dp_lane(): - _assert_rejects("dp % \\(ep \\* etp\\) == 0", world_size=12, tp_size=3, ep_size=3) +def test_cross_lane_expected_folding_tables(): + # world=4 tp4 ep4 dp1: the EP group is the whole TP group; one expert per rank (edp=1). + spec_tp4 = build_folding_spec(world_size=4, pp_size=1, tp_size=4, ep_size=4, etp_size=1) + tables_tp4 = expected_folding_group_tables(spec_tp4) + assert tables_tp4.tp_groups == ((0, 1, 2, 3), ) + assert tables_tp4.ep_groups == ((0, 1, 2, 3), ) + assert tables_tp4.edp_groups == ((0, ), (1, ), (2, ), (3, )) + + # world=4 tp2 ep4: EP group spans both TP lanes (lane-major ordering 0,2,1,3). + spec_tp2 = build_folding_spec(world_size=4, pp_size=1, tp_size=2, ep_size=4, etp_size=1) + tables_tp2 = expected_folding_group_tables(spec_tp2) + assert tables_tp2.tp_groups == ((0, 1), (2, 3)) + assert tables_tp2.ep_groups == ((0, 2, 1, 3), ) + assert tables_tp2.edp_groups == ((0, ), (2, ), (1, ), (3, )) + + # world=8 tp4 ep4 (edp=2): two EP groups, each spanning TP lanes and DP ranks. 
+ spec_w8 = build_folding_spec(world_size=8, pp_size=1, tp_size=4, ep_size=4, etp_size=1) + tables_w8 = expected_folding_group_tables(spec_w8) + assert tables_w8.tp_groups == ((0, 1, 2, 3), (4, 5, 6, 7)) + assert tables_w8.ep_groups == ((0, 4, 1, 5), (2, 6, 3, 7)) + assert tables_w8.edp_groups == ((0, 2), (4, 6), (1, 3), (5, 7)) diff --git a/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py b/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py index 34f8b1450eca..b01003ceb0c8 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py +++ b/tests/unit/v1/moe/test_autoep_autotp_grad_parity.py @@ -17,6 +17,7 @@ from deepspeed.checkpoint.constants import FOLDING_FAMILY, FOLDING_METADATA_KEY, FOLDING_PARAM_FAMILIES from deepspeed.module_inject.auto_ep_folding import ( AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE, + AUTOEP_FOLDING_GRAD_REDUCE_EXPERT_TP_CANCEL, AUTOEP_FOLDING_GRAD_REDUCE_SKIP, AUTOEP_FOLDING_GRAD_REDUCE_SUM, autoep_folding_gradient_reduction_strategy, @@ -113,9 +114,15 @@ def test_autoep_folding_gradient_strategy_uses_parameter_family(): assert (autoep_folding_gradient_reduction_strategy(_folding_spec("tp"), sp_layernorm) == AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE) + # Routed experts cancel the restore all-gather tp_size factor (divide-by-tp, no + # TP all_reduce); their data-parallel reduction stays on the EP/EDP path. expert = torch.nn.Parameter(torch.ones(2)) expert.allreduce = False - assert autoep_folding_gradient_reduction_strategy(_folding_spec("tp"), expert) == AUTOEP_FOLDING_GRAD_REDUCE_SKIP + assert (autoep_folding_gradient_reduction_strategy(_folding_spec("tp"), + expert) == AUTOEP_FOLDING_GRAD_REDUCE_EXPERT_TP_CANCEL) + # With folding disabled (tp_size == 1) experts still SKIP. 
+ assert (autoep_folding_gradient_reduction_strategy(_folding_spec("tp", tp_size=1), + expert) == AUTOEP_FOLDING_GRAD_REDUCE_SKIP) model_parallel = torch.nn.Parameter(torch.ones(2)) model_parallel.tensor_model_parallel = True @@ -419,6 +426,258 @@ def test_cpu_gloo_folded_zero2_router_gate_and_layernorm_grad_parity(tmpdir): run_cpu_gloo_test(_cpu_folded_zero2_router_gate_and_layernorm_worker, tmpdir, world_size=8) +# --------------------------------------------------------------------------- +# Cross-lane EP (expert parallel spanning TP lanes; expert_width = ep need NOT be a +# subset of dp). These reuse the tp2/ep4 workers at world_size=4 (where ep=4 > dp=2) +# and add a tp4/ep4/dp1 worker (EP group == TP group). The router/gate and LayerNorm +# AVERAGE-over-TP convention is unchanged because the dedup/restore key on the +# token-replication (TP) group, independent of the EP layout. +# --------------------------------------------------------------------------- + + +def _folded_zero0_tp4_ep4_config(): + config = make_autoep_config(zero_stage=0, ep_size=4, mixed_precision=False) + config["gradient_accumulation_steps"] = 2 + config["gradient_clipping"] = 0.0 + config["communication_data_type"] = "fp32" + config["optimizer"]["params"]["torch_adam"] = True + config["expert_parallel"]["autoep_size"] = 4 + config["tensor_parallel"] = { + "autotp_size": 4, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + }, + } + return config + + +def _cpu_folded_tp4_ep4_router_gate_and_layernorm_worker(rank, world_size, _shared_tmpdir): + seed = 1234 + tp_size = 4 + logical_dp_world_size = world_size // tp_size + logical_dp_rank = rank // tp_size + + seed_everything(seed) + reference_state = _router_grad_model().state_dict() + + baseline_model = _router_grad_model() + baseline_model.load_state_dict(reference_state) + baseline_engine, _, _, _ = deepspeed.initialize(model=baseline_model, 
config=_zero0_baseline_config()) + _run_router_grad_boundary(baseline_engine, + logical_dp_world_size=logical_dp_world_size, + logical_dp_rank=logical_dp_rank, + seed=seed) + baseline_gate = _full_grad_by_suffix(baseline_engine, GATE_BASELINE) + baseline_layernorm = _full_grad_by_suffix(baseline_engine, INPUT_LAYERNORM) + + folded_model = _router_grad_model() + folded_model.load_state_dict(reference_state) + folded_engine, _, _, _ = deepspeed.initialize(model=folded_model, config=_folded_zero0_tp4_ep4_config()) + _run_router_grad_boundary(folded_engine, + logical_dp_world_size=logical_dp_world_size, + logical_dp_rank=logical_dp_rank, + seed=seed) + folded_gate = _full_grad_by_suffix(folded_engine, GATE_FOLDED) + folded_layernorm = _full_grad_by_suffix(folded_engine, INPUT_LAYERNORM) + + metrics = { + "rank": rank, + "gate": _grad_parity_metrics(folded_gate, baseline_gate), + "layernorm": _grad_parity_metrics(folded_layernorm, baseline_layernorm) + } + if rank == 0: + print("FOLDED_CROSSLANE_TP4_EP4_GRAD_PARITY " + json.dumps(metrics, sort_keys=True)) + torch.testing.assert_close(folded_gate, + baseline_gate, + atol=1e-1, + rtol=5e-3, + msg=f"Cross-lane tp4/ep4 router/gate grad must match baseline; metrics={metrics}") + torch.testing.assert_close(folded_layernorm, + baseline_layernorm, + atol=1e-1, + rtol=5e-3, + msg=f"Cross-lane tp4/ep4 LayerNorm grad must match baseline; metrics={metrics}") + + +def test_cpu_gloo_crosslane_tp2_ep4_zero0_router_gate_and_layernorm_grad_parity(tmpdir): + # world=4: tp2/ep4 => ep=4 > dp=2 (cross-lane; EP group spans both TP lanes and DP ranks). 
+ run_cpu_gloo_test(_cpu_folded_zero0_router_gate_and_layernorm_worker, tmpdir, world_size=4) + + +def test_cpu_gloo_crosslane_tp2_ep4_zero2_router_gate_and_layernorm_grad_parity(tmpdir): + run_cpu_gloo_test(_cpu_folded_zero2_router_gate_and_layernorm_worker, tmpdir, world_size=4) + + +def test_cpu_gloo_crosslane_tp4_ep4_dp1_zero0_router_gate_and_layernorm_grad_parity(tmpdir): + # world=4: tp4/ep4/dp1 => EP group == TP group == {0,1,2,3}. + run_cpu_gloo_test(_cpu_folded_tp4_ep4_router_gate_and_layernorm_worker, tmpdir, world_size=4) + + +# --------------------------------------------------------------------------- +# Expert-weight gradient parity. The folded forward all-gathers expert outputs into +# a replicated full view in ``restore_combined`` (backward injects a ``tp_size`` +# factor); routed experts must cancel it (EXPERT_TP_CANCEL ``/tp_size``) or their +# gradients reach the optimizer ``tp_size`` times too large. This is invisible to +# scale-invariant Adam, so this test uses SGD and compares the post-step expert +# weights against a non-folded baseline. It covers the MVP shape and cross-lane +# shapes including edp>1. 
+# --------------------------------------------------------------------------- + +EXPERTS_W1 = "experts.w1" # folded GroupedExperts gate half (num_local, ffn, hidden) +EXPERTS_W3 = "experts.w3" # folded GroupedExperts up half (num_local, ffn, hidden) +EXPERTS_W2 = "experts.w2" # folded GroupedExperts down (num_local, hidden, ffn) +GATE_UP_PROJ = "mlp.experts.gate_up_proj" # baseline fused gate||up (num_experts, 2*ffn, hidden) +DOWN_PROJ = "mlp.experts.down_proj" # baseline down (num_experts, hidden, ffn) + + +def _sgd_baseline_config(zero_stage=0): + config = { + key: value + for key, value in make_autoep_config(zero_stage=zero_stage, ep_size=1, mixed_precision=False).items() + if key != "expert_parallel" + } + config["gradient_accumulation_steps"] = 1 + config["gradient_clipping"] = 0.0 + config["communication_data_type"] = "fp32" + config["optimizer"] = {"type": "SGD", "params": {"lr": 1.0}} + config["zero_allow_untested_optimizer"] = True # SGD under ZeRO is "untested"; this is a grad-parity probe + return config + + +def _folded_sgd_config(tp_size, ep_size, zero_stage=0): + config = make_autoep_config(zero_stage=zero_stage, ep_size=ep_size, mixed_precision=False) + config["gradient_accumulation_steps"] = 1 + config["gradient_clipping"] = 0.0 + config["communication_data_type"] = "fp32" + config["optimizer"] = {"type": "SGD", "params": {"lr": 1.0}} + config["zero_allow_untested_optimizer"] = True # SGD under ZeRO is "untested"; this is a grad-parity probe + config["expert_parallel"]["autoep_size"] = ep_size + config["tensor_parallel"] = { + "autotp_size": tp_size, + "partition_config": { + "use_default_specs": False, + "layer_specs": [{ + "patterns": [".*\\.weight$"], + "partition_type": "skip", + }], + }, + } + return config + + +def _one_sgd_step(engine, *, tp_size, seed): + rank = dist.get_rank() + generator = torch.Generator().manual_seed(seed + (rank // tp_size)) + x = torch.randn((1, 4, 64), generator=generator, 
dtype=engine_input_dtype(engine)).to(engine.device) + loss = engine(x).float().mean() + engine.backward(loss) + engine.step() + return loss + + +def _local_param_by_suffix(module, suffix): + for name, param in module.named_parameters(): + if name.endswith(suffix): + return param.detach().float().cpu() + raise AssertionError(f"Missing parameter ending with {suffix}") + + +def _gather_full_experts(layer, suffix): + """All-gather a local routed-expert tensor over the EP group into the full + (num_experts, ...) tensor, ordered by ep_rank (expert_start = ep_rank * num_local).""" + local = None + for name, param in layer.named_parameters(): + if name.endswith(suffix): + local = param.detach().contiguous() + break + assert local is not None, f"Missing parameter ending with {suffix}" + ep_group = layer.ep_group + ep_world = dist.get_world_size(group=ep_group) + gathered = [torch.empty_like(local) for _ in range(ep_world)] + dist.all_gather(gathered, local, group=ep_group) + return torch.cat([chunk.float().cpu() for chunk in gathered], dim=0) + + +def _expert_weight_parity_worker(rank, world_size, tp_size, ep_size, zero_stage=0): + seed = 1234 + seed_everything(seed) + reference_state = _router_grad_model().state_dict() + + baseline_model = _router_grad_model() + baseline_model.load_state_dict(reference_state) + baseline_engine, _, _, _ = deepspeed.initialize(model=baseline_model, config=_sgd_baseline_config(zero_stage)) + _one_sgd_step(baseline_engine, tp_size=tp_size, seed=seed) + base_gate_up = _local_param_by_suffix(baseline_engine.module, GATE_UP_PROJ) # (E, 2*ffn, h) + base_down = _local_param_by_suffix(baseline_engine.module, DOWN_PROJ) # (E, h, ffn) + + folded_model = _router_grad_model() + folded_model.load_state_dict(reference_state) + folded_engine, _, _, _ = deepspeed.initialize(model=folded_model, + config=_folded_sgd_config(tp_size, ep_size, zero_stage)) + _one_sgd_step(folded_engine, tp_size=tp_size, seed=seed) + layer = 
folded_engine.module.model.layers[0].mlp + full_w1 = _gather_full_experts(layer, EXPERTS_W1) + full_w3 = _gather_full_experts(layer, EXPERTS_W3) + full_w2 = _gather_full_experts(layer, EXPERTS_W2) + folded_gate_up = torch.cat([full_w1, full_w3], dim=1) # (E, 2*ffn, h) + + if rank == 0: + gu_scale = folded_gate_up.mul(base_gate_up).sum().item() / base_gate_up.square().sum().item() + dn_scale = full_w2.mul(base_down).sum().item() / base_down.square().sum().item() + print(f"EXPERT_WEIGHT_PARITY tp{tp_size}_ep{ep_size}_w{world_size}_z{zero_stage} " + f"gate_up_post_step_scale={gu_scale:.6f} down_post_step_scale={dn_scale:.6f}") + # Post-step weight equality proves the *applied* expert update matches the baseline. + # A tp_size over-scaling would diverge these by ~lr*(tp-1)*grad (lr=1), far above tol. + torch.testing.assert_close( + folded_gate_up, + base_gate_up, + atol=1e-4, + rtol=1e-4, + msg="Folded routed-expert gate/up weights must match non-folded baseline after one SGD step") + torch.testing.assert_close( + full_w2, + base_down, + atol=1e-4, + rtol=1e-4, + msg="Folded routed-expert down weights must match non-folded baseline after one SGD step") + + +def _expert_weight_parity_mvp_tp2_ep4_z0(rank, world_size, _tmp): + _expert_weight_parity_worker(rank, world_size, tp_size=2, ep_size=4, zero_stage=0) + + +def _expert_weight_parity_tp4_ep4_z0(rank, world_size, _tmp): + _expert_weight_parity_worker(rank, world_size, tp_size=4, ep_size=4, zero_stage=0) + + +def _expert_weight_parity_tp4_ep4_z2(rank, world_size, _tmp): + _expert_weight_parity_worker(rank, world_size, tp_size=4, ep_size=4, zero_stage=2) + + +def test_cpu_gloo_expert_weight_parity_mvp_tp2_ep4(tmpdir): + # MVP shape (ep=4 <= dp=4, edp=2). Guards the pre-existing expert over-scaling fix. + run_cpu_gloo_test(_expert_weight_parity_mvp_tp2_ep4_z0, tmpdir, world_size=8) + + +def test_cpu_gloo_expert_weight_parity_crosslane_tp4_ep4_dp1(tmpdir): + # Cross-lane, edp=1: EP group == TP group == {0,1,2,3}. 
+ run_cpu_gloo_test(_expert_weight_parity_tp4_ep4_z0, tmpdir, world_size=4) + + +def test_cpu_gloo_expert_weight_parity_crosslane_tp4_ep4_edp2(tmpdir): + # Cross-lane, edp=2: EP groups span TP lanes and DP ranks. + run_cpu_gloo_test(_expert_weight_parity_tp4_ep4_z0, tmpdir, world_size=8) + + +def test_cpu_gloo_expert_weight_parity_crosslane_tp4_ep4_dp1_zero2(tmpdir): + # Cross-lane expert fix must also apply on the ZeRO-2 reducer path (different hook). + run_cpu_gloo_test(_expert_weight_parity_tp4_ep4_z2, tmpdir, world_size=4) + + def _assert_zero_optimizer_folding_metadata(checkpoint_dir): optim_paths = sorted(glob.glob(os.path.join(str(checkpoint_dir), "folded-zero2", "*_optim_states.pt"))) assert optim_paths From 59fdca461060549fb4d7dad7d583efc3f7041065 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 26 Jun 2026 09:43:48 -0700 Subject: [PATCH 07/16] Gate AutoEP folding routing validation Signed-off-by: Masahiro Tanaka --- deepspeed/module_inject/auto_ep_config.py | 4 + deepspeed/module_inject/auto_ep_layer.py | 13 +- .../module_inject/auto_ep_presets/base.py | 1 + deepspeed/moe/ep_tp_dispatch.py | 112 ++++++++++++++++-- .../v1/moe/test_autoep_autotp_dispatch.py | 33 +++++- tests/unit/v1/moe/test_autoep_unit.py | 11 ++ 6 files changed, 160 insertions(+), 14 deletions(-) diff --git a/deepspeed/module_inject/auto_ep_config.py b/deepspeed/module_inject/auto_ep_config.py index 067c1da4219c..922cb1fc95ee 100644 --- a/deepspeed/module_inject/auto_ep_config.py +++ b/deepspeed/module_inject/auto_ep_config.py @@ -47,6 +47,7 @@ def parse_autoep_config(param_dict: dict) -> AutoEPConfig: config.enabled = param_dict.get("enabled", False) config.autoep_size = param_dict.get("autoep_size", 1) config.expert_tensor_parallel_size = param_dict.get("expert_tensor_parallel_size", 1) + config.validate_folding_routing = param_dict.get("validate_folding_routing", False) config.preset_model = param_dict.get("preset_model", None) config.moe_layer_pattern = 
param_dict.get("moe_layer_pattern", None) config.expert_pattern = param_dict.get("expert_pattern", None) @@ -109,6 +110,9 @@ def validate_autoep_config( if config.load_balance_coeff is not None: _raise_unsupported_load_balance_coeff(config.load_balance_coeff) + if not isinstance(config.validate_folding_routing, bool): + raise ValueError("expert_parallel.validate_folding_routing must be a boolean") + if not config.enabled: return diff --git a/deepspeed/module_inject/auto_ep_layer.py b/deepspeed/module_inject/auto_ep_layer.py index 8e70872cc426..487d06e748e6 100644 --- a/deepspeed/module_inject/auto_ep_layer.py +++ b/deepspeed/module_inject/auto_ep_layer.py @@ -402,6 +402,7 @@ def __init__( self.folding_group_handles = None self.tp_group = None resolved_config = resolve_autoep_config_defaults(config, spec.model_family) + self.validate_folding_routing = bool(resolved_config.validate_folding_routing) # Router: copy gate weights from source source_gate = getattr(source_module, spec.router_name) @@ -630,9 +631,10 @@ def forward( "num_tokens": torch.tensor(bsz * seqlen, device=hidden_states.device, dtype=torch.long), }, ) - assert_tp_payload_consistent(payload, - tp_group=self.tp_group, - tp_size=self.folding_group_handles.spec.tp_size) + if self.validate_folding_routing: + assert_tp_payload_consistent(payload, + tp_group=self.tp_group, + tp_size=self.folding_group_handles.spec.tp_size) tp_rank = dist.get_rank(group=self.tp_group) local_payload, restore_ctx = partition_assignments(payload, tp_group=self.tp_group, @@ -692,7 +694,10 @@ def forward( expert_output = _AllToAllV.apply(self.ep_group, expert_output, plan.output_splits, plan.input_splits) if folded_tp: - output = restore_combined(expert_output, restore_ctx, tp_group=self.tp_group).reshape(bsz, seqlen, hdim) + output = restore_combined(expert_output, + restore_ctx, + tp_group=self.tp_group, + validate_coverage=self.validate_folding_routing).reshape(bsz, seqlen, hdim) self._last_folding_dispatch_counters = 
dispatch_counters(restore_ctx) else: output = combine_from_routed( diff --git a/deepspeed/module_inject/auto_ep_presets/base.py b/deepspeed/module_inject/auto_ep_presets/base.py index c023498109c9..f7ec9a5e37ca 100644 --- a/deepspeed/module_inject/auto_ep_presets/base.py +++ b/deepspeed/module_inject/auto_ep_presets/base.py @@ -99,6 +99,7 @@ class AutoEPConfig: enabled: bool = False autoep_size: int = 1 expert_tensor_parallel_size: int = 1 + validate_folding_routing: bool = False preset_model: str | None = None moe_layer_pattern: str | None = None expert_pattern: str | None = None diff --git a/deepspeed/moe/ep_tp_dispatch.py b/deepspeed/moe/ep_tp_dispatch.py index c085ee7b1cc3..278946bda58e 100644 --- a/deepspeed/moe/ep_tp_dispatch.py +++ b/deepspeed/moe/ep_tp_dispatch.py @@ -7,6 +7,7 @@ from __future__ import annotations from dataclasses import dataclass +import os import torch import deepspeed.comm as dist @@ -135,6 +136,77 @@ def _payload_digest(payload: RoutedAssignmentPayload) -> torch.Tensor: return digest +def _payload_digest_components(payload: RoutedAssignmentPayload) -> dict[str, torch.Tensor]: + device = payload.token_indices.device + active = (~payload.drop_mask & ~payload.pad_mask).to(torch.long) + fields = { + "token_indices": payload.token_indices, + "expert_indices": payload.expert_indices, + "assignment_indices": payload.assignment_indices, + "capacity_slots": payload.capacity_slots, + "combine_weights": payload.combine_weights, + "drop_mask": payload.drop_mask, + "pad_mask": payload.pad_mask, + "active": active, + "destination_ranks": payload.extra.get("destination_ranks", torch.empty(0, device=device, dtype=torch.long)), + } + components: dict[str, torch.Tensor] = {} + for index, (name, field) in enumerate(fields.items(), start=1): + if not torch.is_tensor(field): + continue + words = _tensor_digest_words(field) + components[name] = torch.stack(( + torch.tensor(words.numel(), device=device, dtype=torch.long), + _digest_words(words, salt=17 * 
index, modulus=_FOLDING_DIGEST_MOD_A), + _digest_words(words, salt=31 * index, modulus=_FOLDING_DIGEST_MOD_B), + )) + return components + + +def _format_payload_debug(payload: RoutedAssignmentPayload, *, digest: torch.Tensor, max_digest: torch.Tensor, + min_digest: torch.Tensor, tp_group) -> str: + if os.environ.get("AUTOEP_FOLDING_DEBUG_PAYLOAD", "0") not in {"1", "true", "TRUE", "yes"}: + return "" + + differing_fields = [] + for name, component in _payload_digest_components(payload).items(): + component_max = component.clone() + component_min = component.clone() + dist.all_reduce(component_max, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(component_min, op=dist.ReduceOp.MIN, group=tp_group) + if not torch.equal(component_max, component_min): + differing_fields.append({ + "field": name, + "local": [int(value) for value in component.detach().cpu().tolist()], + "min": [int(value) for value in component_min.detach().cpu().tolist()], + "max": [int(value) for value in component_max.detach().cpu().tolist()], + }) + + sample_limit = int(os.environ.get("AUTOEP_FOLDING_DEBUG_SAMPLE_LIMIT", "12")) + samples = { + "token_indices": payload.token_indices[:sample_limit].detach().cpu().tolist(), + "expert_indices": payload.expert_indices[:sample_limit].detach().cpu().tolist(), + "assignment_indices": payload.assignment_indices[:sample_limit].detach().cpu().tolist(), + "capacity_slots": payload.capacity_slots[:sample_limit].detach().cpu().tolist(), + "combine_weights": payload.combine_weights[:sample_limit].detach().float().cpu().tolist(), + } + try: + tp_group_ranks = dist.get_all_ranks_from_group(tp_group) + except Exception: + tp_group_ranks = [] + details = { + "rank": dist.get_rank(), + "tp_rank": dist.get_rank(group=tp_group), + "tp_group_ranks": tp_group_ranks, + "digest": [int(value) for value in digest.detach().cpu().tolist()], + "digest_min": [int(value) for value in min_digest.detach().cpu().tolist()], + "digest_max": [int(value) for value in 
max_digest.detach().cpu().tolist()], + "differing_fields": differing_fields, + "samples": samples, + } + return f" Debug details: {details}" + + def assert_tp_payload_consistent(payload: RoutedAssignmentPayload, *, tp_group, tp_size: int) -> None: if tp_size <= 1 or not dist.is_initialized(): return @@ -145,8 +217,14 @@ def assert_tp_payload_consistent(payload: RoutedAssignmentPayload, *, tp_group, dist.all_reduce(max_digest, op=dist.ReduceOp.MAX, group=tp_group) dist.all_reduce(min_digest, op=dist.ReduceOp.MIN, group=tp_group) if not torch.equal(max_digest, min_digest): + debug_details = _format_payload_debug(payload, + digest=digest, + max_digest=max_digest, + min_digest=min_digest, + tp_group=tp_group) raise RuntimeError("AutoEP+AutoTP routing decisions differ across tensor-parallel lanes. " - "Folded dispatch requires identical routed-token payloads before TP partitioning.") + "Folded dispatch requires identical routed-token payloads before TP partitioning." + f"{debug_details}") def partition_assignments( @@ -285,18 +363,20 @@ def _all_gather_variable_rows(tensor: torch.Tensor, def _debug_validate_restore_coverage(payload: RoutedAssignmentPayload, ctx: RestoreContext, all_token_indices: torch.Tensor, all_expert_indices: torch.Tensor, - all_assignment_indices: torch.Tensor) -> None: + all_assignment_indices: torch.Tensor, all_capacity_slots: torch.Tensor) -> None: active = ~payload.drop_mask & ~payload.pad_mask expected_rows = torch.stack(( payload.token_indices[active].to(torch.long), payload.expert_indices[active].to(torch.long), payload.assignment_indices[active].to(torch.long), + payload.capacity_slots[active].to(torch.long), ), dim=1) observed_rows = torch.stack(( all_token_indices.to(torch.long), all_expert_indices.to(torch.long), all_assignment_indices.to(torch.long), + all_capacity_slots.to(torch.long), ), dim=1) if expected_rows.numel() == 0 and observed_rows.numel() == 0: @@ -314,7 +394,11 @@ def _debug_validate_restore_coverage(payload: 
RoutedAssignmentPayload, ctx: Rest f"missing={missing} unexpected={duplicate_or_stale}") -def restore_combined(local_combined: torch.Tensor, ctx: RestoreContext, *, tp_group) -> torch.Tensor: +def restore_combined(local_combined: torch.Tensor, + ctx: RestoreContext, + *, + tp_group, + validate_coverage: bool = False) -> torch.Tensor: """Gather TP-partitioned assignment outputs and combine back by token index. The all-gather rebuilds an identical full output on every TP peer, so all @@ -328,8 +412,7 @@ def restore_combined(local_combined: torch.Tensor, ctx: RestoreContext, *, tp_gr """ payload = ctx.original_payload local_token_indices = payload.token_indices.index_select(0, ctx.local_indices) - local_expert_indices = payload.expert_indices.index_select(0, ctx.local_indices) - local_assignment_indices = payload.assignment_indices.index_select(0, ctx.local_indices) + local_capacity_slots = payload.capacity_slots.index_select(0, ctx.local_indices) local_weights = payload.combine_weights.index_select(0, ctx.local_indices).to(local_combined.dtype) all_outputs = _all_gather_variable_rows(local_combined, @@ -337,19 +420,30 @@ def restore_combined(local_combined: torch.Tensor, ctx: RestoreContext, *, tp_gr ctx.tp_size, preserve_grad=local_combined.requires_grad) all_token_indices = _all_gather_variable_rows(local_token_indices, tp_group, ctx.tp_size).to(torch.long) - all_expert_indices = _all_gather_variable_rows(local_expert_indices, tp_group, ctx.tp_size).to(torch.long) - all_assignment_indices = _all_gather_variable_rows(local_assignment_indices, tp_group, ctx.tp_size).to(torch.long) + all_capacity_slots = _all_gather_variable_rows(local_capacity_slots, tp_group, ctx.tp_size).to(torch.long) all_weights = _all_gather_variable_rows(local_weights, tp_group, ctx.tp_size, preserve_grad=local_weights.requires_grad).to(local_combined.dtype) - _debug_validate_restore_coverage(payload, ctx, all_token_indices, all_expert_indices, all_assignment_indices) + if validate_coverage: + 
local_expert_indices = payload.expert_indices.index_select(0, ctx.local_indices) + local_assignment_indices = payload.assignment_indices.index_select(0, ctx.local_indices) + all_expert_indices = _all_gather_variable_rows(local_expert_indices, tp_group, ctx.tp_size).to(torch.long) + all_assignment_indices = _all_gather_variable_rows(local_assignment_indices, tp_group, ctx.tp_size).to(torch.long) + _debug_validate_restore_coverage(payload, ctx, all_token_indices, all_expert_indices, all_assignment_indices, + all_capacity_slots) if ctx.num_tokens <= 0: ctx.num_tokens = int(payload.token_indices.max().item()) + 1 if payload.token_indices.numel() else 0 output = local_combined.new_zeros((ctx.num_tokens, local_combined.shape[-1])) if all_outputs.numel() > 0: - output.index_add_(0, all_token_indices, all_outputs * all_weights.reshape(-1, 1)) + weight_shape = (-1, ) + (1, ) * (all_outputs.dim() - 1) + weighted_outputs = all_outputs * all_weights.reshape(weight_shape) + # Add one top-k slot at a time so token accumulation order stays stable + # without materializing a [tokens, top_k, hidden] buffer. 
+ for slot in torch.unique(all_capacity_slots, sorted=True).tolist(): + rows = all_capacity_slots == int(slot) + output.index_add_(0, all_token_indices[rows], weighted_outputs[rows]) return output diff --git a/tests/unit/v1/moe/test_autoep_autotp_dispatch.py b/tests/unit/v1/moe/test_autoep_autotp_dispatch.py index 5a9d9c970498..31576ced4840 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_dispatch.py +++ b/tests/unit/v1/moe/test_autoep_autotp_dispatch.py @@ -181,14 +181,45 @@ def test_restore_combined_tp_backward_matches_non_partitioned_combine(tmpdir): run_cpu_gloo_test(_restore_combined_backward_parity_worker, tmpdir, world_size=2) +def _restore_combined_topk_slot_order_worker(rank, world_size, _shared_tmpdir): + payload = RoutedAssignmentPayload( + token_indices=torch.tensor([0, 0, 0], dtype=torch.long), + expert_indices=torch.tensor([0, 0, 0], dtype=torch.long), + assignment_indices=torch.tensor([0, 1, 2], dtype=torch.long), + capacity_slots=torch.tensor([0, 1, 2], dtype=torch.long), + combine_weights=torch.ones(3, dtype=torch.float32), + drop_mask=torch.zeros(3, dtype=torch.bool), + pad_mask=torch.zeros(3, dtype=torch.bool), + input_splits=[3], + output_splits=[3], + extra={ + "destination_ranks": torch.zeros(3, dtype=torch.long), + "num_tokens": torch.tensor(1, dtype=torch.long), + }, + ) + tp_group = dist.get_world_group() + local, ctx = partition_assignments(payload, tp_group=tp_group, tp_rank=rank, tp_size=world_size) + full_values = torch.tensor([[1.0e20], [1.0], [-1.0e20]], dtype=torch.float32) + restored = restore_combined(full_values.index_select(0, ctx.local_indices), ctx, tp_group=tp_group) + + torch.testing.assert_close(restored, torch.zeros_like(restored), rtol=0.0, atol=0.0) + + +def test_restore_combined_tp_forward_uses_topk_slot_order(tmpdir): + run_cpu_gloo_test(_restore_combined_topk_slot_order_worker, tmpdir, world_size=2) + + def test_restore_coverage_assertion_detects_missing_assignment(): payload = _payload() local, ctx = 
partition_assignments(payload, tp_group=None, tp_rank=0, tp_size=1) ctx.local_indices = ctx.local_indices[:-1] values = torch.ones((local.token_indices.numel() - 1, 2), dtype=torch.float32) + restored = restore_combined(values, ctx, tp_group=None) + assert restored.shape == (5, 2) + with pytest.raises(RuntimeError, match="restore coverage mismatch"): - restore_combined(values, ctx, tp_group=None) + restore_combined(values, ctx, tp_group=None, validate_coverage=True) def test_tp_payload_consistency_detects_divergent_large_payload(monkeypatch): diff --git a/tests/unit/v1/moe/test_autoep_unit.py b/tests/unit/v1/moe/test_autoep_unit.py index 528531a57a2f..e4633fae7564 100644 --- a/tests/unit/v1/moe/test_autoep_unit.py +++ b/tests/unit/v1/moe/test_autoep_unit.py @@ -150,6 +150,7 @@ def test_parse_and_validate_enabled_size_contract(self): disabled = parse_autoep_config({}) assert disabled.enabled is False assert disabled.autoep_size == 1 + assert disabled.validate_folding_routing is False assert disabled.load_balance_coeff is None assert disabled._load_balance_coeff_explicit is False @@ -160,17 +161,27 @@ def test_parse_and_validate_enabled_size_contract(self): "load_balance_coeff": None, "score_apply": "pre", "route_scale": 2.0, + "validate_folding_routing": True, }) assert config.enabled is True assert config.autoep_size == 4 assert config.preset_model == "mixtral" + assert config.validate_folding_routing is True assert config.load_balance_coeff is None assert config._load_balance_coeff_explicit is True assert config.score_apply == "pre" assert config.route_scale == 2.0 validate_autoep_config(config, world_size=4, pp_size=1, tp_size=1, sp_size=1) + def test_validate_folding_routing_requires_boolean(self): + with pytest.raises(ValueError, match="validate_folding_routing"): + validate_autoep_config(AutoEPConfig(enabled=True, validate_folding_routing="true"), + world_size=1, + pp_size=1, + tp_size=1, + sp_size=1) + @pytest.mark.parametrize("value", 
UNSUPPORTED_LOAD_BALANCE_VALUES) def test_load_balance_coeff_rejected_at_parse(self, value): with pytest.raises(ValueError) as exc_info: From 7426b1214161ec4c8bacac13a21a2c3d7f7df22b Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 26 Jun 2026 09:47:18 -0700 Subject: [PATCH 08/16] Apply AutoEP folding yapf formatting Signed-off-by: Masahiro Tanaka --- deepspeed/moe/ep_tp_dispatch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/moe/ep_tp_dispatch.py b/deepspeed/moe/ep_tp_dispatch.py index 278946bda58e..7205bfb5f306 100644 --- a/deepspeed/moe/ep_tp_dispatch.py +++ b/deepspeed/moe/ep_tp_dispatch.py @@ -429,7 +429,8 @@ def restore_combined(local_combined: torch.Tensor, local_expert_indices = payload.expert_indices.index_select(0, ctx.local_indices) local_assignment_indices = payload.assignment_indices.index_select(0, ctx.local_indices) all_expert_indices = _all_gather_variable_rows(local_expert_indices, tp_group, ctx.tp_size).to(torch.long) - all_assignment_indices = _all_gather_variable_rows(local_assignment_indices, tp_group, ctx.tp_size).to(torch.long) + all_assignment_indices = _all_gather_variable_rows(local_assignment_indices, tp_group, + ctx.tp_size).to(torch.long) _debug_validate_restore_coverage(payload, ctx, all_token_indices, all_expert_indices, all_assignment_indices, all_capacity_slots) From e6239e4a147554e62bd3a2e8e41ec6496b4edf7e Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 Jul 2026 17:59:30 -0700 Subject: [PATCH 09/16] Add folded DeepCompile rejection tests Signed-off-by: Masahiro Tanaka --- .../test_autoep_autotp_deepcompile_reject.py | 51 +++++++++++++++++++ .../moe/test_autoep_autotp_folding_config.py | 21 ++++++++ 2 files changed, 72 insertions(+) create mode 100644 tests/unit/v1/moe/test_autoep_autotp_deepcompile_reject.py diff --git a/tests/unit/v1/moe/test_autoep_autotp_deepcompile_reject.py b/tests/unit/v1/moe/test_autoep_autotp_deepcompile_reject.py new file mode 100644 index 
000000000000..82de0e562997 --- /dev/null +++ b/tests/unit/v1/moe/test_autoep_autotp_deepcompile_reject.py @@ -0,0 +1,51 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Engine-path DeepCompile validation for AutoEP + AutoTP folding.""" + +from types import SimpleNamespace + +import pytest + +from deepspeed.module_inject.auto_ep_config import AutoEPConfig +from deepspeed.runtime.engine import DeepSpeedEngine + + +def test_deepcompile_folded_rejected_before_autoep_process_groups(monkeypatch): + engine = object.__new__(DeepSpeedEngine) + engine.mpu = None + engine._config = SimpleNamespace( + expert_parallel_config=AutoEPConfig(enabled=True, autoep_size=2, expert_tensor_parallel_size=1), + compile_config=SimpleNamespace(deepcompile=True), + use_data_before_expert_parallel_=False, + tensor_parallel_config=SimpleNamespace(preset_model=None), + ) + engine.autotp_size = lambda: 2 + engine._autoep_sequence_parallel_world_size = lambda: 1 + engine.zero_optimization_stage = lambda: 0 + engine.zero_offload_optimizer = lambda: None + engine.zero_offload_param = lambda: None + + group_creations = [] + monkeypatch.setattr("deepspeed.runtime.engine.dist.get_world_size", lambda: 4) + monkeypatch.setattr("deepspeed.runtime.engine.dist.get_rank", lambda group=None: 0) + monkeypatch.setattr("deepspeed.runtime.engine.groups._get_sequence_parallel_world_size", lambda: 1) + monkeypatch.setattr("deepspeed.runtime.engine.groups._create_expert_and_data_parallel", + lambda **kwargs: group_creations.append(kwargs)) + monkeypatch.setattr("deepspeed.runtime.engine.groups._get_expert_parallel_group", lambda name: object()) + + class _AutoEPNoop: + + def __init__(self, model, config): + pass + + def ep_parser(self): + return [] + + monkeypatch.setattr("deepspeed.module_inject.auto_ep.AutoEP", _AutoEPNoop) + + with pytest.raises(ValueError, match="DeepCompile.*AutoEP\\+AutoTP folding"): + engine._configure_expert_parallel(object()) + + assert 
group_creations == [] diff --git a/tests/unit/v1/moe/test_autoep_autotp_folding_config.py b/tests/unit/v1/moe/test_autoep_autotp_folding_config.py index 42d26884a755..71293a0f080d 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_folding_config.py +++ b/tests/unit/v1/moe/test_autoep_autotp_folding_config.py @@ -131,6 +131,27 @@ def test_validation_rule_g11_zero_offload_rejected(offload_key): _assert_rejects("offload", **{offload_key: True}) +def test_deepcompile_folded_rejected(): + config = AutoEPConfig(enabled=True, autoep_size=2, expert_tensor_parallel_size=1) + with pytest.raises(ValueError, match="DeepCompile.*AutoEP\\+AutoTP folding"): + validate_autoep_config(config, + world_size=4, + pp_size=1, + tp_size=2, + sp_size=1, + deepcompile_enabled=True) + + +def test_deepcompile_nonfolded_accepted(): + config = AutoEPConfig(enabled=True, autoep_size=2, expert_tensor_parallel_size=1) + validate_autoep_config(config, + world_size=4, + pp_size=1, + tp_size=1, + sp_size=1, + deepcompile_enabled=True) + + @pytest.mark.parametrize( "world_size,tp_size,ep_size,expected_dp,expected_edp", [ From 3c1e91a84d9c1e3b4062af7b2aaf0eff06032bb8 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 Jul 2026 18:00:06 -0700 Subject: [PATCH 10/16] Reject folded AutoEP DeepCompile configs Signed-off-by: Masahiro Tanaka --- deepspeed/module_inject/auto_ep_config.py | 2 ++ deepspeed/module_inject/auto_ep_folding.py | 5 +++++ deepspeed/runtime/engine.py | 1 + 3 files changed, 8 insertions(+) diff --git a/deepspeed/module_inject/auto_ep_config.py b/deepspeed/module_inject/auto_ep_config.py index 922cb1fc95ee..2f1a81dfa8b6 100644 --- a/deepspeed/module_inject/auto_ep_config.py +++ b/deepspeed/module_inject/auto_ep_config.py @@ -100,6 +100,7 @@ def validate_autoep_config( sp_size: int, *, zero_stage: int = 0, + deepcompile_enabled: bool = False, tp_preset_model: str | None = None, use_data_before_expert_parallel: bool = False, mpu=None, @@ -128,6 +129,7 @@ def validate_autoep_config( 
folding_spec, zero_stage=zero_stage, sp_size=sp_size, + deepcompile_enabled=deepcompile_enabled, use_data_before_expert_parallel=use_data_before_expert_parallel, mpu=mpu, autoep_enabled=config.enabled, diff --git a/deepspeed/module_inject/auto_ep_folding.py b/deepspeed/module_inject/auto_ep_folding.py index effddda89a6b..a22078a7b08d 100644 --- a/deepspeed/module_inject/auto_ep_folding.py +++ b/deepspeed/module_inject/auto_ep_folding.py @@ -193,6 +193,7 @@ def validate_folding_global( *, zero_stage: int = 0, sp_size: int = 1, + deepcompile_enabled: bool = False, use_data_before_expert_parallel: bool = False, mpu=None, autoep_enabled: bool = True, @@ -205,6 +206,10 @@ def validate_folding_global( if not autoep_enabled: return + if deepcompile_enabled and spec.tp_size > 1: + raise ValueError("DeepCompile with AutoEP+AutoTP folding is not supported; " + "disable compile.deepcompile or use non-folded AutoEP with tensor_parallel.autotp_size=1.") + if spec.tp_size > 1 and spec.pp_size > 1: raise ValueError("AutoEP+AutoTP folding currently supports pp_size=1 only; " f"got pp_size={spec.pp_size}. 
Pipeline-parallel validation is planned separately.") diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 642bec125f70..e507ddfb9261 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -564,6 +564,7 @@ def _configure_expert_parallel(self, model): folding_spec, zero_stage=self.zero_optimization_stage(), sp_size=sp_size, + deepcompile_enabled=self._config.compile_config.deepcompile, use_data_before_expert_parallel=self._config.use_data_before_expert_parallel_, mpu=self.mpu, autoep_enabled=autoep_config.enabled, From 3fcc33bd0dd37b404a7ccc4683c2f8e65500164e Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 Jul 2026 18:01:59 -0700 Subject: [PATCH 11/16] Add BF16 folded HP grad correction test Signed-off-by: Masahiro Tanaka --- .../test_autoep_autotp_bf16_folding_parity.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/unit/v1/moe/test_autoep_autotp_bf16_folding_parity.py diff --git a/tests/unit/v1/moe/test_autoep_autotp_bf16_folding_parity.py b/tests/unit/v1/moe/test_autoep_autotp_bf16_folding_parity.py new file mode 100644 index 000000000000..6bda462d97d8 --- /dev/null +++ b/tests/unit/v1/moe/test_autoep_autotp_bf16_folding_parity.py @@ -0,0 +1,76 @@ +# Copyright (c) DeepSpeed Team. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""BF16 HP-buffer coverage for AutoEP + AutoTP folding correction.""" + +from types import SimpleNamespace + +import torch + +from deepspeed.runtime import bf16_optimizer as bf16_mod +from deepspeed.runtime.bf16_optimizer import BF16_Optimizer + + +def _bf16_optimizer_stub(lp, hp_grad): + optimizer = object.__new__(BF16_Optimizer) + optimizer.autoep_folding_spec = SimpleNamespace(tp_size=2, mp_mode="tp") + optimizer.autoep_folding_tp_group = object() + optimizer.param_names = {lp: "model.layers.0.mlp.router.gate.weight"} + optimizer.fp32_groups_gradients = [[hp_grad]] + optimizer.fp32_groups_has_gradients = [[False]] + return optimizer + + +def test_bf16_hp_grad_update_uses_folded_correction_before_hp_buffer(monkeypatch): + lp = torch.nn.Parameter(torch.ones(2, dtype=torch.bfloat16)) + lp.grad = torch.full((2, ), 4.0, dtype=torch.bfloat16) + hp_grad = torch.zeros(2, dtype=torch.float32) + optimizer = _bf16_optimizer_stub(lp, hp_grad) + calls = [] + + def fake_apply_folding_correction(folding_spec, param, grad, *, tp_group, param_name=None): + calls.append({ + "folding_spec": folding_spec, + "param": param, + "tp_group": tp_group, + "param_name": param_name, + "grad_before": grad.detach().float().clone(), + }) + grad.data.mul_(0.5) + param.ds_autoep_folding_grad_corrected = True + return "average" + + monkeypatch.setattr(bf16_mod, "apply_folding_correction_to_grad_buffer", fake_apply_folding_correction, + raising=False) + + optimizer._update_hp_grad(lp, group_idx=0, param_idx=0, clear_lp_grads=False) + + assert len(calls) == 1 + assert calls[0]["param"] is lp + assert calls[0]["param_name"] == "model.layers.0.mlp.router.gate.weight" + torch.testing.assert_close(calls[0]["grad_before"], torch.full((2, ), 4.0)) + torch.testing.assert_close(hp_grad, torch.full((2, ), 2.0)) + assert optimizer.fp32_groups_has_gradients[0][0] is True + assert lp.ds_autoep_folding_grad_corrected is True + + +def 
test_bf16_immediate_grad_update_hook_reuses_corrected_hp_update(monkeypatch): + lp = torch.nn.Parameter(torch.ones(2, dtype=torch.bfloat16)) + lp.grad = torch.full((2, ), 6.0, dtype=torch.bfloat16) + hp_grad = torch.zeros(2, dtype=torch.float32) + optimizer = _bf16_optimizer_stub(lp, hp_grad) + optimizer.immediate_grad_update = True + + def fake_apply_folding_correction(_folding_spec, _param, grad, *, tp_group, param_name=None): + grad.data.div_(3.0) + _param.ds_autoep_folding_grad_corrected = True + return "expert_tp_cancel" + + monkeypatch.setattr(bf16_mod, "apply_folding_correction_to_grad_buffer", fake_apply_folding_correction, + raising=False) + + optimizer.accumulate_hp_grads_and_remove_lp(lp, group_idx=0, param_idx=0) + + torch.testing.assert_close(hp_grad, torch.full((2, ), 2.0)) + assert lp.ds_autoep_folding_grad_corrected is True From b485a75e68979d65e8df911a5ed8dfd74f22d966 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 Jul 2026 18:03:27 -0700 Subject: [PATCH 12/16] Apply folded correction before BF16 HP grad accumulation Signed-off-by: Masahiro Tanaka --- deepspeed/module_inject/auto_ep_folding.py | 27 ++++++++++++++++++++++ deepspeed/runtime/bf16_optimizer.py | 18 +++++++++++++++ deepspeed/runtime/engine.py | 7 +++++- 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/deepspeed/module_inject/auto_ep_folding.py b/deepspeed/module_inject/auto_ep_folding.py index a22078a7b08d..af8625f51992 100644 --- a/deepspeed/module_inject/auto_ep_folding.py +++ b/deepspeed/module_inject/auto_ep_folding.py @@ -19,6 +19,7 @@ AUTOEP_FOLDING_ROUTER_GATE_REPLICATED_PARAM = "router_gate_replicated" AUTOEP_FOLDING_ROUTER_GATE_PARTIAL_PARAM = "router_gate_partial" AUTOEP_FOLDING_SP_SHARDED_LAYERNORM_PARAM = "sp_sharded_layernorm" +AUTOEP_FOLDING_GRAD_CORRECTED_ATTR = "ds_autoep_folding_grad_corrected" AUTOEP_FOLDING_GRAD_REDUCE_SKIP = "skip" AUTOEP_FOLDING_GRAD_REDUCE_SUM = "sum" AUTOEP_FOLDING_GRAD_REDUCE_AVERAGE = "average" @@ -467,6 +468,32 @@ def 
reduce_autoep_folding_gradient( return strategy +def is_autoep_folding_gradient_corrected(param) -> bool: + return bool(getattr(param, AUTOEP_FOLDING_GRAD_CORRECTED_ATTR, False)) + + +def clear_autoep_folding_gradient_corrected(param) -> None: + if hasattr(param, AUTOEP_FOLDING_GRAD_CORRECTED_ATTR): + setattr(param, AUTOEP_FOLDING_GRAD_CORRECTED_ATTR, False) + + +def apply_folding_correction_to_grad_buffer( + folding_spec: ParallelFoldingSpec | None, + param, + grad, + *, + tp_group, + param_name: str | None = None, +) -> str: + if is_autoep_folding_gradient_corrected(param): + return AUTOEP_FOLDING_GRAD_REDUCE_SKIP + + strategy = reduce_autoep_folding_gradient(folding_spec, param, grad, tp_group=tp_group, param_name=param_name) + if strategy != AUTOEP_FOLDING_GRAD_REDUCE_SKIP: + setattr(param, AUTOEP_FOLDING_GRAD_CORRECTED_ATTR, True) + return strategy + + def _normalize_rank_groups(groups: Iterable[Iterable[int]]) -> set[tuple[int, ...]]: return {tuple(int(rank) for rank in group) for group in groups} diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 746618fb5bd9..7b26ed043f39 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -24,6 +24,7 @@ from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, PARAM_SLICE_MAPPINGS) +from deepspeed.module_inject.auto_ep_folding import apply_folding_correction_to_grad_buffer setattr(sys.modules[__name__], 'fragment_address', fragment_address) @@ -70,6 +71,8 @@ def __init__(self, self.clip_grad = clip_grad self.norm_type = norm_type self.mpu = mpu + self.autoep_folding_tp_group = None + self.autoep_folding_spec = None self.allgather_bucket_size = int(allgather_bucket_size) self.dp_process_group = dp_process_group self.dp_rank = dist.get_rank(group=self.dp_process_group) @@ -218,6 +221,14 @@ def _setup_for_real_optimizer(self): 
self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() + def configure_autoep_folding_tp_gradient_reduction(self, folding_spec): + if folding_spec is None or folding_spec.tp_size <= 1: + self.autoep_folding_tp_group = None + self.autoep_folding_spec = None + return + self.autoep_folding_tp_group = groups.get_tensor_model_parallel_group() + self.autoep_folding_spec = folding_spec + def _enable_universal_checkpoint(self): self._universal_checkpoint_info = None for lp_param_group in self.bf16_groups: @@ -346,6 +357,13 @@ def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads): if lp.grad is None: return + if self.autoep_folding_tp_group is not None and getattr(lp, "ds_grad_is_ready", True): + apply_folding_correction_to_grad_buffer(self.autoep_folding_spec, + lp, + lp.grad, + tp_group=self.autoep_folding_tp_group, + param_name=self.param_names.get(lp)) + hp_grad = self.fp32_groups_gradients[group_idx][param_idx] assert hp_grad is not None, \ f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{group_idx}][{param_idx}]' diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e507ddfb9261..e13a398f9283 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -43,7 +43,9 @@ from deepspeed.linear.optimized_linear import LoRAOptimizedLinear from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime, collect_autotp_universal_checkpoint_info -from deepspeed.module_inject.auto_ep_folding import reduce_autoep_folding_gradient +from deepspeed.module_inject.auto_ep_folding import (clear_autoep_folding_gradient_corrected, + is_autoep_folding_gradient_corrected, + reduce_autoep_folding_gradient) from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, 
ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \ @@ -2786,6 +2788,9 @@ def _reduce_autoep_folding_tp_replicated_gradients(self): for param_name, param in self.module.named_parameters(): if not param.requires_grad or param.grad is None: continue + if is_autoep_folding_gradient_corrected(param): + clear_autoep_folding_gradient_corrected(param) + continue reduce_autoep_folding_gradient(folding_spec, param, param.grad, tp_group=tp_group, param_name=param_name) def _backward_prologue(self): From 4906a3a61956e6957801d8ae7272950a1c2f6d05 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 Jul 2026 18:04:02 -0700 Subject: [PATCH 13/16] Add ZeRO-1 overlap folding correction tests Signed-off-by: Masahiro Tanaka --- ...est_autoep_autotp_zero1_overlap_folding.py | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 tests/unit/v1/moe/test_autoep_autotp_zero1_overlap_folding.py diff --git a/tests/unit/v1/moe/test_autoep_autotp_zero1_overlap_folding.py b/tests/unit/v1/moe/test_autoep_autotp_zero1_overlap_folding.py new file mode 100644 index 000000000000..f19063c2592e --- /dev/null +++ b/tests/unit/v1/moe/test_autoep_autotp_zero1_overlap_folding.py @@ -0,0 +1,85 @@ +# Copyright (c) DeepSpeed Team. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""ZeRO-1 overlap hook coverage for AutoEP + AutoTP folding correction.""" + +from types import SimpleNamespace + +import torch + +from deepspeed.runtime.zero import stage_1_and_2 as zero_mod +from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer + + +def _zero_optimizer_stub(*, partition_gradients, overlap_comm): + optimizer = object.__new__(DeepSpeedZeroOptimizer) + optimizer.partition_gradients = partition_gradients + optimizer.overlap_comm = overlap_comm + optimizer.autoep_folding_tp_group = object() + optimizer.autoep_folding_spec = SimpleNamespace(tp_size=2, mp_mode="tp") + return optimizer + + +def test_zero1_overlap_applies_folding_before_hook_consumes_grad(monkeypatch): + optimizer = _zero_optimizer_stub(partition_gradients=False, overlap_comm=True) + param = torch.nn.Parameter(torch.ones(2)) + grad = torch.full((2, ), 4.0) + calls = [] + + def fake_apply_folding_correction(folding_spec, param, grad, *, tp_group, param_name=None): + calls.append({ + "folding_spec": folding_spec, + "param": param, + "tp_group": tp_group, + "grad_before": grad.detach().clone(), + }) + grad.data.mul_(0.5) + param.ds_autoep_folding_grad_corrected = True + return "average" + + monkeypatch.setattr(zero_mod, "apply_folding_correction_to_grad_buffer", fake_apply_folding_correction, + raising=False) + + optimizer._maybe_reduce_autoep_folding_tp_gradient(param, grad) + + assert len(calls) == 1 + assert calls[0]["param"] is param + torch.testing.assert_close(calls[0]["grad_before"], torch.full((2, ), 4.0)) + torch.testing.assert_close(grad, torch.full((2, ), 2.0)) + assert param.ds_autoep_folding_grad_corrected is True + + +def test_zero1_nonoverlap_leaves_engine_boundary_sweep_owner(monkeypatch): + optimizer = _zero_optimizer_stub(partition_gradients=False, overlap_comm=False) + param = torch.nn.Parameter(torch.ones(2)) + grad = torch.full((2, ), 4.0) + calls = [] + monkeypatch.setattr(zero_mod, 
"apply_folding_correction_to_grad_buffer", + lambda *args, **kwargs: calls.append((args, kwargs)), + raising=False) + + optimizer._maybe_reduce_autoep_folding_tp_gradient(param, grad) + + assert calls == [] + torch.testing.assert_close(grad, torch.full((2, ), 4.0)) + + +def test_zero2_partitioned_path_still_applies_folding(monkeypatch): + optimizer = _zero_optimizer_stub(partition_gradients=True, overlap_comm=False) + param = torch.nn.Parameter(torch.ones(2)) + grad = torch.full((2, ), 4.0) + calls = [] + + def fake_apply_folding_correction(_folding_spec, _param, grad, *, tp_group, param_name=None): + calls.append(_param) + grad.data.div_(2.0) + return "expert_tp_cancel" + + monkeypatch.setattr(zero_mod, "apply_folding_correction_to_grad_buffer", fake_apply_folding_correction, + raising=False) + + optimizer._maybe_reduce_autoep_folding_tp_gradient(param, grad) + + assert calls == [param] + torch.testing.assert_close(grad, torch.full((2, ), 2.0)) From 42c4148de3c5af356825b58eb04fcccde53fd169 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 Jul 2026 18:04:22 -0700 Subject: [PATCH 14/16] Apply folded correction in ZeRO-1 overlap hooks Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/zero/stage_1_and_2.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index a74bcee59031..b73a80e87e5d 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -35,7 +35,7 @@ from deepspeed.runtime.constants import PIPE_REPLICATED from deepspeed.accelerator import get_accelerator from deepspeed.runtime.zero.muon.original_muon import muon_update -from deepspeed.module_inject.auto_ep_folding import reduce_autoep_folding_gradient +from deepspeed.module_inject.auto_ep_folding import apply_folding_correction_to_grad_buffer from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam from 
deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, @@ -1037,11 +1037,15 @@ def configure_autoep_folding_tp_gradient_reduction(self, folding_spec): self.autoep_folding_spec = folding_spec def _maybe_reduce_autoep_folding_tp_gradient(self, param, grad): - if not self.partition_gradients or self.autoep_folding_tp_group is None or grad is None: + if ((not self.partition_gradients and not self.overlap_comm) or self.autoep_folding_tp_group is None + or grad is None): return if not getattr(param, "ds_grad_is_ready", True): return - reduce_autoep_folding_gradient(self.autoep_folding_spec, param, grad, tp_group=self.autoep_folding_tp_group) + apply_folding_correction_to_grad_buffer(self.autoep_folding_spec, + param, + grad, + tp_group=self.autoep_folding_tp_group) def _fill_param_grad_accum_attribute(self, param): if param.grad is not None: From 96f3b3b720a685c0c10fcf2f9b418a635c213d5d Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 Jul 2026 20:42:48 -0700 Subject: [PATCH 15/16] Handle DeviceMesh split fallback for AutoTP groups Signed-off-by: Masahiro Tanaka --- deepspeed/utils/groups.py | 67 ++++++++++++++++++- .../model_parallelism/test_autotp_training.py | 57 ++++++++++++++++ 2 files changed, 122 insertions(+), 2 deletions(-) diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py index 7a0fc7eb6fa4..527672bf9cfd 100644 --- a/deepspeed/utils/groups.py +++ b/deepspeed/utils/groups.py @@ -25,6 +25,8 @@ For inference and other new scenarios, the code will be either reused or added to this file. 
""" +import os + from deepspeed import comm as dist from deepspeed.utils import log_dist from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size @@ -52,6 +54,8 @@ mesh_device = None +_DEVICE_MESH_SPLIT_UNSUPPORTED = "No backend for the parent process group or its backend does not support splitting" + # Deprecated groups initialize function. def initialize(ep_size=1, mpu=None): @@ -81,6 +85,54 @@ def _ensure_divisibility(numerator, denominator): _MPU_TENSOR_MODEL_PARALLEL_RANK = None +def _init_tp_groups_with_new_group(tensor_model_parallel_size=1, data_parallel_size=None): + """Initialize TP/DP groups with explicit rank lists. + + This mirrors a 2D DeviceMesh shaped as (data_parallel, tensor_parallel), + while avoiding DeviceMesh's optimized split_group path. + """ + + global _DATA_PARALLEL_GROUP + global _MODEL_PARALLEL_GROUP + global _TENSOR_MODEL_PARALLEL_GROUP + + world_size = dist.get_world_size() + _ensure_divisibility(world_size, tensor_model_parallel_size) + + if data_parallel_size is None: + data_parallel_size = world_size // tensor_model_parallel_size + else: + assert data_parallel_size * tensor_model_parallel_size == world_size, ( + f"data_parallel_size ({data_parallel_size}) * tensor_model_parallel_size " + f"({tensor_model_parallel_size}) must equal world_size ({world_size})") + + rank = dist.get_rank() + data_parallel_group = None + tensor_model_parallel_group = None + + for tensor_rank in range(tensor_model_parallel_size): + ranks = list(range(tensor_rank, world_size, tensor_model_parallel_size)) + group = dist.new_group(ranks) + if rank in ranks: + data_parallel_group = group + + for data_rank in range(data_parallel_size): + start = data_rank * tensor_model_parallel_size + ranks = list(range(start, start + tensor_model_parallel_size)) + group = dist.new_group(ranks) + if rank in ranks: + tensor_model_parallel_group = group + + assert data_parallel_group is not None, 'data parallel group is not 
initialized' + assert tensor_model_parallel_group is not None, 'tensor parallel group is not initialized' + + _DATA_PARALLEL_GROUP = data_parallel_group + _TENSOR_MODEL_PARALLEL_GROUP = tensor_model_parallel_group + _MODEL_PARALLEL_GROUP = _TENSOR_MODEL_PARALLEL_GROUP + + return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP + + def _init_tp_mesh_device(tensor_model_parallel_size=1, data_parallel_size=None): """Initialize model data parallel groups.""" @@ -94,8 +146,19 @@ def _init_tp_mesh_device(tensor_model_parallel_size=1, data_parallel_size=None): if data_parallel_size is None: data_parallel_size = dist.get_world_size() // tensor_model_parallel_size - mesh_device = dist.initialize_mesh_device((data_parallel_size, tensor_model_parallel_size), - ("data_parallel", "tensor_parallel")) + if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "").upper() == "DETAIL": + log_dist("TORCH_DISTRIBUTED_DEBUG=DETAIL detected; initializing TP mesh groups with new_group", ranks=[0]) + return _init_tp_groups_with_new_group(tensor_model_parallel_size, data_parallel_size) + + try: + mesh_device = dist.initialize_mesh_device((data_parallel_size, tensor_model_parallel_size), + ("data_parallel", "tensor_parallel")) + except RuntimeError as exc: + if _DEVICE_MESH_SPLIT_UNSUPPORTED not in str(exc): + raise + log_dist("DeviceMesh process-group splitting is unsupported; falling back to new_group TP mesh groups", ranks=[0]) + return _init_tp_groups_with_new_group(tensor_model_parallel_size, data_parallel_size) + _TENSOR_MODEL_PARALLEL_GROUP = mesh_device.get_group(mesh_dim="tensor_parallel") _DATA_PARALLEL_GROUP = mesh_device.get_group(mesh_dim="data_parallel") diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py index 64f0b1113b16..fd7298d83ca6 100644 --- a/tests/unit/model_parallelism/test_autotp_training.py +++ b/tests/unit/model_parallelism/test_autotp_training.py @@ -32,6 +32,63 @@ def reset_tp_model_init_state(): 
set_autotp_mode(training=False) +def _reset_tp_groups(monkeypatch): + monkeypatch.setattr(groups, "_DATA_PARALLEL_GROUP", None) + monkeypatch.setattr(groups, "_MODEL_PARALLEL_GROUP", None) + monkeypatch.setattr(groups, "_TENSOR_MODEL_PARALLEL_GROUP", None) + + +def _patch_tp_group_creation(monkeypatch, *, rank=2, initialize_mesh_device=None): + new_group_calls = [] + + def fake_new_group(ranks): + ranks = tuple(ranks) + new_group_calls.append(ranks) + return ranks + + _reset_tp_groups(monkeypatch) + monkeypatch.setattr(groups.dist, "get_world_size", lambda group=None: 4) + monkeypatch.setattr(groups.dist, "get_rank", lambda group=None: rank) + monkeypatch.setattr(groups.dist, "new_group", fake_new_group) + monkeypatch.setattr(groups, "log_dist", lambda *args, **kwargs: None) + if initialize_mesh_device is not None: + monkeypatch.setattr(groups.dist, "initialize_mesh_device", initialize_mesh_device) + + return new_group_calls + + +def test_init_tp_mesh_device_debug_detail_uses_explicit_groups(monkeypatch): + def fail_initialize_mesh_device(*args, **kwargs): + raise AssertionError("DeviceMesh should be skipped when TORCH_DISTRIBUTED_DEBUG=DETAIL") + + new_group_calls = _patch_tp_group_creation(monkeypatch, initialize_mesh_device=fail_initialize_mesh_device) + monkeypatch.setenv("TORCH_DISTRIBUTED_DEBUG", "DETAIL") + + data_parallel_group, tensor_parallel_group = groups._init_tp_mesh_device(tensor_model_parallel_size=2) + + assert new_group_calls == [(0, 2), (1, 3), (0, 1), (2, 3)] + assert data_parallel_group == (0, 2) + assert tensor_parallel_group == (2, 3) + assert groups.get_data_parallel_group() == (0, 2) + assert groups.get_tensor_model_parallel_group() == (2, 3) + + +def test_init_tp_mesh_device_split_error_falls_back_to_explicit_groups(monkeypatch): + def raise_split_error(*args, **kwargs): + raise RuntimeError(groups._DEVICE_MESH_SPLIT_UNSUPPORTED) + + new_group_calls = _patch_tp_group_creation(monkeypatch, initialize_mesh_device=raise_split_error) + 
monkeypatch.delenv("TORCH_DISTRIBUTED_DEBUG", raising=False) + + data_parallel_group, tensor_parallel_group = groups._init_tp_mesh_device(tensor_model_parallel_size=2) + + assert new_group_calls == [(0, 2), (1, 3), (0, 1), (2, 3)] + assert data_parallel_group == (0, 2) + assert tensor_parallel_group == (2, 3) + assert groups.get_data_parallel_group() == (0, 2) + assert groups.get_tensor_model_parallel_group() == (2, 3) + + class DummyMPU: def __init__(self, tp_world_size=1): From bc28dd8a6c0e4bb0cebf0c57e10a947f61137a1e Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 Jul 2026 22:34:17 -0700 Subject: [PATCH 16/16] Apply pre-commit formatting Signed-off-by: Masahiro Tanaka --- deepspeed/utils/groups.py | 3 ++- .../unit/model_parallelism/test_autotp_training.py | 2 ++ .../moe/test_autoep_autotp_bf16_folding_parity.py | 8 ++++++-- .../v1/moe/test_autoep_autotp_folding_config.py | 14 ++------------ .../test_autoep_autotp_zero1_overlap_folding.py | 11 ++++++++--- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py index 527672bf9cfd..389886d2ed98 100644 --- a/deepspeed/utils/groups.py +++ b/deepspeed/utils/groups.py @@ -156,7 +156,8 @@ def _init_tp_mesh_device(tensor_model_parallel_size=1, data_parallel_size=None): except RuntimeError as exc: if _DEVICE_MESH_SPLIT_UNSUPPORTED not in str(exc): raise - log_dist("DeviceMesh process-group splitting is unsupported; falling back to new_group TP mesh groups", ranks=[0]) + log_dist("DeviceMesh process-group splitting is unsupported; falling back to new_group TP mesh groups", + ranks=[0]) return _init_tp_groups_with_new_group(tensor_model_parallel_size, data_parallel_size) _TENSOR_MODEL_PARALLEL_GROUP = mesh_device.get_group(mesh_dim="tensor_parallel") diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py index fd7298d83ca6..dcd90076d48f 100644 --- 
a/tests/unit/model_parallelism/test_autotp_training.py +++ b/tests/unit/model_parallelism/test_autotp_training.py @@ -58,6 +58,7 @@ def fake_new_group(ranks): def test_init_tp_mesh_device_debug_detail_uses_explicit_groups(monkeypatch): + def fail_initialize_mesh_device(*args, **kwargs): raise AssertionError("DeviceMesh should be skipped when TORCH_DISTRIBUTED_DEBUG=DETAIL") @@ -74,6 +75,7 @@ def fail_initialize_mesh_device(*args, **kwargs): def test_init_tp_mesh_device_split_error_falls_back_to_explicit_groups(monkeypatch): + def raise_split_error(*args, **kwargs): raise RuntimeError(groups._DEVICE_MESH_SPLIT_UNSUPPORTED) diff --git a/tests/unit/v1/moe/test_autoep_autotp_bf16_folding_parity.py b/tests/unit/v1/moe/test_autoep_autotp_bf16_folding_parity.py index 6bda462d97d8..fa02ce1cd16b 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_bf16_folding_parity.py +++ b/tests/unit/v1/moe/test_autoep_autotp_bf16_folding_parity.py @@ -41,7 +41,9 @@ def fake_apply_folding_correction(folding_spec, param, grad, *, tp_group, param_ param.ds_autoep_folding_grad_corrected = True return "average" - monkeypatch.setattr(bf16_mod, "apply_folding_correction_to_grad_buffer", fake_apply_folding_correction, + monkeypatch.setattr(bf16_mod, + "apply_folding_correction_to_grad_buffer", + fake_apply_folding_correction, raising=False) optimizer._update_hp_grad(lp, group_idx=0, param_idx=0, clear_lp_grads=False) @@ -67,7 +69,9 @@ def fake_apply_folding_correction(_folding_spec, _param, grad, *, tp_group, para _param.ds_autoep_folding_grad_corrected = True return "expert_tp_cancel" - monkeypatch.setattr(bf16_mod, "apply_folding_correction_to_grad_buffer", fake_apply_folding_correction, + monkeypatch.setattr(bf16_mod, + "apply_folding_correction_to_grad_buffer", + fake_apply_folding_correction, raising=False) optimizer.accumulate_hp_grads_and_remove_lp(lp, group_idx=0, param_idx=0) diff --git a/tests/unit/v1/moe/test_autoep_autotp_folding_config.py 
b/tests/unit/v1/moe/test_autoep_autotp_folding_config.py index 71293a0f080d..e4fe8bd7b67a 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_folding_config.py +++ b/tests/unit/v1/moe/test_autoep_autotp_folding_config.py @@ -134,22 +134,12 @@ def test_validation_rule_g11_zero_offload_rejected(offload_key): def test_deepcompile_folded_rejected(): config = AutoEPConfig(enabled=True, autoep_size=2, expert_tensor_parallel_size=1) with pytest.raises(ValueError, match="DeepCompile.*AutoEP\\+AutoTP folding"): - validate_autoep_config(config, - world_size=4, - pp_size=1, - tp_size=2, - sp_size=1, - deepcompile_enabled=True) + validate_autoep_config(config, world_size=4, pp_size=1, tp_size=2, sp_size=1, deepcompile_enabled=True) def test_deepcompile_nonfolded_accepted(): config = AutoEPConfig(enabled=True, autoep_size=2, expert_tensor_parallel_size=1) - validate_autoep_config(config, - world_size=4, - pp_size=1, - tp_size=1, - sp_size=1, - deepcompile_enabled=True) + validate_autoep_config(config, world_size=4, pp_size=1, tp_size=1, sp_size=1, deepcompile_enabled=True) @pytest.mark.parametrize( diff --git a/tests/unit/v1/moe/test_autoep_autotp_zero1_overlap_folding.py b/tests/unit/v1/moe/test_autoep_autotp_zero1_overlap_folding.py index f19063c2592e..1f4388b5115d 100644 --- a/tests/unit/v1/moe/test_autoep_autotp_zero1_overlap_folding.py +++ b/tests/unit/v1/moe/test_autoep_autotp_zero1_overlap_folding.py @@ -38,7 +38,9 @@ def fake_apply_folding_correction(folding_spec, param, grad, *, tp_group, param_ param.ds_autoep_folding_grad_corrected = True return "average" - monkeypatch.setattr(zero_mod, "apply_folding_correction_to_grad_buffer", fake_apply_folding_correction, + monkeypatch.setattr(zero_mod, + "apply_folding_correction_to_grad_buffer", + fake_apply_folding_correction, raising=False) optimizer._maybe_reduce_autoep_folding_tp_gradient(param, grad) @@ -55,7 +57,8 @@ def test_zero1_nonoverlap_leaves_engine_boundary_sweep_owner(monkeypatch): param = 
torch.nn.Parameter(torch.ones(2)) grad = torch.full((2, ), 4.0) calls = [] - monkeypatch.setattr(zero_mod, "apply_folding_correction_to_grad_buffer", + monkeypatch.setattr(zero_mod, + "apply_folding_correction_to_grad_buffer", lambda *args, **kwargs: calls.append((args, kwargs)), raising=False) @@ -76,7 +79,9 @@ def fake_apply_folding_correction(_folding_spec, _param, grad, *, tp_group, para grad.data.div_(2.0) return "expert_tp_cancel" - monkeypatch.setattr(zero_mod, "apply_folding_correction_to_grad_buffer", fake_apply_folding_correction, + monkeypatch.setattr(zero_mod, + "apply_folding_correction_to_grad_buffer", + fake_apply_folding_correction, raising=False) optimizer._maybe_reduce_autoep_folding_tp_gradient(param, grad)