Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 19 additions & 27 deletions fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@
def _generate_batch_sizes(
tbe_data_config: TBEDataConfig,
) -> tuple[list[int], list[list[int]] | None]:
logging.info(
f"DEBUG_TBE: [_generate_batch_sizes] VBE tbe_data_config.variable_B()={tbe_data_config.variable_B()}"
)
logging.debug(f"[_generate_batch_sizes] variable_B={tbe_data_config.variable_B()}")
if tbe_data_config.variable_B():
assert (
tbe_data_config.batch_params.vbe_num_ranks is not None
Expand Down Expand Up @@ -121,7 +119,7 @@ def _generate_indices(
start_offset = L_offsets_list[it * total_B]
end_offset = L_offsets_list[(it + 1) * total_B]

logging.info(f"DEBUG_TBE: _generate_indices E={tbe_data_config.E}")
logging.debug(f"_generate_indices E={tbe_data_config.E}")

indices_list.append(
torch.ops.fbgemm.tbe_generate_indices_from_distribution(
Expand Down Expand Up @@ -215,15 +213,17 @@ def generate_requests(
# Generate batch sizes
if batch_size_per_feature_per_rank:
Bs = tbe_data_config.batch_params.Bs
Bs_feature_rank = batch_size_per_feature_per_rank
else:
Bs, _ = _generate_batch_sizes(tbe_data_config)

logging.info(
f"DEBUG_TBE: VBE [generate_requests] batch_size_per_feature_per_rank={batch_size_per_feature_per_rank} Bs={Bs}"
)
Bs, Bs_feature_rank = _generate_batch_sizes(tbe_data_config)

assert Bs is not None, "Batch sizes (Bs) must be set"

logging.debug(
f"[generate_requests] Bs_len={len(Bs)}"
f" Bs_feature_rank={Bs_feature_rank is not None}"
)

# Generate pooling info
L_offsets = _generate_pooling_info(tbe_data_config, iters, Bs)

Expand All @@ -233,24 +233,14 @@ def generate_requests(

# Build TBE requests
if tbe_data_config.variable_B() or tbe_data_config.variable_L():
if batch_size_per_feature_per_rank:
return _build_requests_jagged(
tbe_data_config,
iters,
Bs,
batch_size_per_feature_per_rank,
L_offsets,
all_indices,
)
else:
return _build_requests_jagged(
tbe_data_config,
iters,
Bs,
batch_size_per_feature_per_rank,
L_offsets,
all_indices,
)
return _build_requests_jagged(
tbe_data_config,
iters,
Bs,
Bs_feature_rank,
L_offsets,
all_indices,
)
else:
return _build_requests_dense(tbe_data_config, iters, all_indices)

Expand Down Expand Up @@ -309,6 +299,8 @@ def generate_requests_with_Llist(
all_indices = all_indices.to(get_device())

# Build TBE requests
# batch_size_per_feature_per_rank is passed through directly
# from the caller (no _generate_batch_sizes call here).
if tbe_data_config.variable_B() or tbe_data_config.variable_L():
return _build_requests_jagged(
tbe_data_config,
Expand Down
Loading