Add OnnxDiscrepancyCheck speedup metric with default timing updates#2502
Add OnnxDiscrepancyCheck speedup metric with default timing updates#2502xadupre wants to merge 8 commits into
Conversation
There was a problem hiding this comment.
Pull request overview
This PR enhances the OnnxDiscrepancyCheck pass by adding an inference speedup measurement (ONNX vs PyTorch) and introducing configurable warmup/timing iteration settings, with updated defaults and tests to validate the new behavior.
Changes:
- Added
warmup_iterationsandtiming_iterationsconfig parameters (defaultingtiming_iterationsto5) and implemented speedup measurement with an option to skip whentiming_iterations=0. - Updated session/device setup to target the configured accelerator (with CPU fallback) and run PyTorch on the matched torch device.
- Added unit tests to validate the new default and the skip behavior.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
olive/passes/onnx/discrepancy_check.py |
Adds speedup timing, new config params/defaults, and device-aware session/model handling. |
test/passes/onnx/test_discrepancy_check.py |
Adds tests for the new timing_iterations default and the “skip when 0” behavior. |
| return speedup | ||
|
|
||
| def compare_generation(self, config: type[BasePassConfig], ref_model) -> int: |
There was a problem hiding this comment.
This is already addressed in commit 595bffc. input_ids is moved to ref_model.device (line 389) and the output is moved back to CPU via .cpu() before .tolist() (line 398).
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
|
@copilot test/passes/onnx/test_discrepancy_check.py:105: in test_compare_generation_returns_common_prefix_length |
Fixed in 25c6ab4. Added |
Describe your changes
Added speedup measurement for
OnnxDiscrepancyCheckand updated behavior based on review feedback:timing_iterationsdefault from10to5.timing_iterationsis set to0, speedup measurement is skipped.deviceattribute forcompare_generationtests.Checklist before requesting a review
lintrunner -a