File size: 38,258 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import dataclasses
import inspect
import math
import warnings
from functools import partial
from typing import Any, Dict, Optional, Union

import hydra
import torch.optim as optim
import torch.optim.lr_scheduler as pt_scheduler
import torch.utils.data.dataloader as dataloader
from omegaconf import DictConfig, OmegaConf
from torch.optim.lr_scheduler import _LRScheduler

from nemo.core.config import SchedulerParams, get_scheduler_config, register_scheduler_params
from nemo.utils import logging
from nemo.utils.model_utils import maybe_update_config_version


class WarmupPolicy(_LRScheduler):
    """Base LR policy that prepends a linear warmup phase to an arbitrary schedule.

    Pass every argument by keyword.

    During warmup (``0 <= step <= warmup_steps``, only when ``warmup_steps > 0``) each group's
    LR is ``base_lr * (step + 1) / (warmup_steps + 1)``. Past ``max_steps`` every group is set
    to ``min_lr``. In between, ``_get_lr`` decides; subclasses override it, and this base
    implementation keeps the base LRs constant.

    Args:
        optimizer: The wrapped optimizer.
        warmup_steps: Absolute number of warmup steps (mutually exclusive with ``warmup_ratio``).
        warmup_ratio: Warmup length as a fraction of ``max_steps``, truncated to an int.
        max_steps: Total number of training steps, or ``None`` for unbounded training.
        min_lr: Learning rate used for every step beyond ``max_steps``.
        last_epoch: Index of the last step, forwarded to ``_LRScheduler``.
    """

    def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1):
        assert not (
            warmup_steps is not None and warmup_ratio is not None
        ), "Either use particular number of step or ratio"
        assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps"

        # Attributes must be in place *before* the base constructor runs,
        # since the class is wrapped by an inner class.
        self.max_steps = max_steps
        if warmup_steps is None and warmup_ratio is not None:
            resolved_warmup = int(warmup_ratio * max_steps)
        elif warmup_steps is None:
            resolved_warmup = 0
        else:
            resolved_warmup = warmup_steps
        self.warmup_steps = resolved_warmup

        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
            )

        current = self.last_epoch

        warming_up = self.warmup_steps > 0 and current <= self.warmup_steps
        if warming_up:
            return self._get_warmup_lr(current)

        # ``max_steps=None`` means the run never finishes, so min_lr is never forced.
        finished = self.max_steps is not None and current > self.max_steps
        if finished:
            return [self.min_lr] * len(self.base_lrs)

        return self._get_lr(current)

    def _get_warmup_lr(self, step):
        # Linear ramp: reaches exactly the base LR at step == warmup_steps.
        fraction = (step + 1) / (self.warmup_steps + 1)
        return [base_lr * fraction for base_lr in self.base_lrs]

    def _get_lr(self, step):
        """Constant policy: hand back the base learning rates unchanged."""
        return self.base_lrs


class SquareRootConstantPolicy(_LRScheduler):
    """LR policy that starts with a constant phase at ``1 / sqrt(constant_steps)``.

    During the constant phase (``step <= constant_steps``) every group is set to
    ``constant_lr`` regardless of its base LR. After that phase ``_get_lr`` decides (the
    base implementation returns the base LRs; subclasses override it). Past ``max_steps``
    every group is set to ``min_lr``.
    All arguments should be passed as kwargs for clarity.

    Args:
        constant_steps: Number of training steps in the constant stage.
        constant_ratio: Ratio of constant steps to total steps (mutually exclusive with
            ``constant_steps``; requires ``max_steps``).
        max_steps: Total number of steps while training or `None` for infinite training.
        min_lr: Learning rate used once ``max_steps`` has been exceeded.
        last_epoch: Index of the last step, forwarded to ``_LRScheduler``.
    """

    def __init__(
        self, optimizer, *, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1
    ):
        assert not (
            constant_steps is not None and constant_ratio is not None
        ), "Either use particular number of step or ratio"
        assert constant_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps"

        # It is necessary to assign all attributes *before* __init__,
        # as class is wrapped by an inner class.
        self.max_steps = max_steps
        if constant_steps is not None:
            self.constant_steps = constant_steps
        elif constant_ratio is not None:
            self.constant_steps = int(constant_ratio * max_steps)
        else:
            self.constant_steps = 0

        # Use the *resolved* step count: the raw ``constant_steps`` argument is None when
        # ``constant_ratio`` was given. Clamp to 1 so an empty constant phase does not
        # divide by zero (it then yields 1 / sqrt(1) == 1.0).
        self.constant_lr = 1 / (max(self.constant_steps, 1) ** 0.5)
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
            )

        step = self.last_epoch

        if step <= self.constant_steps:
            return [self.constant_lr for _ in self.base_lrs]

        # ``max_steps=None`` means infinite training: never switch to min_lr.
        if self.max_steps is not None and step > self.max_steps:
            return [self.min_lr for _ in self.base_lrs]

        return self._get_lr(step)

    def _get_lr(self, step):
        """Simple const lr policy"""
        return self.base_lrs


class WarmupHoldPolicy(WarmupPolicy):
    """Variant of WarmupPolicy which maintains high learning rate for a defined number of steps.

    Schedule: linear warmup to the base LR, then the base LR is held, then ``_get_lr``
    (overridden by subclasses) applies until ``max_steps``; past ``max_steps`` every group
    gets ``min_lr``.
    All arguments should be passed as kwargs for clarity.

    Args:
        warmup_steps: Number of training steps in warmup stage
        warmup_ratio: Ratio of warmup steps to total steps
        hold_steps: Number of training steps to hold the learning rate after warm up
        hold_ratio: Ratio of hold steps to total steps
        max_steps: Total number of steps while training or `None` for
            infinite training
        min_lr: Learning rate used once ``max_steps`` has been exceeded.
        last_epoch: Index of the last step, forwarded to the base scheduler.

    Note:
        ``self.hold_steps`` stores the step index at which the hold phase ends, i.e. the
        requested hold length *plus* ``warmup_steps``. It is 0 when no hold was requested.
    """

    def __init__(
        self,
        optimizer,
        *,
        warmup_steps=None,
        warmup_ratio=None,
        hold_steps=None,
        hold_ratio=None,
        max_steps=None,
        min_lr=0.0,
        last_epoch=-1,
    ):
        assert not (hold_steps is not None and hold_ratio is not None), "Either use particular number of step or ratio"
        assert hold_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps"

        self.min_lr = min_lr
        self._last_warmup_lr = 0.0  # not read anywhere in this class

        # Necessary to duplicate as class attributes are hidden in inner class.
        # ``warmup_steps`` must be resolved first because ``hold_steps`` is offset by it.
        self.max_steps = max_steps
        if warmup_steps is not None:
            self.warmup_steps = warmup_steps
        elif warmup_ratio is not None:
            self.warmup_steps = int(warmup_ratio * max_steps)
        else:
            self.warmup_steps = 0

        if hold_steps is not None:
            self.hold_steps = hold_steps + self.warmup_steps
        elif hold_ratio is not None:
            self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
        else:
            self.hold_steps = 0

        super().__init__(
            optimizer,
            warmup_steps=warmup_steps,
            warmup_ratio=warmup_ratio,
            max_steps=max_steps,
            last_epoch=last_epoch,
            min_lr=min_lr,
        )

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
            )

        step = self.last_epoch

        # Warmup phase
        if step <= self.warmup_steps and self.warmup_steps > 0:
            return self._get_warmup_lr(step)

        # Hold phase
        if (step >= self.warmup_steps) and (step < self.hold_steps):
            return self.base_lrs

        # Past the end of training; ``max_steps=None`` (infinite training) never gets here.
        if self.max_steps is not None and step > self.max_steps:
            return [self.min_lr for _ in self.base_lrs]

        return self._get_lr(step)


class WarmupAnnealHoldPolicy(_LRScheduler):
    """LR policy with warmup, an anneal phase, and a final constant phase.

    Phases (by ``step = last_epoch``):
      * warmup (``step <= warmup_steps``, only if ``warmup_steps > 0``): linear ramp to each base LR;
      * anneal: ``_get_lr`` (base implementation keeps the base LRs; subclasses override it);
      * constant (last ``constant_steps`` steps up to ``max_steps``, only if ``constant_steps > 0``):
        ``_get_constant_lr`` (base implementation returns ``min_lr``);
      * after ``max_steps``: ``min_lr``.

    With ``max_steps=None`` (infinite training) there is no constant phase and no final
    ``min_lr`` clamp, and ``decay_steps`` is ``None``.

    All arguments should be passed as kwargs for clarity,
    Args:
        warmup_steps: Number of training steps in warmup stage
        warmup_ratio: Ratio of warmup steps to total steps
        max_steps: Total number of steps while training or `None` for
            infinite training
        min_lr: Minimum lr to hold the learning rate after decay at.
        constant_steps: Number of steps to keep lr constant at.
        constant_ratio: Ratio of steps to keep lr constant.
    """

    def __init__(
        self,
        optimizer,
        *,
        warmup_steps=None,
        warmup_ratio=None,
        constant_steps=None,
        constant_ratio=None,
        max_steps=None,
        min_lr=0.0,
        last_epoch=-1,
    ):
        assert not (
            warmup_steps is not None and warmup_ratio is not None
        ), "Either use particular number of step or ratio"
        assert not (
            constant_steps is not None and constant_ratio is not None
        ), "Either use constant_steps or constant_ratio"
        assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps"
        assert constant_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps"

        # It is necessary to assign all attributes *before* __init__,
        # as class is wrapped by an inner class.
        self.max_steps = max_steps

        if warmup_steps is not None:
            self.warmup_steps = warmup_steps
        elif warmup_ratio is not None:
            self.warmup_steps = int(warmup_ratio * max_steps)
        else:
            self.warmup_steps = 0

        if constant_steps is not None:
            self.constant_steps = constant_steps
        elif constant_ratio is not None:
            self.constant_steps = int(constant_ratio * max_steps)
        else:
            self.constant_steps = 0

        # Length of the anneal phase between warmup and the final constant phase.
        # It is only defined for bounded training.
        if max_steps is not None:
            self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps)
        else:
            self.decay_steps = None

        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
            )

        step = self.last_epoch

        # Warmup steps
        if self.warmup_steps > 0 and step <= self.warmup_steps:
            return self._get_warmup_lr(step)

        # Constant steps after warmup and decay (bounded training only)
        if (
            self.max_steps is not None
            and self.constant_steps > 0
            and (self.warmup_steps + self.decay_steps) < step <= self.max_steps
        ):
            return self._get_constant_lr(step)

        # Min lr after max steps of updates
        if self.max_steps is not None and step > self.max_steps:
            return [self.min_lr for _ in self.base_lrs]

        return self._get_lr(step)

    def _get_warmup_lr(self, step):
        lr_val = (step + 1) / (self.warmup_steps + 1)
        return [initial_lr * lr_val for initial_lr in self.base_lrs]

    def _get_constant_lr(self, step):
        return [self.min_lr for _ in self.base_lrs]

    def _get_lr(self, step):
        """Simple const lr policy"""
        return self.base_lrs


def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
    mult = ((max_steps - step) / max_steps) ** 0.5
    out_lr = initial_lr * mult
    out_lr = max(out_lr, min_lr)
    return out_lr


def _square_annealing(initial_lr, step, max_steps, min_lr):
    mult = ((max_steps - step) / max_steps) ** 2
    out_lr = initial_lr * mult
    out_lr = max(out_lr, min_lr)
    return out_lr


def _cosine_annealing(initial_lr, step, max_steps, min_lr):
    mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
    out_lr = (initial_lr - min_lr) * mult + min_lr
    return out_lr


def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, decay_steps, min_lr):

    assert max_lr > min_lr
    # Use linear warmup for the initial part.
    if warmup_steps > 0 and step <= warmup_steps:
        return max_lr * float(step) / float(warmup_steps)

    # For any steps larger than `decay_steps`, use `min_lr`.
    if step > warmup_steps + decay_steps:
        return min_lr

    # If we are done with the warmup period, use the decay style.
    num_steps_ = step - warmup_steps
    decay_steps_ = decay_steps
    decay_ratio = float(num_steps_) / float(decay_steps_)
    assert decay_ratio >= 0.0
    assert decay_ratio <= 1.0
    delta_lr = max_lr - min_lr

    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)

    return min_lr + coeff * delta_lr


def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
    if cycle:
        multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
        decay_steps *= multiplier
    else:
        step = min(step, decay_steps)
    p = step / decay_steps
    lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
    lr += min_lr
    return lr


def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr):
    # hold_steps = total number of steps to hold the LR, not the warmup + hold steps.
    T_warmup_decay = max(1, warmup_steps ** decay_rate)
    T_hold_decay = max(1, (step - hold_steps) ** decay_rate)
    lr = (initial_lr * T_warmup_decay) / T_hold_decay
    lr = max(lr, min_lr)
    return lr


class SquareAnnealing(WarmupPolicy):
    """Warmup followed by quadratic decay of each group's base LR, floored at ``min_lr``.

    The decay runs over the post-warmup window (``max_steps - warmup_steps`` steps).
    """

    def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs):
        super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)

    def _get_lr(self, step):
        steps_since_warmup = step - self.warmup_steps
        decay_horizon = self.max_steps - self.warmup_steps
        return [
            _square_annealing(
                initial_lr=base_lr, step=steps_since_warmup, max_steps=decay_horizon, min_lr=self.min_lr,
            )
            for base_lr in self.base_lrs
        ]


class SquareRootAnnealing(WarmupPolicy):
    """Warmup followed by square-root decay of each group's base LR, floored at ``min_lr``.

    The decay is measured over the whole run: ``step`` and ``max_steps`` are used as-is,
    without subtracting ``warmup_steps``.
    """

    def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
        super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)

    def _get_lr(self, step):
        horizon, floor = self.max_steps, self.min_lr
        return [
            _squareroot_annealing(initial_lr=base_lr, step=step, max_steps=horizon, min_lr=floor)
            for base_lr in self.base_lrs
        ]


class CosineAnnealing(WarmupAnnealHoldPolicy):
    """Cosine annealing of each param group's base LR down to ``min_lr``.

    The variant depends on the constant phase (``constant_steps``/``constant_ratio``):

    * no constant phase: the parent's warmup, then cosine decay from the base LR to ``min_lr``
      over ``max_steps - warmup_steps`` steps;
    * constant phase > 0 (Megatron LM style): linear warmup from 0, cosine decay over
      ``decay_steps``, then ``min_lr`` for the remaining constant steps.

    Every param group is annealed from its own base LR in both variants.

    Raises:
        ValueError: when the post-warmup LR is computed and some base LR is below ``min_lr``.
    """

    def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
        super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)

    def _get_lr(self, step):
        for initial_lr in self.base_lrs:
            if initial_lr < self.min_lr:
                raise ValueError(
                    f"{self} received an initial learning rate that was lower than the minimum learning rate."
                )

        if self.constant_steps is None or self.constant_steps == 0:
            new_lrs = [
                _cosine_annealing(
                    initial_lr=initial_lr,
                    step=step - self.warmup_steps,
                    max_steps=self.max_steps - self.warmup_steps,
                    min_lr=self.min_lr,
                )
                for initial_lr in self.base_lrs
            ]
        else:
            new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
        return new_lrs

    def _get_warmup_lr(self, step):
        if self.constant_steps is None or self.constant_steps == 0:
            return super()._get_warmup_lr(step)
        else:
            # Use linear warmup for the initial part.
            return self._get_linear_warmup_with_cosine_annealing_lr(step)

    def _get_constant_lr(self, step):
        # Only called when `constant_steps` > 0.
        return self._get_linear_warmup_with_cosine_annealing_lr(step)

    def _get_linear_warmup_with_cosine_annealing_lr(self, step):
        # Cosine Schedule for Megatron LM, slightly different warmup schedule + constant LR at the end.
        # Each group peaks at its own base LR, matching the no-constant-phase branch, so
        # optimizers with several param groups keep their relative learning rates.
        new_lrs = [
            _linear_warmup_with_cosine_annealing(
                max_lr=initial_lr,
                warmup_steps=self.warmup_steps,
                step=step,
                decay_steps=self.decay_steps,
                min_lr=self.min_lr,
            )
            for initial_lr in self.base_lrs
        ]
        return new_lrs


class NoamAnnealing(_LRScheduler):
    """Noam schedule: ``lr = base_lr * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)``.

    Without warmup the ``min`` is replaced by ``step**-0.5``. ``min_lr`` is enforced only once
    ``step > warmup_steps``. The step is clamped to at least 1 to avoid ``0 ** -0.5``.

    Args:
        optimizer: The wrapped optimizer.
        d_model: Model dimension used for the ``d_model**-0.5`` normalisation.
        warmup_steps: Number of warmup steps (mutually exclusive with ``warmup_ratio``).
        warmup_ratio: Warmup length as a fraction of ``max_steps`` (truncated to an int).
        max_steps: Total number of steps; only needed to resolve ``warmup_ratio``.
        min_lr: Lower bound applied after warmup.
        last_epoch: Index of the last step, forwarded to ``_LRScheduler``.

    Raises:
        ValueError: if any base LR is lower than ``min_lr``.
    """

    def __init__(
        self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1
    ):
        self._normalize = d_model ** (-0.5)
        assert not (
            warmup_steps is not None and warmup_ratio is not None
        ), "Either use particular number of step or ratio"
        assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps"

        # Attributes must be set before the base constructor runs,
        # since the class is wrapped by an inner class.
        self.max_steps = max_steps
        if warmup_steps is not None:
            resolved_warmup = warmup_steps
        elif warmup_ratio is not None:
            resolved_warmup = int(warmup_ratio * max_steps)
        else:
            resolved_warmup = 0
        self.warmup_steps = resolved_warmup

        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
            )

        current = max(1, self.last_epoch)

        if any(base_lr < self.min_lr for base_lr in self.base_lrs):
            raise ValueError(
                f"{self} received an initial learning rate that was lower than the minimum learning rate."
            )

        return [self._noam_annealing(initial_lr=base_lr, step=current) for base_lr in self.base_lrs]

    def _noam_annealing(self, initial_lr, step):
        inverse_sqrt = step ** (-0.5)
        if self.warmup_steps > 0:
            warmup_ramp = step * (self.warmup_steps ** (-1.5))
            mult = self._normalize * min(inverse_sqrt, warmup_ramp)
        else:
            mult = self._normalize * inverse_sqrt

        scaled = initial_lr * mult
        # The floor only applies after warmup, so the ramp may start below min_lr.
        return max(scaled, self.min_lr) if step > self.warmup_steps else scaled


class NoamHoldAnnealing(WarmupHoldPolicy):
    def __init__(self, optimizer, *, max_steps, decay_rate=0.5, min_lr=0.0, last_epoch=-1, **kwargs):
        """
        Noam Hold Annealing schedule, as proposed in the Squeezeformer paper.

        Unlike NoamAnnealing, the peak learning rate is set directly (it is the optimizer's base
        LR) rather than derived from a scaling factor. The schedule has three stages:

        Warmup: the LR ramps linearly up to the peak LR.

        Hold: the peak LR is kept constant. A high LR here lets the model converge faster when
            training is stable, but may also destabilise training. This stage is usually a
            significant fraction of training (around 30-40% of all steps).

        Decay: the LR is scaled by ``warmup_steps ** decay_rate / (step - hold) ** decay_rate``
            (``hold`` being the length of the hold stage; each term clamped below at 1) and
            floored at ``min_lr``. Use ``decay_rate=0.5`` for Noam decay; Squeezeformer
            recommends 1.0. The final LR therefore depends on these hyper-parameters.

        References:
            - [Squeezeformer: An Efficient Transformer for Automatic Speech Recognition](https://arxiv.org/abs/2206.00888)

        Args:
            optimizer: Pytorch compatible Optimizer object.
            warmup_steps: Number of training steps in warmup stage
            warmup_ratio: Ratio of warmup steps to total steps
            hold_steps: Number of training steps to hold the learning rate after warm up
            hold_ratio: Ratio of hold steps to total steps
            max_steps: Total number of steps while training or `None` for
                infinite training
            decay_rate: Exponent of the polynomial decay after the hold stage. The default
                of 0.5 corresponds to Noam decay.
            min_lr: Minimum learning rate.
            last_epoch: Index of the last step, forwarded to the base scheduler.

        Raises:
            ValueError: when the decay stage is reached without any warmup steps.
        """
        # Must exist before the base constructor, which already computes the first LR.
        self.decay_rate = decay_rate
        super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)

    def _get_lr(self, step):
        if not self.warmup_steps:
            raise ValueError("Noam scheduler cannot be used without warmup steps")

        # ``self.hold_steps`` is the step where holding ends (warmup included);
        # the annealing helper expects the hold duration on its own.
        pure_hold = self.hold_steps - self.warmup_steps if self.hold_steps > 0 else 0

        return [
            _noam_hold_annealing(
                base_lr,
                step=step,
                warmup_steps=self.warmup_steps,
                hold_steps=pure_hold,
                decay_rate=self.decay_rate,
                min_lr=self.min_lr,
            )
            for base_lr in self.base_lrs
        ]


class WarmupAnnealing(WarmupPolicy):
    """Warmup followed by linear decay of each param group's base LR down to ``min_lr``.

    After warmup the LR falls linearly from the group's base LR (just after warmup) to
    ``min_lr`` (at ``max_steps``). Past ``max_steps`` the parent policy returns ``min_lr``.
    """

    def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
        super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)

    def _get_lr(self, step):
        # Fraction of the post-warmup window already elapsed (0 after warmup, 1 at max_steps).
        mult = (step - self.warmup_steps) / (self.max_steps - self.warmup_steps)
        # Decay every group from its *own* base LR: warmup ramps each group to its own base
        # LR, so using base_lrs[0] for all groups would make the others jump after warmup.
        out_lr = [self.min_lr + (1 - mult) * (initial_lr - self.min_lr) for initial_lr in self.base_lrs]
        return out_lr


class InverseSquareRootAnnealing(WarmupPolicy):
    """Warmup followed by inverse-square-root decay.

    After warmup each group's LR is ``base_lr / sqrt((step + 1) / (warmup_steps + 1))``.
    """

    def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
        super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)

    def _get_lr(self, step):
        decay_factor = ((step + 1) / (self.warmup_steps + 1)) ** 0.5
        return [base_lr / decay_factor for base_lr in self.base_lrs]


class T5InverseSquareRootAnnealing(SquareRootConstantPolicy):
    """T5-style schedule: constant phase, then ``1 / sqrt(step)`` for every param group.

    The optimizer's base learning rates are ignored by ``_get_lr``.
    """

    def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
        super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)

    def _get_lr(self, step):
        inv_sqrt = 1 / (step ** 0.5)
        return [inv_sqrt] * len(self.base_lrs)


class PolynomialDecayAnnealing(WarmupPolicy):
    """Warmup followed by polynomial decay (``power``, optional ``cycle``) of each base LR to ``min_lr``.

    The decay window is the post-warmup span, ``max_steps - warmup_steps`` steps long.
    """

    def __init__(self, optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=-1, **kwargs):
        # Set before the base constructor, which already evaluates the schedule once.
        self.power = power
        self.cycle = cycle

        super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)

    def _get_lr(self, step):
        since_warmup = step - self.warmup_steps
        horizon = self.max_steps - self.warmup_steps
        return [
            _poly_decay(
                base_lr,
                step=since_warmup,
                decay_steps=horizon,
                power=self.power,
                min_lr=self.min_lr,
                cycle=self.cycle,
            )
            for base_lr in self.base_lrs
        ]


class PolynomialHoldDecayAnnealing(WarmupHoldPolicy):
    """Warmup, optional hold at the base LR, then polynomial decay of each base LR to ``min_lr``."""

    def __init__(self, optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=-1, **kwargs):
        # Set before the base constructor, which already evaluates the schedule once.
        self.power = power
        self.cycle = cycle

        super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)

    def _get_lr(self, step):
        # The decay starts where the flat part of the schedule ends. That is ``hold_steps``
        # (which already includes warmup) when a hold is configured; otherwise ``hold_steps``
        # is 0 and the decay starts right after warmup. The step offset and the window length
        # must use the same start, or the LR jumps after warmup and hits min_lr too early.
        decay_start = max(self.warmup_steps, self.hold_steps)
        new_lrs = [
            _poly_decay(
                initial_lr,
                step=step - decay_start,
                decay_steps=self.max_steps - decay_start,
                power=self.power,
                min_lr=self.min_lr,
                cycle=self.cycle,
            )
            for initial_lr in self.base_lrs
        ]
        return new_lrs


def register_scheduler(name: str, scheduler: _LRScheduler, scheduler_params: SchedulerParams):
    """
    Add a scheduler class to the registry so it can be instantiated by name.

    Its parameter dataclass is registered alongside it under the key
    ``"<SchedulerClassName>_params"``.

    Args:
        name: Registry key under which the scheduler will be looked up.
        scheduler: Scheduler class (inherits from _LRScheduler)
        scheduler_params: The parameters as a dataclass of the scheduler

    Raises:
        ValueError: if ``name`` is already registered (existing entries cannot be overridden).
    """
    if name in AVAILABLE_SCHEDULERS:
        raise ValueError(f"Cannot override pre-existing schedulers. Conflicting scheduler name = {name}")

    AVAILABLE_SCHEDULERS[name] = scheduler
    register_scheduler_params(name=f"{scheduler.__name__}_params", scheduler_params=scheduler_params)


def get_scheduler(name: str, **kwargs: Any) -> _LRScheduler:
    """
    Convenience method to obtain an _LRScheduler class and partially instantiate it with kwargs.

    The optimizer is not bound here. Call the returned ``functools.partial`` with it to build
    the scheduler.

    Args:
        name: Name of the scheduler in the registry.
        kwargs: Optional kwargs of the scheduler used during instantiation. ``max_steps`` is
            dropped when the scheduler's constructor does not declare it.

    Returns:
        a partially instantiated _LRScheduler (a ``functools.partial`` over the class)

    Raises:
        ValueError: if ``name`` is not a registered scheduler.
    """
    if name not in AVAILABLE_SCHEDULERS:
        raise ValueError(
            f"Cannot resolve scheduler '{name}'. Available schedulers are : {AVAILABLE_SCHEDULERS.keys()}"
        )

    scheduler_cls = AVAILABLE_SCHEDULERS[name]
    # Pop 'max_steps' if it's not required by the scheduler
    if 'max_steps' in kwargs and 'max_steps' not in inspect.signature(scheduler_cls).parameters:
        kwargs.pop('max_steps')
    scheduler = partial(scheduler_cls, **kwargs)
    return scheduler


def prepare_lr_scheduler(
    optimizer: optim.Optimizer,
    scheduler_config: Union[Dict[str, Any], DictConfig],
    train_dataloader: Optional[dataloader.DataLoader] = None,
) -> Optional[Dict[str, Any]]:
    """
    Constructs an LR Scheduler (optionally) for a given optimizer, based on a config with the following schema

    optim:
      name: <name of optimizer>
      lr: <maximal learning rate>

      # <additional optimizer arguments>
      args:
        name: auto  # special keyword, resolves to correct optimizer config for given optimizer name
        # cls: nemo.core.config.optimizers.NovogradParams  # explicit instantiation by class path
        params:  # optional override parameters for the optimizer config
          betas: [0.8, 0.5]
          weight_decay: 0.001

      # scheduler setup
      sched:
        name: <name of scheduler>
        max_steps: null  # explicit total number of steps; if null/absent it is computed from `t_max_epochs`

        # values used to compute the effective `max_steps` when it is not given explicitly
        t_max_epochs: <number of training epochs>
        t_accumulate_grad_batches: <batches accumulated per optimizer step>
        t_limit_train_batches: <int batch cap, or float fraction of batches per epoch>
        t_num_workers: <number of workers the dataset size is divided across>

        # pytorch lightning args
        monitor: val_loss  # optional, defaults to `loss`
        reduce_on_plateau: false

        # <scheduler config override>
        # If `args` is absent, the remaining `sched` keys are passed to the scheduler constructor.
        args:
          name: auto  # special keyword, resolves to correct scheduler config for given scheduler name
          # _target_: nemo.core.config.schedulers.CosineAnnealingParams  # explicit instantiation by class path
          params:  # optional override parameters for the scheduler config
            warmup_steps: null
            warmup_ratio: null
            min_lr: 0.0
            last_epoch: -1

    Args:
        optimizer: An instantiated Optimizer.
        scheduler_config: A dictionary / config dict which follows the above schema.
        train_dataloader: Optional requirement, must be passed if "t_max_epochs" is provided
            instead of "max_steps". Used to compute effective "max_steps".

    Returns:
        A dictionary containing the LR Scheduler implementation if the config was successfully parsed
        along with other parameters required by Pytorch Lightning (keys: ``scheduler``, ``interval``,
        ``frequency``, ``monitor``, ``reduce_on_plateau``), otherwise None.

    Raises:
        ValueError: if the batch size cannot be determined from ``train_dataloader`` while computing
            the effective ``max_steps``.
    """
    if scheduler_config is not None:
        scheduler_config = maybe_update_config_version(scheduler_config)

    # Build nested dictionary for convenience out of structured objects
    if isinstance(scheduler_config, DictConfig):
        scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True)

    elif dataclasses.is_dataclass(scheduler_config):
        # Recursively transform data classes to basic dictionaries
        scheduler_config = OmegaConf.create(scheduler_config)
        scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True)

    # Test to see if config follows above schema
    interval = 'step'
    if scheduler_config is not None:
        if 'args' in scheduler_config:
            scheduler_args = scheduler_config.pop('args')
        else:
            scheduler_args = copy.deepcopy(scheduler_config)

            # Remove extra parameters from scheduler_args nest
            # Assume all other parameters are to be passed into scheduler constructor
            scheduler_args.pop('name', None)
            scheduler_args.pop('t_max_epochs', None)
            scheduler_args.pop('t_accumulate_grad_batches', None)
            scheduler_args.pop('t_limit_train_batches', None)
            scheduler_args.pop('t_num_workers', None)
            scheduler_args.pop('monitor', None)
            scheduler_args.pop('reduce_on_plateau', None)

        # Schedulers listed in EPOCH_SCHEDULERS are stepped once per epoch instead of once per step
        if 'name' in scheduler_config and scheduler_config['name'] in EPOCH_SCHEDULERS:
            interval = 'epoch'

    else:
        # Return gracefully in case `sched` was not supplied; inform user
        logging.info('Scheduler not initialized as no `sched` config supplied to setup_optimizer()')
        return None

    # Try instantiation of scheduler params from config class path
    if '_target_' in scheduler_args:
        scheduler_args_cfg = OmegaConf.create(scheduler_args)
        scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg)
        scheduler_args = vars(scheduler_conf)

        # Get name of the scheduler
        scheduler_name = scheduler_conf.__class__.__name__

        if 'Params' in scheduler_name:
            scheduler_name = scheduler_name.replace('Params', '')

    else:
        # No class path (`_target_`) was given; try resolving the "name" component instead

        # Get name of the scheduler
        if 'name' in scheduler_config:
            scheduler_name = scheduler_config['name']
        else:
            logging.warning(
                "Could not resolve classpath for Scheduler Config, and `name` "
                "was not provided either. \n"
                "Scheduler cannot be instantiated !"
            )
            return None

        # If class path was not provided, perhaps `name` is provided for resolution
        if 'name' in scheduler_args:
            # If `auto` is passed as name for resolution of scheduler name,
            # then lookup scheduler name and resolve its parameter config
            if scheduler_args['name'] == 'auto':
                scheduler_params_name = "{}Params".format(scheduler_name)
            else:
                scheduler_params_name = scheduler_args['name']

            # Get override arguments provided in the config yaml file / Dict Config
            scheduler_params_override = scheduler_args.get('params', {})

            # If params is itself a dict config object provided explicitly in Dict Config
            # Resolve to dictionary for convenience
            if isinstance(scheduler_params_override, DictConfig):
                scheduler_params_override = OmegaConf.to_container(scheduler_params_override, resolve=True)

            # Get and instantiate the Config dataclass for this scheduler
            scheduler_params_cls = get_scheduler_config(scheduler_params_name, **scheduler_params_override)
            scheduler_params = scheduler_params_cls()  # instantiate the parameters object
            scheduler_args = vars(scheduler_params)  # extract just the dictionary from the Config object

        else:
            # assume the input dictionary is schedular args (from dataclasses / omegaconf)
            pass

    # Extract value to monitor in losses, if provided.
    if 'monitor' in scheduler_config:
        monitor = scheduler_config.get('monitor')
    else:
        # Default to train loss
        monitor = 'loss'

    # Store exact max_steps if it is provided
    if 'max_steps' in scheduler_config and scheduler_config['max_steps'] is not None:
        max_steps = scheduler_config['max_steps']

    elif 't_max_epochs' in scheduler_config:
        # Compute effective max_steps if t_max_epochs is provided
        if train_dataloader is None:
            logging.warning(
                'As `t_max_epochs` is provided/computed, it is required to pass the train dataloader in order\n'
                'to compute effective maximum number of steps.\n'
                'Scheduler will not be instantiated !'
            )
            return None

        # Give up (with a warning) if `t_max_epochs` is present but unset, since `max_steps` is missing too
        if scheduler_config.get('t_max_epochs', None) is None:
            logging.warning(
                "`t_max_epochs` cannot be None when `max_steps` is not provided.\n"
                "This can occur when `train dataloader` is not available to correctly "
                "prepare the scheduler.\n"
                "Scheduler will not be instantiated !"
            )
            return None

        # Get iters_per_batch
        max_epochs = scheduler_config.get('t_max_epochs')
        accumulate_grad_batches = scheduler_config.get('t_accumulate_grad_batches')
        limit_train_batches = scheduler_config.get('t_limit_train_batches')
        num_workers = scheduler_config.get('t_num_workers')

        # Compute effective num max_steps
        num_samples = len(train_dataloader.dataset)
        # TODO: not sure if this will be the correct LR schedule for Megatron
        # we may need to override ModelPT setup_optimization
        if train_dataloader.batch_size is not None:
            batch_size = train_dataloader.batch_size
        elif hasattr(train_dataloader, 'batch_sampler') and train_dataloader.batch_sampler is not None:
            # Only some custom batch samplers expose `micro_batch_size`; use getattr so that samplers
            # without it reach the descriptive ValueError below instead of raising AttributeError.
            micro_batch_size = getattr(train_dataloader.batch_sampler, 'micro_batch_size', None)
            if micro_batch_size is not None:
                batch_size = micro_batch_size
            else:
                raise ValueError(f'Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}')
        else:
            raise ValueError(f'Could not find batch_size from train_dataloader: {train_dataloader}')
        drop_last = train_dataloader.drop_last

        max_steps = compute_max_steps(
            max_epochs=max_epochs,
            accumulate_grad_batches=accumulate_grad_batches,
            limit_train_batches=limit_train_batches,
            num_workers=num_workers,
            num_samples=num_samples,
            batch_size=batch_size,
            drop_last=drop_last,
        )

    else:
        logging.warning(
            "Neither `max_steps` nor `t_max_epochs` were provided to `optim.sched`, "
            "cannot compute effective `max_steps` !\n"
            "Scheduler will not be instantiated !"
        )
        return None

    # Inject max_steps (effective or provided) into the scheduler config
    scheduler_args['max_steps'] = max_steps

    # Get the scheduler class from the config
    scheduler_cls = get_scheduler(scheduler_name, **scheduler_args)

    # Pop 'max_steps' if it's not required by the scheduler
    if 'max_steps' not in inspect.signature(scheduler_cls).parameters:
        scheduler_args.pop('max_steps')

    # Instantiate the LR schedule
    schedule = scheduler_cls(optimizer, **scheduler_args)

    logging.info(
        'Scheduler "%s" \nwill be used during training (effective maximum steps = %d) - \nParameters : \n(%s)',
        str(schedule),
        max_steps,
        OmegaConf.to_yaml(OmegaConf.create(scheduler_args)),
    )

    # Wrap the schedule in PTL arguments to perform stepwise computation
    # Rather than epoch level computation
    if isinstance(schedule, optim.lr_scheduler.ReduceLROnPlateau):
        reduce_lr_on_plateau = True
    else:
        reduce_lr_on_plateau = False

    schedule_dict = {
        'scheduler': schedule,
        'interval': interval,
        'frequency': 1,
        'monitor': monitor,
        'reduce_on_plateau': reduce_lr_on_plateau,
    }
    return schedule_dict


def compute_max_steps(
    max_epochs, accumulate_grad_batches, limit_train_batches, num_workers, num_samples, batch_size, drop_last
):
    """
    Compute the total number of optimizer steps for a training run.

    Args:
        max_epochs: Number of training epochs.
        accumulate_grad_batches: Number of batches accumulated per optimizer step. ``None`` is treated as 1.
        limit_train_batches: Per-epoch batch limit. An ``int`` (or ``0.0``) caps the number of batches per
            epoch; any other number is used as a fraction of the batches per epoch. ``None`` is treated as 1.0.
        num_workers: Number of workers the dataset is split across; ``num_samples`` is divided by
            ``max(1, num_workers)``. ``None`` is treated as 1.
        num_samples: Number of samples in the training dataset.
        batch_size: Number of samples per batch.
        drop_last: If True, an incomplete final batch is dropped (floor); otherwise it counts as a batch (ceil).

    Returns:
        ``ceil(steps_per_epoch / accumulate_grad_batches) * max_epochs``.
    """
    # Treat unset values as their no-op defaults instead of failing with a TypeError below.
    if accumulate_grad_batches is None:
        accumulate_grad_batches = 1
    if limit_train_batches is None:
        limit_train_batches = 1.0
    if num_workers is None:
        num_workers = 1

    _round = math.floor if drop_last else math.ceil

    sampler_num_samples = math.ceil(num_samples / max(1, num_workers))

    if drop_last and num_workers > 1:
        logging.warning(
            "Please note that drop_last is broken in pytorch 1.6.0. We will fix when pytorch 1.7.0 is released"
        )
        # TODO: Master version, not in pytorch 1.6.0
        # sampler_num_samples = math.ceil((num_samples - num_workers)/ num_workers)

    # math.floor / math.ceil always return an int here, so steps_per_epoch is never infinite.
    steps_per_epoch = _round(sampler_num_samples / batch_size)
    if isinstance(limit_train_batches, int) or limit_train_batches == 0.0:
        # Absolute cap on the number of batches per epoch
        steps_per_epoch = min(steps_per_epoch, int(limit_train_batches))
    else:
        # limit_train_batches is a percentage of batches per epoch
        steps_per_epoch = int(steps_per_epoch * limit_train_batches)

    return math.ceil(steps_per_epoch / accumulate_grad_batches) * max_epochs


# Registry of scheduler classes that `get_scheduler` can resolve by name.
# `register_scheduler` adds custom entries to this dict at runtime.
AVAILABLE_SCHEDULERS = {
    # Warmup / hold / annealing schedulers referenced by bare name in this module
    'WarmupPolicy': WarmupPolicy,
    'WarmupHoldPolicy': WarmupHoldPolicy,
    'SquareAnnealing': SquareAnnealing,
    'CosineAnnealing': CosineAnnealing,
    'NoamAnnealing': NoamAnnealing,
    'NoamHoldAnnealing': NoamHoldAnnealing,
    'WarmupAnnealing': WarmupAnnealing,
    'InverseSquareRootAnnealing': InverseSquareRootAnnealing,
    'T5InverseSquareRootAnnealing': T5InverseSquareRootAnnealing,
    'SquareRootAnnealing': SquareRootAnnealing,
    'PolynomialDecayAnnealing': PolynomialDecayAnnealing,
    'PolynomialHoldDecayAnnealing': PolynomialHoldDecayAnnealing,
    # Schedulers taken from the `pt_scheduler` module
    'StepLR': pt_scheduler.StepLR,
    'ExponentialLR': pt_scheduler.ExponentialLR,
    'ReduceLROnPlateau': pt_scheduler.ReduceLROnPlateau,
    'CyclicLR': pt_scheduler.CyclicLR,
}

# Subset of the registry that `prepare_lr_scheduler` steps once per epoch (interval='epoch')
# rather than once per optimizer step.
EPOCH_SCHEDULERS = {epoch_name: AVAILABLE_SCHEDULERS[epoch_name] for epoch_name in ('ReduceLROnPlateau',)}