diff --git a/circle-mlir/circle-mlir/lib/pass/src/ops/ConvTransposeOp.h b/circle-mlir/circle-mlir/lib/pass/src/ops/ConvTransposeOp.h index 3d6516f75be..3f81a77af39 100644 --- a/circle-mlir/circle-mlir/lib/pass/src/ops/ConvTransposeOp.h +++ b/circle-mlir/circle-mlir/lib/pass/src/ops/ConvTransposeOp.h @@ -128,19 +128,30 @@ class ConvConvTranspose : public mlir::OpConversionPattern os_i32; + mlir::SmallVector os_i64; { - int32_t hin = static_cast(inshape[2]); - int32_t win = static_cast(inshape[3]); - int32_t hfs = static_cast(filtershape[2]); - int32_t wfs = static_cast(filtershape[3]); - int32_t hout = (hin - 1) * stride_h + dilation_h * (hfs - 1) + output_padding_h + 1; - int32_t wout = (win - 1) * stride_w + dilation_w * (wfs - 1) + output_padding_w + 1; - int32_t nin = static_cast(inshape[0]); - int32_t ofs = static_cast(filtershape[1]); - os_i32.push_back(nin); - os_i32.push_back(hout); - os_i32.push_back(wout); - os_i32.push_back(ofs); // from IOHW + int64_t dyn = mlir::ShapedType::kDynamic; + int64_t hin = inshape[2]; + int64_t win = inshape[3]; + int64_t hfs = filtershape[2]; + int64_t wfs = filtershape[3]; + int64_t hout = dyn; + int64_t wout = dyn; + int64_t nin = dyn; + int64_t ofs = filtershape[1]; + + if (!mlir::ShapedType::isDynamic(inshape[0])) + nin = inshape[0]; + if (!mlir::ShapedType::isDynamic(inshape[2])) + hout = (hin - 1) * stride_h + dilation_h * (hfs - 1) + output_padding_h + 1; + if (!mlir::ShapedType::isDynamic(inshape[3])) + wout = (win - 1) * stride_w + dilation_w * (wfs - 1) + output_padding_w + 1; + + os_i64 = {nin, hout, wout, ofs}; + os_i32.push_back(static_cast(nin)); + os_i32.push_back(static_cast(hout)); + os_i32.push_back(static_cast(wout)); + os_i32.push_back(static_cast(ofs)); // from IOHW mlir::Location shape_loc = mlir::NameLoc::get(rewriter.getStringAttr(op_name + "/shape")); mlir::Type i32 = rewriter.getI32Type(); @@ -148,7 +159,7 @@ class ConvConvTranspose : public mlir::OpConversionPattern(shape_loc, DenseIntElementsAttr::get(ostype, os_i32)); } - mlir::SmallVector trconv2d_shape({os_i32[0], os_i32[1], os_i32[2], os_i32[3]}); + mlir::SmallVector trconv2d_shape({os_i64[0], os_i64[1], os_i64[2], os_i64[3]}); auto trconv_output_type = mlir::RankedTensorType::get(trconv2d_shape, outtype.getElementType()); mlir::Value trconv2d = rewriter.create( opLoc, trconv_output_type, output_shape, filter_tran, pre_tran, bias, @@ -173,10 +184,10 @@ class ConvConvTranspose : public mlir::OpConversionPattern size_i32; - size_i32.push_back(os_i32[0]); - size_i32.push_back(os_i32[1] - 2 * padsValue[0]); - size_i32.push_back(os_i32[2] - 2 * padsValue[1]); - size_i32.push_back(os_i32[3]); + size_i32.push_back(static_cast(os_i64[0])); + size_i32.push_back(static_cast(os_i64[1]) - 2 * padsValue[0]); + size_i32.push_back(static_cast(os_i64[2]) - 2 * padsValue[1]); + size_i32.push_back(static_cast(os_i64[3])); auto sizeConst = rewriter.create(ss_loc, DenseIntElementsAttr::get(bstype, size_i32)); diff --git a/circle-mlir/circle-mlir/tools-test/circle-impexp-test/test.lst b/circle-mlir/circle-mlir/tools-test/circle-impexp-test/test.lst index afeb1d01aec..8db39811f13 100644 --- a/circle-mlir/circle-mlir/tools-test/circle-impexp-test/test.lst +++ b/circle-mlir/circle-mlir/tools-test/circle-impexp-test/test.lst @@ -48,6 +48,8 @@ AddModel(ConvTranspose2d_F32_R4_op01) AddModel(ConvTranspose2d_F32_R4_p10) AddModel(ConvTranspose2d_F32_R4_p11) AddModel(ConvTranspose2d_F32_R4_p11_nobias) +# AddModel(ConvTranspose2d_F32_R4_unk_bh) --> Does't support dynamic shape output +# AddModel(ConvTranspose2d_F32_R4_unk_bw) --> Does't support dynamic shape output AddModel(Cos_F32_R4) AddModel(CumSum_F32_R4_1) AddModel(CumSum_F32_R4_2) diff --git a/circle-mlir/circle-mlir/tools-test/onnx2circle-models/test.lst b/circle-mlir/circle-mlir/tools-test/onnx2circle-models/test.lst index 05c3f9131ec..a0299208184 100644 --- a/circle-mlir/circle-mlir/tools-test/onnx2circle-models/test.lst +++ b/circle-mlir/circle-mlir/tools-test/onnx2circle-models/test.lst @@ -50,6 +50,8 @@ AddModel(ConvTranspose2d_F32_R4_op01) AddModel(ConvTranspose2d_F32_R4_p10) AddModel(ConvTranspose2d_F32_R4_p11) AddModel(ConvTranspose2d_F32_R4_p11_nobias) +AddModel(ConvTranspose2d_F32_R4_unk_bh) +AddModel(ConvTranspose2d_F32_R4_unk_bw) AddModel(Cos_F32_R4) AddModel(CumSum_F32_R4_1) AddModel(CumSum_F32_R4_2) diff --git a/circle-mlir/circle-mlir/tools-test/onnx2circle-value-test/test.lst b/circle-mlir/circle-mlir/tools-test/onnx2circle-value-test/test.lst index d7feef42a74..e42a1f5f9b2 100644 --- a/circle-mlir/circle-mlir/tools-test/onnx2circle-value-test/test.lst +++ b/circle-mlir/circle-mlir/tools-test/onnx2circle-value-test/test.lst @@ -50,6 +50,8 @@ AddModel(ConvTranspose2d_F32_R4_op01) AddModel(ConvTranspose2d_F32_R4_p10) AddModel(ConvTranspose2d_F32_R4_p11) AddModel(ConvTranspose2d_F32_R4_p11_nobias) +# AddModel(ConvTranspose2d_F32_R4_unk_bh) --> Does't support dynamic shape output +# AddModel(ConvTranspose2d_F32_R4_unk_bw) --> Does't support dynamic shape output AddModel(Cos_F32_R4) AddModel(CumSum_F32_R4_1) AddModel(CumSum_F32_R4_2) diff --git a/circle-mlir/models/unit/ConvTranspose2d_F32_R4_unk_bh/__init__.py b/circle-mlir/models/unit/ConvTranspose2d_F32_R4_unk_bh/__init__.py new file mode 100644 index 00000000000..3a2f97581fb --- /dev/null +++ b/circle-mlir/models/unit/ConvTranspose2d_F32_R4_unk_bh/__init__.py @@ -0,0 +1,35 @@ +import torch + + +# Generate ConvTranspose2d operator with Float32, Rank-4, unknown +# input : [N, 4, H, 10] +# output : [N, 3, H, 10+7] +# dynamic axes: N, H +class net_ConvTranspose2d(torch.nn.Module): + def __init__(self): + super().__init__() + self.op = torch.nn.ConvTranspose2d( + in_channels=4, + out_channels=3, + kernel_size=(1, 8), + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + bias=True, + ) + + def forward(self, input): + return self.op(input) + + def onnx_opset_version(self): + # TODO set to appropriate value + return 14 + + +_model_ = net_ConvTranspose2d() + +_inputs_ = (torch.Tensor(1, 4, 1, 1)) + +_io_names_ = [['input'], ['output']] +_dynamic_axes_ = {"input": {0: "?", 2: "?"}, "output": {0: "?", 2: "?"}} diff --git a/circle-mlir/models/unit/ConvTranspose2d_F32_R4_unk_bw/__init__.py b/circle-mlir/models/unit/ConvTranspose2d_F32_R4_unk_bw/__init__.py new file mode 100644 index 00000000000..ff4d82e020a --- /dev/null +++ b/circle-mlir/models/unit/ConvTranspose2d_F32_R4_unk_bw/__init__.py @@ -0,0 +1,35 @@ +import torch + + +# Generate ConvTranspose2d operator with Float32, Rank-4, unknown +# input : [N, 4, 1, W] +# output : [N, 3, 1, W+7] +# dynamic axes: N, W +class net_ConvTranspose2d(torch.nn.Module): + def __init__(self): + super().__init__() + self.op = torch.nn.ConvTranspose2d( + in_channels=4, + out_channels=3, + kernel_size=(1, 8), + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + bias=True, + ) + + def forward(self, input): + return self.op(input) + + def onnx_opset_version(self): + # TODO set to appropriate value + return 14 + + +_model_ = net_ConvTranspose2d() + +_inputs_ = (torch.Tensor(1, 4, 1, 1)) + +_io_names_ = [['input'], ['output']] +_dynamic_axes_ = {"input": {0: "?", 3: "?"}, "output": {0: "?", 3: "?"}}