Add weight_init_device support for both Split and Dense TBE kernels #5640

Open

TroyGarden wants to merge 1 commit into pytorch:main from TroyGarden:export-D101011345

Conversation


TroyGarden (Contributor) commented Apr 15, 2026

Summary:

## 1. Context

The main purpose of this diff is to allow TBE weights to be initialized on a given device (CPU/meta) that differs from the originally assigned compute device (`current_device`, which is also used for metadata such as `D_offsets`, `hash_size_cumsum`, `weights_placements`, `weights_offsets`, etc.). The user is responsible for handling the actual allocation and placement of the TBE weights after initialization — e.g. quantization, shared-memory mapping, checkpoint loading, or any other post-processing before moving the weights to the compute device for forward passes.

This enables eval workflows where temporary FP32 weights cause an HBM peak during model initialization but are only needed briefly before being quantized, loaded from shared memory, or otherwise transformed. Example use cases (a minimal sketch of the first follows the list):

1. **Checkpoint eval with quantization**: Init weights on CPU, load the FP32 checkpoint to CPU, quantize to NFP8 on CPU, and move only the smaller quantized weights to GPU — FP32 never touches HBM.
2. **Shared memory eval**: Init weights on the meta device (zero allocation), then replace them with a UVM tensor backed by POSIX shared memory (`/dev/shm`). Multiple eval workers share the same physical memory for read-only embedding lookups without per-process HBM copies.
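
A minimal sketch of use case 1 on the Dense TBE, assuming this diff's `weight_init_device` parameter; the checkpoint path is a placeholder, and the plain `.to("cuda")` stands in for the user's NFP8 quantize-then-move step, which is outside this diff:

```python
import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    DenseTableBatchedEmbeddingBagsCodegen,
)

# Weights are allocated on CPU at init; metadata stays on the compute device.
emb = DenseTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[(1_000_000, 128)],  # (num_embeddings, embedding_dim)
    weight_init_device=torch.device("cpu"),
)
assert emb.weights.device.type == "cpu"  # the FP32 buffer never touches HBM

# Load the FP32 checkpoint into the CPU-resident weights ("checkpoint.pt"
# is an illustrative path).
emb.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))

# The real flow would quantize on CPU here and move only the quantized
# weights; this plain device move is a stand-in for that step.
emb.weights.data = emb.weights.data.to("cuda")
```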

## 2. Approach

1. **Unified `weight_init_device` parameter**: Add `weight_init_device: Optional[torch.device]` to both `SplitTableBatchedEmbeddingBagsCodegen` and `DenseTableBatchedEmbeddingBagsCodegen`. When set, weight buffers are allocated on the specified device instead of `current_device`. Metadata tensors (`D_offsets`, `hash_size_cumsum`, `weights_offsets`, `weights_placements`) remain on `current_device`.

2. **Split TBE (`dev_weight_init_device`)**: In `apply_split_helper`, the new `dev_weight_init_device` parameter overrides the allocation device for `dev_buffer` only. Metadata tensors stay on `current_device` (GPU) so the CUDA forward kernel can read them once the weights are moved back (see the Split TBE sketch after this section).

3. **Dense TBE**: `weight_init_device` directly controls the device for `self.weights = nn.Parameter(torch.randn(..., device=weights_device))`.

TorchRec wiring (`BatchedFusedEmbeddingBag`, `BatchedDenseEmbeddingBag`, `ShardingEnv`, `DistributedModelParallel`) is in the dependent diff D99947880.
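
A corresponding sketch for the Split TBE side (use case 2's meta-device init), again assuming this diff's `weight_init_device` parameter; the buffer/attribute names (`weights_dev`, `weights_offsets`) follow the Split TBE's `weights` prefix convention, and the table size is illustrative:

```python
import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

emb = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_embeddings, embedding_dim, location, compute device)
        (1_000_000, 128, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    ],
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    weight_init_device=torch.device("meta"),  # zero allocation at init
)

# Only the dev weight buffer honors the override; metadata stays on
# current_device (GPU) so the CUDA forward kernel can read it later.
assert emb.weights_dev.device.type == "meta"
assert emb.weights_offsets.device.type == "cuda"
```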

## 3. Analysis

1. **Backward compatibility**: All new parameters default to `None`, preserving existing behavior.
2. **Device mismatch risk**: When `weight_init_device` is set, the TBE weight buffers live on a different device than `current_device`. This mismatch will cause runtime errors if the user attempts a forward pass before moving the weights to the compute device. Downstream consumers (checkpoint loading, state-dict operations, `TableBatchedEmbeddingSlice` view creation) will also see the weights on the init device. The user is responsible for ensuring the weights are on the correct device before any operation that requires them on the compute device.
3. **Caller responsibility**: `weight_init_device` is an init-only setting. The caller must move the weights to the compute device before forward passes — whether via quantization, shared memory registration, or an explicit `.to()` (a hypothetical guard is sketched after this list).
4. **Scope for Split TBE**: `dev_weight_init_device` only affects the `weights`-prefix `_apply_split` call, not the optimizer-state `_apply_split` calls.
5. **Metadata separation**: All metadata tensors (`D_offsets`, `hash_size_cumsum`, `weights_offsets`, `weights_placements`, etc.) remain on `current_device` regardless of `weight_init_device`. Only the weight data buffers are placed on the init device.
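
As a hedged illustration of points 2 and 3, a hypothetical guard the caller might run before the first forward pass — `ensure_weights_on_compute_device` is not part of this diff, and the attribute names assume the Split TBE `weights_dev` buffer:

```python
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    SplitTableBatchedEmbeddingBagsCodegen,
)

def ensure_weights_on_compute_device(
    emb: SplitTableBatchedEmbeddingBagsCodegen,
) -> None:
    """Move the dev weight buffer back to the compute device if needed."""
    target = emb.current_device
    if emb.weights_dev.device != target:
        # e.g. after checkpoint loading or quantization on the init device;
        # a meta-device buffer cannot be copied out and must instead be
        # replaced outright with real storage on the target device.
        emb.weights_dev = emb.weights_dev.to(target)
```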

## 4. Changes

1. **`apply_split_helper`**: Added the `dev_weight_init_device` parameter. When set, `dev_buffer` is allocated on this device instead of `current_device`. Metadata tensors (`{prefix}_offsets`, `{prefix}_placements`) are unaffected (see the sketch after this list).
2. **`SplitTableBatchedEmbeddingBagsCodegen._apply_split`**: Threads `dev_weight_init_device` through to `apply_split_helper`.
3. **`SplitTableBatchedEmbeddingBagsCodegen.__init__`**: Added the `weight_init_device` parameter, passed to `_apply_split` for the `weights` prefix only (not optimizer states).
4. **`DenseTableBatchedEmbeddingBagsCodegen.__init__`**: Added the `weight_init_device` parameter; `self.weights` is allocated on the resulting `weights_device` instead of `self.current_device`.
5. Added logging when the device override is active, for both Split and Dense TBE.
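
A hedged sketch of the device selection in change 1 — the surrounding `apply_split_helper` plumbing is elided, and these names are illustrative rather than the exact FBGEMM source:

```python
from typing import List, Optional, Tuple

import torch

def make_split_buffers(
    dev_size: int,
    offsets: List[int],
    placements: List[int],
    current_device: torch.device,
    dev_weight_init_device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Only the weight data buffer honors the override (changes 1-3)...
    dev_device = (
        dev_weight_init_device if dev_weight_init_device is not None else current_device
    )
    dev_buffer = torch.empty(dev_size, device=dev_device)
    # ...while {prefix}_offsets / {prefix}_placements always stay on
    # current_device so the CUDA kernel can read them (metadata separation).
    weights_offsets = torch.tensor(offsets, device=current_device, dtype=torch.int64)
    weights_placements = torch.tensor(placements, device=current_device, dtype=torch.int32)
    return dev_buffer, weights_offsets, weights_placements
```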

Differential Revision: D101011345

meta-cla bot added the `cla signed` label Apr 15, 2026

meta-codesync Bot commented Apr 15, 2026

@TroyGarden has exported this pull request. If you are a Meta employee, you can view the originating Diff in D101011345.

meta-codesync bot changed the title from "Add UVM host-mapped memory support for dense TBE kernel" to "Add UVM host-mapped memory support for dense TBE kernel (#5640)" on Apr 15, 2026
TroyGarden added a commit to TroyGarden/FBGEMM that referenced this pull request Apr 15, 2026
Summary:

X-link: facebookresearch/FBGEMM#2587

Add support for allocating dense TBE (Table Batched Embedding) weights in host-mapped UVM memory. This enables eval workflows where embedding tables reside in host DRAM (freeing HBM) while the TBE forward kernel still executes on GPU, reading embeddings over PCIe.

Changes:
- Add `CUDAHostRegisteredContext` in FBGEMM GPU memory utils: manages CUDA host registration lifetime for externally-owned memory (e.g. POSIX shared memory from `/dev/shm`). Unlike `CUDAHostMappedContext`, it does NOT free the underlying memory on destruction — only calls `cudaHostUnregister`.
- Add `cuda_register_host_memory()` op: registers an existing CPU tensor with CUDA via `cudaHostRegister` and returns a CUDA tensor backed by the same physical memory. The returned tensor's storage holds a reference to the input tensor's storage to keep the backing memory alive.
- Add `uvm_host_mapped` parameter to `DenseTableBatchedEmbeddingBagsCodegen`: when True, allocates embedding weights via `new_unified_tensor(is_host_mapped=True)` instead of regular device memory.
- Wire `uvm_host_mapped` through TorchRec's `BatchedDenseEmbeddingBag` via `fused_params`.

This is part of the eval workflow that uses data-parallel sharding with dense kernel on UVM, eliminating the need for input/output distribution and permutation across ranks.

Differential Revision: D101011345
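
A hedged usage sketch of the `cuda_register_host_memory()` op this earlier revision describes, assuming it is exposed under the usual `torch.ops.fbgemm` namespace; the tensor shape is illustrative:

```python
import torch
import fbgemm_gpu  # noqa: F401 — importing registers the fbgemm ops

# A CPU tensor whose storage lives in POSIX shared memory; on Linux,
# share_memory_() backs the storage with a file under /dev/shm.
cpu_weights = torch.randn(1_000_000, 128).share_memory_()

# Register the existing host pages with CUDA (cudaHostRegister) and get a
# CUDA tensor backed by the same physical memory; per the commit message,
# the returned tensor keeps the input's storage alive.
gpu_view = torch.ops.fbgemm.cuda_register_host_memory(cpu_weights)
```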
TroyGarden force-pushed the export-D101011345 branch 2 times, most recently from 1857768 to ae3b6da on April 16, 2026 at 15:36
meta-codesync bot changed the title from "Add UVM host-mapped memory support for dense TBE kernel (#5640)" to "Add weight_init_device support for both Split and Dense TBE kernels" on Apr 23, 2026