diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index c6644d30e5..9a2e9b2be0 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -562,6 +562,23 @@ def _export_kineto_trace( default=False, help="Use manual seed for reproduction.", ) +@click.option( + "--device", + type=click.Choice(["cpu", "cuda"]), + default="cuda", + help="Device to run the benchmark on. Default is cuda.", +) +@click.option( + "--export-trace", + is_flag=True, + default=False, + help="Enable export of trace for profiling.", +) +@click.option( + "--trace-url", + type=str, + default="batched_dense_vec_jagged_2d_mul_trace_{ospid}.json", +) def batched_dense_vec_jagged_2d_mul( batch_size: int, h_dim: int, @@ -569,7 +586,13 @@ def batched_dense_vec_jagged_2d_mul( max_len: int, elem_type: str, manual_seed: bool, + device: str, + export_trace: bool, + trace_url: str, ) -> None: + if device == "cuda" and not torch.cuda.is_available(): + raise click.UsageError("CUDA requested but not available.") + # set manual seed for reproducibility if manual_seed: torch.manual_seed(42) @@ -586,7 +609,7 @@ def batched_dense_vec_jagged_2d_mul( # pyre-fixme[6]: For 1st param expected `int` but got `Union[bool, float, int]`. values_2d = torch.rand(total_lengths, h_dim * embedding_dim, dtype=dtype) dense = torch.rand(batch_size * h_dim, max_len, dtype=dtype) - if torch.cuda.is_available(): + if device == "cuda": offsets = offsets.cuda() values_2d = values_2d.cuda() dense = dense.cuda() @@ -605,6 +628,41 @@ def batched_dense_vec_jagged_2d_mul( f"batched_dense_vec_jagged_2d_mul {time} sec {num_flops / time / 1e9} GFLOP/s" ) + if export_trace: + is_cuda = device != "cpu" + + # pyre-fixme[53]: Captured variable `dense` is not annotated. + # pyre-fixme[53]: Captured variable `values_2d` is not annotated. + # pyre-fixme[53]: Captured variable `offsets` is not annotated. + def fn() -> torch.Tensor: + return torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul( + dense, values_2d, offsets + ) + + for _ in range(100): + fn() + if is_cuda: + torch.cuda.synchronize(device) + # pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`. + activities = [torch.profiler.ProfilerActivity.CPU] + if is_cuda: + # pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`. + activities.append(torch.profiler.ProfilerActivity.CUDA) + num_active = 5 + with profile( + activities=activities, + schedule=schedule(wait=0, warmup=0, active=num_active, repeat=1), + record_shapes=True, + on_trace_ready=lambda p: p.export_chrome_trace( + trace_url.format(ospid=os.getpid()) + ), + ) as prof: + for _ in range(num_active): + fn() + if is_cuda: + torch.cuda.synchronize(device) + prof.step() + @cli.command() @click.option("--batch-size", type=int, default=1024) @@ -615,12 +673,37 @@ def batched_dense_vec_jagged_2d_mul( default=False, help="Use manual seed for reproduction.", ) +@click.option( + "--device", + type=click.Choice(["cpu"]), + default="cpu", + help="Device to run the benchmark on. CPU-only (no CUDA kernel exists).", +) +@click.option( + "--export-trace", + is_flag=True, + default=False, + help="Enable export of trace for profiling.", +) +@click.option( + "--trace-url", + type=str, + default="jagged_1d_to_truncated_values_trace_{ospid}.json", +) def jagged_1d_to_truncated_values( batch_size: int, max_len: int, dtype: str, manual_seed: bool, + device: str, + export_trace: bool, + trace_url: str, ) -> None: + if max_len <= 0: + raise click.UsageError("max_len must be positive.") + if batch_size <= 0: + raise click.UsageError("batch_size must be positive.") + # set manual seed for reproducibility if manual_seed: torch.manual_seed(42) @@ -664,6 +747,32 @@ def ref(values: torch.Tensor, lengths: torch.Tensor, max_len: int) -> torch.Tens logging.info(f"reference {time_ref} sec {bytes / time_ref / 1e9} GB/s") logging.info(f"truncate_jagged_1d {time} sec {bytes / time / 1e9} GB/s") + if export_trace: + + # pyre-fixme[53]: Captured variable `values` is not annotated. + # pyre-fixme[53]: Captured variable `lengths` is not annotated. + def fn() -> torch.Tensor: + return torch.ops.fbgemm.jagged_1d_to_truncated_values( + values, lengths, max_len + ) + + for _ in range(100): + fn() + # pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`. + activities = [torch.profiler.ProfilerActivity.CPU] + num_active = 5 + with profile( + activities=activities, + schedule=schedule(wait=0, warmup=0, active=num_active, repeat=1), + record_shapes=True, + on_trace_ready=lambda p: p.export_chrome_trace( + trace_url.format(ospid=os.getpid()) + ), + ) as prof: + for _ in range(num_active): + fn() + prof.step() + @cli.command() @click.option("--batch-size", type=int, default=1024)