-
Notifications
You must be signed in to change notification settings - Fork 129
Expand file tree
/
Copy pathtrainer.py
More file actions
618 lines (515 loc) · 23.7 KB
/
trainer.py
File metadata and controls
618 lines (515 loc) · 23.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
import codecs
from naslib.search_spaces.core.graph import Graph
import time
import json
import logging
import os
import copy
import torch
import numpy as np
from fvcore.common.checkpoint import PeriodicCheckpointer
from naslib.search_spaces.core.query_metrics import Metric
from naslib.utils import utils
from naslib.utils.logging import log_every_n_seconds, log_first_n
from typing import Callable
from .additional_primitives import DropPathWrapper
logger = logging.getLogger(__name__)
class Trainer(object):
"""
Default implementation that handles dataloading and preparing batches, the
train loop, gathering statistics, checkpointing and doing the final
evaluation.
If this does not fulfill your needs, feel free to subclass it and implement
your required logic.
"""
def __init__(self, optimizer, config, lightweight_output=False):
    """
    Initializes the trainer.

    Args:
        optimizer: A NASLib optimizer
        config (AttrDict): The configuration loaded from a yaml file, e.g
            via `utils.get_config_from_args()`
        lightweight_output (bool): If True, drop the heavier statistics
            (loss curves, arch evaluations) when writing the json log.
    """
    self.optimizer = optimizer
    self.config = config
    self.epochs = config.search.epochs
    self.lightweight_output = lightweight_output

    # Run on the first GPU when available, otherwise fall back to CPU.
    use_cuda = torch.cuda.is_available()
    self.device = torch.device("cuda:0" if use_cuda else "cpu")

    # Running-average meters for accuracy and loss on both splits.
    for meter_name in (
        "train_top1",
        "train_top5",
        "train_loss",
        "val_top1",
        "val_top5",
        "val_loss",
    ):
        setattr(self, meter_name, utils.AverageMeter())

    n_parameters = optimizer.get_model_size()
    # logger.info("param size = %fMB", n_parameters)

    # Per-epoch statistics collected while searching/evaluating.
    trajectory = {
        series: []
        for series in (
            "train_acc",
            "train_loss",
            "valid_acc",
            "valid_loss",
            "test_acc",
            "test_loss",
            "runtime",
            "train_time",
            "arch_eval",
        )
    }
    trajectory["params"] = n_parameters
    self.search_trajectory = utils.AttrDict(trajectory)
def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None]=None, report_incumbent=True):
    """
    Start the architecture search.

    Generates a json file with training statistics.

    Args:
        resume_from (str): Checkpoint file to resume from. If not given then
            train from scratch.
        summary_writer: Optional summary writer (e.g. TensorBoard); per-epoch
            metrics are added to it and it is closed when search finishes.
        after_epoch: Optional callback invoked with the epoch index after
            each epoch.
        report_incumbent (bool): Forwarded to the optimizer's
            ``train_statistics`` on the benchmark (no step function) path.
    """
    logger.info("Beginning search")

    # Seed numpy and torch so the search run is reproducible.
    np.random.seed(self.config.search.seed)
    torch.manual_seed(self.config.search.seed)

    checkpoint_freq = self.config.search.checkpoint_freq
    if self.optimizer.using_step_function:
        # Gradient-based optimizers train weights each step, so they need an
        # LR scheduler which must also be checkpointed for resuming.
        self.scheduler = self.build_search_scheduler(
            self.optimizer.op_optimizer, self.config
        )
        start_epoch = self._setup_checkpointers(
            resume_from, period=checkpoint_freq, scheduler=self.scheduler
        )
    else:
        start_epoch = self._setup_checkpointers(resume_from, period=checkpoint_freq)

    self.optimizer.before_training()
    if self.optimizer.using_step_function:
        # Dataloaders are only needed when actual mini-batch training happens.
        self.train_queue, self.valid_queue, _ = self.build_search_dataloaders(
            self.config
        )

    for e in range(start_epoch, self.epochs):
        start_time = time.time()
        self.optimizer.new_epoch(e)

        if self.optimizer.using_step_function:
            for step, data_train in enumerate(self.train_queue):
                data_train = (
                    data_train[0].to(self.device),
                    data_train[1].to(self.device, non_blocking=True),
                )
                # NOTE(review): iter() is re-created on every step, so this
                # always draws the FIRST validation batch rather than cycling
                # through valid_queue — confirm this is intended.
                data_val = next(iter(self.valid_queue))
                data_val = (
                    data_val[0].to(self.device),
                    data_val[1].to(self.device, non_blocking=True),
                )

                # One optimizer step on the (train, val) batch pair; returns
                # logits and losses for bookkeeping only.
                stats = self.optimizer.step(data_train, data_val)
                logits_train, logits_val, train_loss, val_loss = stats

                self._store_accuracies(logits_train, data_train[1], "train")
                self._store_accuracies(logits_val, data_val[1], "val")

                log_every_n_seconds(
                    logging.INFO,
                    "Epoch {}-{}, Train loss: {:.5f}, validation loss: {:.5f}, learning rate: {}".format(
                        e, step, train_loss, val_loss, self.scheduler.get_last_lr()
                    ),
                    n=5,
                )

                if torch.cuda.is_available():
                    log_first_n(
                        logging.INFO,
                        "cuda consumption\n {}".format(torch.cuda.memory_summary()),
                        n=3,
                    )

                self.train_loss.update(float(train_loss.detach().cpu()))
                self.val_loss.update(float(val_loss.detach().cpu()))
                # break

            self.scheduler.step()

            end_time = time.time()

            self.search_trajectory.train_acc.append(self.train_top1.avg)
            self.search_trajectory.train_loss.append(self.train_loss.avg)
            self.search_trajectory.valid_acc.append(self.val_top1.avg)
            self.search_trajectory.valid_loss.append(self.val_loss.avg)
            self.search_trajectory.runtime.append(end_time - start_time)
        else:
            end_time = time.time()
            # TODO: nasbench101 does not have train_loss, valid_loss, test_loss implemented, so this is a quick fix for now
            # train_acc, train_loss, valid_acc, valid_loss, test_acc, test_loss = self.optimizer.train_statistics()
            (
                train_acc,
                valid_acc,
                test_acc,
                train_time,
            ) = self.optimizer.train_statistics(report_incumbent)
            # Losses are not available via the benchmark API; use -1 sentinels.
            train_loss, valid_loss, test_loss = -1, -1, -1

            self.search_trajectory.train_acc.append(train_acc)
            self.search_trajectory.train_loss.append(train_loss)
            self.search_trajectory.valid_acc.append(valid_acc)
            self.search_trajectory.valid_loss.append(valid_loss)
            self.search_trajectory.test_acc.append(test_acc)
            self.search_trajectory.test_loss.append(test_loss)
            self.search_trajectory.runtime.append(end_time - start_time)
            self.search_trajectory.train_time.append(train_time)
            # Overwrite the meter averages so _log_and_reset_accuracies
            # reports the queried accuracies.
            self.train_top1.avg = train_acc
            self.val_top1.avg = valid_acc

        # arch_weights = self.optimizer.get_checkpointables()["arch_weights"]
        # The model itself is handled by the Checkpointer; pass only the
        # remaining checkpointables through the periodic checkpointer.
        add_checkpointables = self.optimizer.get_checkpointables()
        del add_checkpointables["model"]
        self.periodic_checkpointer.step(e, **add_checkpointables)

        anytime_results = self.optimizer.test_statistics()
        # if anytime_results:
        # record anytime performance
        # self.search_trajectory.arch_eval.append(anytime_results)
        # log_every_n_seconds(
        # logging.INFO,
        # "Epoch {}, Anytime results: {}".format(e, anytime_results),
        # n=5,
        # )

        self._log_to_json()
        self._log_and_reset_accuracies(e, summary_writer)

        if after_epoch is not None:
            after_epoch(e)

    self.optimizer.after_training()

    if summary_writer is not None:
        summary_writer.close()

    logger.info("Training finished")
def evaluate_oneshot(self, resume_from="", dataloader=None):
    """
    Evaluate the one-shot model on the specified dataset.

    Generates a json file with training statistics.

    Args:
        resume_from (str): Checkpoint file to resume from. If not given then
            evaluate with the current one-shot weights.
        dataloader: Loader yielding (input, target) batches; defaults to the
            search-phase validation loader.
    """
    logger.info("Start one-shot evaluation")

    self._setup_checkpointers(resume_from)
    self.optimizer.before_training()

    criterion = torch.nn.CrossEntropyLoss()

    if dataloader is None:
        # load only the validation data
        _, dataloader, _ = self.build_search_dataloaders(self.config)

    self.optimizer.graph.eval()
    start_time = time.time()
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0].to(self.device)
            targets = batch[1].to(self.device, non_blocking=True)

            logits = self.optimizer.graph(inputs)
            batch_loss = criterion(logits, targets)

            self._store_accuracies(logits, batch[1], "val")
            self.val_loss.update(float(batch_loss.detach().cpu()))
    elapsed = time.time() - start_time

    self.search_trajectory.valid_acc.append(self.val_top1.avg)
    self.search_trajectory.valid_loss.append(self.val_loss.avg)
    self.search_trajectory.runtime.append(elapsed)

    self._log_to_json()

    logger.info("Evaluation finished")
    return self.val_top1.avg
def evaluate(
    self,
    retrain: bool = True,
    search_model: str = "",
    resume_from: str = "",
    best_arch: Graph = None,
    dataset_api: object = None,
    metric: Metric = None,
):
    """
    Evaluate the final architecture as given from the optimizer.

    If the search space has an interface to a benchmark then query that.
    Otherwise train as defined in the config.

    Args:
        retrain (bool) : Reset the weights from the architecture search
        search_model (str) : Path to checkpoint file that was created during search. If not provided,
            then try to load 'model_final.pth' from search
        resume_from (str) : Resume retraining from the given checkpoint file.
        best_arch : Parsed model you want to directly evaluate and ignore the final model
            from the optimizer.
        dataset_api : Dataset API to use for querying model performance.
        metric : Metric to query the benchmark for.

    Returns:
        The queried benchmark result, or the measured top-1 test accuracy
        when the architecture is actually (re)trained.
    """
    logger.info("Start evaluation")
    if not best_arch:
        if not search_model:
            search_model = os.path.join(
                self.config.save, "search", "model_final.pth"
            )
        self._setup_checkpointers(search_model)  # required to load the architecture

        best_arch = self.optimizer.get_final_architecture()
    logger.info(f"Final architecture hash: {best_arch.get_hash()}")

    if best_arch.QUERYABLE and (not retrain):
        # Benchmark path: no training, just query the tabular benchmark.
        if metric is None:
            metric = Metric.TEST_ACCURACY
        result = best_arch.query(
            metric=metric, dataset=self.config.dataset, dataset_api=dataset_api
        )
        logger.info("Queried results ({}): {}".format(metric, result))
        return result
    else:
        best_arch.to(self.device)
        if retrain:
            logger.info("Starting retraining from scratch")
            best_arch.reset_weights(inplace=True)

            (
                self.train_queue,
                self.valid_queue,
                self.test_queue,
            ) = self.build_eval_dataloaders(self.config)

            optim = self.build_eval_optimizer(best_arch.parameters(), self.config)
            scheduler = self.build_eval_scheduler(optim, self.config)

            start_epoch = self._setup_checkpointers(
                resume_from,
                search=False,
                period=self.config.evaluation.checkpoint_freq,
                model=best_arch,  # checkpointables start here
                optim=optim,
                scheduler=scheduler,
            )

            grad_clip = self.config.evaluation.grad_clip
            loss = torch.nn.CrossEntropyLoss()

            self.train_top1.reset()
            self.train_top5.reset()
            self.val_top1.reset()
            self.val_top5.reset()

            # Enable drop path: wrap every op so edges can be stochastically
            # dropped during retraining.
            best_arch.update_edges(
                update_func=lambda edge: edge.data.set(
                    "op", DropPathWrapper(edge.data.op)
                ),
                scope=best_arch.OPTIMIZER_SCOPE,
                private_edge_data=True,
            )

            # train from scratch
            epochs = self.config.evaluation.epochs
            for e in range(start_epoch, epochs):
                best_arch.train()

                if torch.cuda.is_available():
                    log_first_n(
                        logging.INFO,
                        "cuda consumption\n {}".format(torch.cuda.memory_summary()),
                        n=20,
                    )

                # update drop path probability: linearly ramped from 0 up to
                # the configured value over the course of training
                drop_path_prob = self.config.evaluation.drop_path_prob * e / epochs
                best_arch.update_edges(
                    update_func=lambda edge: edge.data.set(
                        "drop_path_prob", drop_path_prob
                    ),
                    scope=best_arch.OPTIMIZER_SCOPE,
                    private_edge_data=True,
                )

                # Train queue
                for i, (input_train, target_train) in enumerate(self.train_queue):
                    input_train = input_train.to(self.device)
                    target_train = target_train.to(self.device, non_blocking=True)

                    optim.zero_grad()
                    logits_train = best_arch(input_train)
                    train_loss = loss(logits_train, target_train)
                    if hasattr(
                        best_arch, "auxilary_logits"
                    ):  # darts specific stuff
                        log_first_n(logging.INFO, "Auxiliary is used", n=10)
                        auxiliary_loss = loss(
                            best_arch.auxilary_logits(), target_train
                        )
                        train_loss += (
                            self.config.evaluation.auxiliary_weight * auxiliary_loss
                        )
                    train_loss.backward()
                    if grad_clip:
                        torch.nn.utils.clip_grad_norm_(
                            best_arch.parameters(), grad_clip
                        )
                    optim.step()

                    self._store_accuracies(logits_train, target_train, "train")
                    log_every_n_seconds(
                        logging.INFO,
                        "Epoch {}-{}, Train loss: {:.5}, learning rate: {}".format(
                            e, i, train_loss, scheduler.get_last_lr()
                        ),
                        n=5,
                    )

                # Validation queue
                if self.valid_queue:
                    best_arch.eval()
                    for i, (input_valid, target_valid) in enumerate(
                        self.valid_queue
                    ):
                        input_valid = input_valid.to(self.device).float()
                        # NOTE(review): casting the labels to float is unusual
                        # for classification targets; they are only used for
                        # the accuracy computation here — confirm intended.
                        target_valid = target_valid.to(self.device).float()

                        # just log the validation accuracy
                        with torch.no_grad():
                            logits_valid = best_arch(input_valid)
                            self._store_accuracies(
                                logits_valid, target_valid, "val"
                            )

                # NOTE(review): assumes the optimizer's checkpointables always
                # contain an "arch_weights" entry — verify this holds for
                # every optimizer used with retraining.
                arch_weights = self.optimizer.get_checkpointables()["arch_weights"]

                scheduler.step()
                self.periodic_checkpointer.step(iteration=e, arch_weights=arch_weights)
                self._log_and_reset_accuracies(e)

        # Disable drop path: unwrap the ops again before final measurement.
        best_arch.update_edges(
            update_func=lambda edge: edge.data.set(
                "op", edge.data.op.get_embedded_ops()
            ),
            scope=best_arch.OPTIMIZER_SCOPE,
            private_edge_data=True,
        )

        # measure final test accuracy
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()

        best_arch.eval()

        for i, data_test in enumerate(self.test_queue):
            input_test, target_test = data_test
            input_test = input_test.to(self.device)
            target_test = target_test.to(self.device, non_blocking=True)

            n = input_test.size(0)

            with torch.no_grad():
                logits = best_arch(input_test)

                prec1, prec5 = utils.accuracy(logits, target_test, topk=(1, 5))
                top1.update(prec1.data.item(), n)
                top5.update(prec5.data.item(), n)

            log_every_n_seconds(
                logging.INFO,
                "Inference batch {} of {}.".format(i, len(self.test_queue)),
                n=5,
            )

        logger.info(
            "Evaluation finished. Test accuracies: top-1 = {:.5}, top-5 = {:.5}".format(
                top1.avg, top5.avg
            )
        )
        return top1.avg
@staticmethod
def build_search_dataloaders(config):
    """Build the dataloaders used during the search phase.

    Returns (train_queue, valid_queue, placeholder); the test loader is not
    used in search, so the third value mirrors the original placeholder
    (the last element of the loader tuple) rather than the test loader.
    """
    loaders = utils.get_train_val_loaders(config, mode="train")
    return loaders[0], loaders[1], loaders[-1]
@staticmethod
def build_eval_dataloaders(config):
    """Build the train/validation/test dataloaders for final evaluation."""
    loaders = utils.get_train_val_loaders(config, mode="val")
    train_queue, valid_queue, test_queue = loaders[:3]
    return train_queue, valid_queue, test_queue
@staticmethod
def build_eval_optimizer(parameters, config):
    """Create the SGD optimizer used to retrain the final architecture."""
    eval_cfg = config.evaluation
    sgd = torch.optim.SGD(
        parameters,
        lr=eval_cfg.learning_rate,
        momentum=eval_cfg.momentum,
        weight_decay=eval_cfg.weight_decay,
    )
    return sgd
@staticmethod
def build_search_scheduler(optimizer, config):
    """Cosine-annealing LR schedule spanning the search epochs."""
    search_cfg = config.search
    return torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=search_cfg.epochs, eta_min=search_cfg.learning_rate_min
    )
@staticmethod
def build_eval_scheduler(optimizer, config):
    """Cosine-annealing LR schedule spanning the evaluation epochs."""
    eval_cfg = config.evaluation
    return torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=eval_cfg.epochs, eta_min=eval_cfg.learning_rate_min
    )
def _log_and_reset_accuracies(self, epoch, writer=None):
    """Log the epoch's top-1 accuracies, push all meters to the summary
    writer (if one was given), then reset every meter for the next epoch."""
    logger.info(
        "Epoch {} done. Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format(
            epoch, self.train_top1.avg, self.val_top1.avg
        )
    )

    meters = (
        ('Train accuracy (top 1)', self.train_top1),
        ('Train accuracy (top 5)', self.train_top5),
        ('Train loss', self.train_loss),
        ('Validation accuracy (top 1)', self.val_top1),
        ('Validation accuracy (top 5)', self.val_top5),
        ('Validation loss', self.val_loss),
    )

    if writer is not None:
        for tag, meter in meters:
            writer.add_scalar(tag, meter.avg, epoch)

    for _tag, meter in meters:
        meter.reset()
def _store_accuracies(self, logits, target, split):
    """Update the accuracy counters for the given split.

    Args:
        logits: Model outputs; detached CPU copies are used so this never
            touches autograd state or the original device tensors.
        target: Ground-truth labels aligned with ``logits``.
        split (str): Either "train" or "val"; selects the meter pair.

    Raises:
        ValueError: If ``split`` is neither "train" nor "val".
    """
    # Validate up front so a bad split fails fast, before any tensor work.
    # Bug fix: the original raised the message with an unformatted "{}"
    # placeholder because .format(split) was never called.
    if split not in ("train", "val"):
        raise ValueError(
            "Unknown split: {}. Expected either 'train' or 'val'".format(split)
        )

    logits = logits.clone().detach().cpu()
    target = target.clone().detach().cpu()
    prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
    n = logits.size(0)

    if split == "train":
        self.train_top1.update(prec1.data.item(), n)
        self.train_top5.update(prec5.data.item(), n)
    else:
        self.val_top1.update(prec1.data.item(), n)
        self.val_top5.update(prec5.data.item(), n)
def _prepare_dataloaders(self, config, mode="train"):
    """
    Prepare train, validation, and test dataloaders with the splits defined
    in the config and store them on the trainer.

    Args:
        config (AttrDict): config from config file.
        mode (str): Forwarded to ``utils.get_train_val_loaders``.
    """
    queues = utils.get_train_val_loaders(config, mode)
    self.train_queue, self.valid_queue, self.test_queue = queues[:3]
def _setup_checkpointers(
    self, resume_from="", search=True, period=1, **add_checkpointables
):
    """
    Sets up a periodic checkpointer which can be used to save checkpoints
    at every epoch. It will call optimizer's `get_checkpointables()` as objects
    to store.

    Args:
        resume_from (str): A checkpoint file to resume the search or evaluation from.
        search (bool): Whether search or evaluation phase is checkpointed. This is required
            because the files are in different folders to not be overridden
        add_checkpointables (object): Additional things to checkpoint together with the
            optimizer's checkpointables.

    Returns:
        int: The epoch to start from — one past the checkpointed iteration
        when a checkpoint was loaded, otherwise 0.
    """
    checkpointables = self.optimizer.get_checkpointables()
    checkpointables.update(add_checkpointables)

    # The model is handled by the Checkpointer itself; the remaining
    # checkpointables would be passed as kwargs (currently disabled, see
    # NOTE below).
    checkpointer = utils.Checkpointer(
        model=checkpointables.pop("model"),
        save_dir=self.config.save + "/search"
        if search
        else self.config.save + "/eval",
        # **checkpointables #NOTE: this is throwing an Error
    )

    self.periodic_checkpointer = PeriodicCheckpointer(
        checkpointer,
        period=period,
        max_iter=self.config.search.epochs
        if search
        else self.config.evaluation.epochs,
    )

    if resume_from:
        logger.info("loading model from file {}".format(resume_from))
        # if resume=True starts from the last_checkpoint
        # if resume=False starts from the path mentioned as resume_from
        checkpoint = checkpointer.resume_or_load(resume_from, resume=False)
        if checkpointer.has_checkpoint():
            self.optimizer.set_checkpointables(checkpoint)
            return checkpoint.get("iteration", -1) + 1

    return 0
def _log_to_json(self):
    """Write the search trajectory to ``<save>/errors.json``.

    In lightweight mode the heavier series (loss curves, arch evaluations)
    are stripped and the config is written alongside the remaining
    statistics; otherwise the full trajectory is dumped.
    """
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs(self.config.save, exist_ok=True)

    if self.lightweight_output:
        # Deepcopy so the live trajectory keeps all of its series.
        lightweight_dict = copy.deepcopy(self.search_trajectory)
        for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]:
            lightweight_dict.pop(key)
        payload = [self.config, lightweight_dict]
    else:
        payload = self.search_trajectory

    # Builtin open() handles the encoding directly; codecs.open is legacy.
    with open(
        os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8"
    ) as file:
        json.dump(payload, file, separators=(",", ":"))