6 changes: 4 additions & 2 deletions src/accelerate/launchers.py
@@ -27,6 +27,7 @@
get_current_device_type,
get_gpu_info,
is_mps_available,
is_rocm_available,
is_torch_version,
patch_environment,
)
@@ -206,8 +207,9 @@ def train(*args):
# First dummy launch
# Determine device type without initializing any device (which would break fork)
device_type, distributed_type = get_current_device_type()
# XPU requires spawn instead of fork
start_method = "spawn" if device_type == "xpu" else "fork"
# XPU and ROCm require spawn instead of fork (HIP/XPU runtime is initialized in the parent,
# which breaks fork-based subprocesses).
start_method = "spawn" if device_type == "xpu" or is_rocm_available() else "fork"
if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
launcher = PrepareForLaunch(test_launch, distributed_type=distributed_type)
try:
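Below is a minimal sketch (not Accelerate's actual `notebook_launcher`, and checking only ROCm for brevity) of the start-method choice made above. The worker name, process count, and the single-GPU assumption are illustrative; the point is that a HIP context created in the parent is inherited in an unusable state by fork-based children, so `spawn` is selected instead.

```python
# Illustrative sketch only: ROCm check shown, XPU handled analogously in the real code.
import torch
import torch.multiprocessing as mp


def _worker(index):
    # Each child creates its own GPU context. This is only safe if the parent never
    # initialized the HIP runtime (fork) or the child is a fresh interpreter (spawn).
    if torch.cuda.is_available():
        torch.cuda.set_device(index)  # assumes at least `nprocs` visible devices
    print(f"worker {index} ready")


if __name__ == "__main__":
    # Mirrors the condition in launchers.py: ROCm builds expose torch.version.hip.
    start_method = "spawn" if torch.version.hip is not None else "fork"
    mp.start_processes(_worker, nprocs=2, start_method=start_method)
```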
7 changes: 4 additions & 3 deletions src/accelerate/test_utils/scripts/test_notebook.py
@@ -23,7 +23,7 @@

from accelerate import PartialState, notebook_launcher
from accelerate.test_utils import require_bnb
from accelerate.utils import is_bnb_available, is_xpu_available
from accelerate.utils import is_bnb_available, is_rocm_available, is_xpu_available


def basic_function():
@@ -72,8 +72,9 @@ def test_fault_tolerant(max_restarts: int = 3):
# Use torch.multiprocessing to get the right context for the current device
import torch.multiprocessing as mp

# Get appropriate context - 'spawn' for XPU, 'fork' for others
if is_xpu_available():
# Get appropriate context - 'spawn' for XPU/ROCm (matches notebook_launcher's start_method),
# 'fork' for others
if is_xpu_available() or is_rocm_available():
ctx = mp.get_context("spawn")
else:
ctx = mp.get_context("fork")
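On the test side, the context must match the launcher's start method because multiprocessing objects created under one context cannot be shared with processes started under another. A small sketch (ROCm check only; the queue is a stand-in for whatever IPC object a test builds):

```python
# Sketch assuming only fork/spawn matter here; the real test also checks XPU.
import torch
import torch.multiprocessing as mp

start_method = "spawn" if torch.version.hip is not None else "fork"
ctx = mp.get_context(start_method)
queue = ctx.Queue()  # must come from the same context the worker processes use
```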
2 changes: 2 additions & 0 deletions src/accelerate/utils/__init__.py
@@ -90,6 +90,7 @@
is_4bit_bnb_available,
is_8bit_bnb_available,
is_aim_available,
is_amdsmi_available,
is_bf16_available,
is_bitsandbytes_multi_backend_available,
is_bnb_available,
@@ -121,6 +122,7 @@
is_pynvml_available,
is_pytest_available,
is_rich_available,
is_rocm_available,
is_sagemaker_available,
is_schedulefree_available,
is_sdaa_available,
5 changes: 4 additions & 1 deletion src/accelerate/utils/bnb.py
@@ -132,8 +132,11 @@ def load_and_quantize_model(
)
model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
# convert param to the right dtype
# remove_duplicate=False so tied params (e.g. BLOOM's lm_head.weight tied to
# word_embeddings.weight) are visited under every alias — the keep_in_fp32 cast
# would otherwise be skipped if the tied alias came first under a different name.
dtype = bnb_quantization_config.torch_dtype
for name, param in model.named_parameters():
for name, param in model.named_parameters(remove_duplicate=False):
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
param.data = param.data.to(torch.float32)
elif torch.is_floating_point(param):
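A self-contained toy (not BLOOM itself) showing why `remove_duplicate=False` matters: with the default dedup, a tied parameter is listed under only one of its names, so a name-based filter such as `keep_in_fp32_modules` can miss the alias it was written against.

```python
import torch


class TiedToy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(10, 4)
        self.lm_head = torch.nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.word_embeddings.weight  # tie the weights


model = TiedToy()
print([n for n, _ in model.named_parameters()])
# ['word_embeddings.weight']  -> lm_head.weight is deduped away
print([n for n, _ in model.named_parameters(remove_duplicate=False)])
# ['word_embeddings.weight', 'lm_head.weight']  -> every alias is visited
```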
12 changes: 12 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -50,6 +50,7 @@
is_msamp_available,
is_musa_available,
is_npu_available,
is_rocm_available,
is_torchao_available,
is_transformer_engine_available,
is_xpu_available,
@@ -1441,6 +1442,17 @@ def set_mixed_precision(self, mixed_precision):
self.fill_match("fp16.enabled", must_match=False, **kwargs)
self.fill_match("bf16.enabled", must_match=False, **kwargs)

# On ROCm, bf16 DeepSpeed training can silently produce NaN weights because
# bf16 has no NaN/Inf safety net (unlike fp16 loss scaling). Accumulating
# gradients in fp32 for the collective avoids the overflow path.
if mixed_precision in ("bf16", "fp8") and is_rocm_available() and "communication_data_type" not in ds_config:
ds_config["communication_data_type"] = "fp32"
logger.info(
"ROCm + DeepSpeed + bf16 detected: setting "
"`communication_data_type='fp32'` to avoid bf16 overflow corrupting "
"weights. Set it explicitly in your DeepSpeed config to override."
)

def set_deepspeed_weakref(self):
from .imports import is_transformers_available

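In DeepSpeed config terms, the override above is equivalent to the user having written the fragment below (values illustrative; an explicit `communication_data_type` in the user's config always wins):

```python
# Illustrative fragment of the resulting DeepSpeed config on ROCm with bf16 mixed precision.
ds_config = {
    "bf16": {"enabled": True},
    # Only the gradient allreduce/collective payload is upcast; compute stays in bf16.
    "communication_data_type": "fp32",
}
```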
79 changes: 60 additions & 19 deletions src/accelerate/utils/environment.py
@@ -283,6 +283,21 @@ def get_cpu_distributed_information() -> CPUInformation:
return CPUInformation(**information)


def _parse_cpu_list(cpu_list_str: str) -> list[int]:
"""Parse a Linux sysfs-style CPU list (e.g. "0-7,16-23") into a list of ints."""
cpus = []
for part in cpu_list_str.split(","):
part = part.strip()
if not part:
continue
if "-" in part:
start, end = part.split("-")
cpus.extend(range(int(start), int(end) + 1))
else:
cpus.append(int(part))
return cpus


def override_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None:
"""
Overrides whatever NUMA affinity is set for the current process. This is very taxing and requires recalculating the
@@ -297,25 +312,51 @@ def override_numa_affinity(local_process_index: int, verbose: Optional[bool] = N
if verbose is None:
verbose = parse_flag_from_env("ACCELERATE_DEBUG_MODE", False)
if torch.cuda.is_available():
from accelerate.utils import is_pynvml_available

if not is_pynvml_available():
raise ImportError(
"To set CPU affinity on CUDA GPUs the `nvidia-ml-py` package must be available. (`pip install nvidia-ml-py`)"
)
import pynvml as nvml

# The below code is based on https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/BERT/gpu_affinity.py
nvml.nvmlInit()
num_elements = math.ceil(os.cpu_count() / 64)
handle = nvml.nvmlDeviceGetHandleByIndex(local_process_index)
affinity_string = ""
for j in nvml.nvmlDeviceGetCpuAffinity(handle, num_elements):
# assume nvml returns list of 64 bit ints
affinity_string = f"{j:064b}{affinity_string}"
affinity_list = [int(x) for x in affinity_string]
affinity_list.reverse() # so core 0 is the 0th element
affinity_to_set = [i for i, e in enumerate(affinity_list) if e != 0]
from accelerate.utils import is_amdsmi_available, is_pynvml_available, is_rocm_available

affinity_to_set = None

if is_rocm_available():
if not is_amdsmi_available():
raise ImportError(
"To set CPU affinity on ROCm GPUs the `amdsmi` package must be available. "
"It ships with ROCm; ensure the ROCm Python bindings are on PYTHONPATH."
)
import amdsmi

amdsmi.amdsmi_init()
try:
handles = amdsmi.amdsmi_get_processor_handles()
handle = handles[local_process_index]
numa_node = amdsmi.amdsmi_topo_get_numa_node_number(handle)
if numa_node is None or numa_node < 0:
# GPU is not bound to a NUMA node; fall back to all online CPUs
affinity_to_set = list(os.sched_getaffinity(0))
else:
with open(f"/sys/devices/system/node/node{numa_node}/cpulist") as f:
cpu_list_str = f.read().strip()
affinity_to_set = _parse_cpu_list(cpu_list_str)
finally:
amdsmi.amdsmi_shut_down()
else:
if not is_pynvml_available():
raise ImportError(
"To set CPU affinity on CUDA GPUs the `nvidia-ml-py` package must be available. (`pip install nvidia-ml-py`)"
)
import pynvml as nvml

# The below code is based on https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/BERT/gpu_affinity.py
nvml.nvmlInit()
num_elements = math.ceil(os.cpu_count() / 64)
handle = nvml.nvmlDeviceGetHandleByIndex(local_process_index)
affinity_string = ""
for j in nvml.nvmlDeviceGetCpuAffinity(handle, num_elements):
# assume nvml returns list of 64 bit ints
affinity_string = f"{j:064b}{affinity_string}"
affinity_list = [int(x) for x in affinity_string]
affinity_list.reverse() # so core 0 is the 0th element
affinity_to_set = [i for i, e in enumerate(affinity_list) if e != 0]

os.sched_setaffinity(0, affinity_to_set)
if verbose:
cpu_cores = os.sched_getaffinity(0)
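A quick check of the new `_parse_cpu_list` helper (assumes this branch is installed; the input string is illustrative, not read from a real sysfs node file):

```python
from accelerate.utils.environment import _parse_cpu_list

cpus = _parse_cpu_list("0-7,16-23")
assert cpus == list(range(0, 8)) + list(range(16, 24))
# override_numa_affinity() then applies the node's CPUs via os.sched_setaffinity(0, cpus).
```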
23 changes: 21 additions & 2 deletions src/accelerate/utils/fsdp_utils.py
@@ -484,6 +484,12 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
meta_sharded_sd = model.state_dict()
sharded_sd = {}

# `fully_shard` wraps each parameter in its own DTensor, breaking the `id`-equality
# that PyTorch's `state_dict()` uses to dedupe tied weights. So the meta-side dict
# may contain keys (e.g. `lm_head.weight`) whose source `full_sd` deduped away.
# These keys will be re-tied by the caller after loading, so we can skip them.
tied_keys = set(getattr(model, "_tied_weights_keys", None) or [])

# Rank 0 distributes the full state dict to other ranks
def _infer_parameter_dtype(model, param_name, empty_param):
try:
@@ -513,6 +519,8 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
if accelerator.is_main_process:
for param_name, sharded_param in meta_sharded_sd.items():
if param_name not in full_sd:
if param_name in tied_keys:
continue # will be re-tied to its source key after loading
raise KeyError(
f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict. "
f"Full state dict has {len(full_sd)} keys, sharded has {len(meta_sharded_sd)} keys."
@@ -540,6 +548,8 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
# We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
else:
for param_name, sharded_param in meta_sharded_sd.items():
if param_name in tied_keys and param_name not in full_sd:
continue # mirror the rank-0 skip so broadcast counts stay aligned
device_mesh = sharded_param.device_mesh
full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype)
dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)
@@ -555,8 +565,17 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
sharded_tensor = sharded_tensor.to("cpu")
sharded_sd[param_name] = sharded_tensor

# we set `assign=True` because our params are on meta device
model.load_state_dict(sharded_sd, assign=True)
# We set `assign=True` because our params are on meta device. We use `strict=False`
# only when the missing keys are tied-weight keys (skipped above): the caller will
# call `tie_weights()` right after, repointing them at their already-loaded source.
skipped_tied = [k for k in meta_sharded_sd if k in tied_keys and k not in full_sd]
incompatible = model.load_state_dict(sharded_sd, assign=True, strict=not skipped_tied)
if skipped_tied:
unexpected_missing = set(incompatible.missing_keys) - set(skipped_tied)
if unexpected_missing:
raise RuntimeError(
f"Unexpected missing keys when loading sharded state dict: {sorted(unexpected_missing)}"
)
return model


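A self-contained toy (plain `load_state_dict`, not FSDP2) of the `strict=False` contract the comments above describe: load a checkpoint whose tied alias was deduped away, confirm the only missing key is the expected one, then re-tie so the alias points at the loaded tensor. The `_tied_weights_keys` attribute and `tie_weights()` method here are stand-ins for what transformers models provide.

```python
import torch


class TiedLM(torch.nn.Module):
    _tied_weights_keys = ["lm_head.weight"]  # mimics the transformers class attribute

    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 4)
        self.lm_head = torch.nn.Linear(4, 10, bias=False)

    def tie_weights(self):
        self.lm_head.weight = self.embed.weight


model = TiedLM()
checkpoint = {"embed.weight": torch.randn(10, 4)}  # tied alias deduped away, like full_sd
incompatible = model.load_state_dict(checkpoint, strict=False)
assert set(incompatible.missing_keys) == {"lm_head.weight"}
model.tie_weights()  # the skipped key now points at the loaded embedding weight
assert model.lm_head.weight is model.embed.weight
```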
8 changes: 8 additions & 0 deletions src/accelerate/utils/imports.py
@@ -77,6 +77,14 @@ def is_pynvml_available():
return _is_package_available("pynvml") or _is_package_available("pynvml", "nvidia-ml-py")


def is_amdsmi_available():
return _is_package_available("amdsmi")


def is_rocm_available():
return torch.cuda.is_available() and torch.version.hip is not None


def is_pytest_available():
return _is_package_available("pytest")

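A minimal probe of the two detection helpers added above (assumes this branch is installed; the printed HIP version string is build-specific):

```python
import torch
from accelerate.utils import is_amdsmi_available, is_rocm_available

# On a ROCm build of PyTorch, torch.version.hip is a version string and torch.cuda
# reports the HIP devices, so is_rocm_available() is True; on CUDA builds it is None.
print(torch.version.hip, is_rocm_available(), is_amdsmi_available())
```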
2 changes: 1 addition & 1 deletion tests/test_accelerator.py
@@ -823,7 +823,7 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights
@require_non_cpu
@require_huggingface_suite
def test_nested_hook(self):
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel

class MyLinear(torch.nn.Module):
def __init__(self, device=None, dtype=None):
2 changes: 1 addition & 1 deletion tests/test_utils.py
@@ -197,7 +197,7 @@ def test_can_undo_fp16_conversion(self):
@require_triton
@require_non_cpu
def test_dynamo(self):
model = RegressionModel()
model = RegressionModel().to(torch_device)
model._original_forward = model.forward
model.forward = torch.autocast(device_type=torch_device, dtype=torch.float16)(model.forward)
model.forward = convert_outputs_to_fp32(model.forward)