Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 118 additions & 4 deletions olive/passes/onnx/discrepancy_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import time
from typing import Optional

from olive.data.config import DataConfig
from olive.hardware import AcceleratorSpec
from olive.hardware.accelerator import Device
from olive.model import ONNXModelHandler
from olive.passes import Pass
from olive.passes.pass_config import BasePassConfig, PassConfigParam
Expand Down Expand Up @@ -61,6 +63,7 @@ class OnnxDiscrepancyCheck(Pass):
- Maximum absolute error (MaxAE)
- Number of elements where the absolute difference exceeds 0.1
- Number of elements where the absolute difference exceeds 0.01
- Inference speedup of ONNX over PyTorch on the target device (or CPU fallback)
- Longest common token sequence from the beginning between transformers
generate and ONNX Runtime GenAI generate (when enabled)

Expand Down Expand Up @@ -107,6 +110,19 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"computes the longest common token sequence from the beginning of their outputs."
),
),
"warmup_iterations": PassConfigParam(
type_=int,
default_value=3,
description="Number of warmup iterations before timing inference for speedup measurement.",
),
"timing_iterations": PassConfigParam(
Comment thread
xadupre marked this conversation as resolved.
type_=int,
default_value=5,
description=(
"Number of timed iterations to measure inference speedup (ONNX vs PyTorch). "
"Set to 0 to disable speedup measurement."
),
),
"generate_prompt": PassConfigParam(
type_=str,
default_value="The capital of France is",
Expand Down Expand Up @@ -177,8 +193,22 @@ def _run_for_config(
ref_model = AutoModelForCausalLM.from_pretrained(config.reference_model_path)
ref_model.eval()

# Prepare ONNX session
session = model.prepare_session()
# Prepare ONNX session on the target device (fallback to CPU)
device = self.accelerator_spec.accelerator_type if self.accelerator_spec else None
execution_provider = self.accelerator_spec.execution_provider if self.accelerator_spec else None
if device is None:
device = Device.CPU
Comment thread
xadupre marked this conversation as resolved.

# Determine the torch device matching the accelerator spec
torch_device = torch.device("cpu")
if device == Device.GPU and torch.cuda.is_available():
torch_device = torch.device("cuda")
ref_model = ref_model.to(torch_device)

session = model.prepare_session(
device=device,
execution_providers=[execution_provider] if execution_provider else None,
)
io_config = model.io_config

# Run inference on both and compare
Expand All @@ -194,9 +224,9 @@ def _run_for_config(

# Run PyTorch inference
if isinstance(input_data, dict):
torch_inputs = {k: v.clone() for k, v in input_data.items()}
torch_inputs = {k: v.clone().to(torch_device) for k, v in input_data.items()}
else:
torch_inputs = input_data
torch_inputs = input_data.to(torch_device)

torch_output = ref_model(**torch_inputs)
torch_logits = torch_output.logits.detach().cpu().numpy()
Expand Down Expand Up @@ -225,6 +255,23 @@ def _run_for_config(
total_elements,
)

# Measure inference speedup (ONNX vs PyTorch) on the target device
if config.timing_iterations > 0:
self._measure_speedup(
ref_model,
session,
dataloader,
io_config,
torch_device,
config.warmup_iterations,
config.timing_iterations,
)
else:
logger.info(
"OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.",
config.timing_iterations,
)

# Check thresholds
failures = []
if config.max_mae is not None and max_abs_error > config.max_mae:
Expand Down Expand Up @@ -254,6 +301,73 @@ def _run_for_config(
# Return the model unchanged
return model

def _measure_speedup(
    self, ref_model, session, dataloader, io_config, torch_device, warmup_iterations, timing_iterations
):
    """Measure inference speedup of ONNX over PyTorch on the target device.

    Only the first batch of ``dataloader`` is used. Each backend is warmed up
    for ``warmup_iterations`` runs and then timed over ``timing_iterations``
    runs; the per-iteration averages are logged and their ratio is returned.

    Args:
        ref_model: Callable PyTorch reference model.
        session: Inference session exposing ``run(None, input_feed)``.
        dataloader: Iterable of batches. A tuple/list batch contributes its
            first element as the model input.
        io_config: IO configuration forwarded to ``format_data`` to build the
            ONNX input feed.
        torch_device: ``torch.device`` the reference inputs are moved to. When
            its type is ``"cuda"`` the timings are bracketed by
            ``torch.cuda.synchronize()``.
        warmup_iterations: Number of untimed runs per backend.
        timing_iterations: Number of timed runs per backend. A value ``<= 0``
            disables the measurement.

    Returns:
        ``pytorch_avg / onnx_avg`` as a float (``inf`` when the ONNX average
        is 0), or ``None`` when the measurement is skipped.
    """
    if timing_iterations <= 0:
        logger.info(
            "OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.",
            timing_iterations,
        )
        return None

    from collections.abc import Mapping

    import torch

    from olive.common.utils import format_data

    # Use the first batch for timing. An empty dataloader would otherwise leak
    # a bare StopIteration out of the pass, so skip the measurement instead.
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        logger.warning("OnnxDiscrepancyCheck speedup measurement skipped because the dataloader is empty.")
        return None
    input_data = first_batch[0] if isinstance(first_batch, (tuple, list)) else first_batch

    if isinstance(input_data, dict):
        torch_inputs = {k: v.clone().to(torch_device) for k, v in input_data.items()}
    else:
        torch_inputs = input_data.to(torch_device)

    onnx_input_feed = format_data(input_data, io_config)
    use_cuda_sync = torch_device.type == "cuda"

    def run_reference():
        # Mapping-like inputs (plain dicts and other mappings) are expanded to
        # keyword arguments; anything else (e.g. a single tensor) cannot be
        # **-unpacked and is passed positionally.
        if isinstance(torch_inputs, Mapping):
            ref_model(**torch_inputs)
        else:
            ref_model(torch_inputs)

    def synchronize():
        # CUDA kernels are launched asynchronously; wait for them so that the
        # wall-clock timestamps bracket the actual device work.
        if use_cuda_sync:
            torch.cuda.synchronize()

    with torch.no_grad():
        # Warmup PyTorch
        for _ in range(warmup_iterations):
            run_reference()
        synchronize()

        # Time PyTorch
        start = time.perf_counter()
        for _ in range(timing_iterations):
            run_reference()
        synchronize()
        pytorch_time = (time.perf_counter() - start) / timing_iterations

    # Warmup ONNX
    for _ in range(warmup_iterations):
        session.run(None, onnx_input_feed)

    # Time ONNX
    start = time.perf_counter()
    for _ in range(timing_iterations):
        session.run(None, onnx_input_feed)
    onnx_time = (time.perf_counter() - start) / timing_iterations

    # Guard against a zero denominator on very coarse clocks.
    speedup = pytorch_time / onnx_time if onnx_time > 0 else float("inf")

    logger.info(
        "OnnxDiscrepancyCheck speedup: pytorch_avg=%.4fs, onnx_avg=%.4fs, speedup=%.2fx (device=%s)",
        pytorch_time,
        onnx_time,
        speedup,
        torch_device,
    )

    return speedup

def compare_generation(self, config: type[BasePassConfig], ref_model) -> int:
Comment thread
xadupre marked this conversation as resolved.
"""Run generation on both transformers and GenAI, return longest common token sequence length."""
try:
Expand Down
29 changes: 29 additions & 0 deletions test/passes/onnx/test_discrepancy_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,32 @@
mock_generator.append_tokens.assert_called_once_with([[10, 20]])
# All 5 tokens match
assert result == 5


class TestSpeedupSettings:
    """Tests for the speedup-measurement settings of ``OnnxDiscrepancyCheck``."""

    def test_timing_iterations_default_is_5(self):
        """``timing_iterations`` is declared with a default value of 5."""
        from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

        # The default config is only reachable through the protected hook.
        pass_config = OnnxDiscrepancyCheck._default_config(None)  # pylint: disable=protected-access

        assert pass_config["timing_iterations"].default_value == 5

    def test_measure_speedup_skips_when_timing_iterations_is_zero(self):
        """With ``timing_iterations=0`` nothing is run and ``None`` is returned."""
        from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

        # Bypass __init__: the method under test is called with explicit arguments only.
        pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)
        ref_model = MagicMock()
        session = MagicMock()

        result = pass_instance._measure_speedup(  # pylint: disable=protected-access
            ref_model=ref_model,
            session=session,
            dataloader=MagicMock(),
            io_config=MagicMock(),
            torch_device=MagicMock(),
            warmup_iterations=3,
            timing_iterations=0,
        )

        assert result is None
        ref_model.assert_not_called()
        session.run.assert_not_called()
Loading