diff --git a/olive/passes/onnx/discrepancy_check.py b/olive/passes/onnx/discrepancy_check.py index 711d91643..6f509d08b 100644 --- a/olive/passes/onnx/discrepancy_check.py +++ b/olive/passes/onnx/discrepancy_check.py @@ -3,10 +3,12 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging +import time from typing import Optional from olive.data.config import DataConfig from olive.hardware import AcceleratorSpec +from olive.hardware.accelerator import Device from olive.model import ONNXModelHandler from olive.passes import Pass from olive.passes.pass_config import BasePassConfig, PassConfigParam @@ -61,6 +63,7 @@ class OnnxDiscrepancyCheck(Pass): - Maximum absolute error (MaxAE) - Number of elements where the absolute difference exceeds 0.1 - Number of elements where the absolute difference exceeds 0.01 + - Inference speedup of ONNX over PyTorch on the target device (or CPU fallback) - Longest common token sequence from the beginning between transformers generate and ONNX Runtime GenAI generate (when enabled) @@ -107,6 +110,19 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon "computes the longest common token sequence from the beginning of their outputs." ), ), + "warmup_iterations": PassConfigParam( + type_=int, + default_value=3, + description="Number of warmup iterations before timing inference for speedup measurement.", + ), + "timing_iterations": PassConfigParam( + type_=int, + default_value=5, + description=( + "Number of timed iterations to measure inference speedup (ONNX vs PyTorch). " + "Set to 0 to disable speedup measurement." 
+ ), + ), "generate_prompt": PassConfigParam( type_=str, default_value="The capital of France is", @@ -177,8 +193,28 @@ def _run_for_config( ref_model = AutoModelForCausalLM.from_pretrained(config.reference_model_path) ref_model.eval() - # Prepare ONNX session - session = model.prepare_session() + # Prepare ONNX session on the target device (fallback to CPU) + device = self.accelerator_spec.accelerator_type if self.accelerator_spec else None + execution_provider = self.accelerator_spec.execution_provider if self.accelerator_spec else None + if device is None: + device = Device.CPU + elif not isinstance(device, Device): + try: + device = Device(str(device).lower()) + except ValueError: + logger.warning("Unknown accelerator_type=%s; falling back to CPU.", device) + device = Device.CPU + + # Determine the torch device matching the accelerator spec + torch_device = torch.device("cpu") + if device == Device.GPU and torch.cuda.is_available(): + torch_device = torch.device("cuda") + ref_model = ref_model.to(torch_device) + + session = model.prepare_session( + device=device, + execution_providers=[execution_provider] if execution_provider else None, + ) io_config = model.io_config # Run inference on both and compare @@ -194,9 +230,9 @@ def _run_for_config( # Run PyTorch inference if isinstance(input_data, dict): - torch_inputs = {k: v.clone() for k, v in input_data.items()} + torch_inputs = {k: v.clone().to(torch_device) for k, v in input_data.items()} else: - torch_inputs = input_data + torch_inputs = input_data.to(torch_device) torch_output = ref_model(**torch_inputs) torch_logits = torch_output.logits.detach().cpu().numpy() @@ -225,6 +261,23 @@ def _run_for_config( total_elements, ) + # Measure inference speedup (ONNX vs PyTorch) on the target device + if config.timing_iterations > 0: + self._measure_speedup( + ref_model, + session, + dataloader, + io_config, + torch_device, + config.warmup_iterations, + config.timing_iterations, + ) + else: + logger.info( + 
"OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.", + config.timing_iterations, + ) + # Check thresholds failures = [] if config.max_mae is not None and max_abs_error > config.max_mae: @@ -254,6 +307,73 @@ def _run_for_config( # Return the model unchanged return model + def _measure_speedup( + self, ref_model, session, dataloader, io_config, torch_device, warmup_iterations, timing_iterations + ): + """Measure inference speedup of ONNX over PyTorch on the target device.""" + if timing_iterations <= 0: + logger.info( + "OnnxDiscrepancyCheck speedup measurement skipped because timing_iterations=%d.", + timing_iterations, + ) + return None + + import torch + + from olive.common.utils import format_data + + # Use the first batch for timing + first_batch = next(iter(dataloader)) + input_data = first_batch[0] if isinstance(first_batch, (tuple, list)) else first_batch + + if isinstance(input_data, dict): + torch_inputs = {k: v.clone().to(torch_device) for k, v in input_data.items()} + else: + torch_inputs = input_data.to(torch_device) + + onnx_input_feed = format_data(input_data, io_config) + use_cuda_sync = torch_device.type == "cuda" + + # Warmup PyTorch + with torch.no_grad(): + for _ in range(warmup_iterations): + ref_model(**torch_inputs) + if use_cuda_sync: + torch.cuda.synchronize() + + # Time PyTorch + with torch.no_grad(): + if use_cuda_sync: + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(timing_iterations): + ref_model(**torch_inputs) + if use_cuda_sync: + torch.cuda.synchronize() + pytorch_time = (time.perf_counter() - start) / timing_iterations + + # Warmup ONNX + for _ in range(warmup_iterations): + session.run(None, onnx_input_feed) + + # Time ONNX + start = time.perf_counter() + for _ in range(timing_iterations): + session.run(None, onnx_input_feed) + onnx_time = (time.perf_counter() - start) / timing_iterations + + speedup = pytorch_time / onnx_time if onnx_time > 0 else float("inf") + + logger.info( 
+ "OnnxDiscrepancyCheck speedup: pytorch_avg=%.4fs, onnx_avg=%.4fs, speedup=%.2fx (device=%s)", + pytorch_time, + onnx_time, + speedup, + torch_device, + ) + + return speedup + def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: """Run generation on both transformers and GenAI, return longest common token sequence length.""" try: @@ -266,6 +386,7 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: # Transformers generation input_ids = tokenizer(config.generate_prompt, return_tensors="pt").input_ids + input_ids = input_ids.to(ref_model.device) import torch with torch.no_grad(): @@ -274,7 +395,7 @@ def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: max_new_tokens=config.generate_max_new_tokens, do_sample=False, ) - transformers_tokens = transformers_output[0].tolist() + transformers_tokens = transformers_output[0].cpu().tolist() # ONNX Runtime GenAI generation genai_model = og.Model(config.genai_model_path) diff --git a/test/passes/onnx/test_discrepancy_check.py b/test/passes/onnx/test_discrepancy_check.py index 8fe9c677f..c6a2f83eb 100644 --- a/test/passes/onnx/test_discrepancy_check.py +++ b/test/passes/onnx/test_discrepancy_check.py @@ -2,6 +2,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
class TestSpeedupSettings:
    """Checks for the speedup-measurement settings of ``OnnxDiscrepancyCheck``."""

    def test_timing_iterations_default_is_5(self):
        """The ``timing_iterations`` pass parameter defaults to 5."""
        from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

        timing_param = OnnxDiscrepancyCheck._default_config(None)["timing_iterations"]
        assert timing_param.default_value == 5

    def test_measure_speedup_skips_when_timing_iterations_is_zero(self):
        """With ``timing_iterations=0`` neither model is invoked and ``None`` is returned."""
        from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck

        torch_model = MagicMock()
        ort_session = MagicMock()
        call_kwargs = {
            "ref_model": torch_model,
            "session": ort_session,
            "dataloader": MagicMock(),
            "io_config": MagicMock(),
            "torch_device": MagicMock(),
            "warmup_iterations": 3,
            "timing_iterations": 0,
        }

        # __new__ skips __init__, so no pass configuration is needed for this check.
        checker = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)
        outcome = checker._measure_speedup(**call_kwargs)

        assert outcome is None
        torch_model.assert_not_called()
        ort_session.run.assert_not_called()