[WIP] Make accelerate work end-to-end on AMD ROCm #4025

Draft
Abdennacer-Badaoui wants to merge 3 commits into huggingface:main from Abdennacer-Badaoui:accelerate-on-amd

Conversation

Abdennacer-Badaoui (Member) commented Apr 30, 2026

Description

Validated accelerate on 8× MI300X (ROCm 7.1, PyTorch 2.8.0+rocm7.1.0) and patched the gaps. Changes are minimal and gated on ROCm where appropriate; CUDA paths are untouched.

ROCm detection helpers (src/accelerate/utils/imports.py, __init__.py)

  • Added is_rocm_available() and is_amdsmi_available() along with their exports. These helpers are used across the changes below; a minimal sketch follows.
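
A sketch of what these helpers can look like (the PR's exact bodies may differ; torch.version.hip is only set on ROCm builds of PyTorch):

```python
import importlib.util

import torch


def is_rocm_available() -> bool:
    # ROCm builds of PyTorch expose torch.version.hip; it is None on CUDA builds.
    return getattr(torch.version, "hip", None) is not None


def is_amdsmi_available() -> bool:
    # amdsmi is AMD's system-management Python library, shipped with ROCm.
    return importlib.util.find_spec("amdsmi") is not None
```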

NUMA affinity on ROCm (src/accelerate/utils/environment.py)

  • override_numa_affinity now branches on ROCm.
  • Uses amdsmi to resolve a GPU's NUMA node and sysfs (/sys/devices/system/node/node{N}/cpulist) to look up that node's CPUs; see the sketch after this list.
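
A hedged sketch of the affinity logic. The PR resolves the GPU's NUMA node via amdsmi; purely for illustration, this version reads it from the device's PCI sysfs entry instead (pci_bdf is a hypothetical argument, not the PR's API):

```python
import os


def _parse_cpulist(cpulist: str) -> set[int]:
    # Expand a kernel cpulist such as "0-23,48-71" into a set of CPU ids.
    cpus: set[int] = set()
    for part in cpulist.split(","):
        lo, _, hi = part.partition("-")
        cpus.update(range(int(lo), int(hi or lo) + 1))
    return cpus


def set_numa_affinity(pci_bdf: str) -> None:
    # pci_bdf (e.g. "0000:1a:00.0") is a placeholder; the PR obtains the
    # device via amdsmi rather than a hard-coded bus id.
    with open(f"/sys/bus/pci/devices/{pci_bdf}/numa_node") as f:
        node = int(f.read().strip())
    if node < 0:  # -1 means the kernel has no NUMA information for the device
        return
    with open(f"/sys/devices/system/node/node{node}/cpulist") as f:
        cpus = _parse_cpulist(f.read().strip())
    # Pin the current process to the CPUs local to the GPU's NUMA node.
    os.sched_setaffinity(0, cpus)
```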

Notebook launcher under ROCm (src/accelerate/launchers.py, src/accelerate/test_utils/scripts/test_notebook.py)

  • torch.cuda.is_available() initializes the HIP runtime in the parent process on ROCm, which breaks fork-based subprocesses.
  • notebook_launcher now uses spawn on ROCm (same handling as XPU).
  • Fault-tolerant test queue context updated to use spawn so worker processes can consume it correctly. The start-method choice is sketched below.
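
Roughly, the start-method selection looks like this (simplified; the real notebook_launcher also handles XPU, TPU, ports, and error propagation):

```python
import torch
import torch.multiprocessing as mp


def _start_method() -> str:
    # On ROCm, torch.cuda.is_available() has already initialized the HIP
    # runtime in the parent, so forked children inherit a broken context.
    if getattr(torch.version, "hip", None) is not None:
        return "spawn"
    return "fork"


def launch(fn, args, num_processes):
    mp.start_processes(fn, args=args, nprocs=num_processes, start_method=_start_method())
```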

DeepSpeed bf16 silently producing NaNs (src/accelerate/utils/dataclasses.py)

  • On ROCm + DeepSpeed + bf16, training could silently produce all-NaN weights within a few steps.
  • Root cause: DeepSpeed’s bf16 path lacks NaN/Inf protection (unlike fp16 with loss scaling).
  • Fix: when this combination is detected, automatically inject communication_data_type="fp32" into the DeepSpeed config (logged at INFO level). User-defined values are respected (no override); see the sketch below.
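
The injection reduces to a few lines. This sketch is simplified and the helper name is hypothetical, but communication_data_type is a real DeepSpeed config key:

```python
import logging

logger = logging.getLogger(__name__)


def _maybe_force_fp32_comms(ds_config: dict, on_rocm: bool) -> None:
    bf16_enabled = ds_config.get("bf16", {}).get("enabled", False)
    if on_rocm and bf16_enabled and "communication_data_type" not in ds_config:
        # bf16 has no loss scaling, so NaN/Inf from low-precision gradient
        # reductions goes undetected; reducing in fp32 avoids the blow-up.
        ds_config["communication_data_type"] = "fp32"
        logger.info(
            "ROCm + DeepSpeed + bf16 detected: setting communication_data_type=fp32"
        )
    # A user-supplied communication_data_type is left untouched.
```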

FSDP2 + tied weights (src/accelerate/utils/fsdp_utils.py)

  • fully_shard breaks the id-based deduplication used by PyTorch state_dict() for tied weights.
  • Result: meta-side state dict has extra keys (e.g. lm_head.weight vs deduped embed_tokens.weight).
  • Fix in fsdp2_load_full_state_dict:
    • Skip missing keys listed in model._tied_weights_keys on all ranks (keeps broadcast aligned).
    • Use strict=False only when such skips occur.
    • Caller’s tie_weights() restores correct parameter sharing post-load (sketched below).
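
A sketch of the tied-key handling (the broadcast and resharding machinery of fsdp2_load_full_state_dict is omitted; the helper name is hypothetical):

```python
def load_full_state_dict_skipping_tied(model, full_sd: dict) -> None:
    # Keys present in the sharded model but absent from the checkpoint because
    # state_dict() deduped them at save time (e.g. lm_head.weight tied to
    # embed_tokens.weight). Skipping them identically on every rank keeps the
    # per-key broadcast order aligned.
    tied_keys = set(getattr(model, "_tied_weights_keys", None) or [])
    missing = {k for k in model.state_dict() if k not in full_sd and k in tied_keys}
    # Relax strictness only when tied aliases were actually skipped.
    model.load_state_dict(full_sd, strict=not missing)
    # The caller then runs model.tie_weights() so the skipped aliases share
    # storage with their loaded counterparts again.
```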

bitsandbytes (bnb) + tied weights (src/accelerate/utils/bnb.py)

  • load_and_quantize_model iterated over named_parameters() with remove_duplicate=True, skipping aliases of tied parameters.
  • Example (BLOOM):
    • transformer.word_embeddings.weight processed
    • lm_head.weight skipped → keep_in_fp32_modules=["lm_head"] ineffective
  • Fix: use remove_duplicate=False so all aliases are visited, as the snippet below illustrates.
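
A self-contained illustration of why the flag matters, with toy modules standing in for BLOOM:

```python
import torch.nn as nn

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # tie the weights, as BLOOM does

model = nn.ModuleDict({"word_embeddings": embed, "lm_head": lm_head})

deduped = [n for n, _ in model.named_parameters()]
# ['word_embeddings.weight'] -> 'lm_head.weight' is never seen, so a
# keep_in_fp32_modules=["lm_head"] match can never fire
all_aliases = [n for n, _ in model.named_parameters(remove_duplicate=False)]
# ['word_embeddings.weight', 'lm_head.weight'] -> both aliases are visited
```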

test_dynamo cross-device fix (tests/test_utils.py)

  • Model was on CPU while inputs were moved to torch_device.
  • Triggered an Inductor + autograd cross-device assertion on ROCm.
  • Fix: move the model to torch_device to match the inputs (see below).
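
The gist of the fix, with stand-in names for the test's model and device:

```python
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(3, 1).to(torch_device)  # previously left on CPU
inputs = torch.randn(4, 3, device=torch_device)
out = torch.compile(model)(inputs)  # no Inductor cross-device assertion
```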

Notes

  • The bf16 fix is broadly beneficial (including NVIDIA), but the failure mode was only reliably reproducible on ROCm in our testing, hence the gating.

Some related PRs:

Abdennacer-Badaoui marked this pull request as draft on April 30, 2026, 08:16
HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
