Skip to content
100 changes: 96 additions & 4 deletions olive/passes/onnx/discrepancy_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import time
from typing import Optional

from olive.data.config import DataConfig
from olive.hardware import AcceleratorSpec
from olive.hardware.accelerator import Device
from olive.model import ONNXModelHandler
from olive.passes import Pass
from olive.passes.pass_config import BasePassConfig, PassConfigParam
Expand Down Expand Up @@ -61,6 +63,7 @@ class OnnxDiscrepancyCheck(Pass):
- Maximum absolute error (MaxAE)
- Number of elements where the absolute difference exceeds 0.1
- Number of elements where the absolute difference exceeds 0.01
- Inference speedup of ONNX over PyTorch on the target device (or CPU fallback)
- Longest common token sequence from the beginning between transformers
generate and ONNX Runtime GenAI generate (when enabled)

Expand Down Expand Up @@ -107,6 +110,16 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"computes the longest common token sequence from the beginning of their outputs."
),
),
"warmup_iterations": PassConfigParam(
type_=int,
default_value=3,
description="Number of warmup iterations before timing inference for speedup measurement.",
),
"timing_iterations": PassConfigParam(
Comment thread
xadupre marked this conversation as resolved.
type_=int,
default_value=10,
description="Number of timed iterations to measure inference speedup (ONNX vs PyTorch).",
),
"generate_prompt": PassConfigParam(
type_=str,
default_value="The capital of France is",
Expand Down Expand Up @@ -177,8 +190,22 @@ def _run_for_config(
ref_model = AutoModelForCausalLM.from_pretrained(config.reference_model_path)
ref_model.eval()

# Prepare ONNX session
session = model.prepare_session()
# Prepare ONNX session on the target device (fallback to CPU)
device = self.accelerator_spec.accelerator_type if self.accelerator_spec else None
execution_provider = self.accelerator_spec.execution_provider if self.accelerator_spec else None
if device is None:
device = Device.CPU
Comment thread
xadupre marked this conversation as resolved.

# Determine the torch device matching the accelerator spec
torch_device = torch.device("cpu")
if device == Device.GPU and torch.cuda.is_available():
torch_device = torch.device("cuda")
ref_model = ref_model.to(torch_device)

session = model.prepare_session(
device=device,
execution_providers=[execution_provider] if execution_provider else None,
)
io_config = model.io_config

# Run inference on both and compare
Expand All @@ -194,9 +221,9 @@ def _run_for_config(

# Run PyTorch inference
if isinstance(input_data, dict):
torch_inputs = {k: v.clone() for k, v in input_data.items()}
torch_inputs = {k: v.clone().to(torch_device) for k, v in input_data.items()}
else:
torch_inputs = input_data
torch_inputs = input_data.to(torch_device)

torch_output = ref_model(**torch_inputs)
torch_logits = torch_output.logits.detach().cpu().numpy()
Expand Down Expand Up @@ -225,6 +252,11 @@ def _run_for_config(
total_elements,
)

# Measure inference speedup (ONNX vs PyTorch) on the target device
self._measure_speedup(
ref_model, session, dataloader, io_config, torch_device, config.warmup_iterations, config.timing_iterations
)

# Check thresholds
failures = []
if config.max_mae is not None and max_abs_error > config.max_mae:
Expand Down Expand Up @@ -254,6 +286,66 @@ def _run_for_config(
# Return the model unchanged
return model

def _measure_speedup(
    self, ref_model, session, dataloader, io_config, torch_device, warmup_iterations, timing_iterations
):
    """Measure inference speedup of ONNX over PyTorch on the target device.

    Only the first batch of ``dataloader`` is used. Each backend is run
    ``warmup_iterations`` times untimed and then ``timing_iterations`` times
    timed; the reported figures are the average wall-clock seconds per timed
    iteration.

    Args:
        ref_model: PyTorch reference model; called with the batch on ``torch_device``.
        session: ONNX inference session exposing ``run(None, input_feed)``.
        dataloader: Iterable of batches. A tuple/list batch is treated as
            ``(inputs, ...)`` and only its first element is used.
        io_config: IO config forwarded to ``format_data`` to build the ONNX input feed.
        torch_device: ``torch.device`` on which the PyTorch model is timed.
        warmup_iterations: Number of untimed runs per backend.
        timing_iterations: Number of timed runs per backend; must be positive.

    Returns:
        ``pytorch_avg / onnx_avg`` as a float (``inf`` if the ONNX average is
        not positive), or ``None`` when the measurement is skipped because
        ``timing_iterations`` is not positive or the dataloader is empty.

    """
    import torch

    from olive.common.utils import format_data

    # Averaging over zero iterations would divide by zero and a negative count
    # would produce a meaningless negative average, so skip instead of failing
    # the whole pass over an informational measurement.
    if timing_iterations <= 0:
        logger.warning(
            "OnnxDiscrepancyCheck speedup measurement skipped: timing_iterations must be positive, got %s",
            timing_iterations,
        )
        return None

    # Use the first batch for timing
    try:
        first_batch = next(iter(dataloader))
    except StopIteration:
        logger.warning("OnnxDiscrepancyCheck speedup measurement skipped: dataloader yielded no batches")
        return None
    input_data = first_batch[0] if isinstance(first_batch, (tuple, list)) else first_batch

    if isinstance(input_data, dict):
        torch_inputs = {k: v.clone().to(torch_device) for k, v in input_data.items()}

        def run_torch():
            return ref_model(**torch_inputs)

    else:
        torch_inputs = input_data.to(torch_device)

        # A non-mapping batch cannot be expanded with ``**``; pass it positionally.
        def run_torch():
            return ref_model(torch_inputs)

    # The ONNX feed is built from the original batch, not from the copy moved to torch_device.
    onnx_input_feed = format_data(input_data, io_config)
    use_cuda_sync = torch_device.type == "cuda"

    with torch.no_grad():
        # Warmup PyTorch
        for _ in range(warmup_iterations):
            run_torch()
        if use_cuda_sync:
            # CUDA kernels are launched asynchronously: drain pending work so
            # warmup kernels are not attributed to the timed section.
            torch.cuda.synchronize()

        # Time PyTorch
        start = time.perf_counter()
        for _ in range(timing_iterations):
            run_torch()
        if use_cuda_sync:
            # Wait for the queued kernels to finish before stopping the clock.
            torch.cuda.synchronize()
        pytorch_time = (time.perf_counter() - start) / timing_iterations

    # Warmup ONNX
    for _ in range(warmup_iterations):
        session.run(None, onnx_input_feed)

    # Time ONNX
    start = time.perf_counter()
    for _ in range(timing_iterations):
        session.run(None, onnx_input_feed)
    onnx_time = (time.perf_counter() - start) / timing_iterations

    speedup = pytorch_time / onnx_time if onnx_time > 0 else float("inf")

    logger.info(
        "OnnxDiscrepancyCheck speedup: pytorch_avg=%.4fs, onnx_avg=%.4fs, speedup=%.2fx (device=%s)",
        pytorch_time,
        onnx_time,
        speedup,
        torch_device,
    )

    return speedup

def compare_generation(self, config: type[BasePassConfig], ref_model) -> int:
Comment on lines +375 to 377

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already addressed in commit 595bffc. input_ids is moved to ref_model.device (line 389) and the output is moved back to CPU via .cpu() before .tolist() (line 398).

"""Run generation on both transformers and GenAI, return longest common token sequence length."""
try:
Expand Down
Loading