@@ -34,6 +34,32 @@ class BatchMatmulVisitor(NodeVisitor):
    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        """Initialize the visitor by delegating to the NodeVisitor base class.

        Args:
            op_codes: mapping from operator code to its index in the model's
                opcode table (shared across visitors).
            graph: the Circle subgraph being built.
        """
        super().__init__(op_codes, graph)
3636
37+ def define_fc_node (self , inputs , outputs ) -> circle .Operator .OperatorT :
38+ def set_fc_option (operator ):
39+ operator .builtinOptionsType = (
40+ circle .BuiltinOptions .BuiltinOptions .FullyConnectedOptions
41+ )
42+ option = circle .FullyConnectedOptions .FullyConnectedOptionsT ()
43+
44+ option .fusedActivationFunction = (
45+ circle .ActivationFunctionType .ActivationFunctionType .NONE
46+ )
47+ option .weightsFormat = (
48+ circle .FullyConnectedOptionsWeightsFormat .FullyConnectedOptionsWeightsFormat .DEFAULT
49+ )
50+ option .keepNumDims = False
51+ option .asymmetricQuantizeInputs = False
52+ option .quantizedBiasType = circle .TensorType .TensorType .FLOAT32
53+
54+ operator .builtinOptions = option
55+
56+ fc_op_index = get_op_index (
57+ circle .BuiltinOperator .BuiltinOperator .FULLY_CONNECTED , self ._op_codes
58+ )
59+ operator = create_builtin_operator (self .graph , fc_op_index , inputs , outputs )
60+ set_fc_option (operator )
61+ return operator
62+
3763 def define_node (
3864 self ,
3965 node : torch .fx .Node ,
@@ -42,21 +68,72 @@ def define_node(
4268 input = args .input
4369 mat2 = args .mat2
4470
45- op_index = get_op_index (
46- circle .BuiltinOperator .BuiltinOperator .BATCH_MATMUL , self ._op_codes
71+ is_const_tensor = lambda n : (
72+ n .op == "get_attr"
73+ or (
74+ n .op == "placeholder"
75+ and isinstance (n .meta .get ("val" , None ), torch .Tensor )
76+ and not n .meta ["val" ].requires_grad
77+ )
4778 )
4879
49- inputs = [ input , mat2 ]
50- outputs = [ node ]
80+ lhs , rhs = input , mat2
81+ is_const_lhs = is_const_tensor ( lhs )
5182
52- operator = create_builtin_operator (self .graph , op_index , inputs , outputs )
83+ if is_const_lhs :
84+ fc_index = get_op_index (
85+ circle .BuiltinOperator .BuiltinOperator .FULLY_CONNECTED ,
86+ self ._op_codes ,
87+ )
5388
54- # Op-specific option
55- operator .builtinOptionsType = (
56- circle .BuiltinOptions .BuiltinOptions .BatchMatMulOptions
57- )
58- option = circle .BatchMatMulOptions .BatchMatMulOptionsT ()
59- option .adjointLhs , option .adjointRhs = False , False
60- operator .builtinOptions = option
89+ rhs_tid = self .graph .get_tid_registered (rhs )
90+ rhs_tensor : circle .Tensor .TensorT = self .graph .tensors [rhs_tid ]
91+ rhs_shape = list (rhs_tensor .shape ) # [..., batch, in_features]
92+ rhs_dtype = rhs_tensor .type
6193
62- return operator
94+ # lhs : weight, shape = [..., out_features, in_features]
95+ lhs_tid = self .graph .get_tid_registered (lhs )
96+ lhs_tensor : circle .Tensor .TensorT = self .graph .tensors [lhs_tid ]
97+ lhs_shape = list (lhs_tensor .shape )
98+ out_features = lhs_shape [- 2 ]
99+ fc_out_shape = rhs_shape [:- 1 ] + [out_features ]
100+ fc_bias = self .graph .add_const_tensor (data = [0.0 ], source_node = node )
101+ fc_out = self .graph .add_tensor_from_scratch (
102+ prefix = f"{ node .name } _fc_out" ,
103+ shape = fc_out_shape ,
104+ shape_signature = fc_out_shape ,
105+ dtype = rhs_dtype ,
106+ )
107+
108+ fc_inputs = [rhs , lhs , fc_bias ] # order: [input, weight]
109+ fc_outputs = [fc_out ]
110+ fc_op = self .define_fc_node (fc_inputs , fc_outputs )
111+ self .graph .add_operator (fc_op )
112+
113+ trs_index = get_op_index (
114+ circle .BuiltinOperator .BuiltinOperator .TRANSPOSE ,
115+ self ._op_codes ,
116+ )
117+
118+ perm = list (range (len (fc_out .shape )))
119+ perm [- 2 ], perm [- 1 ] = perm [- 1 ], perm [- 2 ]
120+ perm_tensor = self .graph .add_const_tensor (
121+ data = torch .tensor (perm , dtype = torch .int32 ), # to prevent int64
122+ )
123+
124+ trs_inputs = [fc_out , perm_tensor ]
125+ trs_outputs = [node ]
126+ trs_op = create_builtin_operator (
127+ self .graph , trs_index , trs_inputs , trs_outputs
128+ )
129+
130+ return trs_op
131+
132+ bmm_index = get_op_index (
133+ circle .BuiltinOperator .BuiltinOperator .BATCH_MATMUL ,
134+ self ._op_codes ,
135+ )
136+ inputs = [lhs , rhs ]
137+ outputs = [node ]
138+ op = create_builtin_operator (self .graph , bmm_index , inputs , outputs )
139+ return op
0 commit comments