diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index 0ae7c5c527..56a27640b1 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -246,6 +246,9 @@ class InferenceMetrics:
     prefill_time_ms: float = 0.0
     decode_time_ms: float = 0.0
 
+    # Compilation metrics
+    num_recompilations: int = 0
+
     # Per-iteration metrics for variance analysis
     iteration_times: list[float] = field(default_factory=list)
     ttft_times: list[float] = field(default_factory=list)
@@ -572,6 +575,8 @@ def run_benchmark(self) -> InferenceMetrics:
 
         self._calculate_aggregate_metrics(all_metrics)
 
+        self.metrics.num_recompilations = self._get_recompilation_count()
+
         if torch.cuda.is_available():
             self.metrics.memory_used_gb = torch.cuda.memory_allocated() / 1e9
             self.metrics.peak_memory_gb = torch.cuda.max_memory_allocated() / 1e9
@@ -582,6 +587,20 @@ def run_benchmark(self) -> InferenceMetrics:
 
         return self.metrics
 
+    def _get_recompilation_count(self) -> int:
+        """Count Thunder cache misses (recompiles of the same jitted module), excluding the initial compile and Dynamo graph rebuilds."""
+        if self.config.mode == "thunder":
+            backend = self.model._backend
+            total_misses = 0
+            for subgraph_info in backend.subgraph_infos:
+                for thunder_fn in subgraph_info.thunder_compiled_fns:
+                    total_misses += thunder.cache_misses(thunder_fn) - 1
+            return total_misses
+        elif self.config.mode == "thunderjit":
+            return thunder.cache_misses(self.model) - 1
+        else:
+            return 0
+
     def _calculate_aggregate_metrics(self, all_metrics: list[dict[str, Any]]):
         """Calculate aggregate metrics from individual iterations"""
         # Average throughput
@@ -638,6 +657,13 @@ def print_results(self):
             print(f"  Current Memory: {self.metrics.memory_used_gb:.2f} GB")
             print(f"  Peak Memory: {self.metrics.peak_memory_gb:.2f} GB")
 
+        if self.config.mode in ("thunder", "thunderjit"):
+            print("\nCompilation Metrics:")
+            print(
+                "  Number of Thunder module recompilations excluding initial compile: "
+                f"{self.metrics.num_recompilations}"
+            )
+
         if len(self.metrics.iteration_times) > 1:
             print("\nVariance Analysis:")
             print(f"  Throughput Std Dev: {statistics.stdev(self.metrics.iteration_times):.2f} ms")
@@ -659,6 +685,7 @@ def save_results(self, filename: str):
                 "total_time_ms": self.metrics.total_time_ms,
                 "memory_used_gb": self.metrics.memory_used_gb,
                 "peak_memory_gb": self.metrics.peak_memory_gb,
+                "num_recompilations": self.metrics.num_recompilations,
             },
             "detailed_metrics": {
                 "iteration_times": self.metrics.iteration_times,
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 4416d7dbab..3cb13ea79b 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -315,6 +315,7 @@ def __init__(
         fp8_shard_intermediate_activation: bool = False,
         use_sdpa: bool = False,
         use_hf: bool = False,
+        thunder_cache: str | None = None,
     ):
         seed = 1337
         torch.manual_seed(seed)
@@ -350,6 +351,7 @@ def __init__(
 
         self.use_sdpa = use_sdpa
         self.use_hf = use_hf
+        self.thunder_cache = thunder_cache
 
         if self.use_sdpa and sdpa_available and self.compile not in ["eager", "inductor"]:
             warnings.warn(
@@ -692,7 +694,7 @@ def setup_compile(self, model):
             transforms.insert(0, TransformerEngineTransform())
 
         if "jit" in self.compile:
-            model = thunder.jit(model, executors=executors, transforms=transforms, **jit_options)
+            model = thunder.jit(model, executors=executors, transforms=transforms, cache=self.thunder_cache, **jit_options)
 
         else:
             if self.distributed_mode == "fsdp2":
@@ -862,6 +864,7 @@ def train(self):
         self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iters)
         self.perf_metrics["saved_for_backward_tensor_size_mib"] = saved_tensors_size_in_mib
         self.perf_metrics["saved_for_backward_number_of_tensors"] = saved_tensors_len
+        self.perf_metrics["num_recompilations"] = self.compute_num_recompilations()
 
     def add_perf_metrics(self):
         if self.throughput:
@@ -893,6 +896,22 @@ def add_model_info_to_metrics(self):
             self.perf_metrics["Sharding Size"] = None
         self.perf_metrics["compiler"] = self.compile
 
+    def compute_num_recompilations(self) -> int:
+        import thunder
+
+        if "thunder" not in self.compile:
+            return 0
+
+        if "jit" in self.compile:
+            return thunder.cache_misses(self.model) - 1
+
+        # Compiled by ThunderFX
+        total_misses = 0
+        for info in self.backend.subgraph_infos:
+            for thunder_fn in info.thunder_compiled_fns:
+                total_misses += thunder.cache_misses(thunder_fn) - 1
+        return total_misses
+
 
 class DummyDataset(IterableDataset):
     def __init__(self, max_seq_length: int, dynamic: bool):
@@ -979,6 +998,11 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
             print(
                 f"Saved for backward number of tensors: {benchmark.perf_metrics['saved_for_backward_number_of_tensors']}"
             )
+            if "thunder" in benchmark.compile:
+                print(
+                    "Thunder module recompilations excluding initial compile: "
+                    f"{benchmark.perf_metrics.get('num_recompilations', 0)}"
+                )
 
         tokens_per_sec = benchmark.perf_metrics.get("tokens_per_sec")
         if tokens_per_sec:
diff --git a/thunder/benchmarks/benchmark_peft.py b/thunder/benchmarks/benchmark_peft.py
index 1c486e09a0..5f9bf5c321 100644
--- a/thunder/benchmarks/benchmark_peft.py
+++ b/thunder/benchmarks/benchmark_peft.py
@@ -232,7 +232,7 @@ def setup_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
     return model
 
 
-def setup_compilation(model, backend: str):
+def setup_compilation(model, backend: str, thunder_cache: str | None = None):
     # TODO from thunder.executors.transformer_engineex import transformer_engine_ex
     """Apply compilation settings to the model."""
     if backend in ("thunder", "inductor"):
@@ -262,14 +262,15 @@ def setup_compilation(model, backend: str):
 
         if "jit" in backend:
             logger.info("Using thunder.jit")
-            model = thunder.jit(model, transforms=xforms, executors=executors)
+            model = thunder.jit(model, transforms=xforms, executors=executors, cache=thunder_cache)
         else:
             logger.info("Using ThunderFX")
             from thunder.dynamo import thunderfx
 
             # TODO get parameters out from thunderfx CompiledObject
-            compiled_object = thunderfx(model, transforms=xforms, executors=executors)
+            compiled_object = thunderfx(model, transforms=xforms, executors=executors, cache=thunder_cache)
             model = compiled_object._func
+            model._thunder_backend = compiled_object._backend
 
     return model
 
@@ -302,6 +303,12 @@ def parse_args():
         type=str.lower,
         choices=["eager", "inductor", "thunder", "thunder+jit"],
     )
+    parser.add_argument(
+        "--thunder-cache",
+        type=str,
+        default=None,
+        help="Cache option: no caching, same input, constant values, symbolic values. See `cache` argument of `thunder.jit` for more details.",
+    )
     parser.add_argument("--verbose", action="store_true", help="Enable verbose output including model wrapping details")
     parser.add_argument("--trust-remote-code", action="store_true")
     parser.add_argument(
@@ -468,7 +475,7 @@ def main(args: argparse.Namespace):
     # Apply compilation if needed
     if args.compile != "eager":
         logger.info(f"Applying compilation: {args.compile} to model")
-        model = setup_compilation(model, args.compile)
+        model = setup_compilation(model, args.compile, thunder_cache=args.thunder_cache)
         logger.info("Compilation applied to model")
 
     # Verify only LoRA parameters are trainable
@@ -583,6 +590,21 @@ def main(args: argparse.Namespace):
 
     # Print training summary
     total_time = time.time() - start_ts
+
+    # Compute Thunder recompilations (cache misses minus initial compile) if applicable
+    num_recompilations = 0
+    if "thunder" in args.compile:
+        import thunder
+
+        if "jit" in args.compile:
+            num_recompilations = thunder.cache_misses(model) - 1
+        else:
+            total_misses = 0
+            for subgraph_info in model._thunder_backend.subgraph_infos:
+                for thunder_fn in subgraph_info.thunder_compiled_fns:
+                    total_misses += thunder.cache_misses(thunder_fn) - 1
+            num_recompilations = total_misses
+
     print_training_summary(
         args,
         total_time,
@@ -593,6 +615,7 @@ def main(args: argparse.Namespace):
         batches_processed,
         total_tokens_processed,
         WORLD_SIZE,
+        num_recompilations,
     )
 
     # Clean up distributed environment if needed
@@ -610,6 +633,7 @@ def print_training_summary(
     batches_processed: int,
     total_tokens_processed: int,
     WORLD_SIZE: int,
+    num_recompilations: int,
 ) -> None:
     """Print a comprehensive summary of the training run.
 
@@ -650,6 +674,8 @@ def print_training_summary(
     logger.info(f"Maximum allocated memory: {max_allocated_memory / 1024**3:.2f} GB")
     logger.info(f"Total tokens processed: {total_tokens:,}")
     logger.info(f"Total iterations: {args.max_steps}")
+    if "thunder" in args.compile:
+        logger.info(f"Thunder module recompilations excluding initial compile: {num_recompilations}")
 
     # Verify batch processing across all ranks
     if WORLD_SIZE > 1: