Skip to content

Commit d712ce9

Browse files
committed
feat: Add Qwen3.5 support to Megatron backend
1 parent 0211221 commit d712ce9

7 files changed

Lines changed: 1552 additions & 320 deletions

File tree

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
import math
2+
3+
from megatron.bridge.models.conversion.model_bridge import MegatronWeightTuple
4+
from megatron.bridge.models.conversion.peft_bridge import AdapterWeight
5+
from megatron.core.transformer.module import MegatronModule
6+
from megatron.core.transformer.transformer_layer import TransformerLayer
7+
import torch
8+
9+
from art.megatron.lora import (
10+
GatedDeltaNetInProjLoRA,
11+
LoRA,
12+
MLPExpertsLinearFC1LoRA,
13+
MLPExpertsLinearFC2LoRA,
14+
SelfAttentionLinearProjLoRA,
15+
SelfAttentionLinearQKVLoRA,
16+
SharedExpertsLinearFC1LoRA,
17+
SharedExpertsLinearFC2LoRA,
18+
)
19+
20+
21+
def _is_language_transformer_layer_name(module_name: str) -> bool:
22+
while module_name.startswith("module."):
23+
module_name = module_name.removeprefix("module.")
24+
return module_name.startswith(("decoder.layers.", "language_model.decoder.layers."))
25+
26+
27+
def _adapter_alpha_dim(lora: LoRA) -> tuple[int, int]:
    """Recover the integer ``(alpha, dim)`` pair encoded in *lora*.

    ``dim`` is read from the trailing axis of ``lora.A_T`` and ``alpha`` is
    reconstructed as ``scale * dim`` (so ``scale == alpha / dim`` by
    construction).

    Raises:
        ValueError: if ``scale * dim`` is not (close to) an integer, which
            indicates a misconfigured LoRA scale.
    """
    dim = int(lora.A_T.shape[-1])
    alpha = float(lora.scale) * dim
    rounded_alpha = round(alpha)
    # An `assert` here would be stripped under `python -O`; raise explicitly
    # so a non-integer alpha is always caught.
    if not math.isclose(alpha, rounded_alpha):
        raise ValueError(
            f"LoRA scale {lora.scale!r} times dim {dim} gives non-integer alpha {alpha}"
        )
    return rounded_alpha, dim
34+
35+
def _adapter_tensors(lora: LoRA, expert_idx: int | None = None) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the (linear_in, linear_out) weight pair of *lora*.

    When *expert_idx* is given, the per-expert slice of the stacked expert
    LoRA tensors is selected first. Both tensors have their trailing two
    dims transposed and are made contiguous.
    """
    if expert_idx is None:
        a_t, b_t = lora.A_T, lora.B_T
    else:
        a_t, b_t = lora.A_T[expert_idx], lora.B_T[expert_idx]
    linear_in = a_t.transpose(-1, -2).contiguous()
    linear_out = b_t.transpose(-1, -2).contiguous()
    return linear_in, linear_out
40+
41+
def _adapter_param_prefix(base_prefix: str, adapter_key: str | None) -> str:
42+
if adapter_key is None:
43+
return f"{base_prefix}.adapter"
44+
return f"{base_prefix}.adapter.{adapter_key}"
45+
46+
47+
def _adapter_weight(
    *,
    base_prefix: str,
    adapter_key: str | None,
    alpha: int,
    dim: int,
    linear_in: torch.Tensor,
    linear_out: torch.Tensor,
) -> AdapterWeight:
    """Wrap a (linear_in, linear_out) weight pair into an ``AdapterWeight``.

    Parameter names are derived from *base_prefix* (plus the optional
    *adapter_key*); both weight tuples are pinned to virtual pipeline
    stage 0.
    """
    prefix = _adapter_param_prefix(base_prefix, adapter_key)
    in_weight = MegatronWeightTuple(
        param_name=f"{prefix}.linear_in.weight",
        weight=linear_in,
        vp_stage=0,
    )
    out_weight = MegatronWeightTuple(
        param_name=f"{prefix}.linear_out.weight",
        weight=linear_out,
        vp_stage=0,
    )
    return AdapterWeight(
        global_base_prefix=base_prefix,
        adapter_key=adapter_key,
        alpha=alpha,
        dim=dim,
        linear_in_weight=in_weight,
        linear_out_weight=out_weight,
    )
74+
75+
def _simple_adapter_weight(
    base_prefix: str,
    lora: LoRA,
    *,
    adapter_key: str | None = None,
    expert_idx: int | None = None,
) -> AdapterWeight:
    """Convert a single (optionally per-expert) LoRA into an ``AdapterWeight``.

    ``alpha``/``dim`` are recovered from the LoRA's scale, and the A/B
    tensors become the adapter's linear_in/linear_out weights.
    """
    alpha, dim = _adapter_alpha_dim(lora)
    a_weight, b_weight = _adapter_tensors(lora, expert_idx)
    return _adapter_weight(
        base_prefix=base_prefix,
        adapter_key=adapter_key,
        alpha=alpha,
        dim=dim,
        linear_in=a_weight,
        linear_out=b_weight,
    )
93+
94+
def _fused_gdn_adapter_weight(
    base_prefix: str,
    handler: GatedDeltaNetInProjLoRA,
) -> AdapterWeight:
    """Fuse the separate QKV and Z LoRAs of a gated-delta-net ``in_proj``
    into one adapter spanning the whole fused projection.

    The fused ``linear_in`` stacks the two A matrices along the rank dim;
    the fused ``linear_out`` is block-diagonal in the QKV/Z output blocks
    (zero padding off the diagonal) followed by all-zero rows, so the
    adapter contributes nothing to the remaining in_proj outputs.

    NOTE(review): this assumes the in_proj output is ordered
    [qkv | z | <two blocks of num_value_heads_per_partition rows>]
    (presumably the beta and alpha gate heads) — confirm against the
    GatedDeltaNet in_proj layout.
    """
    qkv_linear_in, qkv_linear_out = _adapter_tensors(handler.qkv_lora)
    z_linear_in, z_linear_out = _adapter_tensors(handler.z_lora)
    # Both LoRAs must share one scale so a single fused alpha is well-defined.
    assert math.isclose(float(handler.qkv_lora.scale), float(handler.z_lora.scale))
    total_dim = int(qkv_linear_in.shape[0] + z_linear_in.shape[0])
    # NOTE(review): unlike `_adapter_alpha_dim`, alpha is rounded here without
    # an isclose check — a non-integer scale*dim would be silently truncated.
    alpha = round(float(handler.qkv_lora.scale) * total_dim)

    qkv_rank = int(qkv_linear_in.shape[0])
    z_rank = int(z_linear_in.shape[0])
    qkv_out = int(qkv_linear_out.shape[0])
    z_out = int(z_linear_out.shape[0])
    beta_alpha_out = int(handler.num_value_heads_per_partition)

    # Off-diagonal zero blocks keep the QKV and Z adapters independent, and
    # `zeros` rows zero out the adapter for the gate-head outputs.
    qkv_padding = qkv_linear_out.new_zeros((qkv_out, z_rank))
    z_padding = z_linear_out.new_zeros((z_out, qkv_rank))
    zeros = qkv_linear_out.new_zeros((beta_alpha_out, total_dim))

    return _adapter_weight(
        base_prefix=base_prefix,
        adapter_key=None,
        alpha=alpha,
        dim=total_dim,
        # Stacked A matrices: [qkv_rank + z_rank, in_features].
        linear_in=torch.cat([qkv_linear_in, z_linear_in], dim=0),
        # Row blocks: [qkv | 0], [0 | z], then two zero blocks.
        linear_out=torch.cat(
            [
                torch.cat([qkv_linear_out, qkv_padding], dim=1),
                torch.cat([z_padding, z_linear_out], dim=1),
                zeros,
                zeros.clone(),
            ],
            dim=0,
        ),
    )
131+
132+
def _fused_pair_adapter_weight(
    base_prefix: str,
    first_lora: LoRA,
    second_lora: LoRA,
    *,
    first_expert_idx: int | None = None,
    second_expert_idx: int | None = None,
) -> AdapterWeight:
    """Fuse two LoRAs over a concatenated base weight into one adapter.

    Used when two projections (e.g. gate and up) share one fused base
    weight: the fused ``linear_in`` stacks both A matrices, and the fused
    ``linear_out`` is block-diagonal so each LoRA only touches its own
    output block.

    Raises:
        ValueError: if the two LoRAs do not share the same scale, or if
            ``scale * total_dim`` is not (close to) an integer alpha —
            consistent with the validation in ``_adapter_alpha_dim``.
    """
    first_linear_in, first_linear_out = _adapter_tensors(first_lora, first_expert_idx)
    second_linear_in, second_linear_out = _adapter_tensors(second_lora, second_expert_idx)
    # Explicit raises instead of `assert`: asserts vanish under `python -O`,
    # and a silently rounded non-integer alpha would mask a misconfiguration.
    if not math.isclose(float(first_lora.scale), float(second_lora.scale)):
        raise ValueError(
            f"Cannot fuse LoRAs with different scales: "
            f"{first_lora.scale!r} vs {second_lora.scale!r}"
        )
    total_dim = int(first_linear_in.shape[0] + second_linear_in.shape[0])
    alpha_float = float(first_lora.scale) * total_dim
    alpha = round(alpha_float)
    if not math.isclose(alpha_float, alpha):
        raise ValueError(
            f"LoRA scale {first_lora.scale!r} times fused dim {total_dim} "
            f"gives non-integer alpha {alpha_float}"
        )

    first_rank = int(first_linear_in.shape[0])
    second_rank = int(second_linear_in.shape[0])
    first_out = int(first_linear_out.shape[0])
    second_out = int(second_linear_out.shape[0])

    # Off-diagonal zero blocks keep the two adapters independent.
    first_padding = first_linear_out.new_zeros((first_out, second_rank))
    second_padding = second_linear_out.new_zeros((second_out, first_rank))

    return _adapter_weight(
        base_prefix=base_prefix,
        adapter_key=None,
        alpha=alpha,
        dim=total_dim,
        # Stacked A matrices: [first_rank + second_rank, in_features].
        linear_in=torch.cat([first_linear_in, second_linear_in], dim=0),
        # Block-diagonal B: [first | 0] over [0 | second].
        linear_out=torch.cat(
            [
                torch.cat([first_linear_out, first_padding], dim=1),
                torch.cat([second_padding, second_linear_out], dim=1),
            ],
            dim=0,
        ),
    )
169+
170+
def _collect_attention_adapters(
    adapter_weights_by_base: dict[str, list[AdapterWeight]],
    layer_prefix: str,
    self_attention,
) -> None:
    """Collect LoRA adapters attached to a layer's attention projections."""
    linear_proj = getattr(self_attention, "linear_proj", None)
    if isinstance(linear_proj, SelfAttentionLinearProjLoRA):
        base_prefix = f"{layer_prefix}.self_attention.linear_proj"
        adapter_weights_by_base[f"{base_prefix}.weight"] = [
            _simple_adapter_weight(base_prefix, linear_proj.lora)
        ]

    linear_qkv = getattr(self_attention, "linear_qkv", None)
    if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA):
        # One fused base weight, three named adapters (q, k, v).
        base_prefix = f"{layer_prefix}.self_attention.linear_qkv"
        adapter_weights_by_base[f"{base_prefix}.weight"] = [
            _simple_adapter_weight(base_prefix, linear_qkv.q_proj_lora, adapter_key="adapter_q"),
            _simple_adapter_weight(base_prefix, linear_qkv.k_proj_lora, adapter_key="adapter_k"),
            _simple_adapter_weight(base_prefix, linear_qkv.v_proj_lora, adapter_key="adapter_v"),
        ]

    out_proj = getattr(self_attention, "out_proj", None)
    if isinstance(out_proj, SelfAttentionLinearProjLoRA):
        base_prefix = f"{layer_prefix}.self_attention.out_proj"
        adapter_weights_by_base[f"{base_prefix}.weight"] = [
            _simple_adapter_weight(base_prefix, out_proj.lora)
        ]

    in_proj = getattr(self_attention, "in_proj", None)
    if isinstance(in_proj, GatedDeltaNetInProjLoRA):
        # Gated-delta-net layers fuse QKV and Z LoRAs into one adapter.
        base_prefix = f"{layer_prefix}.self_attention.in_proj"
        adapter_weights_by_base[f"{base_prefix}.weight"] = [
            _fused_gdn_adapter_weight(base_prefix, in_proj)
        ]


def _collect_mlp_adapters(
    adapter_weights_by_base: dict[str, list[AdapterWeight]],
    layer_prefix: str,
    mlp,
) -> None:
    """Collect LoRA adapters from the routed-expert (MoE) or dense MLP path."""
    experts = getattr(mlp, "experts", None)
    if experts is not None:
        # MoE path: one entry per expert, keyed by the *global* expert index
        # (local index plus this rank's expert offset).
        if isinstance(experts.linear_fc1, MLPExpertsLinearFC1LoRA):
            base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1"
            for local_expert_idx in range(experts.linear_fc1.gate_lora.num_local_experts):
                global_expert_idx = local_expert_idx + experts.linear_fc1.gate_lora._expert_offset
                adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [
                    # fc1 fuses gate and up projections in one base weight.
                    _fused_pair_adapter_weight(
                        base_prefix,
                        experts.linear_fc1.gate_lora,
                        experts.linear_fc1.up_lora,
                        first_expert_idx=local_expert_idx,
                        second_expert_idx=local_expert_idx,
                    )
                ]
        if isinstance(experts.linear_fc2, MLPExpertsLinearFC2LoRA):
            base_prefix = f"{layer_prefix}.mlp.experts.linear_fc2"
            for local_expert_idx in range(experts.linear_fc2.lora.num_local_experts):
                global_expert_idx = local_expert_idx + experts.linear_fc2.lora._expert_offset
                adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [
                    _simple_adapter_weight(
                        base_prefix,
                        experts.linear_fc2.lora,
                        expert_idx=local_expert_idx,
                    )
                ]
    else:
        # Dense path: single fc1 (gate + up adapters) and fc2.
        linear_fc1 = getattr(mlp, "linear_fc1", None)
        if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA):
            base_prefix = f"{layer_prefix}.mlp.linear_fc1"
            adapter_weights_by_base[f"{base_prefix}.weight"] = [
                _simple_adapter_weight(base_prefix, linear_fc1.gate_lora, adapter_key="adapter_gate"),
                _simple_adapter_weight(base_prefix, linear_fc1.up_lora, adapter_key="adapter_up"),
            ]
        linear_fc2 = getattr(mlp, "linear_fc2", None)
        if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA):
            base_prefix = f"{layer_prefix}.mlp.linear_fc2"
            adapter_weights_by_base[f"{base_prefix}.weight"] = [
                _simple_adapter_weight(base_prefix, linear_fc2.row_parallel_lora.lora)
            ]

    # Shared experts sit alongside the routed experts in MoE layers.
    shared_experts = getattr(mlp, "shared_experts", None)
    if shared_experts is not None:
        if isinstance(shared_experts.linear_fc1, SharedExpertsLinearFC1LoRA):
            base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc1"
            adapter_weights_by_base[f"{base_prefix}.weight"] = [
                _simple_adapter_weight(
                    base_prefix,
                    shared_experts.linear_fc1.gate_lora,
                    adapter_key="adapter_gate",
                ),
                _simple_adapter_weight(
                    base_prefix,
                    shared_experts.linear_fc1.up_lora,
                    adapter_key="adapter_up",
                ),
            ]
        if isinstance(shared_experts.linear_fc2, SharedExpertsLinearFC2LoRA):
            base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc2"
            adapter_weights_by_base[f"{base_prefix}.weight"] = [
                _simple_adapter_weight(
                    base_prefix,
                    shared_experts.linear_fc2.row_parallel_lora.lora,
                )
            ]


def build_adapter_weights_by_base(
    model_chunks: list[MegatronModule],
) -> dict[str, list[AdapterWeight]]:
    """Collect every LoRA adapter in *model_chunks*, keyed by the global name
    of the base weight it attaches to.

    Only language-model decoder ``TransformerLayer`` modules are scanned.
    For each layer, attention-projection adapters and MLP adapters (routed
    experts, dense MLP, shared experts) are gathered via the helpers above.

    Returns:
        Mapping from base-weight name (e.g.
        ``language_model.decoder.layers.0.self_attention.linear_qkv.weight``)
        to the list of ``AdapterWeight``s attached to it.
    """
    adapter_weights_by_base: dict[str, list[AdapterWeight]] = {}
    for chunk in model_chunks:
        for module_name, module in chunk.named_modules():
            if not isinstance(module, TransformerLayer):
                continue
            if not _is_language_transformer_layer_name(module_name):
                continue

            # NOTE(review): assumes `layer_number` is 1-based while exported
            # names are 0-based, and that exported names always carry the
            # "language_model." prefix even when the local module path is a
            # bare "decoder.layers." — confirm against the bridge's naming.
            layer_prefix = f"language_model.decoder.layers.{module.layer_number - 1}"
            _collect_attention_adapters(adapter_weights_by_base, layer_prefix, module.self_attention)
            _collect_mlp_adapters(adapter_weights_by_base, layer_prefix, module.mlp)
    return adapter_weights_by_base

0 commit comments

Comments
 (0)