Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,30 @@

namespace fbgemm_gpu {

// Validate that every element in total_num_blocks is divisible by my_size.
//
// The per-feature block counts in total_num_blocks must split evenly across
// the my_size buckets used by block bucketization (presumably one bucket per
// rank — confirm with callers); an uneven count is rejected up front with a
// descriptive error instead of producing silently wrong bucketing.
//
// @param total_num_blocks  integer tensor of per-feature block counts; may
//                          live on any device (it is copied to CPU here).
// @param my_size           number of buckets; must be positive.
// @throws c10::Error (via TORCH_CHECK) if my_size is not positive or any
//         element of total_num_blocks is not a multiple of my_size.
inline void check_total_num_blocks_divisibility(
    const at::Tensor& total_num_blocks,
    int64_t my_size) {
  // Guard first: `% my_size` below is undefined behavior when my_size == 0.
  TORCH_CHECK(
      my_size > 0,
      "block_bucketize_sparse_features: my_size must be positive, got ",
      my_size);
  // Copy to CPU and force contiguity so the raw-pointer walk below visits
  // every element even if the caller passed a non-contiguous (e.g. sliced)
  // tensor; for an already-contiguous CPU tensor this is a no-op.
  const auto tnb = total_num_blocks.cpu().contiguous();
  AT_DISPATCH_INDEX_TYPES(
      tnb.scalar_type(),
      "block_bucketize_sparse_features_total_num_blocks_check",
      [&] {
        const auto* tnb_data = tnb.const_data_ptr<index_t>();
        for (const auto t : c10::irange(tnb.numel())) {
          TORCH_CHECK(
              tnb_data[t] % my_size == 0,
              "block_bucketize_sparse_features: total_num_blocks[",
              t,
              "] = ",
              tnb_data[t],
              " must be a multiple of my_size (",
              my_size,
              ")");
        }
      });
}

/// @defgroup sparse-data-cuda Sparse Data CUDA Operators
/// The following are CUDA operators
///
Expand Down
5 changes: 5 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,11 @@ _block_bucketize_sparse_features_cuda(

CUDA_DEVICE_GUARD(lengths);

if (total_num_blocks.has_value()) {
fbgemm_gpu::check_total_num_blocks_divisibility(
total_num_blocks.value(), my_size);
}

// allocate tensors and buffers
const auto lengths_size = lengths.numel();
const auto T = block_sizes.numel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,11 @@ _block_bucketize_sparse_features_2d_weights_cuda(

CUDA_DEVICE_GUARD(lengths);

if (total_num_blocks.has_value()) {
fbgemm_gpu::check_total_num_blocks_divisibility(
total_num_blocks.value(), my_size);
}

// allocate tensors and buffers
const auto lengths_size = lengths.numel();
const auto T = block_sizes.numel();
Expand Down
10 changes: 10 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,11 @@ _block_bucketize_sparse_features_cpu(
const bool return_bucket_mapping,
const bool keep_orig_idx,
const std::optional<Tensor>& keep_orig_idx_per_feature = std::nullopt) {
if (total_num_blocks.has_value()) {
fbgemm_gpu::check_total_num_blocks_divisibility(
total_num_blocks.value(), my_size);
}

const auto lengths_size = lengths.numel();
const auto new_lengths_size = lengths_size * my_size;
auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
Expand Down Expand Up @@ -1619,6 +1624,11 @@ _block_bucketize_sparse_features_2d_weights_cpu(
const bool return_bucket_mapping,
const bool keep_orig_idx,
const std::optional<Tensor>& keep_orig_idx_per_feature = std::nullopt) {
if (total_num_blocks.has_value()) {
fbgemm_gpu::check_total_num_blocks_divisibility(
total_num_blocks.value(), my_size);
}

const auto lengths_size = lengths.numel();
const auto new_lengths_size = lengths_size * my_size;
auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
Expand Down
16 changes: 16 additions & 0 deletions fbgemm_gpu/test/sparse/block_bucketize_2d_weights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,22 @@ def test_block_bucketize_sparse_features_2d_weights_with_variable_batch_sizes(
new_indices_ref, new_indices_gpu.cpu(), new_lengths_ref
)

def test_block_bucketize_sparse_features_2d_weights_total_num_blocks_not_divisible(
    self,
) -> None:
    """total_num_blocks[0] == 7 is not divisible by my_size == 3, so the op
    must raise a RuntimeError mentioning the divisibility requirement."""
    my_size = 3
    lengths = torch.tensor([0, 3, 2, 0, 1, 5], dtype=torch.int)
    sparse_indices = torch.tensor(
        [1, 2, 10, 4, 16, 6, 7, 18, 19, 10, 0], dtype=torch.int
    )
    block_sizes = torch.tensor([2, 3, 4], dtype=torch.int)
    weights_2d = torch.rand(sparse_indices.numel(), 3)
    bad_total_num_blocks = torch.tensor([7, 6, 6], dtype=torch.int)
    with self.assertRaisesRegex(RuntimeError, "must be a multiple of my_size"):
        torch.ops.fbgemm.block_bucketize_sparse_features_2d_weights(
            lengths,
            sparse_indices,
            False,
            False,
            block_sizes,
            my_size,
            weights_2d,
            total_num_blocks=bad_total_num_blocks,
        )


extend_test_class(BlockBucketize2DWeightsTest)

Expand Down
14 changes: 14 additions & 0 deletions fbgemm_gpu/test/sparse/block_bucketize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2111,6 +2111,20 @@ def test_block_bucketize_sparse_features_float64_weights(
new_pos_ref, new_pos_gpu.cpu(), new_lengths_ref
)

def test_block_bucketize_sparse_features_total_num_blocks_not_divisible(
    self,
) -> None:
    """total_num_blocks[0] == 7 is not divisible by my_size == 3, so the op
    must raise a RuntimeError mentioning the divisibility requirement."""
    my_size = 3
    lengths = torch.tensor([0, 3, 2, 0, 1, 5], dtype=torch.int)
    sparse_indices = torch.tensor(
        [1, 2, 10, 4, 16, 6, 7, 18, 19, 10, 0], dtype=torch.int
    )
    block_sizes = torch.tensor([2, 3, 4], dtype=torch.int)
    bad_total_num_blocks = torch.tensor([7, 6, 6], dtype=torch.int)
    with self.assertRaisesRegex(RuntimeError, "must be a multiple of my_size"):
        torch.ops.fbgemm.block_bucketize_sparse_features(
            lengths,
            sparse_indices,
            False,
            False,
            block_sizes,
            my_size,
            total_num_blocks=bad_total_num_blocks,
        )


extend_test_class(BlockBucketizeTest)

Expand Down
Loading