Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions thunder/benchmarks/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
import os
import statistics
import time
import warnings
from typing import Any
from collections.abc import Callable
from looseversion import LooseVersion

import torch
Expand Down Expand Up @@ -52,6 +49,7 @@
from thunder.torch.custom_op import _register_custom_op, _register_nvfuser_translator

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any


Expand Down Expand Up @@ -217,6 +215,7 @@ class InferenceBenchmarkConfig:
num_layers: int | None
num_iterations: int
warmup_iterations: int
use_torchao_nvfp4: bool
enable_nvfp4: bool # Enable NVFP4 registration and quantize GroupedSwiGLU in MoE
fx_report_folder: str | None
enable_nv_linear: bool
Expand Down Expand Up @@ -345,7 +344,21 @@ def __init__(self, config: InferenceBenchmarkConfig):
self.vocab_size = model.vocab_size

if self.config.enable_nvfp4:
_quantize_llama4(model)
if not self.config.use_torchao_nvfp4:
_quantize_llama4(model)
else:
from torchao.quantization import quantize_
from torchao.prototype.mx_formats.inference_workflow import (
NVFP4InferenceConfig,
NVFP4MMConfig,
)

config = NVFP4InferenceConfig(
mm_config=NVFP4MMConfig.DYNAMIC,
use_triton_kernel=True,
use_dynamic_per_tensor_scale=True,
)
quantize_(model, config=config)
self.model = self._compile_model(model)

@property
Expand Down Expand Up @@ -737,6 +750,7 @@ def parse_args() -> argparse.Namespace:
help="Specify the folder for thunderfx_benchmark_report.",
)

parser.add_argument("--use-torchao-nvfp4", action="store_true", help="Use TorchAO for NVFP4")
parser.add_argument(
"--enable-nvfp4",
action="store_true",
Expand Down Expand Up @@ -796,6 +810,7 @@ def main():
warmup_iterations=args.warmup_iterations,
mode=args.mode,
enable_nvfp4=args.enable_nvfp4,
use_torchao_nvfp4=args.use_torchao_nvfp4,
fx_report_folder=args.fx_report_folder,
enable_nv_linear=args.enable_nv_linear,
disable_moe_replacement=args.disable_moe_replacement,
Expand Down
Loading