Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from __future__ import annotations

import warnings
from typing import Sequence
from typing import Optional, Sequence

from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_opset import opset19
from onnxscript.onnx_types import FLOAT, INT64

_INT64_MAX = 0x7FFFFFFFFFFFFFFF
Expand Down Expand Up @@ -91,3 +93,36 @@ def torchvision_roi_pool(input, boxes, output_size: Sequence[int], spatial_scale
pooled_shape=(pooled_height, pooled_width),
spatial_scale=spatial_scale,
)


@torch_op("torchvision::deform_conv2d", trace_only=True)
def torchvision_deform_conv2d(
input: TFloat,
weight: TFloat,
offset: TFloat,
mask: Optional[TFloat],
bias: Optional[TFloat],
stride_h: int,
stride_w: int,
pad_h: int,
pad_w: int,
dilation_h: int,
dilation_w: int,
groups: int,
offset_groups: int,
use_mask: bool,
):
"""torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor? mask, Tensor? bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"""

return opset19.DeformConv(
X=input,
W=weight,
offset=offset,
B=bias,
mask=mask if use_mask else None,
dilations=(dilation_h, dilation_w),
strides=(stride_h, stride_w),
pads=(pad_h, pad_w, pad_h, pad_w),
group=groups,
offset_group=offset_groups,
)
45 changes: 45 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo
import torch
import torchvision
from torch.onnx._internal.exporter import _testing


Expand Down Expand Up @@ -272,6 +273,50 @@ def forward(self, x):
)
_testing.assert_onnx_program(onnx_program)

def test_torchvision_deform_conv2d(self):
class Model(torch.nn.Module):
def forward(self, x, offset, weight, bias):
return torchvision.ops.deform_conv2d(x, offset, weight, bias=bias)

x = torch.randn(1, 2, 5, 5, dtype=torch.float32)
weight = torch.randn(3, 2, 3, 3, dtype=torch.float32)
offset = torch.randn(1, 18, 3, 3, dtype=torch.float32)
bias = torch.randn(3, dtype=torch.float32)

onnx_program = torch.onnx.export(
Model(),
(x, offset, weight, bias),
opset_version=19,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_torchvision_deform_conv2d_with_mask_groups_and_padding(self):
class Model(torch.nn.Module):
def forward(self, x, offset, weight, bias, mask):
return torchvision.ops.deform_conv2d(
x,
offset,
weight,
bias=bias,
padding=(1, 1),
mask=mask,
)

x = torch.randn(1, 4, 5, 5, dtype=torch.float32)
weight = torch.randn(4, 2, 3, 3, dtype=torch.float32)
offset = torch.randn(1, 36, 5, 5, dtype=torch.float32)
bias = torch.randn(4, dtype=torch.float32)
mask = torch.randn(1, 18, 5, 5, dtype=torch.float32)

onnx_program = torch.onnx.export(
Model(),
(x, offset, weight, bias, mask),
opset_version=19,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_dft_axis_promoted_from_attribute_to_input(self):
class Model(torch.nn.Module):
def forward(self, x):
Expand Down
Loading