@@ -1544,8 +1544,8 @@ def _register_res_buffers(self) -> None:
self.enable_raw_embedding_streaming
), "Should not register res buffers when raw embedding streaming is not enabled"
cache_size = self.lxu_cache_weights.size(0)
self.log(f"[RES] registering buffers: cache_size={cache_size}")
if cache_size == 0:
self.log("Registering empty res buffers when there is no cache")
self._register_empty_res_buffers()
return
self.register_buffer(
@@ -1596,6 +1596,7 @@ def _register_res_buffers(self) -> None:
(cache_size, 1),
is_host_mapped=self.uvm_host_mapped,
),
persistent=False, # shape may change via lazy resize, exclude from checkpoints
)
self.register_buffer(
"res_count",
@@ -1645,6 +1646,7 @@ def _register_empty_res_buffers(self) -> None:
self.register_buffer(
"res_runtime_meta",
torch.zeros(0, 1, device=self.current_device, dtype=torch.long),
persistent=False, # shape may change via lazy resize, exclude from checkpoints
)
self.register_buffer(
"res_count",
@@ -4404,10 +4406,38 @@ def raw_embedding_stream(self) -> None:
: prefetched_info.hash_zch_identities.size(0)
].copy_(prefetched_info.hash_zch_identities)
if prefetched_info.hash_zch_runtime_meta is not None:
# pyre-ignore[29]: `Union[...]` is not a function.
self.res_runtime_meta[
: prefetched_info.hash_zch_runtime_meta.size(0)
].copy_(prefetched_info.hash_zch_runtime_meta)
runtime_meta = prefetched_info.hash_zch_runtime_meta
if runtime_meta.dim() != 2:
self.log(
f"[RES] unexpected runtime_meta rank: {runtime_meta.dim()}, expected 2, skipping"
)
else:
if (
runtime_meta.shape[1] != self.res_runtime_meta.shape[1]
or runtime_meta.dtype != self.res_runtime_meta.dtype
):
self.log(
f"[RES] lazy resize runtime_meta: {self.res_runtime_meta.shape} -> ({self.res_runtime_meta.shape[0]}, {runtime_meta.shape[1]}), dtype {self.res_runtime_meta.dtype} -> {runtime_meta.dtype}"
)
# Lazy resize: runtime_meta shape/dtype is not known until
# the first data arrives from the MC module. Must use UVM
# (new_unified_tensor) because the C++ RawEmbeddingStreamer
# reads this buffer via raw CPU pointers in tensor_copy().
self.register_buffer(
"res_runtime_meta",
torch.ops.fbgemm.new_unified_tensor(
torch.zeros(
1,
device=self.current_device,
dtype=runtime_meta.dtype,
),
(self.res_runtime_meta.shape[0], runtime_meta.shape[1]),
is_host_mapped=self.uvm_host_mapped,
),
persistent=False, # shape may change via lazy resize, exclude from checkpoints
)
# pyre-ignore[29]: `Union[...]` is not a function.
self.res_runtime_meta[: runtime_meta.size(0)].copy_(runtime_meta)

self.res_copy_done.fill_(1)

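The resize branch above re-registers res_runtime_meta once the first real batch reveals the true meta width and dtype; the production code allocates through torch.ops.fbgemm.new_unified_tensor so the C++ RawEmbeddingStreamer can read the buffer through host-mapped pointers. A minimal sketch of the same lazy-resize pattern with plain tensors (the LazyMeta module is hypothetical, for illustration only):

import torch

class LazyMeta(torch.nn.Module):
    def __init__(self, rows: int) -> None:
        super().__init__()
        # Default until the first data arrives: one meta column, int64.
        self.register_buffer(
            "runtime_meta",
            torch.zeros(rows, 1, dtype=torch.int64),
            persistent=False,
        )

    def ingest(self, incoming: torch.Tensor) -> None:
        if incoming.dim() != 2:
            return  # mirror the rank check above: skip malformed input
        if (
            incoming.shape[1] != self.runtime_meta.shape[1]
            or incoming.dtype != self.runtime_meta.dtype
        ):
            # Re-register with the observed width/dtype; row count is kept.
            self.register_buffer(
                "runtime_meta",
                torch.zeros(
                    self.runtime_meta.shape[0],
                    incoming.shape[1],
                    dtype=incoming.dtype,
                ),
                persistent=False,
            )
        self.runtime_meta[: incoming.size(0)].copy_(incoming)

m = LazyMeta(rows=8)
m.ingest(torch.tensor([[1, 10], [2, 20]]))  # widens (8, 1) -> (8, 2)
assert m.runtime_meta.shape == (8, 2)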
@@ -122,7 +122,7 @@ fbgemm_gpu::StreamQueueItem tensor_copy(
});
}
if (runtime_meta.has_value()) {
FBGEMM_DISPATCH_INTEGRAL_TYPES(
FBGEMM_DISPATCH_ALL_TYPES(
runtime_meta->scalar_type(), "tensor_copy", [&] {
using runtime_meta_t = scalar_t;
auto runtime_meta_addr =
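The one-line C++ change above swaps FBGEMM_DISPATCH_INTEGRAL_TYPES for FBGEMM_DISPATCH_ALL_TYPES, so tensor_copy can now dispatch on floating-point runtime_meta dtypes as well as integral ones. Roughly, in Python terms (the exact dtype sets each macro covers are an assumption here; the macro definitions are authoritative):

import torch

INTEGRAL_TYPES = {torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}
ALL_TYPES = INTEGRAL_TYPES | {
    torch.float16, torch.bfloat16, torch.float32, torch.float64,
}

runtime_meta = torch.zeros(4, 2, dtype=torch.float32)
# Would fail to dispatch under the integral-only macro:
assert runtime_meta.dtype not in INTEGRAL_TYPES
# Covered once the dispatch is widened to all types:
assert runtime_meta.dtype in ALL_TYPES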
222 changes: 222 additions & 0 deletions fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py
@@ -8,13 +8,15 @@
# pyre-strict

import unittest
from unittest.mock import patch

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
ComputeDevice,
EmbeddingLocation,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
RESParams,
SplitTableBatchedEmbeddingBagsCodegen,
)

@@ -627,6 +629,226 @@ def test_get_prefetched_info_with_neither(self) -> None:
self.assertIsNone(prefetched_info.hash_zch_identities)
self.assertIsNone(prefetched_info.hash_zch_runtime_meta)

@unittest.skipIf(*gpu_unavailable)
def test_register_res_buffers_default_dim(self) -> None:
"""
Test that RES buffers are registered with default dim=1.
"""
res_params = RESParams(
res_store_shards=1,
table_names=["table_0"],
table_offsets=[0, 100],
table_sizes=[100],
)
with patch(
"fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
):
tbe = SplitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
(100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
],
enable_raw_embedding_streaming=True,
res_params=res_params,
)
cache_size = tbe.lxu_cache_weights.size(0)
self.assertGreater(cache_size, 0)
self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 1))

@unittest.skipIf(*gpu_unavailable)
def test_register_empty_res_buffers_default_dim(self) -> None:
"""
Test that empty RES buffers have dim=1 when streaming is disabled.
"""
tbe = SplitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
(100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
],
enable_raw_embedding_streaming=False,
)
self.assertEqual(tbe.res_runtime_meta.shape[1], 1)

@unittest.skipIf(*gpu_unavailable)
def test_lazy_resize_runtime_meta(self) -> None:
"""
Test that lazy resize in raw_embedding_stream() resizes res_runtime_meta
buffer when actual data has a different dim or dtype than the default.
"""
res_params = RESParams(
res_store_shards=1,
table_names=["table_0"],
table_offsets=[0, 100],
table_sizes=[100],
)
with patch(
"fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
):
tbe = SplitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
(100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
],
enable_raw_embedding_streaming=True,
res_params=res_params,
)
cache_size = tbe.lxu_cache_weights.size(0)
# Initially dim=1
self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 1))

# Simulate runtime_meta with dim=2 arriving via prefetch
n = 4
runtime_meta_data = torch.tensor(
[[1, 10], [2, 20], [3, 30], [4, 40]],
device=torch.cuda.current_device(),
dtype=torch.int64,
)

# Manually trigger the resize logic
data = runtime_meta_data
if (
data.shape[1] != tbe.res_runtime_meta.shape[1]
or data.dtype != tbe.res_runtime_meta.dtype
):
tbe.register_buffer(
"res_runtime_meta",
torch.ops.fbgemm.new_unified_tensor(
torch.zeros(1, device=tbe.current_device, dtype=data.dtype),
(tbe.res_runtime_meta.shape[0], data.shape[1]),
is_host_mapped=tbe.uvm_host_mapped,
),
persistent=False,
)

# After resize, dim should be 2
self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 2))
# Copy should succeed
tbe.res_runtime_meta[:n].copy_(runtime_meta_data)
self.assertEqual(
runtime_meta_data.tolist(),
tbe.res_runtime_meta[:n].tolist(),
)

@unittest.skipIf(*gpu_unavailable)
def test_res_runtime_meta_not_in_state_dict(self) -> None:
"""
Test that res_runtime_meta is registered with persistent=False and
does not appear in state_dict() (shape changes with runtime_meta_dim).
"""
res_params = RESParams(
res_store_shards=1,
table_names=["table_0"],
table_offsets=[0, 100],
table_sizes=[100],
)
with patch(
"fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
):
tbe = SplitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
(100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
],
enable_raw_embedding_streaming=True,
res_params=res_params,
)
state_dict = tbe.state_dict()
self.assertNotIn(
"res_runtime_meta",
state_dict,
"res_runtime_meta should not be in state_dict",
)

@unittest.skipIf(*gpu_unavailable)
def test_prefetched_info_with_multi_dim_runtime_meta(self) -> None:
"""
Test that _get_prefetched_info preserves multi-dimensional runtime_meta.
When runtime_meta has shape [N, 2], output should also have dim=2.
"""
hash_zch_runtime_meta = torch.tensor(
[
[1, 10],
[2, 20],
[3, 30],
[4, 40],
],
device=torch.cuda.current_device(),
dtype=torch.int64,
)
total_cache_hash_size = 100
linear_cache_indices_merged = torch.tensor(
[54, 27, 43, 90],
device=torch.cuda.current_device(),
dtype=torch.int64,
)

prefetched_info = SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info(
linear_indices=linear_cache_indices_merged,
linear_cache_indices_merged=linear_cache_indices_merged,
total_cache_hash_size=total_cache_hash_size,
hash_zch_identities=None,
hash_zch_runtime_meta=hash_zch_runtime_meta,
max_indices_length=200,
)

assert prefetched_info.hash_zch_runtime_meta is not None
self.assertEqual(prefetched_info.hash_zch_runtime_meta.shape[1], 2)
self.assertEqual(prefetched_info.hash_zch_runtime_meta.shape[0], 4)
# Verify sorted order (by cache index: 27, 43, 54, 90)
self.assertEqual(
[
[2, 20], # runtime meta for index 27
[3, 30], # runtime meta for index 43
[1, 10], # runtime meta for index 54
[4, 40], # runtime meta for index 90
],
prefetched_info.hash_zch_runtime_meta.tolist(),
)
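# Aside (illustration only, not part of the API under test): the expected
# order asserted above is just the runtime_meta rows reordered by ascending
# cache index, equivalent to:
#   _, order = torch.sort(linear_cache_indices_merged)
#   hash_zch_runtime_meta[order]  # -> [[2, 20], [3, 30], [1, 10], [4, 40]]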

@unittest.skipIf(*gpu_unavailable)
def test_copy_runtime_meta_none_skipped(self) -> None:
"""
Test that when hash_zch_runtime_meta is None in prefetched_info,
the copy to res_runtime_meta is skipped without crashing.
"""
res_params = RESParams(
res_store_shards=1,
table_names=["table_0"],
table_offsets=[0, 100],
table_sizes=[100],
)
with patch(
"fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
):
tbe = SplitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
(100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
],
enable_raw_embedding_streaming=True,
res_params=res_params,
)

# Store a prefetched_info with runtime_meta=None
indices = torch.tensor(
[1, 2, 3], device=torch.cuda.current_device(), dtype=torch.int64
)
offsets = torch.tensor(
[0, 3], device=torch.cuda.current_device(), dtype=torch.int64
)
linear_cache_indices_merged = torch.tensor(
[1, 2, 3], device=torch.cuda.current_device(), dtype=torch.int64
)

# This should not crash even though runtime_meta is None
tbe._store_prefetched_tensors(
indices=indices,
offsets=offsets,
vbe_metadata=None,
linear_cache_indices_merged=linear_cache_indices_merged,
final_lxu_cache_locations=torch.ones_like(linear_cache_indices_merged),
hash_zch_identities=None,
hash_zch_runtime_meta=None,
)

self.assertEqual(len(tbe.prefetched_info_list), 1)
self.assertIsNone(tbe.prefetched_info_list[0].hash_zch_runtime_meta)


if __name__ == "__main__":
unittest.main()