|
| 1 | +import math |
| 2 | + |
| 3 | +from megatron.bridge.models.conversion.model_bridge import MegatronWeightTuple |
| 4 | +from megatron.bridge.models.conversion.peft_bridge import AdapterWeight |
| 5 | +from megatron.core.transformer.module import MegatronModule |
| 6 | +from megatron.core.transformer.transformer_layer import TransformerLayer |
| 7 | +import torch |
| 8 | + |
| 9 | +from art.megatron.lora import ( |
| 10 | + GatedDeltaNetInProjLoRA, |
| 11 | + LoRA, |
| 12 | + MLPExpertsLinearFC1LoRA, |
| 13 | + MLPExpertsLinearFC2LoRA, |
| 14 | + SelfAttentionLinearProjLoRA, |
| 15 | + SelfAttentionLinearQKVLoRA, |
| 16 | + SharedExpertsLinearFC1LoRA, |
| 17 | + SharedExpertsLinearFC2LoRA, |
| 18 | +) |
| 19 | + |
| 20 | + |
| 21 | +def _is_language_transformer_layer_name(module_name: str) -> bool: |
| 22 | + while module_name.startswith("module."): |
| 23 | + module_name = module_name.removeprefix("module.") |
| 24 | + return module_name.startswith(("decoder.layers.", "language_model.decoder.layers.")) |
| 25 | + |
| 26 | + |
| 27 | +def _adapter_alpha_dim(lora: LoRA) -> tuple[int, int]: |
| 28 | + dim = int(lora.A_T.shape[-1]) |
| 29 | + alpha = float(lora.scale) * dim |
| 30 | + rounded_alpha = round(alpha) |
| 31 | + assert math.isclose(alpha, rounded_alpha) |
| 32 | + return rounded_alpha, dim |
| 33 | + |
| 34 | + |
| 35 | +def _adapter_tensors(lora: LoRA, expert_idx: int | None = None) -> tuple[torch.Tensor, torch.Tensor]: |
| 36 | + a_t = lora.A_T if expert_idx is None else lora.A_T[expert_idx] |
| 37 | + b_t = lora.B_T if expert_idx is None else lora.B_T[expert_idx] |
| 38 | + return a_t.transpose(-1, -2).contiguous(), b_t.transpose(-1, -2).contiguous() |
| 39 | + |
| 40 | + |
| 41 | +def _adapter_param_prefix(base_prefix: str, adapter_key: str | None) -> str: |
| 42 | + if adapter_key is None: |
| 43 | + return f"{base_prefix}.adapter" |
| 44 | + return f"{base_prefix}.adapter.{adapter_key}" |
| 45 | + |
| 46 | + |
| 47 | +def _adapter_weight( |
| 48 | + *, |
| 49 | + base_prefix: str, |
| 50 | + adapter_key: str | None, |
| 51 | + alpha: int, |
| 52 | + dim: int, |
| 53 | + linear_in: torch.Tensor, |
| 54 | + linear_out: torch.Tensor, |
| 55 | +) -> AdapterWeight: |
| 56 | + param_prefix = _adapter_param_prefix(base_prefix, adapter_key) |
| 57 | + return AdapterWeight( |
| 58 | + global_base_prefix=base_prefix, |
| 59 | + adapter_key=adapter_key, |
| 60 | + alpha=alpha, |
| 61 | + dim=dim, |
| 62 | + linear_in_weight=MegatronWeightTuple( |
| 63 | + param_name=f"{param_prefix}.linear_in.weight", |
| 64 | + weight=linear_in, |
| 65 | + vp_stage=0, |
| 66 | + ), |
| 67 | + linear_out_weight=MegatronWeightTuple( |
| 68 | + param_name=f"{param_prefix}.linear_out.weight", |
| 69 | + weight=linear_out, |
| 70 | + vp_stage=0, |
| 71 | + ), |
| 72 | + ) |
| 73 | + |
| 74 | + |
| 75 | +def _simple_adapter_weight( |
| 76 | + base_prefix: str, |
| 77 | + lora: LoRA, |
| 78 | + *, |
| 79 | + adapter_key: str | None = None, |
| 80 | + expert_idx: int | None = None, |
| 81 | +) -> AdapterWeight: |
| 82 | + alpha, dim = _adapter_alpha_dim(lora) |
| 83 | + linear_in, linear_out = _adapter_tensors(lora, expert_idx) |
| 84 | + return _adapter_weight( |
| 85 | + base_prefix=base_prefix, |
| 86 | + adapter_key=adapter_key, |
| 87 | + alpha=alpha, |
| 88 | + dim=dim, |
| 89 | + linear_in=linear_in, |
| 90 | + linear_out=linear_out, |
| 91 | + ) |
| 92 | + |
| 93 | + |
| 94 | +def _fused_gdn_adapter_weight( |
| 95 | + base_prefix: str, |
| 96 | + handler: GatedDeltaNetInProjLoRA, |
| 97 | +) -> AdapterWeight: |
| 98 | + qkv_linear_in, qkv_linear_out = _adapter_tensors(handler.qkv_lora) |
| 99 | + z_linear_in, z_linear_out = _adapter_tensors(handler.z_lora) |
| 100 | + assert math.isclose(float(handler.qkv_lora.scale), float(handler.z_lora.scale)) |
| 101 | + total_dim = int(qkv_linear_in.shape[0] + z_linear_in.shape[0]) |
| 102 | + alpha = round(float(handler.qkv_lora.scale) * total_dim) |
| 103 | + |
| 104 | + qkv_rank = int(qkv_linear_in.shape[0]) |
| 105 | + z_rank = int(z_linear_in.shape[0]) |
| 106 | + qkv_out = int(qkv_linear_out.shape[0]) |
| 107 | + z_out = int(z_linear_out.shape[0]) |
| 108 | + beta_alpha_out = int(handler.num_value_heads_per_partition) |
| 109 | + |
| 110 | + qkv_padding = qkv_linear_out.new_zeros((qkv_out, z_rank)) |
| 111 | + z_padding = z_linear_out.new_zeros((z_out, qkv_rank)) |
| 112 | + zeros = qkv_linear_out.new_zeros((beta_alpha_out, total_dim)) |
| 113 | + |
| 114 | + return _adapter_weight( |
| 115 | + base_prefix=base_prefix, |
| 116 | + adapter_key=None, |
| 117 | + alpha=alpha, |
| 118 | + dim=total_dim, |
| 119 | + linear_in=torch.cat([qkv_linear_in, z_linear_in], dim=0), |
| 120 | + linear_out=torch.cat( |
| 121 | + [ |
| 122 | + torch.cat([qkv_linear_out, qkv_padding], dim=1), |
| 123 | + torch.cat([z_padding, z_linear_out], dim=1), |
| 124 | + zeros, |
| 125 | + zeros.clone(), |
| 126 | + ], |
| 127 | + dim=0, |
| 128 | + ), |
| 129 | + ) |
| 130 | + |
| 131 | + |
| 132 | +def _fused_pair_adapter_weight( |
| 133 | + base_prefix: str, |
| 134 | + first_lora: LoRA, |
| 135 | + second_lora: LoRA, |
| 136 | + *, |
| 137 | + first_expert_idx: int | None = None, |
| 138 | + second_expert_idx: int | None = None, |
| 139 | +) -> AdapterWeight: |
| 140 | + first_linear_in, first_linear_out = _adapter_tensors(first_lora, first_expert_idx) |
| 141 | + second_linear_in, second_linear_out = _adapter_tensors(second_lora, second_expert_idx) |
| 142 | + assert math.isclose(float(first_lora.scale), float(second_lora.scale)) |
| 143 | + total_dim = int(first_linear_in.shape[0] + second_linear_in.shape[0]) |
| 144 | + alpha = round(float(first_lora.scale) * total_dim) |
| 145 | + |
| 146 | + first_rank = int(first_linear_in.shape[0]) |
| 147 | + second_rank = int(second_linear_in.shape[0]) |
| 148 | + first_out = int(first_linear_out.shape[0]) |
| 149 | + second_out = int(second_linear_out.shape[0]) |
| 150 | + |
| 151 | + first_padding = first_linear_out.new_zeros((first_out, second_rank)) |
| 152 | + second_padding = second_linear_out.new_zeros((second_out, first_rank)) |
| 153 | + |
| 154 | + return _adapter_weight( |
| 155 | + base_prefix=base_prefix, |
| 156 | + adapter_key=None, |
| 157 | + alpha=alpha, |
| 158 | + dim=total_dim, |
| 159 | + linear_in=torch.cat([first_linear_in, second_linear_in], dim=0), |
| 160 | + linear_out=torch.cat( |
| 161 | + [ |
| 162 | + torch.cat([first_linear_out, first_padding], dim=1), |
| 163 | + torch.cat([second_padding, second_linear_out], dim=1), |
| 164 | + ], |
| 165 | + dim=0, |
| 166 | + ), |
| 167 | + ) |
| 168 | + |
| 169 | + |
| 170 | +def build_adapter_weights_by_base( |
| 171 | + model_chunks: list[MegatronModule], |
| 172 | +) -> dict[str, list[AdapterWeight]]: |
| 173 | + adapter_weights_by_base: dict[str, list[AdapterWeight]] = {} |
| 174 | + for chunk in model_chunks: |
| 175 | + for module_name, module in chunk.named_modules(): |
| 176 | + if not isinstance(module, TransformerLayer): |
| 177 | + continue |
| 178 | + if not _is_language_transformer_layer_name(module_name): |
| 179 | + continue |
| 180 | + |
| 181 | + layer_prefix = f"language_model.decoder.layers.{module.layer_number - 1}" |
| 182 | + self_attention = module.self_attention |
| 183 | + |
| 184 | + linear_proj = getattr(self_attention, "linear_proj", None) |
| 185 | + if isinstance(linear_proj, SelfAttentionLinearProjLoRA): |
| 186 | + base_prefix = f"{layer_prefix}.self_attention.linear_proj" |
| 187 | + adapter_weights_by_base[f"{base_prefix}.weight"] = [ |
| 188 | + _simple_adapter_weight(base_prefix, linear_proj.lora) |
| 189 | + ] |
| 190 | + |
| 191 | + linear_qkv = getattr(self_attention, "linear_qkv", None) |
| 192 | + if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA): |
| 193 | + base_prefix = f"{layer_prefix}.self_attention.linear_qkv" |
| 194 | + adapter_weights_by_base[f"{base_prefix}.weight"] = [ |
| 195 | + _simple_adapter_weight(base_prefix, linear_qkv.q_proj_lora, adapter_key="adapter_q"), |
| 196 | + _simple_adapter_weight(base_prefix, linear_qkv.k_proj_lora, adapter_key="adapter_k"), |
| 197 | + _simple_adapter_weight(base_prefix, linear_qkv.v_proj_lora, adapter_key="adapter_v"), |
| 198 | + ] |
| 199 | + |
| 200 | + out_proj = getattr(self_attention, "out_proj", None) |
| 201 | + if isinstance(out_proj, SelfAttentionLinearProjLoRA): |
| 202 | + base_prefix = f"{layer_prefix}.self_attention.out_proj" |
| 203 | + adapter_weights_by_base[f"{base_prefix}.weight"] = [ |
| 204 | + _simple_adapter_weight(base_prefix, out_proj.lora) |
| 205 | + ] |
| 206 | + |
| 207 | + in_proj = getattr(self_attention, "in_proj", None) |
| 208 | + if isinstance(in_proj, GatedDeltaNetInProjLoRA): |
| 209 | + base_prefix = f"{layer_prefix}.self_attention.in_proj" |
| 210 | + adapter_weights_by_base[f"{base_prefix}.weight"] = [ |
| 211 | + _fused_gdn_adapter_weight(base_prefix, in_proj) |
| 212 | + ] |
| 213 | + |
| 214 | + experts = getattr(module.mlp, "experts", None) |
| 215 | + if experts is not None: |
| 216 | + if isinstance(experts.linear_fc1, MLPExpertsLinearFC1LoRA): |
| 217 | + base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" |
| 218 | + for local_expert_idx in range(experts.linear_fc1.gate_lora.num_local_experts): |
| 219 | + global_expert_idx = local_expert_idx + experts.linear_fc1.gate_lora._expert_offset |
| 220 | + adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ |
| 221 | + _fused_pair_adapter_weight( |
| 222 | + base_prefix, |
| 223 | + experts.linear_fc1.gate_lora, |
| 224 | + experts.linear_fc1.up_lora, |
| 225 | + first_expert_idx=local_expert_idx, |
| 226 | + second_expert_idx=local_expert_idx, |
| 227 | + ) |
| 228 | + ] |
| 229 | + if isinstance(experts.linear_fc2, MLPExpertsLinearFC2LoRA): |
| 230 | + base_prefix = f"{layer_prefix}.mlp.experts.linear_fc2" |
| 231 | + for local_expert_idx in range(experts.linear_fc2.lora.num_local_experts): |
| 232 | + global_expert_idx = local_expert_idx + experts.linear_fc2.lora._expert_offset |
| 233 | + adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ |
| 234 | + _simple_adapter_weight( |
| 235 | + base_prefix, |
| 236 | + experts.linear_fc2.lora, |
| 237 | + expert_idx=local_expert_idx, |
| 238 | + ) |
| 239 | + ] |
| 240 | + else: |
| 241 | + linear_fc1 = getattr(module.mlp, "linear_fc1", None) |
| 242 | + if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): |
| 243 | + base_prefix = f"{layer_prefix}.mlp.linear_fc1" |
| 244 | + adapter_weights_by_base[f"{base_prefix}.weight"] = [ |
| 245 | + _simple_adapter_weight(base_prefix, linear_fc1.gate_lora, adapter_key="adapter_gate"), |
| 246 | + _simple_adapter_weight(base_prefix, linear_fc1.up_lora, adapter_key="adapter_up"), |
| 247 | + ] |
| 248 | + linear_fc2 = getattr(module.mlp, "linear_fc2", None) |
| 249 | + if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): |
| 250 | + base_prefix = f"{layer_prefix}.mlp.linear_fc2" |
| 251 | + adapter_weights_by_base[f"{base_prefix}.weight"] = [ |
| 252 | + _simple_adapter_weight(base_prefix, linear_fc2.row_parallel_lora.lora) |
| 253 | + ] |
| 254 | + |
| 255 | + shared_experts = getattr(module.mlp, "shared_experts", None) |
| 256 | + if shared_experts is not None: |
| 257 | + if isinstance(shared_experts.linear_fc1, SharedExpertsLinearFC1LoRA): |
| 258 | + base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc1" |
| 259 | + adapter_weights_by_base[f"{base_prefix}.weight"] = [ |
| 260 | + _simple_adapter_weight( |
| 261 | + base_prefix, |
| 262 | + shared_experts.linear_fc1.gate_lora, |
| 263 | + adapter_key="adapter_gate", |
| 264 | + ), |
| 265 | + _simple_adapter_weight( |
| 266 | + base_prefix, |
| 267 | + shared_experts.linear_fc1.up_lora, |
| 268 | + adapter_key="adapter_up", |
| 269 | + ), |
| 270 | + ] |
| 271 | + if isinstance(shared_experts.linear_fc2, SharedExpertsLinearFC2LoRA): |
| 272 | + base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc2" |
| 273 | + adapter_weights_by_base[f"{base_prefix}.weight"] = [ |
| 274 | + _simple_adapter_weight( |
| 275 | + base_prefix, |
| 276 | + shared_experts.linear_fc2.row_parallel_lora.lora, |
| 277 | + ) |
| 278 | + ] |
| 279 | + return adapter_weights_by_base |
0 commit comments