torchlib: add aten::bincount ONNX lowering and bool-mask aten::index_put fix with e2e coverage#2921
torchlib: add aten::bincount ONNX lowering and bool-mask aten::index_put fix with e2e coverage#2921Copilot wants to merge 10 commits into
aten::bincount ONNX lowering and bool-mask aten::index_put fix with e2e coverage#2921Conversation
Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/bbc26d89-9521-4afa-9885-df627c50aecf Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
aten::bincount ONNX lowering (with minlength) and e2e coverage
|
@copilot what about other ops? Can you add them here too? |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2921 +/- ##
==========================================
- Coverage 72.64% 72.58% -0.06%
==========================================
Files 259 259
Lines 31652 31687 +35
Branches 2980 2987 +7
==========================================
+ Hits 22994 23001 +7
- Misses 7649 7676 +27
- Partials 1009 1010 +1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Pull request overview
Adds an ONNX lowering for aten::bincount in torchlib (including minlength) and an end-to-end export test to ensure the new translation is exercised through the torch export → ONNX → runtime path.
Changes:
- Register
aten::bincountas atrace_onlytorchlib op and implement lowering viaReduceMax+OneHot+ReduceSum(withminlengthsupport). - Explicitly reject weighted bincount (
weights != None) viaNotImplementedError. - Add an e2e test covering
torch.bincount(..., minlength=...).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
onnxscript/function_libs/torch_lib/ops/core.py |
Implements aten::bincount ONNX lowering and minlength handling. |
tests/function_libs/torch_lib/e2e_ops_tests.py |
Adds an e2e export test for torch.bincount(..., minlength=...). |
Comments suppressed due to low confidence (1)
onnxscript/function_libs/torch_lib/ops/core.py:1273
depthis derived fromReduceMax(self). For an emptyself, ONNXReduceMaxreduces over an empty set by returning the minimum value of the dtype (for INT64), sodepth = min_int64 + 1becomes a huge negative number andOneHotwill be invalid.torch.bincounton empty input should return an empty result (or lengthminlengthfilled with zeros), so this lowering needs an explicit empty-input branch (e.g.,if op.Size(self) == 0: ...) to return the correct zero-filled output without callingReduceMax/OneHot.
axis_0 = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
max_val = op.Unsqueeze(op.ReduceMax(self, keepdims=0), axis_0)
depth = op.Add(max_val, one)
if minlength > 0:
depth = op.Max(depth, op.Constant(value_ints=[minlength]))
one_hot = op.OneHot(self, depth, op.Constant(value_ints=[0, 1]), axis=-1)
return op.ReduceSum(one_hot, axis_0, keepdims=0)
| one = op.Constant(value_ints=[1]) | ||
| max_val = op.Unsqueeze(op.ReduceMax(self, keepdims=0), axis_0) | ||
| depth = op.Add(max_val, one) | ||
| if minlength > 0: | ||
| depth = op.Max(depth, op.Constant(value_ints=[minlength])) |
| onnx_program = torch.onnx.export( | ||
| Model(), | ||
| (torch.tensor([0, 1, 1, 3, 5], dtype=torch.int64),), | ||
| dynamo=True, | ||
| optimize=False, | ||
| ) | ||
| _testing.assert_onnx_program(onnx_program) | ||
|
|
Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/5ff8e3c5-f1b7-4d47-bbc0-4b7043a2bc61 Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Added another applicable op update in torchlib: I implemented the boolean-mask |
aten::bincount ONNX lowering (with minlength) and e2e coverageaten::bincount ONNX lowering and bool-mask aten::index_put fix with e2e coverage
| def _aten_index_put_bool( | ||
| self: TReal, | ||
| indices: Sequence[BOOL], | ||
| indices: Sequence[Optional[Union[INT64, BOOL]]], |
There was a problem hiding this comment.
@copilot should this have INT64? I don't think so?
There was a problem hiding this comment.
Good catch — I removed INT64 from that bool helper signature, so it now expects Sequence[Optional[BOOL]] only. Fixed in 6ad7f14.
| return op.Where(index, values, self) | ||
| del accumulate # Boolean masks index each position at most once. | ||
|
|
||
| bool_mask = indices[0] |
There was a problem hiding this comment.
@copilot so indices with more than 1 elements is still not supported? What is the reference implementation in huggingface?
There was a problem hiding this comment.
Correct — the bool path was effectively single-mask only. I made that explicit so we fail fast for multi-element indices instead of silently ignoring extras in ceb0881. I also couldn’t find a Hugging Face-specific lowering for this case; the reference we follow here is PyTorch’s torch.onnx.symbolic_opset11.index_put (already linked in the docstring).
There was a problem hiding this comment.
Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/95a124b2-bf9a-4a54-bb80-8e10a16a6b53 Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/9b023bd8-af00-478c-bc75-2369c734b5c2 Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
I added another applicable torchlib op in this PR ( |
|
@copilot huggingface/transformers@98445c3/src/transformers/exporters/exporter_onnx.py#L647 this is the reference. Please update index_put so that it is generalized to multiple masks |
Implemented in
|
titaiwangms
left a comment
There was a problem hiding this comment.
Review summary
The added e2e tests pass (3/3) and common bool-mask index_put cases match torch eager. However, there's one verified regression plus several correctness gaps worth addressing.
🔴 Critical
C1. Single bool-mask index_put breaks scalar/broadcast values (verified with a repro).
The new cumsum→gather scheme assigns each selected element a flat value index 0..num_true-1, so when values is a scalar (or shorter-than-count broadcastable tensor), Gather goes out of bounds:
x = torch.zeros(2, 3); mask = torch.tensor([True, False]); v = torch.tensor(5.0)
torch.ops.aten.index_put(x, [mask], v) # eager: row 0 -> 5
# exported ONNX in ORT: "Gather ... indices element out of data bounds, idx=1 must be within [-1,0]"This is the very common t[t < 0] = 0 pattern, and a regression from the previous Where(mask, values, self) which broadcast correctly. Suggested fix: Expand values to the selected shape before scatter/where, or retain a Where+broadcast path when values isn't full-length. (Row-mask with full-shape values, and 1-D mask with length-N values, both work correctly.)
C2. bincount on empty input errors instead of returning zeros(minlength).
ReduceMax over an empty tensor throws in ORT, whereas torch.bincount([]) returns a length-minlength zero vector (length 0 when minlength == 0). Suggest gating on Size(self) == 0.
🟠 Major
accumulate=Truesilently dropped in the single-mask path (del accumulate). torch honorstensor[mask] += v. (Pre-existing in the old code, but the rewrite is a good opportunity — at minimumraise NotImplementedErrorwhenaccumulate=True.)- Multi-mask path requires equal True counts. torch broadcasts a count-1 axis; here
Concatshape-mismatches. Broadcast the position tensors, or raise a clear error. bincountnegative inputs silently miscount (ONNXOneHotwraps negative indices), whereas torch raises. Add a guard or document the limitation.bincountmemory isO(N · max(self))via the denseOneHot, vs torch'sO(N + depth). A scatter-add lowering (ScatterElements(reduction="add")into a[depth]zero vector) would avoid the blow-up on large inputs.bincounttypedINT64-only is narrower than torch (accepts int8/16/32/uint8). Consider widening to integral types and casting internally.
🟡 Minor
CumSumaxisandClipminare passed as rank-1[..]tensors; the ONNX spec wants 0-D scalars (ORT is lenient and existing code does the same). For reference,OneHotdepthas a rank-1 single-element tensor is spec-legal.- No 1-D rank check on
bincountinput; a mask of shape[1]onself[2,3]silently broadcasts where torch would raise a shape mismatch. - No
TorchLibOpInfo("bincount", ...)entry inops_test_data.py— the repo convention for parameterized correctness coverage against torch. The single happy-path e2e test (minlength=6) doesn't cover defaultminlength=0, empty input, or the unsupportedweightspath.
⚪ Readability
- Add a 1–2 line comment for the cumsum-minus-one → gather trick and for the
NonZero→ScatterNDindex construction — the two hardest-to-decode blocks. - Naming:
positionsvsselected_positionscollide in the multi-mask loop;minus_onecould beneg_1to match the file's existing convention;oneis over-named.
✅ Confirmed correct
OneHot+ReduceSum int64 output dtype matches torch; the multi-mask ScatterND shape arithmetic is correct for the equal-count case; ScatterND reduction="add" is valid at opset-18; and the same_shape None-hardening is a genuine correctness improvement.
Bottom line: C1 (verified) and C2 hit common real-world patterns and should be resolved before merge. The remaining items are scope-appropriate hardening for an op-lowering PR.
Reviewed with a multi-model review team (correctness, adversarial, spec-adherence, readability) plus local runtime verification against onnxruntime.
@copilot Does this make sense? If so, can you address it? |
…alar values, accumulate)
Yes, this makes sense — addressed in the latest commit:
Added e2e coverage: scalar-value and row-mask bool |
|
@copilot Can you fix the lint CI? |
Fixed the lint CI — the failure was a RUFF-FORMAT diff on |
This issue asks us to absorb relevant exporter-side patches into torchlib where applicable (example: explicit
TopK.sorted). In this PR, applicable missing/insufficient torchlib patches were implemented foraten::bincountand boolean-maskaten::index_put.aten::bincounttranslation in torchlibaten::bincountas atrace_onlytorchlib op.ScatterElements(reduction="add")of ones into a zero vector of lengthdepth. This uses O(N + depth) memory instead of a dense one-hot.depthfromReduceMax(with an appended sentinel0so it is defined on empty input) and supportsminlengthby clamping the depth withMax.bincount([])returns an empty vector andbincount([], minlength=k)returnszeros(k), matching torch.INT64internally).weights != None); negative inputs are documented as unsupported.Boolean-mask
aten::index_putfix in torchlib_aten_index_put_boolsingle-mask path to use aNonZero+ScatterNDstrategy.valuesis broadcast to the selection shape[num_true, *self.shape[mask_rank:]]before scatter, so scalar/broadcastable, length-num_true, and row-mask full-shape values all work correctly (fixing aGatherout-of-bounds regression for scalar values such as thet[t < 0] = 0pattern).accumulate=Trueis now honored viaScatterNDreduction="add"instead of being dropped.NonZero, builds a multi-columnScatterNDindex, and expands updates over the trailing dimensions; equal True-count across masks remains a known limitation.Focused e2e coverage
torch.bincount(..., minlength=...), defaultminlength, and empty input.torch.ops.aten.index_put, including scalar-value, row-mask scalar-value, and multi-mask cases.Example of the
aten::bincountlowering pattern: