Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions thunder/benchmarks/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ class InferenceMetrics:
prefill_time_ms: float = 0.0
decode_time_ms: float = 0.0

# Compilation metrics
num_recompilations: int = 0

# Per-iteration metrics for variance analysis
iteration_times: list[float] = field(default_factory=list)
ttft_times: list[float] = field(default_factory=list)
Expand Down Expand Up @@ -572,6 +575,8 @@ def run_benchmark(self) -> InferenceMetrics:

self._calculate_aggregate_metrics(all_metrics)

self.metrics.num_recompilations = self._get_recompilation_count()

if torch.cuda.is_available():
self.metrics.memory_used_gb = torch.cuda.memory_allocated() / 1e9
self.metrics.peak_memory_gb = torch.cuda.max_memory_allocated() / 1e9
Expand All @@ -582,6 +587,20 @@ def run_benchmark(self) -> InferenceMetrics:

return self.metrics

def _get_recompilation_count(self) -> int:
"""Count Thunder cache misses (recompiles of the same jitted module), excluding the initial compile and Dynamo graph rebuilds."""
if self.config.mode == "thunder":
backend = self.model._backend
total_misses = 0
for subgraph_info in backend.subgraph_infos:
for thunder_fn in subgraph_info.thunder_compiled_fns:
total_misses += thunder.cache_misses(thunder_fn) - 1
return total_misses
elif self.config.mode == "thunderjit":
return thunder.cache_misses(self.model) - 1
else:
return 0

def _calculate_aggregate_metrics(self, all_metrics: list[dict[str, Any]]):
"""Calculate aggregate metrics from individual iterations"""
# Average throughput
Expand Down Expand Up @@ -638,6 +657,13 @@ def print_results(self):
print(f" Current Memory: {self.metrics.memory_used_gb:.2f} GB")
print(f" Peak Memory: {self.metrics.peak_memory_gb:.2f} GB")

if self.config.mode in ("thunder", "thunderjit"):
print("\nCompilation Metrics:")
print(
" Number of Thunder module recompilations excluding initial compile: "
f"{self.metrics.num_recompilations}"
)

if len(self.metrics.iteration_times) > 1:
print("\nVariance Analysis:")
print(f" Throughput Std Dev: {statistics.stdev(self.metrics.iteration_times):.2f} ms")
Expand All @@ -659,6 +685,7 @@ def save_results(self, filename: str):
"total_time_ms": self.metrics.total_time_ms,
"memory_used_gb": self.metrics.memory_used_gb,
"peak_memory_gb": self.metrics.peak_memory_gb,
"num_recompilations": self.metrics.num_recompilations,
},
"detailed_metrics": {
"iteration_times": self.metrics.iteration_times,
Expand Down
26 changes: 25 additions & 1 deletion thunder/benchmarks/benchmark_litgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def __init__(
fp8_shard_intermediate_activation: bool = False,
use_sdpa: bool = False,
use_hf: bool = False,
thunder_cache: str | None = None,
):
seed = 1337
torch.manual_seed(seed)
Expand Down Expand Up @@ -350,6 +351,7 @@ def __init__(

self.use_sdpa = use_sdpa
self.use_hf = use_hf
self.thunder_cache = thunder_cache

if self.use_sdpa and sdpa_available and self.compile not in ["eager", "inductor"]:
warnings.warn(
Expand Down Expand Up @@ -692,7 +694,7 @@ def setup_compile(self, model):
transforms.insert(0, TransformerEngineTransform())

if "jit" in self.compile:
model = thunder.jit(model, executors=executors, transforms=transforms, **jit_options)
model = thunder.jit(model, executors=executors, transforms=transforms, cache=self.thunder_cache)

else:
if self.distributed_mode == "fsdp2":
Expand Down Expand Up @@ -862,6 +864,7 @@ def train(self):
self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iters)
self.perf_metrics["saved_for_backward_tensor_size_mib"] = saved_tensors_size_in_mib
self.perf_metrics["saved_for_backward_number_of_tensors"] = saved_tensors_len
self.perf_metrics["num_recompilations"] = self.compute_num_recompilations()

def add_perf_metrics(self):
if self.throughput:
Expand Down Expand Up @@ -893,6 +896,22 @@ def add_model_info_to_metrics(self):
self.perf_metrics["Sharding Size"] = None
self.perf_metrics["compiler"] = self.compile

def compute_num_recompilations(self) -> int:
    """Return the number of Thunder recompilations of this benchmark's model.

    A recompilation is a Thunder cache miss beyond the initial compile of a
    jitted module. Returns 0 for non-Thunder compile modes.

    Returns:
        int: total cache misses minus one per Thunder-compiled module.
    """
    # Guard first so eager/inductor runs never need thunder importable;
    # the original imported thunder unconditionally.
    if "thunder" not in self.compile:
        return 0

    import thunder

    if "jit" in self.compile:
        # thunder.jit compiles the whole model as a single jitted module.
        return thunder.cache_misses(self.model) - 1

    # Compiled by ThunderFX: a Dynamo graph may be split into several
    # Thunder-compiled subgraphs; discount each one's initial compile.
    total_misses = 0
    for info in self.backend.subgraph_infos:
        for thunder_fn in info.thunder_compiled_fns:
            total_misses += thunder.cache_misses(thunder_fn) - 1
    return total_misses


class DummyDataset(IterableDataset):
def __init__(self, max_seq_length: int, dynamic: bool):
Expand Down Expand Up @@ -979,6 +998,11 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
print(
f"Saved for backward number of tensors: {benchmark.perf_metrics['saved_for_backward_number_of_tensors']}"
)
if "thunder" in benchmark.compile:
print(
"Thunder module recompilations excluding initial compile: "
f"{benchmark.perf_metrics.get('num_recompilations', 0)}"
)

tokens_per_sec = benchmark.perf_metrics.get("tokens_per_sec")
if tokens_per_sec:
Expand Down
34 changes: 30 additions & 4 deletions thunder/benchmarks/benchmark_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def setup_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
return model


def setup_compilation(model, backend: str):
def setup_compilation(model, backend: str, thunder_cache: str | None = None):
# TODO from thunder.executors.transformer_engineex import transformer_engine_ex
"""Apply compilation settings to the model."""
if backend in ("thunder", "inductor"):
Expand Down Expand Up @@ -262,14 +262,15 @@ def setup_compilation(model, backend: str):

if "jit" in backend:
logger.info("Using thunder.jit")
model = thunder.jit(model, transforms=xforms, executors=executors)
model = thunder.jit(model, transforms=xforms, executors=executors, cache=thunder_cache)
else:
logger.info("Using ThunderFX")
from thunder.dynamo import thunderfx

# TODO get parameters out from thunderfx CompiledObject
compiled_object = thunderfx(model, transforms=xforms, executors=executors)
compiled_object = thunderfx(model, transforms=xforms, executors=executors, cache=thunder_cache)
model = compiled_object._func
model._thunder_backend = compiled_object._backend

return model

Expand Down Expand Up @@ -302,6 +303,12 @@ def parse_args():
type=str.lower,
choices=["eager", "inductor", "thunder", "thunder+jit"],
)
parser.add_argument(
"--thunder-cache",
type=str,
default=None,
help="Cache option: no caching, same input, constant values, symbolic values. See `cache` argument of `thunder.jit` for more details.",
)
parser.add_argument("--verbose", action="store_true", help="Enable verbose output including model wrapping details")
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument(
Expand Down Expand Up @@ -468,7 +475,7 @@ def main(args: argparse.Namespace):
# Apply compilation if needed
if args.compile != "eager":
logger.info(f"Applying compilation: {args.compile} to model")
model = setup_compilation(model, args.compile)
model = setup_compilation(model, args.compile, thunder_cache=args.thunder_cache)
logger.info("Compilation applied to model")

# Verify only LoRA parameters are trainable
Expand Down Expand Up @@ -583,6 +590,21 @@ def main(args: argparse.Namespace):

# Print training summary
total_time = time.time() - start_ts

# Compute Thunder recompilations (cache misses minus initial compile) if applicable
num_recompilations = 0
if "thunder" in args.compile:
import thunder

if "jit" in args.compile:
num_recompilations = thunder.cache_misses(model) - 1
else:
total_misses = 0
for subgraph_info in model._thunder_backend.subgraph_infos:
for thunder_fn in subgraph_info.thunder_compiled_fns:
total_misses += thunder.cache_misses(thunder_fn) - 1
num_recompilations = total_misses

print_training_summary(
args,
total_time,
Expand All @@ -593,6 +615,7 @@ def main(args: argparse.Namespace):
batches_processed,
total_tokens_processed,
WORLD_SIZE,
num_recompilations,
)

# Clean up distributed environment if needed
Expand All @@ -610,6 +633,7 @@ def print_training_summary(
batches_processed: int,
total_tokens_processed: int,
WORLD_SIZE: int,
num_recompilations: int,
) -> None:
"""Print a comprehensive summary of the training run.

Expand Down Expand Up @@ -650,6 +674,8 @@ def print_training_summary(
logger.info(f"Maximum allocated memory: {max_allocated_memory / 1024**3:.2f} GB")
logger.info(f"Total tokens processed: {total_tokens:,}")
logger.info(f"Total iterations: {args.max_steps}")
if "thunder" in args.compile:
logger.info(f"Thunder module recompilations excluding initial compile: {num_recompilations}")

# Verify batch processing across all ranks
if WORLD_SIZE > 1:
Expand Down
Loading