Skip to content
Open
54 changes: 39 additions & 15 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,12 +1254,23 @@ def aten_binary_cross_entropy_with_logits(
raise NotImplementedError()


@torch_op("aten::bincount", trace_only=True)
def aten_bincount(
self: TensorType, weights: Optional[TensorType] = None, minlength: int = 0
self: INT64, weights: Optional[TensorType] = None, minlength: int = 0
) -> TensorType:
"""bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor"""
if weights is not None:
raise NotImplementedError("aten::bincount with weights is not supported.")

raise NotImplementedError()
axis_0 = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
max_val = op.Unsqueeze(op.ReduceMax(self, keepdims=0), axis_0)
depth = op.Add(max_val, one)
if minlength > 0:
depth = op.Max(depth, op.Constant(value_ints=[minlength]))

one_hot = op.OneHot(self, depth, op.Constant(value_ints=[0, 1]), axis=-1)
return op.ReduceSum(one_hot, axis_0, keepdims=0)


def aten_binomial(
Expand Down Expand Up @@ -5071,24 +5082,37 @@ def same_shape(other_shape: ir.Shape) -> bool:

def _aten_index_put_bool(
self: TReal,
indices: Sequence[BOOL],
indices: Sequence[Optional[Union[INT64, BOOL]]],

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot should this have INT64? I don't think so?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — I removed INT64 from that bool helper signature, so it now expects Sequence[Optional[BOOL]] only. Fixed in 6ad7f14.

values: TReal,
accumulate: bool = False,
) -> TReal:
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""

# TODO: Support indices with more than 1 elements
index = indices[0]
# accumulate should be always False, True does not make sense but an assert would be great
# Reshape indices so it can be properly broadcasted
self_rank = len(self.shape)
index_rank = len(index.shape)
if self_rank > index_rank:
index_shape = op.Shape(index)
padding = op.Constant(value_ints=[1 for _ in range(self_rank - index_rank)])
padded_shape = op.Concat(index_shape, padding, axis=0)
index = op.Reshape(index, padded_shape)
return op.Where(index, values, self)
del accumulate # Boolean masks index each position at most once.

bool_mask = indices[0]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot so indices with more than 1 elements is still not supported? What is the reference implementation in huggingface?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct — the bool path was effectively single-mask only. I made that explicit so we fail fast for multi-element indices instead of silently ignoring extras in ceb0881. I also couldn’t find a Hugging Face-specific lowering for this case; the reference we follow here is PyTorch’s torch.onnx.symbolic_opset11.index_put (already linked in the docstring).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if bool_mask is None or bool_mask.dtype != BOOL.dtype:
raise NotImplementedError(
"Boolean index_put expects a boolean mask as the first index."
)

for _ in range(len(self.shape) - len(bool_mask.shape)):
bool_mask = op.Unsqueeze(bool_mask, op.Constant(value_ints=[-1]))

expanded_mask = op.Expand(bool_mask, op.Shape(self))
flat_mask = op.Reshape(expanded_mask, op.Constant(value_ints=[-1]))
flat_mask_int = op.Cast(flat_mask, to=INT64.dtype)
positions = op.Clip(
op.Sub(
op.CumSum(flat_mask_int, op.Constant(value_ints=[0])), op.Constant(value_ints=[1])
),
op.Constant(value_ints=[0]),
)
flat_values = op.Reshape(values, op.Constant(value_ints=[-1]))
gathered_values = op.Gather(flat_values, positions)
flat_self = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.Where(flat_mask, gathered_values, flat_self)
return op.Reshape(result, op.Shape(self))


def aten_index_reduce(
Expand Down
31 changes: 31 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
_testing.assert_onnx_program(onnx_program)

def test_bincount(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.bincount(x, minlength=6)

onnx_program = torch.onnx.export(
Model(),
(torch.tensor([0, 1, 1, 3, 5], dtype=torch.int64),),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)

Comment on lines +87 to +94
def test_repeat_interleave_integer_1(self):
class Model(torch.nn.Module):
def forward(self, x):
Expand Down Expand Up @@ -902,6 +915,24 @@ def forward(self, x, index, update):
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_bool_mask(self):
class Model(torch.nn.Module):
def forward(self, x, mask, update):
return torch.ops.aten.index_put(x, [mask], update)

x = torch.zeros((2, 3), dtype=torch.float32)
mask = torch.tensor([[True, False, True], [False, True, False]], dtype=torch.bool)
update = torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32)
onnx_program = torch.onnx.export(
Model(),
(x, mask, update),
input_names=["x", "mask", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_std_mean(self):
"""Test torch.std_mean which will be decomposed into prims.sum."""

Expand Down
Loading