File tree Expand file tree Collapse file tree 2 files changed +30
-15
lines changed
Expand file tree Collapse file tree 2 files changed +30
-15
lines changed Original file line number Diff line number Diff line change 2727from tico .serialize .operators .hashable_opcode import OpCode
2828from tico .serialize .operators .node_visitor import NodeVisitor , register_node_visitor
2929from tico .serialize .operators .utils import create_builtin_operator , get_op_index
30-
30+ from tico . utils . validate_args_kwargs import CircleAttentionArgs
3131
3232
3333def llama_attention_forward_adapter (
@@ -78,20 +78,7 @@ def define_node(
7878 self ,
7979 node : torch .fx .Node ,
8080 ) -> circle .Operator .OperatorT :
81- (
82- hidden_states ,
83- wq ,
84- wk ,
85- wv ,
86- wo ,
87- position_cos ,
88- position_sin ,
89- attention_mask ,
90- past_key ,
91- past_value ,
92- cache_position ,
93- ) = node .args
94-
81+ args = CircleAttentionArgs (* node .args , ** node .kwargs ) # type: ignore[arg-type]
9582 op_index = get_op_index (
9683 circle .BuiltinOperator .BuiltinOperator .ATTENTION , self ._op_codes
9784 )
Original file line number Diff line number Diff line change @@ -171,6 +171,34 @@ class CatArgs:
171171 dim : int = 0
172172
173173
174+ @enforce_type
175+ @dataclass
176+ class CircleAttentionArgs :
177+ """
178+ For circle.BuiltinOperator.BuiltinOperator.RMS_NORM
179+ """
180+
181+
182+ @enforce_type
183+ @dataclass
184+ class CircleAttentionArgs :
185+ """
186+ For circle.BuiltinOperator.BuiltinOperator.ATTENTION
187+ """
188+
189+ hidden_states : torch .fx .Node
190+ wq : torch .fx .Node
191+ wk : torch .fx .Node
192+ wv : torch .fx .Node
193+ wo : torch .fx .Node
194+ position_cos : torch .fx .Node
195+ position_sin : torch .fx .Node
196+ attention_mask : torch .fx .Node
197+ past_key : torch .fx .Node
198+ past_value : torch .fx .Node
199+ cache_position : torch .fx .Node
200+
201+
174202@enforce_type
175203@dataclass
176204class CircleRMSNormArgs :
You can’t perform that action at this time.
0 commit comments