Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 49 additions & 11 deletions fbgemm_gpu/bench/merge_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


import logging
import os
import signal
from contextlib import nullcontext

Check failure on line 14 in fbgemm_gpu/bench/merge_embeddings_benchmark.py

View workflow job for this annotation

GitHub Actions / run-lint (3.14)

F401 'contextlib.nullcontext' imported but unused

import click
import fbgemm_gpu
Expand Down Expand Up @@ -288,6 +290,8 @@
skip_dequantization: bool = False,
num_of_embeddings: int = 10000,
pooling_factor: int = 25,
export_trace: bool = False,
trace_url: str = "merge_embeddings_fwd_trace_{ospid}.json",
) -> str:
assert torch.cuda.is_available()
torch.cuda.set_device(dst_device)
Expand Down Expand Up @@ -448,17 +452,33 @@
flush_gpu_cache_size_mb=0,
iters=iters,
)
# pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`.
with profile(activities=[ProfilerActivity.CUDA]) as prof:
pool_func_with_quantization(
batch_indices,
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

def _kineto_trace_handler(p: profile) -> None:
    """Write the trace held by profiler ``p`` to a Chrome-trace file.

    The output path comes from the ``trace_url`` template in the
    surrounding scope; its ``{ospid}`` placeholder is filled in with the
    current process ID.
    """
    pid = os.getpid()
    p.export_chrome_trace(trace_url.format(ospid=pid))

if export_trace:
with profile(on_trace_ready=_kineto_trace_handler):
pool_func_with_quantization(
batch_indices,
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
)
else:
# pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`.
with profile(activities=[ProfilerActivity.CUDA]) as prof:
pool_func_with_quantization(
batch_indices,
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

if isinstance(merged, Tensor):
# all_to_one_only returns a list of tensors,
Expand Down Expand Up @@ -519,6 +539,18 @@
default=False,
help="Use manual seed for reproduction.",
)
@click.option(
"--export-trace",
is_flag=True,
default=False,
help="Enable export of Kineto trace for profiling.",
)
@click.option(
"--trace-url",
type=str,
default="merge_embeddings_fwd_trace_{ospid}.json",
help="Trace output file path template ({ospid} is replaced with PID).",
)
def cli(
all_to_one_only: bool,
sum_reduce_to_one_only: bool,
Expand All @@ -536,6 +568,8 @@
sweep: bool,
use_pitched: bool,
manual_seed: bool,
export_trace: bool,
trace_url: str,
) -> None:
# set manual seed for reproducibility
if manual_seed:
Expand Down Expand Up @@ -584,6 +618,8 @@
skip_dequantization,
num_of_embeddings,
pooling_factor,
export_trace,
trace_url,
)
results.append(result)
except (TimeoutError, RuntimeError) as err:
Expand All @@ -609,6 +645,8 @@
skip_dequantization,
num_of_embeddings,
pooling_factor,
export_trace,
trace_url,
)
print(csv_header)
print(result)
Expand Down
Loading