@@ -751,7 +751,7 @@ def set_log_level(log_level: str = "INFO", include_handlers: bool = False):
     set_logger_level(logging.getLogger("NP"), log_level, include_handlers)


-def smooth_loss_and_suggest(lr_finder_results, window=10):
+def smooth_loss_and_suggest(lr_finder, window=10):
     """
     Smooth loss using a Hamming filter.

@@ -769,10 +769,12 @@ def smooth_loss_and_suggest(lr_finder_results, window=10):
     suggested_lr: float
         Suggested learning rate based on gradient
     """
+    lr_finder_results = lr_finder.results
     lr = lr_finder_results["lr"]
     loss = lr_finder_results["loss"]
-    # Derive window size from num lr searches, ensure window is divisible by 2
-    half_window = math.ceil(round(len(loss) * 0.1) / 2)
+    # Derive half window size from the requested window, rounding up so the
+    # Hamming kernel (length half_window * 2) covers at least `window` samples
+    half_window = math.ceil(window / 2)
-    # Pad sequence and initialialize hamming filter
+    # Pad sequence and initialize Hamming filter
     loss = np.pad(np.array(loss), pad_width=half_window, mode="edge")
     window = np.hamming(half_window * 2)
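The hunk above pads the loss curve at both edges and builds an even-length Hamming kernel; the convolution itself happens in lines outside this hunk. For reference, here is a minimal self-contained sketch of that smoothing step. The "valid"-mode convolution and the trim back to the input length are assumptions, not the committed code, since that part of the function is not shown:

import numpy as np

def hamming_smooth(loss, window=10):
    # Pad by half the kernel length so edge values keep their local level
    half_window = int(np.ceil(window / 2))
    padded = np.pad(np.asarray(loss, dtype=float), pad_width=half_window, mode="edge")
    # Normalized Hamming kernel of even length half_window * 2
    kernel = np.hamming(half_window * 2)
    smoothed = np.convolve(kernel / kernel.sum(), padded, mode="valid")
    # "valid" over len(loss) + 2 * half_window samples yields len(loss) + 1
    # points; trim to the original length (assumed trimming, see note above)
    return smoothed[: len(loss)]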
@@ -798,7 +800,17 @@ def smooth_loss_and_suggest(lr_finder_results, window=10):
798800 "samples or manually set the learning rate."
799801 )
800802 raise
801- return (loss , lr , suggestion )
+    suggestion_default = lr_finder.suggestion(skip_begin=10, skip_end=3)
+    if suggestion is not None and suggestion_default is not None:
+        log_suggestion_smooth = np.log(suggestion)
+        log_suggestion_default = np.log(suggestion_default)
+        lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_default) / 2)
+    elif suggestion is None and suggestion_default is None:
+        log.error("Automatic learning rate test failed. Please set the learning rate manually.")
+        raise ValueError("Automatic learning rate test failed.")
+    else:
+        lr_suggestion = suggestion if suggestion is not None else suggestion_default
+    return (loss, lr, lr_suggestion)


 def _smooth_loss(loss, beta=0.9):
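The new combination step averages the smoothed-curve suggestion and Lightning's default suggestion in log space, which is their geometric mean: the combined rate lands between the two inputs on the log-scaled axis an LR-range test is plotted on. A minimal sketch with hypothetical placeholder values (neither number comes from a real LR-range test):

import numpy as np

suggestion_smooth = 3e-3   # hypothetical suggestion from the smoothed loss curve
suggestion_default = 1e-4  # hypothetical result of lr_finder.suggestion(skip_begin=10, skip_end=3)

# exp of the mean of the logs is exactly the geometric mean
combined = np.exp((np.log(suggestion_smooth) + np.log(suggestion_default)) / 2)
assert np.isclose(combined, np.sqrt(suggestion_smooth * suggestion_default))
print(f"{combined:.2e}")  # ~5.48e-04, between the two inputs on a log scale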