Skip to content

Commit 72a3617

Browse files
committed
[passes] Add ConvertMatmulToLinear pass
Let's add a ConvertMatmulToLinear pass. This commit refactors the mm serialization logic into a convert_matmul_to_linear pass and introduces the new CompileConfig attributes convert_lhs/rhs_const_mm_to_fc. TICO-DCO-1.0-Signed-off-by: Dayoung Lee <dayoung.lee@samsung.com>
1 parent b9995ae commit 72a3617

File tree

6 files changed

+286
-109
lines changed

6 files changed

+286
-109
lines changed

test/modules/op/mm.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
import torch
1616

17+
from tico.config.v1 import CompileConfigV1
18+
1719
from test.modules.base import TestModuleBase
20+
from test.utils.tag import test_negative, use_onert
1821

1922

2023
class SimpleMatmul(TestModuleBase):
@@ -27,3 +30,64 @@ def forward(self, lhs, rhs):
2730

2831
def get_example_inputs(self):
2932
return (torch.randn(3, 4), torch.randn(4, 5)), {}
33+
34+
35+
class SimpleMatmulConstRhs(TestModuleBase):
    """Matmul whose right-hand operand is a constant tensor held on the module."""

    def __init__(self):
        super().__init__()
        # Constant RHS operand; its (4, 5) shape matches the (3, 4) example input.
        self.weight = torch.randn(4, 5)

    def forward(self, lhs):
        return torch.mm(lhs, self.weight)

    def get_example_inputs(self):
        example_args = (torch.randn(3, 4),)
        example_kwargs = {}
        return example_args, example_kwargs
46+
47+
48+
@use_onert
class SimpleMatmulConstRhsOnert(TestModuleBase):
    """Same constant-RHS matmul as SimpleMatmulConstRhs, but run on onert."""

    def __init__(self):
        super().__init__()
        # Constant RHS operand; its (4, 5) shape matches the (3, 4) example input.
        self.weight = torch.randn(4, 5)

    def forward(self, lhs):
        return torch.mm(lhs, self.weight)

    def get_example_inputs(self):
        example_args = (torch.randn(3, 4),)
        example_kwargs = {}
        return example_args, example_kwargs
60+
61+
62+
@use_onert
@test_negative(expected_err="NNFW_STATUS_ERROR")
class SimpleMatmulConstLhsOnert(TestModuleBase):
    """Matmul with a constant LHS, run on onert without linear conversion.

    Registered as a negative test: the `test_negative` decorator declares
    that onert is expected to fail with NNFW_STATUS_ERROR for this graph
    (compare SimpleMatmulConstLhsOnertWithLinearConversion, which enables
    `convert_lhs_const_mm_to_fc` and is expected to pass).
    """

    def __init__(self):
        super().__init__()
        # Constant LHS operand; its (3, 4) shape matches the (4, 5) example input.
        self.weight = torch.randn(3, 4)

    def forward(self, rhs):
        out = torch.mm(self.weight, rhs)
        return out

    def get_example_inputs(self):
        return (torch.randn(4, 5),), {}
77+
78+
79+
@use_onert
class SimpleMatmulConstLhsOnertWithLinearConversion(TestModuleBase):
    """Matmul with a constant LHS, compiled with `convert_lhs_const_mm_to_fc` enabled."""

    def __init__(self):
        super().__init__()
        # Constant LHS operand; its (3, 4) shape matches the (4, 5) example input.
        self.weight = torch.randn(3, 4)

    def forward(self, rhs):
        return torch.mm(self.weight, rhs)

    def get_example_inputs(self):
        example_args = (torch.randn(4, 5),)
        example_kwargs = {}
        return example_args, example_kwargs

    def get_compile_config(self):
        # Opt in to the LHS-const matmul-to-linear conversion for this module.
        return CompileConfigV1(convert_lhs_const_mm_to_fc=True)

tico/config/v1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
@dataclass
2121
class CompileConfigV1(CompileConfigBase):
2222
legalize_causal_mask_value: bool = False
23+
convert_lhs_const_mm_to_fc: bool = False
24+
convert_rhs_const_mm_to_fc: bool = True
2325

2426
def get(self, name: str):
2527
return super().get(name)
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Optional, TYPE_CHECKING
16+
17+
if TYPE_CHECKING:
18+
import torch.fx
19+
import torch
20+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
21+
from torch.export import ExportedProgram
22+
23+
from tico.utils import logging
24+
from tico.utils.graph import create_node
25+
from tico.utils.passes import PassBase, PassResult
26+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
27+
from tico.utils.validate_args_kwargs import MatmulArgs
28+
29+
30+
class Converter:  # type: ignore[empty-body]
    """Interface for matmul-to-linear conversion strategies.

    A subclass reports whether a node is eligible via `match` and performs
    the actual graph rewrite via `convert`. The base class matches nothing.
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:  # type: ignore[empty-body]
        # The base strategy never matches; subclasses override this.
        return False

    def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empty-body]
        # No-op in the base class; subclasses perform the rewrite.
        pass
39+
40+
41+
class ConvertRhsConstMatmulToLinear(Converter):
    """Rewrite `aten.mm` whose RHS is constant into `permute` + `aten.linear`.

    The inserted transpose feeds a constant tensor, so a later constant
    folding step can remove it entirely.
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        if node.target != torch.ops.aten.mm.default:
            return False

        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        rhs = mm_args.other
        if not isinstance(rhs, torch.fx.Node):
            return False
        # "Constant" means a lifted tensor constant, a parameter, or a buffer.
        return (
            is_lifted_tensor_constant(exported_program, rhs)
            or is_param(exported_program, rhs)
            or is_buffer(exported_program, rhs)
        )

    def convert(self, exported_program, node) -> torch.fx.Node:
        graph = exported_program.graph_module.graph

        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs, rhs = mm_args.input, mm_args.other

        with graph.inserting_before(node):
            # linear computes input @ weight.T, so feed rhs transposed:
            # linear(lhs, rhs.T) == lhs @ rhs == mm(lhs, rhs).
            rhs_t = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(rhs, [1, 0]),
            )
            fc_node = create_node(
                graph,
                torch.ops.aten.linear.default,
                args=(lhs, rhs_t),
            )
        node.replace_all_uses_with(fc_node, propagate_meta=True)

        return fc_node
86+
87+
88+
class ConvertLhsConstMatmulToLinear(Converter):
    """Rewrite `aten.mm` whose LHS is constant into `permute` + `aten.linear`.

    Note that the inserted transpose is applied to the non-constant RHS, so
    it cannot be constant-folded; this conversion trades potential latency
    for the better (per-channel) weight quantization of FullyConnected.
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        if not node.target == torch.ops.aten.mm.default:
            return False

        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs = mm_args.input
        if isinstance(lhs, torch.fx.Node):
            if is_lifted_tensor_constant(exported_program, lhs):
                return True
            elif is_param(exported_program, lhs):
                return True
            elif is_buffer(exported_program, lhs):
                return True
            else:
                return False
        # Fix: previously this fell off the end and returned None (not bool)
        # when lhs was not an fx.Node; return an explicit False like the
        # RHS converter does.
        return False

    def convert(self, exported_program, node) -> torch.fx.Node:
        graph_module = exported_program.graph_module
        graph = graph_module.graph

        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        lhs = mm_args.input
        rhs = mm_args.other

        with graph.inserting_before(node):
            # linear computes input @ weight.T, so feed rhs transposed:
            # linear(lhs, rhs.T) == lhs @ rhs == mm(lhs, rhs).
            transpose_node = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(rhs, [1, 0]),
            )
            fc_node = create_node(
                graph,
                torch.ops.aten.linear.default,
                args=(lhs, transpose_node),
            )
        node.replace_all_uses_with(fc_node, propagate_meta=True)

        return fc_node
131+
132+
133+
@trace_graph_diff_on_pass
class ConvertMatmulToLinear(PassBase):
    """
    This pass converts matmul to linear selectively

    How to select between `matmul` and `linear`?

    * Linear has better quantization accuracy (NPU backend)
      Due to ONE compiler's quantization policy:
      FullyConnected (= Linear) uses per-channel quantization for weight and
      per-tensor for input, while BatchMatmul (= matmul) uses per-tensor
      quantization for both rhs and lhs.

    * Matmul to Linear requires Transpose, which may harm latency.
      When RHS is constant, the additional transpose can be folded.

    [RHS non-const case]
    Constant folding cannot be performed.

        lhs   rhs (non-const)
         |     |
         |  transpose
         |     |
         -- linear --
              |
             out

    [RHS const case]
    Constant folding can be performed to

        lhs   rhs (const)        lhs   rhs (folded const)
         |     |                  |     |
         |  transpose             |     |
         |     |                  |     |
         -- linear --     -->     -- linear --
              |                        |
             out                      out

    enable_lhs_const: If true, also convert matmul where LHS is constant tensor. Default is False.
    enable_rhs_const: If true, also convert matmul where RHS is constant tensor. Default is True.
    """

    def __init__(
        self,
        enable_lhs_const: Optional[bool] = False,
        enable_rhs_const: Optional[bool] = True,
    ):
        super().__init__()
        # Conversion strategies, applied in registration order to each node.
        self.converters: List[Converter] = []
        if enable_lhs_const:
            self.converters.append(ConvertLhsConstMatmulToLinear())
        if enable_rhs_const:
            self.converters.append(ConvertRhsConstMatmulToLinear())

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue

            # Every registered converter gets a chance to match this node.
            for converter in self.converters:
                if converter.match(exported_program, node):
                    new_node = converter.convert(exported_program, node)
                    modified = True
                    logger.debug(
                        f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
                    )

        # Converted mm nodes are left without users; sweep them here.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)

tico/passes/convert_to_relu6.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
172172
converter.convert(exported_program, node)
173173
modified = True
174174
logger.debug(f"{node.name} is replaced with ReLU6 operator")
175-
break
175+
continue
176176

177177
graph.eliminate_dead_code()
178178
graph.lint()

0 commit comments

Comments
 (0)