warmup_scheduler/scheduler.py (29 changes: 24 additions, 5 deletions)
@@ -1,7 +1,6 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
""" Gradually warm-up(increasing) learning rate in optimizer.
Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
@@ -28,20 +27,39 @@ def get_lr(self):
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
-                return self.after_scheduler.get_last_lr()
+                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
-            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
+            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
+                    self.base_lrs]

+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which is not
+        the optimizer; the wrapped after_scheduler is stored as a nested state_dict.
+        """
+        result = {key: value for key, value in self.__dict__.items() if key != 'optimizer' and key != "after_scheduler"}
+        if self.after_scheduler:
+            result.update({"after_scheduler": self.after_scheduler.state_dict()})
+        return result
+
+    def load_state_dict(self, state_dict):
+        after_scheduler_state = state_dict.pop("after_scheduler", None)
+        self.__dict__.update(state_dict)
+        if after_scheduler_state:
+            self.after_scheduler.load_state_dict(after_scheduler_state)
+
    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
-            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
+            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
+                         self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
@@ -57,8 +75,9 @@ def step(self, epoch=None, metrics=None):
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
-                self._last_lr = self.after_scheduler.get_last_lr()
+                self._last_lr = self.after_scheduler.get_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
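
A minimal usage sketch of the patched scheduler, assuming the repo's warmup_scheduler.scheduler module is importable; the SGD optimizer, the single toy parameter, the StepLR after-scheduler and all hyperparameters below are illustrative, not part of the diff. During warmup, get_lr() scales each base learning rate by ((multiplier - 1) * last_epoch / total_epoch + 1), so with base lr 0.1, multiplier 5 and total_epoch 5 the loop prints 0.18, 0.26, 0.34, 0.42 and 0.5 over the first five epochs; the new state_dict()/load_state_dict() pair then round-trips the scheduler, including the nested after_scheduler state.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from warmup_scheduler.scheduler import GradualWarmupScheduler

# Illustrative setup: one toy parameter, SGD at base lr 0.1.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = SGD([param], lr=0.1)

# Warm up to 5x the base lr over 5 epochs, then hand control to StepLR.
after = StepLR(optimizer, step_size=10, gamma=0.1)
scheduler = GradualWarmupScheduler(optimizer, multiplier=5.0, total_epoch=5, after_scheduler=after)

for epoch in range(1, 8):
    optimizer.step()                 # optimizer.step() before scheduler.step() (PyTorch >= 1.1 ordering)
    scheduler.step(epoch)            # explicit epoch, as in the repo's documented usage
    print(epoch, optimizer.param_groups[0]['lr'])

# Checkpoint: the new state_dict() excludes the optimizer but nests the after_scheduler state.
ckpt = scheduler.state_dict()

# Restore into a freshly built wrapper; load_state_dict() routes the nested
# entry back into whatever after-scheduler the new wrapper was built with.
restored_after = StepLR(optimizer, step_size=10, gamma=0.1)
restored = GradualWarmupScheduler(optimizer, multiplier=5.0, total_epoch=5, after_scheduler=restored_after)
restored.load_state_dict(ckpt)

Nesting the after_scheduler state inside the warmup scheduler's own state_dict keeps resuming down to a single checkpoint entry, which appears to be the point of the added methods.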