Skip to content
Open
15 changes: 13 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,12 +1254,23 @@ def aten_binary_cross_entropy_with_logits(
raise NotImplementedError()


@torch_op("aten::bincount", trace_only=True)
def aten_bincount(
self: TensorType, weights: Optional[TensorType] = None, minlength: int = 0
self: INT64, weights: Optional[TensorType] = None, minlength: int = 0
) -> TensorType:
"""bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor"""
if weights is not None:
raise NotImplementedError("aten::bincount with weights is not supported.")

raise NotImplementedError()
axis_0 = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
max_val = op.Unsqueeze(op.ReduceMax(self, keepdims=0), axis_0)
depth = op.Add(max_val, one)
if minlength > 0:
depth = op.Max(depth, op.Constant(value_ints=[minlength]))

one_hot = op.OneHot(self, depth, op.Constant(value_ints=[0, 1]), axis=-1)
return op.ReduceSum(one_hot, axis_0, keepdims=0)


def aten_binomial(
Expand Down
13 changes: 13 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
_testing.assert_onnx_program(onnx_program)

def test_bincount(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.bincount(x, minlength=6)

onnx_program = torch.onnx.export(
Model(),
(torch.tensor([0, 1, 1, 3, 5], dtype=torch.int64),),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)

Comment on lines +87 to +94
def test_repeat_interleave_integer_1(self):
class Model(torch.nn.Module):
def forward(self, x):
Expand Down
Loading