torchlib: add aten::bincount ONNX lowering and bool-mask aten::index_put fix with e2e coverage#2921

Open
Copilot wants to merge 10 commits into main from copilot/update-torchlib

Contributor

Copilot AI commented May 20, 2026

This issue asks us to absorb relevant exporter-side patches into torchlib where applicable (example: explicit TopK.sorted). This PR implements the patches that were missing or insufficient in torchlib for aten::bincount and boolean-mask aten::index_put.

  • aten::bincount translation in torchlib

    • Registered aten::bincount as a trace_only torchlib op.
    • Implemented lowering using a scatter-add strategy: ScatterElements(reduction="add") of ones into a zero vector of length depth. This uses O(N + depth) memory instead of a dense one-hot.
    • Computes depth from ReduceMax (with an appended sentinel 0 so it is defined on empty input) and supports minlength by clamping the depth with Max.
    • Handles empty input correctly: bincount([]) returns an empty vector and bincount([], minlength=k) returns zeros(k), matching torch.
    • Widened the accepted input to integral types (cast to INT64 internally).
    • Explicitly rejects weighted bincount (weights != None); negative inputs are documented as unsupported.
  • Boolean-mask aten::index_put fix in torchlib

    • Updated _aten_index_put_bool single-mask path to use a NonZero + ScatterND strategy. values is broadcast to the selection shape [num_true, *self.shape[mask_rank:]] before scatter, so scalar/broadcastable, length-num_true, and row-mask full-shape values all work correctly (fixing a Gather out-of-bounds regression for scalar values such as the t[t < 0] = 0 pattern).
    • accumulate=True is now honored via ScatterND reduction="add" instead of being dropped.
    • The multi-mask path (multiple 1-D boolean masks) converts each mask via NonZero, builds a multi-column ScatterND index, and expands updates over the trailing dimensions. All masks must select the same number of True entries; broadcasting across unequal True counts remains a known limitation.
  • Focused e2e coverage

    • Added torch-export e2e tests for torch.bincount(..., minlength=...), default minlength, and empty input.
    • Added torch-export e2e regression tests for boolean-mask torch.ops.aten.index_put, including scalar-value, row-mask scalar-value, and multi-mask cases.

Example of the aten::bincount lowering pattern:

@torch_op("aten::bincount", trace_only=True)
def aten_bincount(self: IntType, weights: Optional[TensorType] = None, minlength: int = 0):
    if weights is not None:
        raise NotImplementedError("aten::bincount with weights is not supported.")

    self = op.Cast(self, to=INT64.dtype)
    axis_0 = op.Constant(value_ints=[0])
    data_max = op.Unsqueeze(
        op.ReduceMax(op.Concat(self, op.Constant(value_ints=[0]), axis=0), keepdims=0),
        axis_0,
    )
    non_empty = op.Unsqueeze(
        op.Cast(op.Greater(op.Size(self), op.Constant(value_int=0)), to=INT64.dtype),
        axis_0,
    )
    depth = op.Mul(op.Add(data_max, op.Constant(value_ints=[1])), non_empty)
    if minlength > 0:
        depth = op.Max(depth, op.Constant(value_ints=[minlength]))

    zeros = op.Expand(op.Constant(value_int=0), depth)
    ones = op.Expand(op.Constant(value_int=1), op.Shape(self))
    return op.ScatterElements(zeros, self, ones, axis=0, reduction="add")
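
The lowering above can be mirrored in plain Python to make the depth arithmetic easier to follow. This is a minimal sketch, not the torchlib code: the function name bincount_scatter_add is hypothetical, lists stand in for tensors, and the loop stands in for ScatterElements(reduction="add").

```python
def bincount_scatter_add(values, minlength=0):
    """Plain-Python model of the scatter-add lowering (illustrative only)."""
    # ReduceMax over the input with a sentinel 0 appended, so the max is
    # defined even when the input is empty.
    data_max = max(list(values) + [0])
    # non_empty gates the depth to 0 when there is nothing to count.
    non_empty = 1 if len(values) > 0 else 0
    depth = (data_max + 1) * non_empty
    # minlength only raises the depth; it never shrinks it.
    if minlength > 0:
        depth = max(depth, minlength)
    # ScatterElements(reduction="add"): add a 1 at each index of a zero vector.
    # Negative values are unsupported here, as in the lowering.
    counts = [0] * depth
    for v in values:
        counts[v] += 1
    return counts
```

Under this model, an empty input yields a depth of 0 unless minlength raises it, which is how bincount([]) and bincount([], minlength=k) come out as an empty vector and zeros(k) respectively.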

Copilot AI changed the title from "[WIP] Update torchlib to absorb patches and improvements" to "torchlib: add aten::bincount ONNX lowering (with minlength) and e2e coverage" May 20, 2026
Copilot AI requested a review from justinchuby May 20, 2026 16:05
@justinchuby justinchuby marked this pull request as ready for review May 20, 2026 19:16
@justinchuby justinchuby requested review from Copilot and xadupre May 20, 2026 19:16
@justinchuby
Collaborator

@copilot what about other ops? Can you add them here too?

@codecov

codecov Bot commented May 20, 2026

Codecov Report

❌ Patch coverage is 34.78261% with 30 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.58%. Comparing base (1fb43af) to head (d8df795).
⚠️ Report is 1 commits behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines | Patch % | Lines
onnxscript/function_libs/torch_lib/ops/core.py | 34.78% | 28 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2921      +/-   ##
==========================================
- Coverage   72.64%   72.58%   -0.06%     
==========================================
  Files         259      259              
  Lines       31652    31687      +35     
  Branches     2980     2987       +7     
==========================================
+ Hits        22994    23001       +7     
- Misses       7649     7676      +27     
- Partials     1009     1010       +1     


Contributor

Copilot AI left a comment

Pull request overview

Adds an ONNX lowering for aten::bincount in torchlib (including minlength) and an end-to-end export test to ensure the new translation is exercised through the torch export → ONNX → runtime path.

Changes:

  • Register aten::bincount as a trace_only torchlib op and implement lowering via ReduceMax + OneHot + ReduceSum (with minlength support).
  • Explicitly reject weighted bincount (weights != None) via NotImplementedError.
  • Add an e2e test covering torch.bincount(..., minlength=...).

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File | Description
onnxscript/function_libs/torch_lib/ops/core.py | Implements aten::bincount ONNX lowering and minlength handling.
tests/function_libs/torch_lib/e2e_ops_tests.py | Adds an e2e export test for torch.bincount(..., minlength=...).
Comments suppressed due to low confidence (1)

onnxscript/function_libs/torch_lib/ops/core.py:1273

  • depth is derived from ReduceMax(self). For an empty self, ONNX ReduceMax reduces over an empty set by returning the minimum value of the dtype (for INT64), so depth = min_int64 + 1 becomes a huge negative number and OneHot will be invalid. torch.bincount on empty input should return an empty result (or length minlength filled with zeros), so this lowering needs an explicit empty-input branch (e.g., if op.Size(self) == 0: ...) to return the correct zero-filled output without calling ReduceMax/OneHot.
    axis_0 = op.Constant(value_ints=[0])
    one = op.Constant(value_ints=[1])
    max_val = op.Unsqueeze(op.ReduceMax(self, keepdims=0), axis_0)
    depth = op.Add(max_val, one)
    if minlength > 0:
        depth = op.Max(depth, op.Constant(value_ints=[minlength]))

    one_hot = op.OneHot(self, depth, op.Constant(value_ints=[0, 1]), axis=-1)
    return op.ReduceSum(one_hot, axis_0, keepdims=0)

Comment on lines +1266 to +1270
    one = op.Constant(value_ints=[1])
    max_val = op.Unsqueeze(op.ReduceMax(self, keepdims=0), axis_0)
    depth = op.Add(max_val, one)
    if minlength > 0:
        depth = op.Max(depth, op.Constant(value_ints=[minlength]))

Comment on lines +87 to +94
    onnx_program = torch.onnx.export(
        Model(),
        (torch.tensor([0, 1, 1, 3, 5], dtype=torch.int64),),
        dynamo=True,
        optimize=False,
    )
    _testing.assert_onnx_program(onnx_program)

Contributor Author

Copilot AI commented May 20, 2026

@copilot what about other ops? Can you add them here too?

Added another applicable op update in torchlib: I implemented the boolean-mask aten.index_put lowering using a cumsum/gather/where strategy and added an e2e regression test (test_index_put_bool_mask). This is in commit 64c7ff3.

Copilot AI changed the title from "torchlib: add aten::bincount ONNX lowering (with minlength) and e2e coverage" to "torchlib: add aten::bincount ONNX lowering and bool-mask aten::index_put fix with e2e coverage" May 20, 2026

  def _aten_index_put_bool(
      self: TReal,
-     indices: Sequence[BOOL],
+     indices: Sequence[Optional[Union[INT64, BOOL]]],
Collaborator

@copilot should this have INT64? I don't think so?

Contributor Author

Good catch — I removed INT64 from that bool helper signature, so it now expects Sequence[Optional[BOOL]] only. Fixed in 6ad7f14.

-     return op.Where(index, values, self)
+     del accumulate  # Boolean masks index each position at most once.
+
+     bool_mask = indices[0]
Collaborator

@copilot so indices with more than 1 elements is still not supported? What is the reference implementation in huggingface?

Contributor Author

Correct: the bool path was effectively single-mask only. In ceb0881 I made that explicit, so we now fail fast on multi-element indices instead of silently ignoring the extras. I also couldn’t find a Hugging Face-specific lowering for this case; the reference we follow here is PyTorch’s torch.onnx.symbolic_opset11.index_put (already linked in the docstring).

Contributor Author

Copilot AI commented May 20, 2026

@copilot what about other ops? Can you add them here too?

I added another applicable torchlib op in this PR (aten.index_put bool-mask lowering in 64c7ff3) in addition to aten::bincount. If you want more absorbed in this PR, share the specific op names and I’ll add them in follow-up commits.

@justinchuby
Collaborator

@copilot huggingface/transformers@98445c3/src/transformers/exporters/exporter_onnx.py#L647 this is the reference. Please update index_put so that it is generalized to multiple masks

@justinchuby justinchuby marked this pull request as draft May 26, 2026 15:11
Contributor Author

Copilot AI commented May 26, 2026

@copilot huggingface/transformers@98445c3/src/transformers/exporters/exporter_onnx.py#L647 this is the reference. Please update index_put so that it is generalized to multiple masks

Implemented in 7fbee18.

_aten_index_put_bool now handles multiple boolean masks (1-D masks) by converting each mask via NonZero, constructing a multi-column ScatterND index, and applying updates with proper trailing-dimension expansion. I also added e2e coverage in test_index_put_bool_multi_mask.
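
The multi-mask steps described above can be sketched in plain Python. This is an illustrative model under stated assumptions, not the code from 7fbee18: nested lists stand in for tensors, the names nonzero_1d and index_put_multi_mask are hypothetical, and updates is assumed to already hold one entry per selected position (a scalar, or a nested list covering the trailing dimensions).

```python
import copy


def nonzero_1d(mask):
    """NonZero on a 1-D boolean mask: positions of the True entries."""
    return [i for i, flag in enumerate(mask) if flag]


def index_put_multi_mask(data, masks, updates):
    """Model of x[mask0, mask1, ...] = updates via a multi-column index."""
    columns = [nonzero_1d(m) for m in masks]
    # Known limitation mirrored here: every mask must select the same
    # number of elements, because the columns are paired element-wise.
    if len({len(c) for c in columns}) != 1:
        raise ValueError("all masks must have the same number of True entries")
    # One row per selected position, one column per mask: shape [K, len(masks)].
    index_rows = list(zip(*columns))
    out = copy.deepcopy(data)
    # ScatterND: write each update at the position named by its index row.
    for row, update in zip(index_rows, updates):
        target = out
        for i in row[:-1]:
            target = target[i]
        target[row[-1]] = copy.deepcopy(update)
    return out
```

For example, masks [True, False, True] and [False, True, True] on a 3x3 input pair up as positions (0, 1) and (2, 2), so two updates are written. Masks with unequal True counts raise in this sketch, matching the limitation noted in the thread.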

@titaiwangms titaiwangms self-requested a review June 1, 2026 15:57
Contributor

@titaiwangms titaiwangms left a comment

Review summary

The added e2e tests pass (3/3) and common bool-mask index_put cases match torch eager. However, there's one verified regression plus several correctness gaps worth addressing.

🔴 Critical

C1. Single bool-mask index_put breaks scalar/broadcast values (verified with a repro).
The new cumsum→gather scheme assigns each selected element a flat value index 0..num_true-1, so when values is a scalar (or shorter-than-count broadcastable tensor), Gather goes out of bounds:

x = torch.zeros(2, 3); mask = torch.tensor([True, False]); v = torch.tensor(5.0)
torch.ops.aten.index_put(x, [mask], v)   # eager: row 0 -> 5
# exported ONNX in ORT: "Gather ... indices element out of data bounds, idx=1 must be within [-1,0]"

This is the very common t[t < 0] = 0 pattern, and a regression from the previous Where(mask, values, self) which broadcast correctly. Suggested fix: Expand values to the selected shape before scatter/where, or retain a Where+broadcast path when values isn't full-length. (Row-mask with full-shape values, and 1-D mask with length-N values, both work correctly.)
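
One plausible reading of the C1 failure, sketched in plain Python. The assumptions are loud ones: the row mask is taken to be expanded to the full shape and flattened, the Clip at 0 is inferred from the review notes, and all function names are hypothetical. The actual torchlib code is not shown in this thread.

```python
def cumsum_minus_one(flat_mask):
    """Value index per position: running count of True entries minus one, clipped at 0."""
    running = 0
    indices = []
    for flag in flat_mask:
        running += int(flag)
        indices.append(max(running - 1, 0))
    return indices


def gather(values, indices):
    """Gather with an ONNX-style bounds check: valid indices are [-len, len - 1]."""
    n = len(values)
    for i in indices:
        if not -n <= i < n:
            raise IndexError(f"idx={i} must be within [{-n},{n - 1}]")
    return [values[i] for i in indices]


def masked_put(flat_self, flat_mask, values, broadcast_first):
    """Where(mask, Gather(values, cumsum - 1), self), optionally expanding values first."""
    num_true = sum(flat_mask)
    if num_true == 0:
        return list(flat_self)
    if broadcast_first and len(values) == 1:
        # Suggested fix: expand a scalar to one value per selected element.
        values = values * num_true
    picked = gather(values, cumsum_minus_one(flat_mask))
    return [p if m else s for s, m, p in zip(flat_self, flat_mask, picked)]
```

With the repro's mask expanded to [True, True, True, False, False, False], the value indices are [0, 1, 2, 2, 2, 2]. A scalar values tensor has length 1, so index 1 is already out of bounds unless values is expanded first.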

C2. bincount on empty input errors instead of returning zeros(minlength).
ReduceMax over an empty tensor throws in ORT, whereas torch.bincount([]) returns a length-minlength zero vector (length 0 when minlength == 0). Suggest gating on Size(self) == 0.

🟠 Major

  • accumulate=True silently dropped in the single-mask path (del accumulate). torch honors tensor[mask] += v. (Pre-existing in the old code, but the rewrite is a good opportunity — at minimum raise NotImplementedError when accumulate=True.)
  • Multi-mask path requires equal True counts. torch broadcasts a count-1 axis; here Concat shape-mismatches. Broadcast the position tensors, or raise a clear error.
  • bincount negative inputs silently miscount (ONNX OneHot wraps negative indices), whereas torch raises. Add a guard or document the limitation.
  • bincount memory is O(N · max(self)) via the dense OneHot, vs torch's O(N + depth). A scatter-add lowering (ScatterElements(reduction="add") into a [depth] zero vector) would avoid the blow-up on large inputs.
  • bincount typed INT64-only is narrower than torch (accepts int8/16/32/uint8). Consider widening to integral types and casting internally.
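
The memory and negative-index points above can be illustrated with a plain-Python model of the OneHot + ReduceSum lowering. This is a sketch under assumptions: the function name bincount_one_hot is hypothetical, lists stand in for tensors, and the wrap rule follows the ONNX OneHot behavior of counting negative indices from the back and zeroing entries outside [-depth, depth - 1].

```python
def bincount_one_hot(values, minlength=0):
    """Model of the dense lowering; returns (counts, number of one-hot cells)."""
    # ReduceMax without a sentinel: max() raises on empty input, which
    # loosely mirrors the empty-input problem raised in C2.
    depth = max(values) + 1
    if minlength > 0:
        depth = max(depth, minlength)
    rows = []
    for v in values:
        row = [0] * depth
        # OneHot: indices in [-depth, depth - 1] are valid; negatives wrap.
        if -depth <= v < depth:
            row[v % depth] = 1
        rows.append(row)
    # ReduceSum over the first axis collapses the [N, depth] matrix to [depth].
    counts = [sum(col) for col in zip(*rows)]
    return counts, len(values) * depth
```

For five inputs with a maximum of 5, the dense matrix has 5 x 6 = 30 cells, whereas a scatter-add needs roughly 5 + 6. The input [-1, 2] is silently counted as [0, 0, 2] in this model, where torch would raise.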

🟡 Minor

  • CumSum axis and Clip min are passed as rank-1 [..] tensors; the ONNX spec wants 0-D scalars (ORT is lenient and existing code does the same). For reference, OneHot depth as a rank-1 single-element tensor is spec-legal.
  • No 1-D rank check on bincount input; a mask of shape [1] on self [2,3] silently broadcasts where torch would raise a shape mismatch.
  • No TorchLibOpInfo("bincount", ...) entry in ops_test_data.py — the repo convention for parameterized correctness coverage against torch. The single happy-path e2e test (minlength=6) doesn't cover default minlength=0, empty input, or the unsupported weights path.

⚪ Readability

  • Add a 1–2 line comment for the cumsum-minus-one → gather trick and for the NonZero → ScatterND index construction, the two hardest-to-decode blocks.
  • Naming: positions vs selected_positions collide in the multi-mask loop; minus_one could be neg_1 to match the file's existing convention; one is over-named.

✅ Confirmed correct

OneHot+ReduceSum int64 output dtype matches torch; the multi-mask ScatterND shape arithmetic is correct for the equal-count case; ScatterND reduction="add" is valid at opset-18; and the same_shape None-hardening is a genuine correctness improvement.

Bottom line: C1 (verified) and C2 hit common real-world patterns and should be resolved before merge. The remaining items are scope-appropriate hardening for an op-lowering PR.

Reviewed with a multi-model review team (correctness, adversarial, spec-adherence, readability) plus local runtime verification against onnxruntime.

@titaiwangms
Contributor

@copilot Does this make sense? If so, can you address it?

Contributor Author

Copilot AI commented Jun 1, 2026

Review summary

The added e2e tests pass (3/3) and common bool-mask index_put cases match torch eager. However, there's one verified regression plus several correctness gaps worth address...

Yes, this makes sense — addressed in the latest commit:

  • C1 (index_put scalar/broadcast values): the single-mask path now uses NonZero + ScatterND, broadcasting values to the selection shape [num_true, *self.shape[mask_rank:]] before scatter. Scalar, short-broadcast, length-num_true, and row-mask full-shape values all work now (verified against numpy/ORT).
  • C2 (bincount empty input): rewrote the lowering as ScatterElements(reduction="add") into a zero vector of length depth, with depth gated to 0 on empty input. bincount([]) -> [] and bincount([], minlength=k) -> zeros(k) now match torch (verified in ORT).
  • accumulate=True: now honored via ScatterND/ScatterElements reduction="add" instead of being silently dropped.
  • bincount memory: the scatter-add lowering is O(N+depth) instead of the dense O(N·max) one-hot.
  • bincount typed INT64-only: widened to integral inputs via an internal Cast.
  • negative inputs / weights: documented as unsupported in the docstring.

Added e2e coverage: scalar-value and row-mask bool index_put, plus default-minlength and empty-input bincount. The multi-mask path's equal-true-count broadcasting remains a known limitation.
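
A minimal sketch of the single-mask NonZero + ScatterND strategy described above, restricted for brevity to a 2-D input with a 1-D row mask (mask_rank = 1). It is not the committed code: nested lists stand in for tensors, and the names broadcast_updates and index_put_row_mask are hypothetical.

```python
def broadcast_updates(values, num_true, ncols):
    """Broadcast values to the selection shape [num_true, ncols]."""
    if not isinstance(values, list):
        # Scalar: fill every selected row.
        return [[values] * ncols for _ in range(num_true)]
    if all(not isinstance(v, list) for v in values):
        # Single row: repeat it for each selected row.
        if len(values) != ncols:
            raise ValueError("row length does not match the trailing dimension")
        return [list(values) for _ in range(num_true)]
    if len(values) != num_true or any(len(r) != ncols for r in values):
        raise ValueError("values do not match the selection shape")
    return [list(r) for r in values]


def index_put_row_mask(data, mask, values, accumulate=False):
    """Model of NonZero + ScatterND for a row mask on a 2-D input."""
    rows = [i for i, flag in enumerate(mask) if flag]  # NonZero
    ncols = len(data[0]) if data else 0
    updates = broadcast_updates(values, len(rows), ncols)
    out = [list(r) for r in data]
    for r, update in zip(rows, updates):  # ScatterND
        for c in range(ncols):
            # reduction="add" when accumulate=True, plain overwrite otherwise.
            out[r][c] = out[r][c] + update[c] if accumulate else update[c]
    return out
```

Broadcasting before the scatter is what lets the scalar case from C1 work: a scalar becomes one full row per selected index, so no index ever points past the end of values.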

Copilot AI requested a review from titaiwangms June 1, 2026 17:09
@titaiwangms titaiwangms closed this Jun 1, 2026
@github-project-automation github-project-automation Bot moved this from Todo to Done in ONNX Script Review Board Jun 1, 2026
@titaiwangms titaiwangms reopened this Jun 1, 2026
@titaiwangms titaiwangms marked this pull request as ready for review June 1, 2026 17:15
@titaiwangms
Contributor

@copilot Can you fix the lint CI?

Contributor Author

Copilot AI commented Jun 3, 2026

@copilot Can you fix the lint CI?

Fixed the lint CI — the failure was a RUFF-FORMAT diff on core.py. Applied ruff format and verified both format and ruff check pass.

Successfully merging this pull request may close these issues.

Update torchlib

4 participants