Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 126 additions & 5 deletions olive/passes/onnx/discrepancy_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import time
from typing import Optional

from olive.data.config import DataConfig
from olive.hardware import AcceleratorSpec
from olive.hardware.accelerator import Device
from olive.model import ONNXModelHandler
from olive.passes import Pass
from olive.passes.pass_config import BasePassConfig, PassConfigParam
Expand Down Expand Up @@ -61,6 +63,7 @@ class OnnxDiscrepancyCheck(Pass):
- Maximum absolute error (MaxAE)
- Number of elements where the absolute difference exceeds 0.1
- Number of elements where the absolute difference exceeds 0.01
- Inference speedup of ONNX over PyTorch on the target device (or CPU fallback)
- Longest common token sequence from the beginning between transformers
generate and ONNX Runtime GenAI generate (when enabled)

Expand Down Expand Up @@ -107,6 +110,19 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"computes the longest common token sequence from the beginning of their outputs."
),
),
"warmup_iterations": PassConfigParam(
type_=int,
default_value=3,
description="Number of warmup iterations before timing inference for speedup measurement.",
),
"timing_iterations": PassConfigParam(
Comment thread
xadupre marked this conversation as resolved.
type_=int,
default_value=5,
description=(
"Number of timed iterations to measure inference speedup (ONNX vs PyTorch). "
"Set to 0 to disable speedup measurement."
),
),
"generate_prompt": PassConfigParam(
type_=str,
default_value="The capital of France is",
Expand Down Expand Up @@ -177,8 +193,28 @@ def _run_for_config(
ref_model = AutoModelForCausalLM.from_pretrained(config.reference_model_path)
ref_model.eval()

# Prepare ONNX session
session = model.prepare_session()
# Prepare ONNX session on the target device (fallback to CPU)
device = self.accelerator_spec.accelerator_type if self.accelerator_spec else None
execution_provider = self.accelerator_spec.execution_provider if self.accelerator_spec else None
if device is None:
device = Device.CPU
Comment thread
xadupre marked this conversation as resolved.
elif not isinstance(device, Device):
try:
device = Device(str(device).lower())
except ValueError:
logger.warning("Unknown accelerator_type=%s; falling back to CPU.", device)
device = Device.CPU

# Determine the torch device matching the accelerator spec
torch_device = torch.device("cpu")
if device == Device.GPU and torch.cuda.is_available():
torch_device = torch.device("cuda")
ref_model = ref_model.to(torch_device)

session = model.prepare_session(
device=device,
execution_providers=[execution_provider] if execution_provider else None,
)
io_config = model.io_config

# Run inference on both and compare
Expand All @@ -194,9 +230,9 @@ def _run_for_config(

# Run PyTorch inference
if isinstance(input_data, dict):
torch_inputs = {k: v.clone() for k, v in input_data.items()}
torch_inputs = {k: v.clone().to(torch_device) for k, v in input_data.items()}
else:
torch_inputs = input_data
torch_inputs = input_data.to(torch_device)

torch_output = ref_model(**torch_inputs)
torch_logits = torch_output.logits.detach().cpu().numpy()
Expand Down Expand Up @@ -225,6 +261,23 @@ def _run_for_config(
total_elements,
)

# Measure inference speedup (ONNX vs PyTorch) on the target device
if config.timing_iterations > 0:
self._measure_speedup(
ref_model,
session,
dataloader,
io_config,
torch_device,
config.warmup_iterations,
config.timing_iterations,
)
else:
logger.info(
"OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.",
config.timing_iterations,
)

# Check thresholds
failures = []
if config.max_mae is not None and max_abs_error > config.max_mae:
Expand Down Expand Up @@ -254,6 +307,73 @@ def _run_for_config(
# Return the model unchanged
return model

def _measure_speedup(
self, ref_model, session, dataloader, io_config, torch_device, warmup_iterations, timing_iterations
):
"""Measure inference speedup of ONNX over PyTorch on the target device."""
if timing_iterations <= 0:
logger.info(
"OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.",
timing_iterations,
)
return None

import torch

from olive.common.utils import format_data

# Use the first batch for timing
first_batch = next(iter(dataloader))
input_data = first_batch[0] if isinstance(first_batch, (tuple, list)) else first_batch

if isinstance(input_data, dict):
torch_inputs = {k: v.clone().to(torch_device) for k, v in input_data.items()}
else:
torch_inputs = input_data.to(torch_device)

onnx_input_feed = format_data(input_data, io_config)
use_cuda_sync = torch_device.type == "cuda"

# Warmup PyTorch
with torch.no_grad():
for _ in range(warmup_iterations):
ref_model(**torch_inputs)
if use_cuda_sync:
torch.cuda.synchronize()

# Time PyTorch
with torch.no_grad():
if use_cuda_sync:
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(timing_iterations):
ref_model(**torch_inputs)
if use_cuda_sync:
torch.cuda.synchronize()
pytorch_time = (time.perf_counter() - start) / timing_iterations

# Warmup ONNX
for _ in range(warmup_iterations):
session.run(None, onnx_input_feed)

# Time ONNX
start = time.perf_counter()
for _ in range(timing_iterations):
session.run(None, onnx_input_feed)
onnx_time = (time.perf_counter() - start) / timing_iterations

speedup = pytorch_time / onnx_time if onnx_time > 0 else float("inf")

logger.info(
"OnnxDiscrepancyCheck speedup: pytorch_avg=%.4fs, onnx_avg=%.4fs, speedup=%.2fx (device=%s)",
pytorch_time,
onnx_time,
speedup,
torch_device,
)

return speedup

def compare_generation(self, config: type[BasePassConfig], ref_model) -> int:
Comment on lines +375 to 377

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already addressed in commit 595bffc. input_ids is moved to ref_model.device (line 389) and the output is moved back to CPU via .cpu() before .tolist() (line 398).

"""Run generation on both transformers and GenAI, return longest common token sequence length."""
try:
Expand All @@ -266,6 +386,7 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> int:

# Transformers generation
input_ids = tokenizer(config.generate_prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(ref_model.device)
import torch

with torch.no_grad():
Expand All @@ -274,7 +395,7 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> int:
max_new_tokens=config.generate_max_new_tokens,
do_sample=False,
)
transformers_tokens = transformers_output[0].tolist()
transformers_tokens = transformers_output[0].cpu().tolist()

# ONNX Runtime GenAI generation
genai_model = og.Model(config.genai_model_path)
Expand Down
31 changes: 31 additions & 0 deletions test/passes/onnx/test_discrepancy_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=protected-access

import sys
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -159,3 +161,32 @@ def get_next_tokens_side_effect():
mock_generator.append_tokens.assert_called_once_with([[10, 20]])
# All 5 tokens match
assert result == 5


class TestSpeedupSettings:
    """Tests for the speedup-measurement settings of ``OnnxDiscrepancyCheck``."""

    def test_timing_iterations_default_is_5(self):
        """The ``timing_iterations`` pass parameter defaults to 5."""
        from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

        timing_param = OnnxDiscrepancyCheck._default_config(None)["timing_iterations"]

        assert timing_param.default_value == 5

    def test_measure_speedup_skips_when_timing_iterations_is_zero(self):
        """With ``timing_iterations=0`` nothing is run and ``None`` is returned."""
        from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

        # Build the pass without running __init__; only the method under test is needed.
        checker = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)
        mock_ref_model = MagicMock()
        mock_session = MagicMock()

        outcome = checker._measure_speedup(
            ref_model=mock_ref_model,
            session=mock_session,
            dataloader=MagicMock(),
            io_config=MagicMock(),
            torch_device=MagicMock(),
            warmup_iterations=3,
            timing_iterations=0,
        )

        assert outcome is None
        # Neither backend may be invoked when the measurement is disabled.
        mock_ref_model.assert_not_called()
        mock_session.run.assert_not_called()
Loading