From ad34109ec535c5b622fd3a70f6e365539a5c8534 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 15:45:35 +0000 Subject: [PATCH 1/8] Initial plan From db985c750a7d2bfe6d0b9580bf3fa4d2a05290ce Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 16:02:12 +0000 Subject: [PATCH 2/8] feat(torchlib): add aten::bincount translation and e2e test Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/bbc26d89-9521-4afa-9885-df627c50aecf Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 15 +++++++++++++-- tests/function_libs/torch_lib/e2e_ops_tests.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 877a83a403..b3789be7e1 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1254,12 +1254,23 @@ def aten_binary_cross_entropy_with_logits( raise NotImplementedError() +@torch_op("aten::bincount", trace_only=True) def aten_bincount( - self: TensorType, weights: Optional[TensorType] = None, minlength: int = 0 + self: INT64, weights: Optional[TensorType] = None, minlength: int = 0 ) -> TensorType: """bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor""" + if weights is not None: + raise NotImplementedError("aten::bincount with weights is not supported.") - raise NotImplementedError() + axis_0 = op.Constant(value_ints=[0]) + one = op.Constant(value_ints=[1]) + max_val = op.Unsqueeze(op.ReduceMax(self, keepdims=0), axis_0) + depth = op.Add(max_val, one) + if minlength > 0: + depth = op.Max(depth, op.Constant(value_ints=[minlength])) + + one_hot = op.OneHot(self, depth, op.Constant(value_ints=[0, 1]), axis=-1) + return op.ReduceSum(one_hot, axis_0, keepdims=0) def aten_binomial( diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 019e6f7fe5..fa2339b0bc 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -79,6 +79,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) _testing.assert_onnx_program(onnx_program) + def test_bincount(self): + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.bincount(x, minlength=6) + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([0, 1, 1, 3, 5], dtype=torch.int64),), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_integer_1(self): class Model(torch.nn.Module): def forward(self, x): From 64c7ff3db261289136f2ed25443b0a3c2f60c5bc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 19:25:24 +0000 Subject: [PATCH 3/8] fix(torchlib): add bool-mask index_put lowering and e2e test Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/5ff8e3c5-f1b7-4d47-bbc0-4b7043a2bc61 Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- .../function_libs/torch_lib/ops/core.py | 39 ++++++++++++------- .../function_libs/torch_lib/e2e_ops_tests.py | 18 +++++++++ 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b3789be7e1..af01b4e4dc 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5082,24 +5082,37 @@ def same_shape(other_shape: ir.Shape) -> bool: def _aten_index_put_bool( self: TReal, - indices: Sequence[BOOL], + indices: Sequence[Optional[Union[INT64, BOOL]]], values: TReal, accumulate: bool = False, ) -> TReal: """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor""" - # TODO: Support indices with more than 1 elements - index = indices[0] - # accumulate should be always False, True does not make sense but an assert would be great - # Reshape indices so it can be properly broadcasted - self_rank = len(self.shape) - index_rank = len(index.shape) - if self_rank > index_rank: - index_shape = op.Shape(index) - padding = op.Constant(value_ints=[1 for _ in range(self_rank - index_rank)]) - padded_shape = op.Concat(index_shape, padding, axis=0) - index = op.Reshape(index, padded_shape) - return op.Where(index, values, self) + del accumulate # Boolean masks index each position at most once. + + bool_mask = indices[0] + if bool_mask is None or bool_mask.dtype != BOOL.dtype: + raise NotImplementedError( + "Boolean index_put expects a boolean mask as the first index." + ) + + for _ in range(len(self.shape) - len(bool_mask.shape)): + bool_mask = op.Unsqueeze(bool_mask, op.Constant(value_ints=[-1])) + + expanded_mask = op.Expand(bool_mask, op.Shape(self)) + flat_mask = op.Reshape(expanded_mask, op.Constant(value_ints=[-1])) + flat_mask_int = op.Cast(flat_mask, to=INT64.dtype) + positions = op.Clip( + op.Sub( + op.CumSum(flat_mask_int, op.Constant(value_ints=[0])), op.Constant(value_ints=[1]) + ), + op.Constant(value_ints=[0]), + ) + flat_values = op.Reshape(values, op.Constant(value_ints=[-1])) + gathered_values = op.Gather(flat_values, positions) + flat_self = op.Reshape(self, op.Constant(value_ints=[-1])) + result = op.Where(flat_mask, gathered_values, flat_self) + return op.Reshape(result, op.Shape(self)) def aten_index_reduce( diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index fa2339b0bc..8c197a4191 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -915,6 +915,24 @@ def forward(self, x, index, update): ) _testing.assert_onnx_program(onnx_program) + def test_index_put_bool_mask(self): + class Model(torch.nn.Module): + def forward(self, x, mask, update): + return torch.ops.aten.index_put(x, [mask], update) + + x = torch.zeros((2, 3), dtype=torch.float32) + mask = torch.tensor([[True, False, True], [False, True, False]], dtype=torch.bool) + update = torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, mask, update), + input_names=["x", "mask", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + def test_std_mean(self): """Test torch.std_mean which will be decomposed into prims.sum.""" From 6ad7f146370a7c1feb6d7b25abcc5e09882deff5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 19:36:07 +0000 Subject: [PATCH 4/8] chore(torchlib): tighten bool index_put index type annotation Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/95a124b2-bf9a-4a54-bb80-8e10a16a6b53 Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index af01b4e4dc..8c985b89bf 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5082,7 +5082,7 @@ def same_shape(other_shape: ir.Shape) -> bool: def _aten_index_put_bool( self: TReal, - indices: Sequence[Optional[Union[INT64, BOOL]]], + indices: Sequence[Optional[BOOL]], values: TReal, accumulate: bool = False, ) -> TReal: From ceb08814d36f435c5697243b8a93ad63d7ed70d0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 May 2026 19:43:56 +0000 Subject: [PATCH 5/8] fix(torchlib): explicitly reject multi-index bool index_put Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/9b023bd8-af00-478c-bc75-2369c734b5c2 Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8c985b89bf..132205f386 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5090,6 +5090,9 @@ def _aten_index_put_bool( del accumulate # Boolean masks index each position at most once. + if len(indices) != 1: + raise NotImplementedError("Boolean index_put with multiple indices is not supported.") + bool_mask = indices[0] if bool_mask is None or bool_mask.dtype != BOOL.dtype: raise NotImplementedError( From 7fbee18930bc60d803885e6d8a957cfd74ef8011 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 May 2026 15:22:46 +0000 Subject: [PATCH 6/8] fix(torchlib): generalize bool index_put to multi-mask --- .../function_libs/torch_lib/ops/core.py | 45 ++++++++++++++++--- .../function_libs/torch_lib/e2e_ops_tests.py | 19 ++++++++ 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 132205f386..48409815df 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4987,8 +4987,14 @@ def is_advanced_index(index): # will invalidate equality-based check. first_shape = indices[advanced_indices[0]].shape - def same_shape(other_shape: ir.Shape) -> bool: - return (not any(d is None for d in other_shape)) and other_shape == first_shape + def same_shape(other_shape: Optional[ir.Shape]) -> bool: + return ( + first_shape is not None + and other_shape is not None + and not any(d is None for d in first_shape) + and not any(d is None for d in other_shape) + and other_shape == first_shape + ) all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices) if not all_same_shape: @@ -5082,18 +5088,43 @@ def same_shape(other_shape: ir.Shape) -> bool: def _aten_index_put_bool( self: TReal, - indices: Sequence[Optional[BOOL]], + indices: Sequence[Optional[Union[INT64, BOOL]]], values: TReal, accumulate: bool = False, ) -> TReal: """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor""" - del accumulate # Boolean masks index each position at most once. + bool_mask = indices[0] + if len(indices) > 1: + if any(index is None for index in indices): + raise NotImplementedError( + "Boolean index_put with multiple indices does not support None indices." + ) - if len(indices) != 1: - raise NotImplementedError("Boolean index_put with multiple indices is not supported.") + advanced_indices = [] + selected_positions = [] + minus_one = op.Constant(value_ints=[-1]) + for index in indices: + if index.dtype != BOOL.dtype or len(index.shape) != 1: + raise NotImplementedError( + "Boolean index_put with multiple indices supports only 1-D boolean masks." + ) + positions = op.Reshape(op.Transpose(op.NonZero(index), perm=[1, 0]), minus_one) + selected_positions.append(positions) + advanced_indices.append(op.Unsqueeze(positions, minus_one)) + onnx_index = op.Concat(*advanced_indices, axis=-1) + target_shape = op.Concat( + op.Shape(selected_positions[0]), + op.Slice(op.Shape(self), starts=[len(indices)], ends=[len(self.shape)], axes=[0]), + axis=0, + ) + expanded_values = op.Expand(values, target_shape) + return op.ScatterND( + self, onnx_index, expanded_values, reduction="add" if accumulate else None + ) + + del accumulate # Boolean masks index each position at most once. - bool_mask = indices[0] if bool_mask is None or bool_mask.dtype != BOOL.dtype: raise NotImplementedError( "Boolean index_put expects a boolean mask as the first index." diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 8c197a4191..6e69402714 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -933,6 +933,25 @@ def forward(self, x, mask, update): ) _testing.assert_onnx_program(onnx_program) + def test_index_put_bool_multi_mask(self): + class Model(torch.nn.Module): + def forward(self, x, mask0, mask1, update): + return torch.ops.aten.index_put(x, [mask0, mask1], update) + + x = torch.zeros((3, 4), dtype=torch.float32) + mask0 = torch.tensor([True, False, True], dtype=torch.bool) + mask1 = torch.tensor([True, False, True, False], dtype=torch.bool) + update = torch.tensor([10.0, 20.0], dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, mask0, mask1, update), + input_names=["x", "mask0", "mask1", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + def test_std_mean(self): """Test torch.std_mean which will be decomposed into prims.sum.""" From 4aff8b77e5d804c806b6105cc209b5e0570e2340 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:06:59 +0000 Subject: [PATCH 7/8] fix(torchlib): harden bincount (empty/scatter) and bool index_put (scalar values, accumulate) --- .../function_libs/torch_lib/ops/core.py | 75 +++++++++++++------ .../function_libs/torch_lib/e2e_ops_tests.py | 62 +++++++++++++++ 2 files changed, 113 insertions(+), 24 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 48409815df..ccea7a2679 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1256,21 +1256,39 @@ def aten_binary_cross_entropy_with_logits( @torch_op("aten::bincount", trace_only=True) def aten_bincount( - self: INT64, weights: Optional[TensorType] = None, minlength: int = 0 + self: IntType, weights: Optional[TensorType] = None, minlength: int = 0 ) -> TensorType: - """bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor""" + """bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor + + ``weights`` is not supported. Negative inputs are rejected by torch and are not + handled here (ONNX integer ops would wrap them around). + """ if weights is not None: raise NotImplementedError("aten::bincount with weights is not supported.") + self = op.Cast(self, to=INT64.dtype) axis_0 = op.Constant(value_ints=[0]) - one = op.Constant(value_ints=[1]) - max_val = op.Unsqueeze(op.ReduceMax(self, keepdims=0), axis_0) - depth = op.Add(max_val, one) + # Append a 0 so ReduceMax is defined even when ``self`` is empty. It only sizes the + # output and never contributes to the counts (the scatter below uses ``self``). + data_max = op.Unsqueeze( + op.ReduceMax(op.Concat(self, op.Constant(value_ints=[0]), axis=0), keepdims=0), + axis_0, + ) + # An empty input yields depth 0, so the output is empty unless ``minlength`` applies. + non_empty = op.Unsqueeze( + op.Cast(op.Greater(op.Size(self), op.Constant(value_int=0)), to=INT64.dtype), + axis_0, + ) + depth = op.Mul(op.Add(data_max, op.Constant(value_ints=[1])), non_empty) if minlength > 0: depth = op.Max(depth, op.Constant(value_ints=[minlength])) - one_hot = op.OneHot(self, depth, op.Constant(value_ints=[0, 1]), axis=-1) - return op.ReduceSum(one_hot, axis_0, keepdims=0) + # Scatter-add 1 for each value into a zero vector of length ``depth``. This uses + # O(N + depth) memory instead of the dense O(N * depth) one-hot, and behaves + # correctly for empty inputs. + zeros = op.Expand(op.Constant(value_int=0), depth) + ones = op.Expand(op.Constant(value_int=1), op.Shape(self)) + return op.ScatterElements(zeros, self, ones, axis=0, reduction="add") def aten_binomial( @@ -5123,30 +5141,39 @@ def _aten_index_put_bool( self, onnx_index, expanded_values, reduction="add" if accumulate else None ) - del accumulate # Boolean masks index each position at most once. - if bool_mask is None or bool_mask.dtype != BOOL.dtype: raise NotImplementedError( "Boolean index_put expects a boolean mask as the first index." ) - for _ in range(len(self.shape) - len(bool_mask.shape)): - bool_mask = op.Unsqueeze(bool_mask, op.Constant(value_ints=[-1])) + neg_1 = op.Constant(value_ints=[-1]) + self_rank = len(self.shape) + mask_rank = len(bool_mask.shape) + + # Expand a lower-rank mask (e.g. a row mask) across the trailing dimensions of self + # so it selects whole slices, then collect the coordinates of every selected element. + # NonZero returns them in row-major order. + expanded_mask = bool_mask + for _ in range(self_rank - mask_rank): + expanded_mask = op.Unsqueeze(expanded_mask, neg_1) + expanded_mask = op.Expand(expanded_mask, op.Shape(self)) + selected_indices = op.Transpose(op.NonZero(expanded_mask), perm=[1, 0]) + + # Broadcast ``values`` to the selection shape ``[num_true, *self.shape[mask_rank:]]`` + # and flatten it to one update per selected element. This keeps scalar and + # broadcastable ``values`` working, matching ``self[mask] = values`` semantics. + num_true = op.ReduceSum( + op.Cast(op.Reshape(bool_mask, neg_1), to=INT64.dtype), keepdims=1 + ) + trailing_shape = op.Slice( + op.Shape(self), starts=[mask_rank], ends=[self_rank], axes=[0] + ) + selection_shape = op.Concat(num_true, trailing_shape, axis=0) + flat_values = op.Reshape(op.Expand(values, selection_shape), neg_1) - expanded_mask = op.Expand(bool_mask, op.Shape(self)) - flat_mask = op.Reshape(expanded_mask, op.Constant(value_ints=[-1])) - flat_mask_int = op.Cast(flat_mask, to=INT64.dtype) - positions = op.Clip( - op.Sub( - op.CumSum(flat_mask_int, op.Constant(value_ints=[0])), op.Constant(value_ints=[1]) - ), - op.Constant(value_ints=[0]), + return op.ScatterND( + self, selected_indices, flat_values, reduction="add" if accumulate else None ) - flat_values = op.Reshape(values, op.Constant(value_ints=[-1])) - gathered_values = op.Gather(flat_values, positions) - flat_self = op.Reshape(self, op.Constant(value_ints=[-1])) - result = op.Where(flat_mask, gathered_values, flat_self) - return op.Reshape(result, op.Shape(self)) def aten_index_reduce( diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 6e69402714..ecaf13335a 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -92,6 +92,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) _testing.assert_onnx_program(onnx_program) + def test_bincount_default_minlength(self): + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.bincount(x) + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([2, 2, 2], dtype=torch.int64),), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_bincount_empty_input(self): + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.bincount(x, minlength=4) + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([], dtype=torch.int64),), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_integer_1(self): class Model(torch.nn.Module): def forward(self, x): @@ -933,6 +959,42 @@ def forward(self, x, mask, update): ) _testing.assert_onnx_program(onnx_program) + def test_index_put_bool_mask_scalar_value(self): + class Model(torch.nn.Module): + def forward(self, x, mask, update): + return torch.ops.aten.index_put(x, [mask], update) + + x = torch.arange(6, dtype=torch.float32).reshape((2, 3)) + mask = torch.tensor([[True, False, True], [False, True, False]], dtype=torch.bool) + update = torch.tensor(5.0, dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, mask, update), + input_names=["x", "mask", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_bool_row_mask_scalar_value(self): + class Model(torch.nn.Module): + def forward(self, x, mask, update): + return torch.ops.aten.index_put(x, [mask], update) + + x = torch.arange(6, dtype=torch.float32).reshape((2, 3)) + mask = torch.tensor([True, False], dtype=torch.bool) + update = torch.tensor(7.0, dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, mask, update), + input_names=["x", "mask", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + def test_index_put_bool_multi_mask(self): class Model(torch.nn.Module): def forward(self, x, mask0, mask1, update): From d8df7950584a018b87ef49c84c63414e36d4f38c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 3 Jun 2026 17:15:14 +0000 Subject: [PATCH 8/8] Apply ruff format to core.py --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a10b1d6e2c..95c62a70cb 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5162,12 +5162,8 @@ def _aten_index_put_bool( # Broadcast ``values`` to the selection shape ``[num_true, *self.shape[mask_rank:]]`` # and flatten it to one update per selected element. This keeps scalar and # broadcastable ``values`` working, matching ``self[mask] = values`` semantics. - num_true = op.ReduceSum( - op.Cast(op.Reshape(bool_mask, neg_1), to=INT64.dtype), keepdims=1 - ) - trailing_shape = op.Slice( - op.Shape(self), starts=[mask_rank], ends=[self_rank], axes=[0] - ) + num_true = op.ReduceSum(op.Cast(op.Reshape(bool_mask, neg_1), to=INT64.dtype), keepdims=1) + trailing_shape = op.Slice(op.Shape(self), starts=[mask_rank], ends=[self_rank], axes=[0]) selection_shape = op.Concat(num_true, trailing_shape, axis=0) flat_values = op.Reshape(op.Expand(values, selection_shape), neg_1)