Add weight_init_device support for both Split and Dense TBE kernels #5640
Open
TroyGarden wants to merge 1 commit into pytorch:main
Conversation
Contributor
@TroyGarden has exported this pull request. If you are a Meta employee, you can view the originating Diff in D101011345.
TroyGarden added a commit to TroyGarden/FBGEMM that referenced this pull request on Apr 15, 2026
Summary: X-link: facebookresearch/FBGEMM#2587

Add support for allocating dense TBE (Table Batched Embedding) weights in host-mapped UVM memory. This enables eval workflows where embedding tables reside in host DRAM (freeing HBM) while the TBE forward kernel still executes on GPU, reading embeddings over PCIe.

Changes:
- Add `CUDAHostRegisteredContext` in FBGEMM GPU memory utils: manages CUDA host registration lifetime for externally-owned memory (e.g. POSIX shared memory from `/dev/shm`). Unlike `CUDAHostMappedContext`, it does NOT free the underlying memory on destruction — it only calls `cudaHostUnregister`.
- Add a `cuda_register_host_memory()` op: registers an existing CPU tensor with CUDA via `cudaHostRegister` and returns a CUDA tensor backed by the same physical memory. The returned tensor's storage holds a reference to the input tensor's storage to keep the backing memory alive.
- Add a `uvm_host_mapped` parameter to `DenseTableBatchedEmbeddingBagsCodegen`: when True, embedding weights are allocated via `new_unified_tensor(is_host_mapped=True)` instead of regular device memory.
- Wire `uvm_host_mapped` through TorchRec's `BatchedDenseEmbeddingBag` via `fused_params`.

This is part of the eval workflow that uses data-parallel sharding with the dense kernel on UVM, eliminating the need for input/output distribution and permutation across ranks.

Differential Revision: D101011345
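The ownership distinction between the two contexts can be sketched with a Python analogy (a minimal sketch, not the real C++ class: `register`/`unregister` stand in for `cudaHostRegister`/`cudaHostUnregister`, and `buf` for externally-owned memory such as a `/dev/shm` mapping):

```python
# Illustrative analogy of CUDAHostRegisteredContext's lifetime rule:
# on teardown, only unregister the memory -- never free it, because the
# external owner (e.g. a POSIX shared-memory segment) still holds it.
class HostRegistered:
    def __init__(self, buf, register, unregister):
        self.buf = buf
        self._register = register      # stand-in for cudaHostRegister
        self._unregister = unregister  # stand-in for cudaHostUnregister

    def __enter__(self):
        self._register(self.buf)
        return self.buf

    def __exit__(self, *exc):
        # Unlike CUDAHostMappedContext, there is no cudaFreeHost()-style
        # release here: freeing the memory is the owner's responsibility.
        self._unregister(self.buf)
        return False
```

The key design choice is that registration lifetime is decoupled from allocation lifetime, so the same physical pages can outlive any one process's CUDA registration.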
6f9358b to 4368607 (Compare)
TroyGarden added a commit to TroyGarden/FBGEMM that referenced this pull request on Apr 15, 2026
4368607 to 50362d8 (Compare)
TroyGarden added a commit to TroyGarden/FBGEMM that referenced this pull request on Apr 15, 2026
50362d8 to fbff767 (Compare)
TroyGarden added a commit to TroyGarden/FBGEMM that referenced this pull request on Apr 15, 2026
1857768 to ae3b6da (Compare)
TroyGarden added a commit to TroyGarden/FBGEMM that referenced this pull request on Apr 16, 2026
ae3b6da to 2bcaca8 (Compare)
Summary:
## 1. Context
The main purpose of this diff is to allow TBE weights to be initialized on a given device (CPU/meta) that differs from the originally assigned compute device (`current_device`, which is also used for metadata like `D_offsets`, `hash_size_cumsum`, `weights_placements`, `weights_offsets`, etc.). The user is responsible for handling the actual allocation and placement of the TBE weights after initialization — such as quantization, shared memory mapping, checkpoint loading, or any other post-processing before moving weights to the compute device for forward passes.
This enables eval workflows where temporary FP32 weights cause an HBM peak during model initialization, but are only needed briefly before being quantized, loaded from shared memory, or otherwise transformed. Example use cases:
1. **Checkpoint eval with quantization**: Init weights on CPU, load FP32 checkpoint to CPU, quantize to NFP8 on CPU, move only the smaller quantized weights to GPU — FP32 never touches HBM.
2. **Shared memory eval**: Init weights on meta device (zero allocation), then replace with a UVM tensor backed by POSIX shared memory (`/dev/shm`). Multiple eval workers share the same physical memory for read-only embedding lookups without per-process HBM copies.
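The HBM saving in use case 1 can be made concrete with a rough accounting sketch (illustrative numbers and function name, not from the diff; it assumes a quantized format of roughly 1 byte per element):

```python
# Rough weight-memory accounting for the checkpoint-eval use case above.
# With weight_init_device="cpu", the FP32 copy of the weights stays in host
# DRAM, so only the quantized weights ever occupy HBM.
def weight_hbm_peak(rows: int, dim: int, init_on_cpu: bool, quant_bytes: int = 1) -> int:
    fp32 = rows * dim * 4            # FP32 footprint of the embedding table
    quant = rows * dim * quant_bytes  # quantized footprint moved to GPU
    # Without the override, FP32 weights are first allocated on the compute
    # device, so the HBM peak is at least the full FP32 footprint.
    return quant if init_on_cpu else max(fp32, quant)
```

For a 1000-row, 128-dim table this is 512,000 bytes of peak HBM without the override versus 128,000 with it; the 4x ratio scales to real table sizes.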
## 2. Approach
1. **Unified `weight_init_device` parameter**: Add `weight_init_device: Optional[torch.device]` to both `SplitTableBatchedEmbeddingBagsCodegen` and `DenseTableBatchedEmbeddingBagsCodegen`. When set, weight buffers are allocated on the specified device instead of `current_device`. Metadata tensors (`D_offsets`, `hash_size_cumsum`, `weights_offsets`, `weights_placements`) remain on `current_device`.
2. **Split TBE (`dev_weight_init_device`)**: In `apply_split_helper`, the new `dev_weight_init_device` parameter overrides the allocation device for `dev_buffer` only. Metadata tensors stay on `current_device` (GPU) so the CUDA forward kernel can read them once weights are moved back.
3. **Dense TBE**: `weight_init_device` directly controls the device for `self.weights = nn.Parameter(torch.randn(..., device=weights_device))`.
TorchRec wiring (`BatchedFusedEmbeddingBag`, `BatchedDenseEmbeddingBag`, `ShardingEnv`, `DistributedModelParallel`) is in the dependent diff D99947880.
## 3. Analysis
1. **Backward compatibility**: All new parameters default to `None`, preserving existing behavior.
2. **Device mismatch risk**: When `weight_init_device` is set, the TBE weight buffers live on a different device than `current_device`. This creates a device mismatch that will cause runtime errors if the user attempts a forward pass before moving weights to the compute device. Downstream processes (checkpoint loading, state dict operations, `TableBatchedEmbeddingSlice` view creation) will also see weights on the init device. The user is responsible for ensuring weights are on the correct device before any operation that requires them on the compute device.
3. **Caller responsibility**: `weight_init_device` is an init-only setting. The caller must handle moving weights to the compute device before forward passes — whether via quantization, shared memory registration, or explicit `.to()`.
4. **Scope for Split TBE**: `dev_weight_init_device` only affects the `weights` prefix `_apply_split` call, not optimizer state `_apply_split` calls.
5. **Metadata separation**: All metadata tensors (`D_offsets`, `hash_size_cumsum`, `weights_offsets`, `weights_placements`, etc.) remain on `current_device` regardless of `weight_init_device`. Only the weight data buffers are placed on the init device.
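Points 2-3 above imply a simple invariant a caller could check before running forward. A hypothetical guard (the function name and error text are illustrative, not part of the diff) looks like:

```python
# Hypothetical pre-forward check: with weight_init_device set, the weights
# start on a different device than the compute device, and forward must not
# run until the caller has moved them (via quantization, shared-memory
# registration, or an explicit .to()).
def check_ready_for_forward(weights_device: str, current_device: str) -> None:
    if weights_device != current_device:
        raise RuntimeError(
            f"TBE weights are on {weights_device} but the compute device is "
            f"{current_device}; move the weights before calling forward"
        )
```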
## 4. Changes
1. **`apply_split_helper`**: Added `dev_weight_init_device` parameter. When set, `dev_buffer` is allocated on this device instead of `current_device`. Metadata tensors (`{prefix}_offsets`, `{prefix}_placements`) are unaffected.
2. **`SplitTableBatchedEmbeddingBagsCodegen._apply_split`**: Threads `dev_weight_init_device` to `apply_split_helper`.
3. **`SplitTableBatchedEmbeddingBagsCodegen.__init__`**: Added `weight_init_device` parameter, passed to `_apply_split` for the `weights` prefix only (not optimizer states).
4. **`DenseTableBatchedEmbeddingBagsCodegen.__init__`**: Added `weight_init_device` parameter. `self.weights` is allocated on `weights_device` instead of `self.current_device`.
5. Added logging when device override is active for both Split and Dense TBE.
Differential Revision: D101011345
2bcaca8 to 9bed063 (Compare)