Skip to content

Commit 9027726

Browse files
committed
Fix
1 parent 09efb88 commit 9027726

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tico/passes/convert_matmul_to_linear.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def convert(self, exported_program, node) -> torch.fx.Node: # type: ignore[empt
3838
pass
3939

4040

41-
class ConvertMatmulToLinear(Converter):
41+
class MatmulToLinearConverter(Converter):
4242
def __init__(self):
4343
super().__init__()
4444

@@ -67,7 +67,7 @@ def convert(self, exported_program, node) -> torch.fx.Node:
6767
return fc_node
6868

6969

70-
class ConvertRhsConstMatmulToLinear(ConvertMatmulToLinear):
70+
class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
7171
def __init__(self):
7272
super().__init__()
7373

@@ -93,7 +93,7 @@ def convert(self, exported_program, node) -> torch.fx.Node:
9393
return super().convert(exported_program, node)
9494

9595

96-
class ConvertLhsConstMatmulToLinear(ConvertMatmulToLinear):
96+
class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
9797
def __init__(self):
9898
super().__init__()
9999

@@ -112,6 +112,7 @@ def match(self, exported_program, node) -> bool:
112112
return True
113113
else:
114114
return False
115+
return False
115116

116117
def convert(self, exported_program, node) -> torch.fx.Node:
117118
return super().convert(exported_program, node)
@@ -167,9 +168,9 @@ def __init__(
167168
super().__init__()
168169
self.converters: List[Converter] = []
169170
if enable_lhs_const:
170-
self.converters.append(ConvertLhsConstMatmulToLinear())
171+
self.converters.append(LhsConstMatmulToLinearConverter())
171172
if enable_rhs_const:
172-
self.converters.append(ConvertRhsConstMatmulToLinear())
173+
self.converters.append(RhsConstMatmulToLinearConverter())
173174

174175
def call(self, exported_program: ExportedProgram) -> PassResult:
175176
logger = logging.getLogger(__name__)

0 commit comments

Comments
 (0)