@@ -3,14 +3,13 @@
 
 # Copyright Lightning AI. Licensed under the Apache License 2.0,
 # see LICENSE file at https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
-
+import warnings
 from dataclasses import dataclass
 from typing import Any, Literal, Optional, Type
 
 import torch
 from typing_extensions import Self
 
-import samba_pytorch.samba
 from samba_pytorch.utils import find_multiple
 
 
@@ -101,8 +100,9 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
 
     @property
     def mlp_class(self) -> Type:
+        from samba_pytorch import samba
         # `self._mlp_class` cannot be the type to keep the config json serializable
-        return getattr(samba_pytorch.samba, self._mlp_class)
+        return getattr(samba, self._mlp_class)
 
     @property
     def norm_class(self) -> Type:
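The module-level `import samba_pytorch.samba` removed in the first hunk is replaced by an import deferred into the `mlp_class` property. A minimal self-contained sketch of the pattern; the circular dependency it avoids (samba.py importing `Config` back from this module) is an assumption based on the litgpt layout this file derives from:

```python
# Sketch of the deferred-import pattern above, not the repo's full Config.
from dataclasses import dataclass
from typing import Type


@dataclass
class Config:
    # Kept as a string (not a class) so the config stays JSON-serializable.
    _mlp_class: str = "LLaMAMLP"

    @property
    def mlp_class(self) -> Type:
        # Imported lazily: if samba.py itself imports Config from this
        # module (assumed), a top-level import here would be circular and
        # could observe a partially initialized module.
        from samba_pytorch import samba

        return getattr(samba, self._mlp_class)
```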
@@ -112,9 +112,12 @@ def norm_class(self) -> Type:
 
             return RMSNorm
         elif self._norm_class == "FusedRMSNorm":
-            from samba_pytorch.modules.rmsnorm import FusedRMSNorm
+            warnings.warn(
+                "FusedRMSNorm has been removed; using the standard RMSNorm implementation instead"
+            )
+            from samba_pytorch.modules.rmsnorm import RMSNorm
 
-            return FusedRMSNorm
+            return RMSNorm
         return getattr(torch.nn, self._norm_class)
 
 
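Configs or checkpoints serialized with the old `"FusedRMSNorm"` string keep loading after this change: the property warns and resolves to the plain `RMSNorm`. A minimal sketch of that fallback, assuming `Config` accepts `_norm_class` as a keyword while its other fields keep their defaults:

```python
import warnings

from samba_pytorch.config import Config

# Hypothetical instance standing in for a config restored from an older
# checkpoint that still names the removed fused norm.
cfg = Config(_norm_class="FusedRMSNorm")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    norm_cls = cfg.norm_class  # triggers the removal warning above

assert norm_cls.__name__ == "RMSNorm"
assert "FusedRMSNorm has been removed" in str(caught[0].message)
```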
@@ -133,7 +136,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -150,7 +153,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -168,7 +171,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         full_per_layer=2,
         _mlp_class="LLaMAMLP",
@@ -187,7 +190,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -206,7 +209,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -225,7 +228,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -244,7 +247,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -263,7 +266,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -280,7 +283,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -298,7 +301,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -316,7 +319,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -335,7 +338,7 @@ def norm_class(self) -> Type:
         parallel_residual=True,
         shared_attention_norm=True,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -354,7 +357,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -373,7 +376,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -393,7 +396,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -412,7 +415,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -431,7 +434,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -450,7 +453,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -469,7 +472,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -489,7 +492,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -510,7 +513,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -531,7 +534,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -552,7 +555,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -573,7 +576,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -592,7 +595,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -612,7 +615,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -632,7 +635,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -653,7 +656,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -673,7 +676,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -693,7 +696,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -712,7 +715,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -731,7 +734,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -750,7 +753,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -769,7 +772,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -787,7 +790,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=8192,
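The remaining hunks apply the same one-line substitution to every preset, so newly created configs name `"RMSNorm"` directly and never hit the warning path. A usage sketch; the preset name `Samba_421M`, the `n_embd` field, and the `RMSNorm(size, eps=...)` signature are assumptions:

```python
from samba_pytorch.config import Config

cfg = Config.from_name("Samba_421M")  # hypothetical preset name

# _norm_class is resolved to a class only here, so updated presets and
# old "FusedRMSNorm" checkpoints land on the same implementation.
norm = cfg.norm_class(cfg.n_embd, eps=cfg.norm_eps)  # signature assumed
```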