diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d478af87fc..95c62a70cb 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1254,12 +1254,41 @@ def aten_binary_cross_entropy_with_logits( raise NotImplementedError() +@torch_op("aten::bincount", trace_only=True) def aten_bincount( - self: TensorType, weights: Optional[TensorType] = None, minlength: int = 0 + self: IntType, weights: Optional[TensorType] = None, minlength: int = 0 ) -> TensorType: - """bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor""" + """bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor - raise NotImplementedError() + ``weights`` is not supported. Negative inputs are rejected by torch and are not + handled here (ONNX integer ops would wrap them around). + """ + if weights is not None: + raise NotImplementedError("aten::bincount with weights is not supported.") + + self = op.Cast(self, to=INT64.dtype) + axis_0 = op.Constant(value_ints=[0]) + # Append a 0 so ReduceMax is defined even when ``self`` is empty. It only sizes the + # output and never contributes to the counts (the scatter below uses ``self``). + data_max = op.Unsqueeze( + op.ReduceMax(op.Concat(self, op.Constant(value_ints=[0]), axis=0), keepdims=0), + axis_0, + ) + # An empty input yields depth 0, so the output is empty unless ``minlength`` applies. + non_empty = op.Unsqueeze( + op.Cast(op.Greater(op.Size(self), op.Constant(value_int=0)), to=INT64.dtype), + axis_0, + ) + depth = op.Mul(op.Add(data_max, op.Constant(value_ints=[1])), non_empty) + if minlength > 0: + depth = op.Max(depth, op.Constant(value_ints=[minlength])) + + # Scatter-add 1 for each value into a zero vector of length ``depth``. This uses + # O(N + depth) memory instead of the dense O(N * depth) one-hot, and behaves + # correctly for empty inputs. + zeros = op.Expand(op.Constant(value_int=0), depth) + ones = op.Expand(op.Constant(value_int=1), op.Shape(self)) + return op.ScatterElements(zeros, self, ones, axis=0, reduction="add") def aten_binomial( @@ -4976,8 +5005,14 @@ def is_advanced_index(index): # will invalidate equality-based check. first_shape = indices[advanced_indices[0]].shape - def same_shape(other_shape: ir.Shape) -> bool: - return (not any(d is None for d in other_shape)) and other_shape == first_shape + def same_shape(other_shape: Optional[ir.Shape]) -> bool: + return ( + first_shape is not None + and other_shape is not None + and not any(d is None for d in first_shape) + and not any(d is None for d in other_shape) + and other_shape == first_shape + ) all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices) if not all_same_shape: @@ -5071,24 +5106,70 @@ def same_shape(other_shape: ir.Shape) -> bool: def _aten_index_put_bool( self: TReal, - indices: Sequence[BOOL], + indices: Sequence[Optional[Union[INT64, BOOL]]], values: TReal, accumulate: bool = False, ) -> TReal: """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor""" - # TODO: Support indices with more than 1 elements - index = indices[0] - # accumulate should be always False, True does not make sense but an assert would be great - # Reshape indices so it can be properly broadcasted + bool_mask = indices[0] + if len(indices) > 1: + if any(index is None for index in indices): + raise NotImplementedError( + "Boolean index_put with multiple indices does not support None indices." + ) + + advanced_indices = [] + selected_positions = [] + minus_one = op.Constant(value_ints=[-1]) + for index in indices: + if index.dtype != BOOL.dtype or len(index.shape) != 1: + raise NotImplementedError( + "Boolean index_put with multiple indices supports only 1-D boolean masks." + ) + positions = op.Reshape(op.Transpose(op.NonZero(index), perm=[1, 0]), minus_one) + selected_positions.append(positions) + advanced_indices.append(op.Unsqueeze(positions, minus_one)) + onnx_index = op.Concat(*advanced_indices, axis=-1) + target_shape = op.Concat( + op.Shape(selected_positions[0]), + op.Slice(op.Shape(self), starts=[len(indices)], ends=[len(self.shape)], axes=[0]), + axis=0, + ) + expanded_values = op.Expand(values, target_shape) + return op.ScatterND( + self, onnx_index, expanded_values, reduction="add" if accumulate else None + ) + + if bool_mask is None or bool_mask.dtype != BOOL.dtype: + raise NotImplementedError( + "Boolean index_put expects a boolean mask as the first index." + ) + + neg_1 = op.Constant(value_ints=[-1]) self_rank = len(self.shape) - index_rank = len(index.shape) - if self_rank > index_rank: - index_shape = op.Shape(index) - padding = op.Constant(value_ints=[1 for _ in range(self_rank - index_rank)]) - padded_shape = op.Concat(index_shape, padding, axis=0) - index = op.Reshape(index, padded_shape) - return op.Where(index, values, self) + mask_rank = len(bool_mask.shape) + + # Expand a lower-rank mask (e.g. a row mask) across the trailing dimensions of self + # so it selects whole slices, then collect the coordinates of every selected element. + # NonZero returns them in row-major order. + expanded_mask = bool_mask + for _ in range(self_rank - mask_rank): + expanded_mask = op.Unsqueeze(expanded_mask, neg_1) + expanded_mask = op.Expand(expanded_mask, op.Shape(self)) + selected_indices = op.Transpose(op.NonZero(expanded_mask), perm=[1, 0]) + + # Broadcast ``values`` to the selection shape ``[num_true, *self.shape[mask_rank:]]`` + # and flatten it to one update per selected element. This keeps scalar and + # broadcastable ``values`` working, matching ``self[mask] = values`` semantics. + num_true = op.ReduceSum(op.Cast(op.Reshape(bool_mask, neg_1), to=INT64.dtype), keepdims=1) + trailing_shape = op.Slice(op.Shape(self), starts=[mask_rank], ends=[self_rank], axes=[0]) + selection_shape = op.Concat(num_true, trailing_shape, axis=0) + flat_values = op.Reshape(op.Expand(values, selection_shape), neg_1) + + return op.ScatterND( + self, selected_indices, flat_values, reduction="add" if accumulate else None + ) def aten_index_reduce( diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 019e6f7fe5..ecaf13335a 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -79,6 +79,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) _testing.assert_onnx_program(onnx_program) + def test_bincount(self): + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.bincount(x, minlength=6) + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([0, 1, 1, 3, 5], dtype=torch.int64),), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_bincount_default_minlength(self): + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.bincount(x) + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([2, 2, 2], dtype=torch.int64),), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_bincount_empty_input(self): + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.bincount(x, minlength=4) + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([], dtype=torch.int64),), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_integer_1(self): class Model(torch.nn.Module): def forward(self, x): @@ -902,6 +941,79 @@ def forward(self, x, index, update): ) _testing.assert_onnx_program(onnx_program) + def test_index_put_bool_mask(self): + class Model(torch.nn.Module): + def forward(self, x, mask, update): + return torch.ops.aten.index_put(x, [mask], update) + + x = torch.zeros((2, 3), dtype=torch.float32) + mask = torch.tensor([[True, False, True], [False, True, False]], dtype=torch.bool) + update = torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, mask, update), + input_names=["x", "mask", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_bool_mask_scalar_value(self): + class Model(torch.nn.Module): + def forward(self, x, mask, update): + return torch.ops.aten.index_put(x, [mask], update) + + x = torch.arange(6, dtype=torch.float32).reshape((2, 3)) + mask = torch.tensor([[True, False, True], [False, True, False]], dtype=torch.bool) + update = torch.tensor(5.0, dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, mask, update), + input_names=["x", "mask", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_bool_row_mask_scalar_value(self): + class Model(torch.nn.Module): + def forward(self, x, mask, update): + return torch.ops.aten.index_put(x, [mask], update) + + x = torch.arange(6, dtype=torch.float32).reshape((2, 3)) + mask = torch.tensor([True, False], dtype=torch.bool) + update = torch.tensor(7.0, dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, mask, update), + input_names=["x", "mask", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_bool_multi_mask(self): + class Model(torch.nn.Module): + def forward(self, x, mask0, mask1, update): + return torch.ops.aten.index_put(x, [mask0, mask1], update) + + x = torch.zeros((3, 4), dtype=torch.float32) + mask0 = torch.tensor([True, False, True], dtype=torch.bool) + mask1 = torch.tensor([True, False, True, False], dtype=torch.bool) + update = torch.tensor([10.0, 20.0], dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, mask0, mask1, update), + input_names=["x", "mask0", "mask1", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + def test_std_mean(self): """Test torch.std_mean which will be decomposed into prims.sum."""