@@ -38,7 +38,7 @@ def convert(self, exported_program, node) -> torch.fx.Node: # type: ignore[empt
3838 pass
3939
4040
41- class ConvertMatmulToLinear (Converter ):
41+ class MatmulToLinearConverter (Converter ):
4242 def __init__ (self ):
4343 super ().__init__ ()
4444
@@ -67,7 +67,7 @@ def convert(self, exported_program, node) -> torch.fx.Node:
6767 return fc_node
6868
6969
70- class ConvertRhsConstMatmulToLinear ( ConvertMatmulToLinear ):
70+ class RhsConstMatmulToLinearConverter ( MatmulToLinearConverter ):
7171 def __init__ (self ):
7272 super ().__init__ ()
7373
@@ -93,7 +93,7 @@ def convert(self, exported_program, node) -> torch.fx.Node:
9393 return super ().convert (exported_program , node )
9494
9595
96- class ConvertLhsConstMatmulToLinear ( ConvertMatmulToLinear ):
96+ class LhsConstMatmulToLinearConverter ( MatmulToLinearConverter ):
9797 def __init__ (self ):
9898 super ().__init__ ()
9999
@@ -112,6 +112,7 @@ def match(self, exported_program, node) -> bool:
112112 return True
113113 else :
114114 return False
115+ return False
115116
116117 def convert (self , exported_program , node ) -> torch .fx .Node :
117118 return super ().convert (exported_program , node )
@@ -167,9 +168,9 @@ def __init__(
167168 super ().__init__ ()
168169 self .converters : List [Converter ] = []
169170 if enable_lhs_const :
170- self .converters .append (ConvertLhsConstMatmulToLinear ())
171+ self .converters .append (LhsConstMatmulToLinearConverter ())
171172 if enable_rhs_const :
172- self .converters .append (ConvertRhsConstMatmulToLinear ())
173+ self .converters .append (RhsConstMatmulToLinearConverter ())
173174
174175 def call (self , exported_program : ExportedProgram ) -> PassResult :
175176 logger = logging .getLogger (__name__ )
0 commit comments