-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathresume_utils.py
More file actions
55 lines (46 loc) · 2.05 KB
/
resume_utils.py
File metadata and controls
55 lines (46 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from pathlib import Path
from typing import Any, Optional
from twinkle import get_logger
# Module-level logger shared by the resume helpers in this file.
logger = get_logger()
def _build_model_kwargs(adapter_name: str) -> dict:
if not adapter_name:
return {}
return {'adapter_name': adapter_name}
def resume_from_checkpoint(
    model: Any,
    dataloader: Any,
    checkpoint_path: Path,
    *,
    resume_only_model: bool,
    ignore_data_skip: bool,
    adapter_name: Optional[str] = None) -> int:
    """Restore training state from ``checkpoint_path``.

    Args:
        model: Training model wrapper exposing ``load``,
            ``read_training_progress``, ``load_training_state`` and
            ``optimizer_group``.
        dataloader: Dataloader exposing ``skip_consumed_samples``.
        checkpoint_path: Directory holding the checkpoint to resume from.
        resume_only_model: If True, restore only model weights (optionally
            fast-forwarding the data/optimizer progress); otherwise restore
            the full training state including optimizer states.
        ignore_data_skip: If True (and ``resume_only_model`` is True),
            restart progress from step 0 without skipping any data.
            NOTE(review): this flag is not consulted on the full-resume
            path — confirm that is intentional.
        adapter_name: Optional adapter to resume; empty/None targets the
            base model.

    Returns:
        Number of already-consumed training samples to skip (0 when
        progress is restarted).
    """
    name = adapter_name if adapter_name else ''
    extra_kwargs = _build_model_kwargs(name)
    ckpt_dir = str(checkpoint_path)

    if extra_kwargs:
        # Adapter training: weights live under a named sub-checkpoint.
        model.load(
            name=checkpoint_path.name,
            output_dir=str(checkpoint_path.parent),
            **extra_kwargs,
        )

    if not resume_only_model:
        # Full resume: model weights, optimizer states, and training progress.
        state = model.load_training_state(ckpt_dir, **extra_kwargs)
        consumed = int(state['consumed_train_samples'])
        dataloader.skip_consumed_samples(consumed)
        logger.info(f'Restored full training state from step {state["cur_step"]}.')
        return consumed

    if ignore_data_skip:
        # Weights only, fresh progress: nothing to skip.
        logger.info('Resumed weights only and restarted progress from step 0.')
        return 0

    # Weights only, but fast-forward the dataloader and optimizer counters
    # to where the checkpoint left off.
    progress = model.read_training_progress(ckpt_dir, **extra_kwargs)
    consumed = int(progress['consumed_train_samples'])
    dataloader.skip_consumed_samples(consumed)
    opt_group = model.optimizer_group[name]
    opt_group.cur_step = progress['cur_step']
    opt_group.gradient_accumulation_steps = progress['gradient_accumulation_steps']
    logger.info(f'Skipped {consumed} consumed samples.')
    return consumed