From 889e2ebf6e4d3440fb61f5006d537f592c25a33f Mon Sep 17 00:00:00 2001 From: Seungho Henry Park Date: Tue, 31 Mar 2026 19:54:55 +0900 Subject: [PATCH] [circle-mlir/pass] Revise ConvTransposeOp pass This updates the ONNX `ConvTranspose` operation pass to support dynamic-shape inputs. ONE-DCO-1.0-Signed-off-by: Seungho Henry Park --- .../lib/pass/src/ops/ConvTransposeOp.h | 45 ++++++++++++------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/circle-mlir/circle-mlir/lib/pass/src/ops/ConvTransposeOp.h b/circle-mlir/circle-mlir/lib/pass/src/ops/ConvTransposeOp.h index 3d6516f75be..3f81a77af39 100644 --- a/circle-mlir/circle-mlir/lib/pass/src/ops/ConvTransposeOp.h +++ b/circle-mlir/circle-mlir/lib/pass/src/ops/ConvTransposeOp.h @@ -128,19 +128,30 @@ class ConvConvTranspose : public mlir::OpConversionPattern os_i32; + mlir::SmallVector os_i64; { - int32_t hin = static_cast(inshape[2]); - int32_t win = static_cast(inshape[3]); - int32_t hfs = static_cast(filtershape[2]); - int32_t wfs = static_cast(filtershape[3]); - int32_t hout = (hin - 1) * stride_h + dilation_h * (hfs - 1) + output_padding_h + 1; - int32_t wout = (win - 1) * stride_w + dilation_w * (wfs - 1) + output_padding_w + 1; - int32_t nin = static_cast(inshape[0]); - int32_t ofs = static_cast(filtershape[1]); - os_i32.push_back(nin); - os_i32.push_back(hout); - os_i32.push_back(wout); - os_i32.push_back(ofs); // from IOHW + int64_t dyn = mlir::ShapedType::kDynamic; + int64_t hin = inshape[2]; + int64_t win = inshape[3]; + int64_t hfs = filtershape[2]; + int64_t wfs = filtershape[3]; + int64_t hout = dyn; + int64_t wout = dyn; + int64_t nin = dyn; + int64_t ofs = filtershape[1]; + + if (!mlir::ShapedType::isDynamic(inshape[0])) + nin = inshape[0]; + if (!mlir::ShapedType::isDynamic(inshape[2])) + hout = (hin - 1) * stride_h + dilation_h * (hfs - 1) + output_padding_h + 1; + if (!mlir::ShapedType::isDynamic(inshape[3])) + wout = (win - 1) * stride_w + dilation_w * (wfs - 1) + output_padding_w + 1; + + os_i64 = {nin, hout, wout, ofs}; + os_i32.push_back(static_cast(nin)); + os_i32.push_back(static_cast(hout)); + os_i32.push_back(static_cast(wout)); + os_i32.push_back(static_cast(ofs)); // from IOHW mlir::Location shape_loc = mlir::NameLoc::get(rewriter.getStringAttr(op_name + "/shape")); mlir::Type i32 = rewriter.getI32Type(); @@ -148,7 +159,7 @@ class ConvConvTranspose : public mlir::OpConversionPattern(shape_loc, DenseIntElementsAttr::get(ostype, os_i32)); } - mlir::SmallVector trconv2d_shape({os_i32[0], os_i32[1], os_i32[2], os_i32[3]}); + mlir::SmallVector trconv2d_shape({os_i64[0], os_i64[1], os_i64[2], os_i64[3]}); auto trconv_output_type = mlir::RankedTensorType::get(trconv2d_shape, outtype.getElementType()); mlir::Value trconv2d = rewriter.create( opLoc, trconv_output_type, output_shape, filter_tran, pre_tran, bias, @@ -173,10 +184,10 @@ class ConvConvTranspose : public mlir::OpConversionPattern size_i32; - size_i32.push_back(os_i32[0]); - size_i32.push_back(os_i32[1] - 2 * padsValue[0]); - size_i32.push_back(os_i32[2] - 2 * padsValue[1]); - size_i32.push_back(os_i32[3]); + size_i32.push_back(static_cast(os_i64[0])); + size_i32.push_back(static_cast(os_i64[1]) - 2 * padsValue[0]); + size_i32.push_back(static_cast(os_i64[2]) - 2 * padsValue[1]); + size_i32.push_back(static_cast(os_i64[3])); auto sizeConst = rewriter.create(ss_loc, DenseIntElementsAttr::get(bstype, size_i32));