# Patch summary: fbgemm_gpu/bench/merge_embeddings_benchmark.py
#
# Purpose: add an opt-in export of the profiler trace for the profiled
# pool_func_with_quantization run.
#
# Changes carried by the hunks below:
#   * Imports: adds `import os` and `from contextlib import nullcontext`.
#     NOTE(review): no added line in this patch uses `nullcontext` -- confirm
#     whether the import is needed.
#   * benchmark(): two new trailing parameters with defaults,
#     `export_trace: bool = False` and
#     `trace_url: str = "merge_embeddings_fwd_trace_{ospid}.json"`.
#   * New nested helper `_kineto_trace_handler(p)`: formats `trace_url` with
#     `ospid=os.getpid()` and passes the result to `p.export_chrome_trace`.
#     NOTE(review): the template is expanded with `str.format`, so any other
#     `{...}` field in a user-supplied --trace-url would raise; confirm that
#     this is acceptable.
#   * Profiling section: when `export_trace` is true, the run is wrapped in
#     `profile(on_trace_ready=_kineto_trace_handler)`; otherwise the previous
#     `profile(activities=[ProfilerActivity.CUDA]) as prof` path is kept.
#     The `key_averages()` table print reads `prof`, which is only bound on
#     the non-export path, so it presumably lives inside the `else` branch
#     (indentation is not recoverable from this collapsed text -- verify).
#     NOTE(review): the export path does not pass `activities=`, unlike the
#     non-export path -- confirm the difference is intended.
#   * CLI: adds the `--export-trace` flag (default False) and the
#     `--trace-url` string option (same default template as benchmark()),
#     adds matching `export_trace` / `trace_url` parameters to cli(), and
#     appends both positionally at the two call sites shown in the last two
#     hunks (presumably the benchmark(...) calls -- the callee name is
#     outside the visible context).
#
# Everything from the `diff --git` header onward is unchanged.
diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py index 887b2a955a..b23799773b 100644 --- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py @@ -9,7 +9,9 @@ import logging +import os import signal +from contextlib import nullcontext import click import fbgemm_gpu @@ -288,6 +290,8 @@ def benchmark( # noqa C901 skip_dequantization: bool = False, num_of_embeddings: int = 10000, pooling_factor: int = 25, + export_trace: bool = False, + trace_url: str = "merge_embeddings_fwd_trace_{ospid}.json", ) -> str: assert torch.cuda.is_available() torch.cuda.set_device(dst_device) @@ -448,17 +452,33 @@ def pool_func_with_quantization( flush_gpu_cache_size_mb=0, iters=iters, ) - # pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`. - with profile(activities=[ProfilerActivity.CUDA]) as prof: - pool_func_with_quantization( - batch_indices, - include_quantization, - include_tbe, - fused_tbe, - skip_dequantization, - data_type, - ) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + def _kineto_trace_handler(p: profile) -> None: + url = trace_url.format(ospid=os.getpid()) + p.export_chrome_trace(url) + + if export_trace: + with profile(on_trace_ready=_kineto_trace_handler): + pool_func_with_quantization( + batch_indices, + include_quantization, + include_tbe, + fused_tbe, + skip_dequantization, + data_type, + ) + else: + # pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`. 
+ with profile(activities=[ProfilerActivity.CUDA]) as prof: + pool_func_with_quantization( + batch_indices, + include_quantization, + include_tbe, + fused_tbe, + skip_dequantization, + data_type, + ) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) if isinstance(merged, Tensor): # all_to_one_only returns a list of tensors, @@ -519,6 +539,18 @@ def pool_func_with_quantization( default=False, help="Use manual seed for reproduction.", ) +@click.option( + "--export-trace", + is_flag=True, + default=False, + help="Enable export of Kineto trace for profiling.", +) +@click.option( + "--trace-url", + type=str, + default="merge_embeddings_fwd_trace_{ospid}.json", + help="Trace output file path template ({ospid} is replaced with PID).", +) def cli( all_to_one_only: bool, sum_reduce_to_one_only: bool, @@ -536,6 +568,8 @@ def cli( sweep: bool, use_pitched: bool, manual_seed: bool, + export_trace: bool, + trace_url: str, ) -> None: # set manual seed for reproducibility if manual_seed: @@ -584,6 +618,8 @@ def handler(signum, frame): skip_dequantization, num_of_embeddings, pooling_factor, + export_trace, + trace_url, ) results.append(result) except (TimeoutError, RuntimeError) as err: @@ -609,6 +645,8 @@ def handler(signum, frame): skip_dequantization, num_of_embeddings, pooling_factor, + export_trace, + trace_url, ) print(csv_header) print(result)