Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,17 @@ def apply_split_helper(
uvm_tensors_log: Optional[list[str]] = None,
uvm_host_mapped: bool = False,
make_persistent: bool = False,
dev_weight_init_device: Optional[torch.device] = None,
) -> None:
dev_device = (
dev_weight_init_device if dev_weight_init_device is not None else current_device
)
if dev_weight_init_device is not None and dev_weight_init_device != current_device:
logging.info(
f"[FBGEMM TBE] Allocating {prefix}_dev buffer on {dev_device} "
f"instead of {current_device} (dev_size={split.dev_size})"
)

set_attr_fn(f"{prefix}_physical_placements", split.placements)
set_attr_fn(f"{prefix}_physical_offsets", split.offsets)

Expand All @@ -308,7 +318,7 @@ def apply_split_helper(
if split.dev_size > 0:
dev_buffer = torch.zeros(
split.dev_size,
device=current_device,
device=dev_device,
# pyre-fixme[6]
dtype=dtype,
)
Expand Down Expand Up @@ -655,6 +665,13 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
True, defaults to RESParams().

is_qr_tbe (bool = False): Whether this is a QRSplitTableBatchedEmbeddingBagsCodegen.

weight_init_device (Optional[torch.device] = None): When set, allocate
the dev weight buffer on this device instead of `current_device`.
Metadata tensors (`weights_offsets`, `weights_placements`) remain on
`current_device`. The caller is responsible for moving weights to the
compute device before any forward pass. Only affects the `weights`
prefix `_apply_split` call, not optimizer states.
"""

embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]]
Expand Down Expand Up @@ -730,6 +747,7 @@ def __init__( # noqa C901
enable_raw_embedding_streaming: bool = False,
res_params: Optional[RESParams] = None,
is_qr_tbe: bool = False,
weight_init_device: Optional[torch.device] = None,
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
self.uuid = str(uuid.uuid4())
Expand Down Expand Up @@ -1059,6 +1077,7 @@ def __init__( # noqa C901
make_dev_param=optimizer == OptimType.NONE,
dev_reshape=(-1, self.max_D) if optimizer == OptimType.NONE else None,
uvm_host_mapped=self.uvm_host_mapped,
dev_weight_init_device=weight_init_device,
)

assert optimizer not in (
Expand Down Expand Up @@ -3634,6 +3653,7 @@ def _apply_split(
make_dev_param: bool = False,
dev_reshape: Optional[tuple[int, ...]] = None,
uvm_host_mapped: bool = False,
dev_weight_init_device: Optional[torch.device] = None,
) -> None:
apply_split_helper(
self.register_buffer,
Expand All @@ -3651,6 +3671,7 @@ def _apply_split(
uvm_host_mapped=uvm_host_mapped,
# Only force persistent for Split TBE on MTIA, see D97971757 for details.
make_persistent=(self.use_mtia and not self.is_qr_tbe),
dev_weight_init_device=dev_weight_init_device,
)

def _apply_cache_state(
Expand Down Expand Up @@ -4638,7 +4659,15 @@ def __init__(
use_cpu: bool = False,
output_dtype: SparseType = SparseType.FP32,
use_mtia: bool = False,
) -> None: # noqa C901 # tuple of (rows, dims,)
weight_init_device: Optional[torch.device] = None,
) -> None: # noqa C901
"""
Args:
weight_init_device: When set, allocate the embedding weight tensor
on this device instead of the compute device. The caller is
responsible for moving weights to the compute device before
any forward pass.
"""
super(DenseTableBatchedEmbeddingBagsCodegen, self).__init__()
self.uuid = str(uuid.uuid4())

Expand Down Expand Up @@ -4720,10 +4749,20 @@ def __init__(
weights_offsets = [0] + list(
accumulate([row * dim for (row, dim) in embedding_specs])
)
weights_device = (
weight_init_device
if weight_init_device is not None
else self.current_device
)
if weight_init_device is not None and weight_init_device != self.current_device:
logging.info(
f"[FBGEMM DenseTBE] Allocating weights on {weights_device} "
f"instead of {self.current_device} (size={weights_offsets[-1]})"
)
self.weights = nn.Parameter(
torch.randn(
weights_offsets[-1],
device=self.current_device,
device=weights_device,
dtype=table_embedding_dtype,
)
)
Expand Down
167 changes: 167 additions & 0 deletions fbgemm_gpu/test/tbe/training/weight_init_device_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# pyre-ignore-all-errors[56]

import unittest

import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
EmbeddingLocation,
PoolingMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
ComputeDevice,
DenseTableBatchedEmbeddingBagsCodegen,
SplitTableBatchedEmbeddingBagsCodegen,
)

from .. import common # noqa E402
from ..common import open_source

if open_source:
# pyre-ignore[21]
from test_utils import gpu_unavailable
else:
from fbgemm_gpu.test.test_utils import gpu_unavailable


class WeightInitDeviceTest(unittest.TestCase):
    """Tests for the `weight_init_device` constructor argument.

    When set, the embedding weight buffer is allocated on the requested
    device while metadata tensors (offsets/placements) remain on the
    compute device; optimizer states are unaffected. The caller is then
    responsible for moving weights to the compute device before forward.
    """

    def _build_split_tbe(self, num_tables, rows, dim, optimizer, init_device, **extra):
        # Factory: a CUDA split TBE with `num_tables` identical tables.
        specs = [
            (rows, dim, EmbeddingLocation.DEVICE, ComputeDevice.CUDA)
            for _ in range(num_tables)
        ]
        return SplitTableBatchedEmbeddingBagsCodegen(
            embedding_specs=specs,
            weights_precision=SparseType.FP32,
            optimizer=optimizer,
            pooling_mode=PoolingMode.SUM,
            weight_init_device=init_device,
            **extra,
        )

    def _build_dense_tbe(self, num_tables, rows, dim, init_device):
        # Factory: a dense TBE with `num_tables` identical tables.
        return DenseTableBatchedEmbeddingBagsCodegen(
            embedding_specs=[(rows, dim) for _ in range(num_tables)],
            pooling_mode=PoolingMode.SUM,
            weight_init_device=init_device,
        )

    @unittest.skipIf(*gpu_unavailable)
    def test_split_tbe_weight_init_device_cpu(self) -> None:
        num_tables, rows, dim = 3, 100, 32
        tbe = self._build_split_tbe(
            num_tables, rows, dim, OptimType.NONE, torch.device("cpu")
        )
        # The dev weight buffer honors the requested init device.
        # pyre-ignore[29]: weights_dev is a Tensor
        self.assertEqual(tbe.weights_dev.device.type, "cpu")
        # pyre-ignore[29]: weights_dev is a Tensor
        self.assertEqual(tbe.weights_dev.numel(), num_tables * rows * dim)
        # Metadata tensors stay on the compute device.
        self.assertEqual(tbe.weights_offsets.device.type, "cuda")
        self.assertEqual(tbe.weights_placements.device.type, "cuda")

    @unittest.skipIf(*gpu_unavailable)
    def test_split_tbe_weight_init_device_none(self) -> None:
        tbe = self._build_split_tbe(2, 64, 16, OptimType.NONE, None)
        # With no override, everything defaults to the compute device.
        self.assertEqual(tbe.weights_dev.device.type, "cuda")
        self.assertEqual(tbe.weights_offsets.device.type, "cuda")

    @unittest.skipIf(*gpu_unavailable)
    def test_split_tbe_weight_init_device_does_not_affect_optimizer_states(
        self,
    ) -> None:
        tbe = self._build_split_tbe(
            2,
            64,
            16,
            OptimType.EXACT_ROWWISE_ADAGRAD,
            torch.device("cpu"),
            learning_rate=0.01,
        )
        # Only the weight buffer moves; optimizer state (momentum1)
        # remains on the compute device.
        self.assertEqual(tbe.weights_dev.device.type, "cpu")
        self.assertEqual(tbe.momentum1_dev.device.type, "cuda")

    @unittest.skipIf(*gpu_unavailable)
    def test_dense_tbe_weight_init_device_cpu(self) -> None:
        num_tables, rows, dim = 3, 100, 32
        tbe = self._build_dense_tbe(num_tables, rows, dim, torch.device("cpu"))
        # Weight parameter honors the requested init device.
        self.assertEqual(tbe.weights.device.type, "cpu")
        self.assertEqual(tbe.weights.numel(), num_tables * rows * dim)
        # Metadata tensors stay on the compute device.
        self.assertEqual(tbe.D_offsets.device.type, "cuda")
        self.assertEqual(tbe.hash_size_cumsum.device.type, "cuda")

    @unittest.skipIf(*gpu_unavailable)
    def test_dense_tbe_weight_init_device_none(self) -> None:
        tbe = self._build_dense_tbe(2, 64, 16, None)
        # With no override, weights default to the compute device.
        self.assertEqual(tbe.weights.device.type, "cuda")

    @unittest.skipIf(*gpu_unavailable)
    def test_split_tbe_weight_init_device_cpu_move_to_cuda(self) -> None:
        tbe = self._build_split_tbe(2, 64, 16, OptimType.NONE, torch.device("cpu"))
        self.assertEqual(tbe.weights_dev.device.type, "cpu")
        # Simulate the caller relocating weights to CUDA after init.
        # pyre-ignore[16, 6]: weights_dev is a Tensor
        tbe.weights_dev = torch.nn.Parameter(tbe.weights_dev.data.cuda())
        self.assertEqual(tbe.weights_dev.device.type, "cuda")

    @unittest.skipIf(*gpu_unavailable)
    def test_dense_tbe_weight_init_device_cpu_move_to_cuda(self) -> None:
        tbe = self._build_dense_tbe(2, 64, 16, torch.device("cpu"))
        self.assertEqual(tbe.weights.device.type, "cpu")
        # Simulate the caller relocating weights to CUDA after init.
        tbe.weights = torch.nn.Parameter(tbe.weights.cuda())
        self.assertEqual(tbe.weights.device.type, "cuda")


# Allow running this test module directly (e.g. `python weight_init_device_test.py`).
if __name__ == "__main__":
    unittest.main()
Loading