Skip to content

Commit 69f5e92

Browse files
author
Sanggyu Lee
committed
Update tico/utils/validate_args_kwargs.py
1 parent fbbd673 commit 69f5e92

File tree

2 files changed

+30
-15
lines changed

2 files changed

+30
-15
lines changed

tico/serialize/operators/op_attention.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from tico.serialize.operators.hashable_opcode import OpCode
2828
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
2929
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
30-
30+
from tico.utils.validate_args_kwargs import CircleAttentionArgs
3131

3232

3333
def llama_attention_forward_adapter(
@@ -78,20 +78,7 @@ def define_node(
7878
self,
7979
node: torch.fx.Node,
8080
) -> circle.Operator.OperatorT:
81-
(
82-
hidden_states,
83-
wq,
84-
wk,
85-
wv,
86-
wo,
87-
position_cos,
88-
position_sin,
89-
attention_mask,
90-
past_key,
91-
past_value,
92-
cache_position,
93-
) = node.args
94-
81+
args = CircleAttentionArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
9582
op_index = get_op_index(
9683
circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
9784
)

tico/utils/validate_args_kwargs.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,34 @@ class CatArgs:
171171
dim: int = 0
172172

173173

174+
@enforce_type
175+
@dataclass
176+
class CircleAttentionArgs:
177+
"""
178+
For circle.BuiltinOperator.BuiltinOperator.RMS_NORM
179+
"""
180+
181+
182+
@enforce_type
183+
@dataclass
184+
class CircleAttentionArgs:
185+
"""
186+
For circle.BuiltinOperator.BuiltinOperator.ATTENTION
187+
"""
188+
189+
hidden_states: torch.fx.Node
190+
wq: torch.fx.Node
191+
wk: torch.fx.Node
192+
wv: torch.fx.Node
193+
wo: torch.fx.Node
194+
position_cos: torch.fx.Node
195+
position_sin: torch.fx.Node
196+
attention_mask: torch.fx.Node
197+
past_key: torch.fx.Node
198+
past_value: torch.fx.Node
199+
cache_position: torch.fx.Node
200+
201+
174202
@enforce_type
175203
@dataclass
176204
class CircleRMSNormArgs:

0 commit comments

Comments
 (0)