diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py index 4639c5df4e..083b8dbcc0 100644 --- a/thunder/benchmarks/benchmark_inference.py +++ b/thunder/benchmarks/benchmark_inference.py @@ -156,7 +156,6 @@ class InferenceBenchmarkConfig: mode: str disable_moe_replacement: bool attn_implementation: str | None - profile: bool thunder_cache: str | None enable_thunder_cudagraph: bool @@ -493,10 +492,17 @@ def run_benchmark(self) -> InferenceMetrics: for _ in tqdm(range(self.config.num_iterations), disable=LOCAL_RANK != 0): past_key_values.reset() - if self.config.profile: + is_under_nsys = bool(os.environ.get("NSYS_PROFILING_SESSION_ID")) + # Wrap each non-warmup iteration with cudaProfilerStart() and + # cudaProfilerStop(). This allows the user to run + # ```shell + # nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat: ... + # ``` + # to record only the non-warmup iterations. + if is_under_nsys: torch.cuda.cudart().cudaProfilerStart() iter_metrics = self.measure_inference_step(input_ids, past_key_values, self.config.output_length) - if self.config.profile: + if is_under_nsys: torch.cuda.cudart().cudaProfilerStop() all_metrics.append(iter_metrics) @@ -680,11 +686,6 @@ def parse_args() -> argparse.Namespace: action="store_true", help="let nvfuser take care of linear and matmul, note that this might fail with distributed run. See: https://github.com/NVIDIA/Fuser/issues/4507", ) - parser.add_argument( - "--profile", - action="store_true", - help="Wrap each non-warmup iteration with cudaProfilerStart() and cudaProfilerStop(). This allows us to run `nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat: ... --profile` to record only the non-warmup iterations.", - ) parser.add_argument("--save-results", action="store_true", help="Save results to JSON file") parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results") @@ -728,7 +729,6 @@ def main(): enable_nv_linear=args.enable_nv_linear, disable_moe_replacement=args.disable_moe_replacement, attn_implementation=args.attn_implementation, - profile=args.profile, thunder_cache=args.thunder_cache, enable_thunder_cudagraph=args.enable_thunder_cudagraph, )