From 8e5c9a59a4abc5e992a0b8e3c1b15caf6238b462 Mon Sep 17 00:00:00 2001 From: Joey Yang Date: Fri, 17 Apr 2026 23:45:10 -0700 Subject: [PATCH] Support multi-dimensional runtime_meta in RES streaming buffers by lazy init (#5643) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5643 X-link: https://github.com/facebookresearch/FBGEMM/pull/2591 The `res_runtime_meta` buffer in `SplitTableBatchedEmbeddingBagsCodegen` was hardcoded to shape `(cache_size, 1)`. When `_hash_zch_runtime_meta` has dim > 1 (e.g., a feature cache storing 2 cached features via `zch_custom_runtime_meta_dim=2`), the `.copy_()` in `raw_embedding_stream()` crashes with: `RuntimeError: output with shape [N, 1] doesn't match the broadcast shape [N, 2]` Full output P2274437578 This diff lazily resizes the buffer: it defaults to `(cache_size, 1, torch.long)` and auto-corrects on the first iteration when runtime_meta data arrives with a different shape or dtype. This is a one-time operation: after the first resize, dims match and no further reallocation occurs. 
Reviewed By: chouxi Differential Revision: D100944325 --- ...t_table_batched_embeddings_ops_training.py | 40 +++- .../raw_embedding_streamer.cpp | 2 +- .../training/store_prefetched_tensors_test.py | 222 ++++++++++++++++++ 3 files changed, 258 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 626495532f..bb1f7c175c 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1544,8 +1544,8 @@ def _register_res_buffers(self) -> None: self.enable_raw_embedding_streaming ), "Should not register res buffers when raw embedding streaming is not enabled" cache_size = self.lxu_cache_weights.size(0) + self.log(f"[RES] registering buffers: cache_size={cache_size}") if cache_size == 0: - self.log("Registering empty res buffers when there is no cache") self._register_empty_res_buffers() return self.register_buffer( @@ -1596,6 +1596,7 @@ def _register_res_buffers(self) -> None: (cache_size, 1), is_host_mapped=self.uvm_host_mapped, ), + persistent=False, # shape may change via lazy resize, exclude from checkpoints ) self.register_buffer( "res_count", @@ -1645,6 +1646,7 @@ def _register_empty_res_buffers(self) -> None: self.register_buffer( "res_runtime_meta", torch.zeros(0, 1, device=self.current_device, dtype=torch.long), + persistent=False, # shape may change via lazy resize, exclude from checkpoints ) self.register_buffer( "res_count", @@ -4404,10 +4406,38 @@ def raw_embedding_stream(self) -> None: : prefetched_info.hash_zch_identities.size(0) ].copy_(prefetched_info.hash_zch_identities) if prefetched_info.hash_zch_runtime_meta is not None: - # pyre-ignore[29]: `Union[...]` is not a function. 
- self.res_runtime_meta[ - : prefetched_info.hash_zch_runtime_meta.size(0) - ].copy_(prefetched_info.hash_zch_runtime_meta) + runtime_meta = prefetched_info.hash_zch_runtime_meta + if runtime_meta.dim() != 2: + self.log( + f"[RES] unexpected runtime_meta rank: {runtime_meta.dim()}, expected 2, skipping" + ) + else: + if ( + runtime_meta.shape[1] != self.res_runtime_meta.shape[1] + or runtime_meta.dtype != self.res_runtime_meta.dtype + ): + self.log( + f"[RES] lazy resize runtime_meta: {self.res_runtime_meta.shape} -> ({self.res_runtime_meta.shape[0]}, {runtime_meta.shape[1]}), dtype {self.res_runtime_meta.dtype} -> {runtime_meta.dtype}" + ) + # Lazy resize: runtime_meta shape/dtype is not known until + # the first data arrives from the MC module. Must use UVM + # (new_unified_tensor) because the C++ RawEmbeddingStreamer + # reads this buffer via raw CPU pointers in tensor_copy(). + self.register_buffer( + "res_runtime_meta", + torch.ops.fbgemm.new_unified_tensor( + torch.zeros( + 1, + device=self.current_device, + dtype=runtime_meta.dtype, + ), + (self.res_runtime_meta.shape[0], runtime_meta.shape[1]), + is_host_mapped=self.uvm_host_mapped, + ), + persistent=False, # shape may change via lazy resize, exclude from checkpoints + ) + # pyre-ignore[29]: `Union[...]` is not a function. 
+ self.res_runtime_meta[: runtime_meta.size(0)].copy_(runtime_meta) self.res_copy_done.fill_(1) diff --git a/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp b/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp index b7328b1298..2674a1db2e 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp @@ -122,7 +122,7 @@ fbgemm_gpu::StreamQueueItem tensor_copy( }); } if (runtime_meta.has_value()) { - FBGEMM_DISPATCH_INTEGRAL_TYPES( + FBGEMM_DISPATCH_ALL_TYPES( runtime_meta->scalar_type(), "tensor_copy", [&] { using runtime_meta_t = scalar_t; auto runtime_meta_addr = diff --git a/fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py b/fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py index 9013edfc1f..170f987e3f 100644 --- a/fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py +++ b/fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py @@ -8,6 +8,7 @@ # pyre-strict import unittest +from unittest.mock import patch import torch from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( @@ -15,6 +16,7 @@ EmbeddingLocation, ) from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + RESParams, SplitTableBatchedEmbeddingBagsCodegen, ) @@ -627,6 +629,226 @@ def test_get_prefetched_info_with_neither(self) -> None: self.assertIsNone(prefetched_info.hash_zch_identities) self.assertIsNone(prefetched_info.hash_zch_runtime_meta) + @unittest.skipIf(*gpu_unavailable) + def test_register_res_buffers_default_dim(self) -> None: + """ + Test that RES buffers are registered with default dim=1. 
+ """ + res_params = RESParams( + res_store_shards=1, + table_names=["table_0"], + table_offsets=[0, 100], + table_sizes=[100], + ) + with patch( + "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer" + ): + tbe = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + (100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA), + ], + enable_raw_embedding_streaming=True, + res_params=res_params, + ) + cache_size = tbe.lxu_cache_weights.size(0) + self.assertGreater(cache_size, 0) + self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 1)) + + @unittest.skipIf(*gpu_unavailable) + def test_register_empty_res_buffers_default_dim(self) -> None: + """ + Test that empty RES buffers have dim=1 when streaming is disabled. + """ + tbe = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + (100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA), + ], + enable_raw_embedding_streaming=False, + ) + self.assertEqual(tbe.res_runtime_meta.shape[1], 1) + + @unittest.skipIf(*gpu_unavailable) + def test_lazy_resize_runtime_meta(self) -> None: + """ + Test that lazy resize in raw_embedding_stream() resizes res_runtime_meta + buffer when actual data has a different dim or dtype than the default. 
+ """ + res_params = RESParams( + res_store_shards=1, + table_names=["table_0"], + table_offsets=[0, 100], + table_sizes=[100], + ) + with patch( + "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer" + ): + tbe = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + (100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA), + ], + enable_raw_embedding_streaming=True, + res_params=res_params, + ) + cache_size = tbe.lxu_cache_weights.size(0) + # Initially dim=1 + self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 1)) + + # Simulate runtime_meta with dim=2 arriving via prefetch + n = 4 + runtime_meta_data = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40]], + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + + # Manually trigger the resize logic + data = runtime_meta_data + if ( + data.shape[1] != tbe.res_runtime_meta.shape[1] + or data.dtype != tbe.res_runtime_meta.dtype + ): + tbe.register_buffer( + "res_runtime_meta", + torch.ops.fbgemm.new_unified_tensor( + torch.zeros(1, device=tbe.current_device, dtype=data.dtype), + (tbe.res_runtime_meta.shape[0], data.shape[1]), + is_host_mapped=tbe.uvm_host_mapped, + ), + persistent=False, + ) + + # After resize, dim should be 2 + self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 2)) + # Copy should succeed + tbe.res_runtime_meta[:n].copy_(runtime_meta_data) + self.assertEqual( + runtime_meta_data.tolist(), + tbe.res_runtime_meta[:n].tolist(), + ) + + @unittest.skipIf(*gpu_unavailable) + def test_res_runtime_meta_not_in_state_dict(self) -> None: + """ + Test that res_runtime_meta is registered with persistent=False and + does not appear in state_dict() (shape changes with runtime_meta_dim). 
+ """ + res_params = RESParams( + res_store_shards=1, + table_names=["table_0"], + table_offsets=[0, 100], + table_sizes=[100], + ) + with patch( + "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer" + ): + tbe = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + (100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA), + ], + enable_raw_embedding_streaming=True, + res_params=res_params, + ) + state_dict = tbe.state_dict() + self.assertNotIn( + "res_runtime_meta", + state_dict, + "res_runtime_meta should not be in state_dict", + ) + + @unittest.skipIf(*gpu_unavailable) + def test_prefetched_info_with_multi_dim_runtime_meta(self) -> None: + """ + Test that _get_prefetched_info preserves multi-dimensional runtime_meta. + When runtime_meta has shape [N, 2], output should also have dim=2. + """ + hash_zch_runtime_meta = torch.tensor( + [ + [1, 10], + [2, 20], + [3, 30], + [4, 40], + ], + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + total_cache_hash_size = 100 + linear_cache_indices_merged = torch.tensor( + [54, 27, 43, 90], + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + + prefetched_info = SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info( + linear_indices=linear_cache_indices_merged, + linear_cache_indices_merged=linear_cache_indices_merged, + total_cache_hash_size=total_cache_hash_size, + hash_zch_identities=None, + hash_zch_runtime_meta=hash_zch_runtime_meta, + max_indices_length=200, + ) + + assert prefetched_info.hash_zch_runtime_meta is not None + self.assertEqual(prefetched_info.hash_zch_runtime_meta.shape[1], 2) + self.assertEqual(prefetched_info.hash_zch_runtime_meta.shape[0], 4) + # Verify sorted order (by cache index: 27, 43, 54, 90) + self.assertEqual( + [ + [2, 20], # runtime meta for index 27 + [3, 30], # runtime meta for index 43 + [1, 10], # runtime meta for index 54 + [4, 40], # runtime meta for index 90 + ], + 
prefetched_info.hash_zch_runtime_meta.tolist(), + ) + + @unittest.skipIf(*gpu_unavailable) + def test_copy_runtime_meta_none_skipped(self) -> None: + """ + Test that when hash_zch_runtime_meta is None in prefetched_info, + the copy to res_runtime_meta is skipped without crashing. + """ + res_params = RESParams( + res_store_shards=1, + table_names=["table_0"], + table_offsets=[0, 100], + table_sizes=[100], + ) + with patch( + "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer" + ): + tbe = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + (100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA), + ], + enable_raw_embedding_streaming=True, + res_params=res_params, + ) + + # Store a prefetched_info with runtime_meta=None + indices = torch.tensor( + [1, 2, 3], device=torch.cuda.current_device(), dtype=torch.int64 + ) + offsets = torch.tensor( + [0, 3], device=torch.cuda.current_device(), dtype=torch.int64 + ) + linear_cache_indices_merged = torch.tensor( + [1, 2, 3], device=torch.cuda.current_device(), dtype=torch.int64 + ) + + # This should not crash even though runtime_meta is None + tbe._store_prefetched_tensors( + indices=indices, + offsets=offsets, + vbe_metadata=None, + linear_cache_indices_merged=linear_cache_indices_merged, + final_lxu_cache_locations=torch.ones_like(linear_cache_indices_merged), + hash_zch_identities=None, + hash_zch_runtime_meta=None, + ) + + self.assertEqual(len(tbe.prefetched_info_list), 1) + self.assertIsNone(tbe.prefetched_info_list[0].hash_zch_runtime_meta) + if __name__ == "__main__": unittest.main()