Skip to content
Open
92 changes: 75 additions & 17 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,12 +1254,23 @@ def aten_binary_cross_entropy_with_logits(
raise NotImplementedError()


@torch_op("aten::bincount", trace_only=True)
def aten_bincount(
self: TensorType, weights: Optional[TensorType] = None, minlength: int = 0
self: INT64, weights: Optional[TensorType] = None, minlength: int = 0
) -> TensorType:
"""bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor"""
if weights is not None:
raise NotImplementedError("aten::bincount with weights is not supported.")

raise NotImplementedError()
axis_0 = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
max_val = op.Unsqueeze(op.ReduceMax(self, keepdims=0), axis_0)
depth = op.Add(max_val, one)
if minlength > 0:
depth = op.Max(depth, op.Constant(value_ints=[minlength]))

one_hot = op.OneHot(self, depth, op.Constant(value_ints=[0, 1]), axis=-1)
return op.ReduceSum(one_hot, axis_0, keepdims=0)


def aten_binomial(
Expand Down Expand Up @@ -4976,8 +4987,14 @@ def is_advanced_index(index):
# will invalidate equality-based check.
first_shape = indices[advanced_indices[0]].shape

def same_shape(other_shape: ir.Shape) -> bool:
return (not any(d is None for d in other_shape)) and other_shape == first_shape
def same_shape(other_shape: Optional[ir.Shape]) -> bool:
return (
first_shape is not None
and other_shape is not None
and not any(d is None for d in first_shape)
and not any(d is None for d in other_shape)
and other_shape == first_shape
)

all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices)
if not all_same_shape:
Expand Down Expand Up @@ -5071,24 +5088,65 @@ def same_shape(other_shape: ir.Shape) -> bool:

def _aten_index_put_bool(
self: TReal,
indices: Sequence[BOOL],
indices: Sequence[Optional[Union[INT64, BOOL]]],
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot should this have INT64? I don't think so?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — I removed INT64 from that bool helper signature, so it now expects Sequence[Optional[BOOL]] only. Fixed in 6ad7f14.

values: TReal,
accumulate: bool = False,
) -> TReal:
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""

# TODO: Support indices with more than 1 elements
index = indices[0]
# accumulate should be always False, True does not make sense but an assert would be great
# Reshape indices so it can be properly broadcasted
self_rank = len(self.shape)
index_rank = len(index.shape)
if self_rank > index_rank:
index_shape = op.Shape(index)
padding = op.Constant(value_ints=[1 for _ in range(self_rank - index_rank)])
padded_shape = op.Concat(index_shape, padding, axis=0)
index = op.Reshape(index, padded_shape)
return op.Where(index, values, self)
bool_mask = indices[0]
if len(indices) > 1:
if any(index is None for index in indices):
raise NotImplementedError(
"Boolean index_put with multiple indices does not support None indices."
)

advanced_indices = []
selected_positions = []
minus_one = op.Constant(value_ints=[-1])
for index in indices:
if index.dtype != BOOL.dtype or len(index.shape) != 1:
raise NotImplementedError(
"Boolean index_put with multiple indices supports only 1-D boolean masks."
)
positions = op.Reshape(op.Transpose(op.NonZero(index), perm=[1, 0]), minus_one)
selected_positions.append(positions)
advanced_indices.append(op.Unsqueeze(positions, minus_one))
onnx_index = op.Concat(*advanced_indices, axis=-1)
target_shape = op.Concat(
op.Shape(selected_positions[0]),
op.Slice(op.Shape(self), starts=[len(indices)], ends=[len(self.shape)], axes=[0]),
axis=0,
)
expanded_values = op.Expand(values, target_shape)
return op.ScatterND(
self, onnx_index, expanded_values, reduction="add" if accumulate else None
)

del accumulate # Boolean masks index each position at most once.

if bool_mask is None or bool_mask.dtype != BOOL.dtype:
raise NotImplementedError(
"Boolean index_put expects a boolean mask as the first index."
)

for _ in range(len(self.shape) - len(bool_mask.shape)):
bool_mask = op.Unsqueeze(bool_mask, op.Constant(value_ints=[-1]))

expanded_mask = op.Expand(bool_mask, op.Shape(self))
flat_mask = op.Reshape(expanded_mask, op.Constant(value_ints=[-1]))
flat_mask_int = op.Cast(flat_mask, to=INT64.dtype)
positions = op.Clip(
op.Sub(
op.CumSum(flat_mask_int, op.Constant(value_ints=[0])), op.Constant(value_ints=[1])
),
op.Constant(value_ints=[0]),
)
flat_values = op.Reshape(values, op.Constant(value_ints=[-1]))
gathered_values = op.Gather(flat_values, positions)
flat_self = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.Where(flat_mask, gathered_values, flat_self)
return op.Reshape(result, op.Shape(self))


def aten_index_reduce(
Expand Down
50 changes: 50 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
_testing.assert_onnx_program(onnx_program)

def test_bincount(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.bincount(x, minlength=6)

onnx_program = torch.onnx.export(
Model(),
(torch.tensor([0, 1, 1, 3, 5], dtype=torch.int64),),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)

Comment on lines +87 to +94
def test_repeat_interleave_integer_1(self):
class Model(torch.nn.Module):
def forward(self, x):
Expand Down Expand Up @@ -902,6 +915,43 @@ def forward(self, x, index, update):
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_bool_mask(self):
class Model(torch.nn.Module):
def forward(self, x, mask, update):
return torch.ops.aten.index_put(x, [mask], update)

x = torch.zeros((2, 3), dtype=torch.float32)
mask = torch.tensor([[True, False, True], [False, True, False]], dtype=torch.bool)
update = torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32)
onnx_program = torch.onnx.export(
Model(),
(x, mask, update),
input_names=["x", "mask", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_bool_multi_mask(self):
class Model(torch.nn.Module):
def forward(self, x, mask0, mask1, update):
return torch.ops.aten.index_put(x, [mask0, mask1], update)

x = torch.zeros((3, 4), dtype=torch.float32)
mask0 = torch.tensor([True, False, True], dtype=torch.bool)
mask1 = torch.tensor([True, False, True, False], dtype=torch.bool)
update = torch.tensor([10.0, 20.0], dtype=torch.float32)
onnx_program = torch.onnx.export(
Model(),
(x, mask0, mask1, update),
input_names=["x", "mask0", "mask1", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_std_mean(self):
"""Test torch.std_mean which will be decomposed into prims.sum."""

Expand Down