Skip to content

Commit 9573413

Browse files
author
Sanggyu Lee
committed
[serialize] prefer fc over bmm on const lhs
If the lhs of a matmul is constant, emit FULLY_CONNECTED instead of BATCH_MATMUL. TICO-DCO-1.0-Signed-off-by: Sanggyu Lee <[email protected]>
1 parent 07f96a3 commit 9573413

File tree

1 file changed

+90
-13
lines changed

1 file changed

+90
-13
lines changed

tico/serialize/operators/op_bmm.py

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,32 @@ class BatchMatmulVisitor(NodeVisitor):
3434
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
3535
super().__init__(op_codes, graph)
3636

37+
def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
    """Create a Circle FULLY_CONNECTED operator wiring *inputs* to *outputs*.

    The operator is emitted with plain, non-quantized options: no fused
    activation, DEFAULT weights format, ``keepNumDims`` off,
    ``asymmetricQuantizeInputs`` off, and a FLOAT32 bias type.

    :param inputs: operator input tensors/nodes, in FC order (input, weight, bias)
    :param outputs: operator output tensors/nodes
    :return: the constructed ``circle.Operator.OperatorT``
    """
    opcode_index = get_op_index(
        circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
    )
    fc_operator = create_builtin_operator(self.graph, opcode_index, inputs, outputs)

    # Attach op-specific options.
    fc_operator.builtinOptionsType = (
        circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
    )
    fc_options = circle.FullyConnectedOptions.FullyConnectedOptionsT()
    fc_options.fusedActivationFunction = (
        circle.ActivationFunctionType.ActivationFunctionType.NONE
    )
    fc_options.weightsFormat = (
        circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
    )
    fc_options.keepNumDims = False
    fc_options.asymmetricQuantizeInputs = False
    fc_options.quantizedBiasType = circle.TensorType.TensorType.FLOAT32
    fc_operator.builtinOptions = fc_options

    return fc_operator
62+
3763
def define_node(
3864
self,
3965
node: torch.fx.Node,
@@ -42,21 +68,72 @@ def define_node(
4268
input = args.input
4369
mat2 = args.mat2
4470

45-
op_index = get_op_index(
46-
circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
71+
is_const_tensor = lambda n: (
72+
n.op == "get_attr"
73+
or (
74+
n.op == "placeholder"
75+
and isinstance(n.meta.get("val", None), torch.Tensor)
76+
and not n.meta["val"].requires_grad
77+
)
4778
)
4879

49-
inputs = [input, mat2]
50-
outputs = [node]
80+
lhs, rhs = input, mat2
81+
is_const_lhs = is_const_tensor(lhs)
5182

52-
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
83+
if is_const_lhs:
84+
fc_index = get_op_index(
85+
circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED,
86+
self._op_codes,
87+
)
5388

54-
# Op-specific option
55-
operator.builtinOptionsType = (
56-
circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
57-
)
58-
option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
59-
option.adjointLhs, option.adjointRhs = False, False
60-
operator.builtinOptions = option
89+
rhs_tid = self.graph.get_tid_registered(rhs)
90+
rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
91+
rhs_shape = list(rhs_tensor.shape) # [..., batch, in_features]
92+
rhs_dtype = rhs_tensor.type
6193

62-
return operator
94+
# lhs : weight, shape = [..., out_features, in_features]
95+
lhs_tid = self.graph.get_tid_registered(lhs)
96+
lhs_tensor: circle.Tensor.TensorT = self.graph.tensors[lhs_tid]
97+
lhs_shape = list(lhs_tensor.shape)
98+
out_features = lhs_shape[-2]
99+
fc_out_shape = rhs_shape[:-1] + [out_features]
100+
fc_bias = self.graph.add_const_tensor(data=[0.0], source_node=node)
101+
fc_out = self.graph.add_tensor_from_scratch(
102+
prefix=f"{node.name}_fc_out",
103+
shape=fc_out_shape,
104+
shape_signature=fc_out_shape,
105+
dtype=rhs_dtype,
106+
)
107+
108+
fc_inputs = [rhs, lhs, fc_bias]  # order: [input, weight, bias]
109+
fc_outputs = [fc_out]
110+
fc_op = self.define_fc_node(fc_inputs, fc_outputs)
111+
self.graph.add_operator(fc_op)
112+
113+
trs_index = get_op_index(
114+
circle.BuiltinOperator.BuiltinOperator.TRANSPOSE,
115+
self._op_codes,
116+
)
117+
118+
perm = list(range(len(fc_out.shape)))
119+
perm[-2], perm[-1] = perm[-1], perm[-2]
120+
perm_tensor = self.graph.add_const_tensor(
121+
data=torch.tensor(perm, dtype=torch.int32), # to prevent int64
122+
)
123+
124+
trs_inputs = [fc_out, perm_tensor]
125+
trs_outputs = [node]
126+
trs_op = create_builtin_operator(
127+
self.graph, trs_index, trs_inputs, trs_outputs
128+
)
129+
130+
return trs_op
131+
132+
bmm_index = get_op_index(
133+
circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL,
134+
self._op_codes,
135+
)
136+
inputs = [lhs, rhs]
137+
outputs = [node]
138+
op = create_builtin_operator(self.graph, bmm_index, inputs, outputs)
139+
return op

0 commit comments

Comments
 (0)