-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathoptimizer.py
More file actions
164 lines (139 loc) · 5.16 KB
/
Copy pathoptimizer.py
File metadata and controls
164 lines (139 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/python
# -*- encoding: utf-8 -*-
"""Custom optimizer with linear warmup and a polynomial LR schedule.

Supports differential learning rates (e.g., decoder x10) and is meant to
integrate safely with AMP."""
import logging
from typing import Any, Dict, List
import torch
# Module-level logger; Optimizer reports its configuration and the end of
# warmup through it.
logger = logging.getLogger(__name__)
class Optimizer:
    """SGD wrapper with linear warmup, polynomial decay and per-group LR scaling.

    Schedule (``it`` is the number of completed :meth:`step` calls):

    * warmup, while ``it < warmup_steps``::

          lr = warmup_start_lr + (it / warmup_steps) * (lr0 - warmup_start_lr)

    * polynomial decay afterwards::

          k  = clamp((it - warmup_steps) / (max_iter - warmup_steps), 0, 1)
          lr = lr0 * (1 - k) ** power

      ``k`` is clamped so that training past ``max_iter`` keeps ``lr == 0``
      instead of raising a negative base to a fractional power (which yields
      a complex number in Python). If ``max_iter <= warmup_steps`` there is
      no decay window and the post-warmup LR is 0.

    Each parameter group's LR is multiplied by its ``"lr_scale"`` entry:
    ``lr_multiplier`` for the two extra groups of a 4-group
    ``model.get_params()``, and 1.0 otherwise. Weight decay is set per group;
    the SGD-level default weight decay is 0.

    The schedule position ``it`` is stored in :meth:`state_dict`, so resuming
    from a checkpoint continues the schedule instead of restarting warmup.
    """

    # Extra key stored next to torch's optimizer state to persist ``it``.
    _ITER_KEY = "lr_schedule_iter"

    def __init__(
        self,
        model: torch.nn.Module,
        lr0: float,
        momentum: float = 0.9,
        wd: float = 1e-4,
        warmup_steps: int = 0,
        warmup_start_lr: float = 1e-5,
        max_iter: int = 100000,
        power: float = 0.9,
        lr_multiplier: float = 10.0,
    ):
        """Build the wrapped SGD optimizer from ``model.get_params()``.

        Args:
            model: object exposing ``get_params()`` that returns either
                ``(wd, nowd)`` or ``(wd, nowd, lr_mul_wd, lr_mul_nowd)``
                parameter collections.
            lr0: base learning rate, reached at the end of warmup.
            momentum: SGD momentum.
            wd: weight decay for the ``wd`` groups; the ``nowd`` groups get 0.
            warmup_steps: number of linear-warmup steps (0 disables warmup).
            warmup_start_lr: learning rate at the first warmup step.
            max_iter: iteration at which the polynomial schedule reaches 0.
            power: exponent of the polynomial decay.
            lr_multiplier: LR scale for the ``lr_mul_*`` groups.

        Raises:
            RuntimeError: ``model`` has no callable ``get_params``, the call
                fails, or it returns a number of groups other than 2 or 4.
            ValueError: every returned group is empty.
        """
        self.lr0 = lr0
        self.momentum = momentum
        self.wd = wd
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr
        self.max_iter = float(max_iter)
        self.power = power
        self.lr_multiplier = lr_multiplier
        self.it = 0

        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = (
            self._split_params(model, lr_multiplier)
        )

        # Keep the original group order; empty groups are skipped.
        param_groups: List[Dict[str, Any]] = [
            {"params": params, "weight_decay": decay, **extra}
            for params, decay, extra in (
                (wd_params, wd, {}),
                (nowd_params, 0.0, {}),
                (lr_mul_wd_params, wd, {"lr_scale": lr_multiplier}),
                (lr_mul_nowd_params, 0.0, {"lr_scale": lr_multiplier}),
            )
            if params
        ]
        if not param_groups:
            raise ValueError("No parameters found in model!")

        # SGD-level weight_decay is 0 because each group sets its own.
        self.optim = torch.optim.SGD(
            param_groups, lr=lr0, momentum=momentum, weight_decay=0.0
        )

        # Geometric warmup factor. It is kept as a public attribute for
        # backward compatibility but is NOT used by get_lr, which warms up
        # linearly. The guard avoids a ZeroDivisionError when
        # warmup_start_lr == 0.
        if warmup_steps > 0 and warmup_start_lr != 0:
            self.warmup_factor = (lr0 / warmup_start_lr) ** (1.0 / warmup_steps)
        else:
            self.warmup_factor = 1.0

        logger.info(
            f"[Optimizer] Initialized with LR={lr0}, WD={wd}, "
            f"Warmup={warmup_steps} steps, Poly(power={power}), "
            f"Max Iter={max_iter}"
        )

    @staticmethod
    def _split_params(model: Any, lr_multiplier: float) -> tuple:
        """Call ``model.get_params()`` and normalise the result to four lists.

        Returns ``(wd, nowd, lr_mul_wd, lr_mul_nowd)``. The last two are empty
        when the model reports only two groups. A bare tensor group is wrapped
        in a list, and any other iterable is materialised into a list.

        Raises:
            RuntimeError: as documented on ``__init__``.
        """
        get_params = getattr(model, "get_params", None)
        if not callable(get_params):
            raise RuntimeError(
                "Model must have .get_params() method returning param groups"
            )
        try:
            groups = [
                [group] if isinstance(group, torch.Tensor) else list(group)
                for group in get_params()
            ]
        except Exception as e:
            raise RuntimeError(f"Error parsing model parameters: {e}") from e

        if len(groups) == 2:
            # Encoder-only mode: no decoder-specific groups.
            logger.info(
                "[Optimizer] Model returned 2 param groups (no decoder LR scaling)"
            )
            return groups[0], groups[1], [], []
        if len(groups) == 4:
            logger.info(
                f"[Optimizer] Using differential LR: x{lr_multiplier} for decoder"
            )
            return groups[0], groups[1], groups[2], groups[3]
        raise RuntimeError(
            "Error parsing model parameters: "
            f"Expected 2 or 4 param groups, got {len(groups)}"
        )

    def get_lr(self, group_idx: int, group: Dict[str, Any]) -> float:
        """Return the scheduled LR for ``group`` at the current iteration.

        ``group_idx`` is unused. It is kept for interface compatibility.
        """
        lr_scale = float(group.get("lr_scale", 1.0))
        if self.it < self.warmup_steps:
            # Linear warmup from warmup_start_lr to lr0.
            alpha = self.it / self.warmup_steps
            lr = self.warmup_start_lr + alpha * (self.lr0 - self.warmup_start_lr)
        else:
            decay_span = self.max_iter - self.warmup_steps
            if decay_span > 0:
                k = (self.it - self.warmup_steps) / decay_span
                # Clamp to [0, 1]. Past max_iter an unclamped k > 1 would make
                # (1 - k) ** power complex for fractional powers.
                k = min(max(k, 0.0), 1.0)
            else:
                # No decay window (max_iter <= warmup_steps): schedule is done.
                k = 1.0
            lr = self.lr0 * (1.0 - k) ** self.power
        return lr * lr_scale

    def step(self, closure: Any = None) -> Any:
        """Set every group's LR from the schedule, then step the optimizer.

        ``closure`` is forwarded to the wrapped optimizer, and its result
        (the loss, or None) is returned.
        """
        for idx, group in enumerate(self.optim.param_groups):
            group["lr"] = self.get_lr(idx, group)
        loss = self.optim.step(closure)
        # Only log when there actually was a warmup phase.
        if self.warmup_steps > 0 and self.it == self.warmup_steps:
            logger.info(
                f"==> Warmup completed at step {self.it}. "
                f"Switching to poly({self.power}) LR schedule."
            )
        self.it += 1
        return loss

    def zero_grad(self, *args: Any, **kwargs: Any) -> None:
        """Zero gradients; arguments (e.g. ``set_to_none``) are forwarded."""
        self.optim.zero_grad(*args, **kwargs)

    def state_dict(self) -> Dict[str, Any]:
        """Return the wrapped optimizer's state plus the schedule position."""
        state = self.optim.state_dict()
        state[self._ITER_KEY] = self.it
        return state

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Load optimizer state and restore ``it`` when the checkpoint has it.

        Checkpoints without the schedule key (e.g. a plain torch optimizer
        state dict) still load, and ``it`` is left unchanged.
        """
        state = dict(state)  # Shallow copy so the caller's dict is not mutated.
        saved_it = state.pop(self._ITER_KEY, None)
        self.optim.load_state_dict(state)
        if saved_it is not None:
            self.it = int(saved_it)

    @property
    def param_groups(self):
        """Parameter groups of the wrapped optimizer (shared, not copied)."""
        return self.optim.param_groups

    @property
    def defaults(self):
        """Default hyper-parameters of the wrapped optimizer."""
        return self.optim.defaults