diff --git a/warmup_scheduler/scheduler.py b/warmup_scheduler/scheduler.py
index 2ebdc70..97cfbf5 100644
--- a/warmup_scheduler/scheduler.py
+++ b/warmup_scheduler/scheduler.py
@@ -1,7 +1,6 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.lr_scheduler import ReduceLROnPlateau
-
 
 class GradualWarmupScheduler(_LRScheduler):
     """ Gradually warm-up(increasing) learning rate in optimizer.
     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
@@ -28,20 +27,39 @@ def get_lr(self):
                 if not self.finished:
                     self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                     self.finished = True
-                return self.after_scheduler.get_last_lr()
+                return self.after_scheduler.get_lr()
             return [base_lr * self.multiplier for base_lr in self.base_lrs]
 
         if self.multiplier == 1.0:
             return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
         else:
-            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
+            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
+                    self.base_lrs]
+
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which is not
+        the optimizer; after_scheduler is stored via its own state_dict.
+        """
+        result = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'after_scheduler')}
+        if self.after_scheduler:
+            result.update({"after_scheduler": self.after_scheduler.state_dict()})
+        return result
+
+    def load_state_dict(self, state_dict):
+        after_scheduler_state = state_dict.pop("after_scheduler", None)
+        self.__dict__.update(state_dict)
+        if after_scheduler_state:
+            self.after_scheduler.load_state_dict(after_scheduler_state)
 
     def step_ReduceLROnPlateau(self, metrics, epoch=None):
         if epoch is None:
             epoch = self.last_epoch + 1
         self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
         if self.last_epoch <= self.total_epoch:
-            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
+            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
+                         self.base_lrs]
             for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                 param_group['lr'] = lr
         else:
@@ -57,8 +75,9 @@ def step(self, epoch=None, metrics=None):
                 self.after_scheduler.step(None)
             else:
                 self.after_scheduler.step(epoch - self.total_epoch)
-            self._last_lr = self.after_scheduler.get_last_lr()
+            self._last_lr = self.after_scheduler.get_lr()
         else:
             return super(GradualWarmupScheduler, self).step(epoch)
     else:
         self.step_ReduceLROnPlateau(metrics, epoch)
+
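
The new `state_dict()` / `load_state_dict()` pair makes the warmup wrapper checkpointable: the optimizer is excluded, and the wrapped `after_scheduler` is serialized under its own key via its own `state_dict()`. Below is a minimal sketch of a save/restore round-trip, assuming the package is importable as `warmup_scheduler` and using `StepLR` as the post-warmup scheduler; the model, optimizer, and hyperparameters are illustrative only, not part of this patch.

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from warmup_scheduler import GradualWarmupScheduler

model = torch.nn.Linear(2, 1)
optimizer = SGD(model.parameters(), lr=0.1)

# Warm up from base_lr towards multiplier * base_lr over total_epoch epochs,
# then hand control over to the wrapped StepLR.
scheduler = GradualWarmupScheduler(
    optimizer, multiplier=2, total_epoch=5,
    after_scheduler=StepLR(optimizer, step_size=10, gamma=0.1))

for epoch in range(3):
    optimizer.step()   # optimizer.step() before scheduler.step()
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])

# Snapshot mid-warmup; "after_scheduler" holds the nested StepLR state.
checkpoint = scheduler.state_dict()

# Rebuild the same optimizer/scheduler pair and restore the snapshot.
optimizer2 = SGD(model.parameters(), lr=0.1)
scheduler2 = GradualWarmupScheduler(
    optimizer2, multiplier=2, total_epoch=5,
    after_scheduler=StepLR(optimizer2, step_size=10, gamma=0.1))
scheduler2.load_state_dict(checkpoint)
assert scheduler2.last_epoch == scheduler.last_epoch
```

In a real training loop the returned dict would typically pass through `torch.save`/`torch.load` alongside the model and optimizer states; note that `load_state_dict` pops the `"after_scheduler"` entry out of the dict it is given, so pass it a fresh copy if you need to reuse the checkpoint object.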