  year={2025}
}

+@inproceedings{heo2024rotary,
+  title={Rotary position embedding for vision transformer},
+  author={Heo, Byeongho and Park, Song and Han, Dongyoon and Yun, Sangdoo},
+  booktitle={European Conference on Computer Vision},
+  pages={289--305},
+  year={2024},
+  organization={Springer}
+}
+
This file contains a number of ViT variants that utilise RoPE position embeddings, SwiGLU, and other additions:
 * EVA & EVA02 model implementations that evolved from BEiT; additional models in vision_transformer.py.
 * `timm` original SBB ViT w/ RoPE position embeddings
 * Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)
+ * RoPE-ViT from Naver AI (https://arxiv.org/abs/2403.13298)

Modifications by / Copyright 2023 Ross Wightman, original copyrights below
"""
@@ -773,7 +783,7 @@ def forward_intermediates(
        else:
            blocks = self.blocks[:max_index + 1]
        # Handle depth-dependent embeddings for mixed mode
-        if self.rope_mixed and rot_pos_embed is not None:
+        if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None:
            for i, blk in enumerate(blocks):
                if self.grad_checkpointing and not torch.jit.is_scripting():
                    x = checkpoint(blk, x, rope=rot_pos_embed[i])
@@ -850,7 +860,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm_pre(x)

        # Handle depth-dependent embeddings for mixed mode
-        if self.rope_mixed and rot_pos_embed is not None:
+        if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None:
            # rot_pos_embed has shape (depth, H*W, dim) for mixed mode
            for i, blk in enumerate(self.blocks):
                if self.grad_checkpointing and not torch.jit.is_scripting():
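Both call sites (forward_intermediates and forward_features) follow the same dispatch: in mixed mode the model holds one rotary table per block and indexes it by depth, while axial mode shares a single table across blocks. The `getattr` default keeps instances that predate the `rope_mixed` attribute on the axial path. A condensed, self-contained sketch of the pattern, with `blocks`, `x`, and `rot_pos_embed` standing in for the module's actual state:

```python
# Condensed sketch of the per-block rope dispatch in the hunks above.
from typing import List, Optional
import torch

def apply_blocks(
        x: torch.Tensor,
        blocks: List,                       # transformer blocks taking (x, rope=...)
        rot_pos_embed: Optional[torch.Tensor],
        rope_mixed: bool = False,           # mirrors getattr(self, 'rope_mixed', False)
) -> torch.Tensor:
    for i, blk in enumerate(blocks):
        if rope_mixed and rot_pos_embed is not None:
            # mixed mode: (depth, H*W, dim) table, one slice per block
            x = blk(x, rope=rot_pos_embed[i])
        else:
            # axial mode: one shared (H*W, dim) table (or None)
            x = blk(x, rope=rot_pos_embed)
    return x
```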
@@ -991,23 +1001,6 @@ def checkpoint_filter_fn(
    state_dict = state_dict.get('module', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)

-    # FIXME remove after conversion, check if this is a rope-vit checkpoint
-    if 'freqs' in state_dict:
-        # Handle rope-vit specific conversions
-        for k, v in state_dict.items():
-            # Skip rope-vit specific buffers
-            if any([kk in k for kk in ('freqs_t_x', 'freqs_t_y')]):
-                continue
-            # Handle mixed mode frequency parameters
-            if k == 'freqs':
-                # Check if model uses mixed mode by looking at other keys or freqs shape
-                # Mixed mode has learnable freqs, axial mode doesn't use them
-                k = 'rope.freqs'
-                model_shape = model.state_dict().get(k).shape
-                v = v.reshape(model_shape)
-            out_dict[k] = v
-        return out_dict
-
    # Loading Meta PE (Perception Encoder) weights
    if 'visual.conv1.weight' in state_dict:
        return _convert_pe(state_dict, model)
@@ -1031,7 +1024,7 @@ def checkpoint_filter_fn(
            continue
        k = k[len_prefix:]

-        if 'rope' in k:
+        if 'rope' in k and not k == 'rope.freqs':
            # fixed embedding no need to load buffer from checkpoint
            continue

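The net effect of this hunk pair: the one-off rope-vit key conversion is dropped (the removed FIXME anticipated this, and the converted checkpoints are now rehosted, per the cfg changes below), and the generic filter skips every fixed rope buffer except `rope.freqs`, which in mixed mode is a learned parameter that must come from the checkpoint. A sketch of the resulting rule, assuming keys already have any prefix handling applied:

```python
# Hedged sketch of the checkpoint filtering rule after this change.
def filter_rope_keys(state_dict: dict) -> dict:
    out = {}
    for k, v in state_dict.items():
        if 'rope' in k and not k == 'rope.freqs':
            continue  # fixed sin/cos buffers are regenerated at model init
        out[k] = v  # keep 'rope.freqs' (learnable, mixed mode) and all else
    return out
```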
@@ -1375,76 +1368,76 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]:

    # RoPE-ViT models from Naver
    'vit_small_patch16_rope_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_axial_deit_small_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
    'vit_base_patch16_rope_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_axial_deit_base_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
    'vit_large_patch16_rope_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_axial_deit_large_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
-    'vit_small_patch16_mrope_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_mixed_deit_small_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+    'vit_small_patch16_rope_mixed_224.naver_in1k': _cfg(
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
-    'vit_base_patch16_mrope_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_mixed_deit_base_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+    'vit_base_patch16_rope_mixed_224.naver_in1k': _cfg(
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
-    'vit_large_patch16_mrope_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_mixed_deit_large_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+    'vit_large_patch16_rope_mixed_224.naver_in1k': _cfg(
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
    'vit_small_patch16_rope_ape_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_axial_ape_deit_small_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
    'vit_base_patch16_rope_ape_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_axial_ape_deit_base_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
    'vit_large_patch16_rope_ape_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_axial_ape_deit_large_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
-    'vit_small_patch16_mrope_ape_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_mixed_ape_deit_small_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+    'vit_small_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
-    'vit_base_patch16_mrope_ape_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_mixed_ape_deit_base_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+    'vit_base_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
-    'vit_large_patch16_mrope_ape_224.naver_in1k': _cfg(
-        hf_hub_id='naver-ai/rope_mixed_ape_deit_large_patch16_LS',
-        hf_hub_filename='pytorch_model.bin',
+    'vit_large_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
+        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
+        license='apache-2.0',
    ),
})

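All twelve entries now point at `hf_hub_id='timm/'`, which (as I understand timm's pretrained-config handling) expands to a hub repo named after the full model identifier on the timm Hugging Face org, so the explicit `hf_hub_filename` is no longer needed; the added `license` field records the upstream Apache-2.0 terms. Illustrative usage:

```python
# Illustrative usage under the new cfgs; the resolved hub repo name is an
# assumption based on how timm expands a bare 'timm/' hub id.
import timm

model = timm.create_model('vit_base_patch16_rope_mixed_224.naver_in1k', pretrained=True)
# weights resolve to something like hf-hub:timm/vit_base_patch16_rope_mixed_224.naver_in1k
model.eval()
```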
@@ -2023,7 +2016,7 @@ def vit_large_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva:


@register_model
-def vit_small_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_small_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
    """RoPE-Mixed ViT-S/16 from https://github.com/naver-ai/rope-vit"""
    model_args = dict(
        patch_size=16,
@@ -2042,12 +2035,12 @@ def vit_small_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
        rope_temperature=10.0,
        rope_mixed_mode=True,
    )
-    model = _create_eva('vit_small_patch16_mrope_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_small_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
-def vit_base_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_base_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
    """RoPE-Mixed ViT-B/16 from https://github.com/naver-ai/rope-vit"""
    model_args = dict(
        patch_size=16,
@@ -2066,12 +2059,12 @@ def vit_base_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
        rope_temperature=10.0,
        rope_mixed_mode=True,
    )
-    model = _create_eva('vit_base_patch16_mrope_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_base_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
-def vit_large_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_large_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
    """RoPE-Mixed ViT-L/16 from https://github.com/naver-ai/rope-vit"""
    model_args = dict(
        patch_size=16,
@@ -2090,7 +2083,7 @@ def vit_large_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
        rope_temperature=10.0,
        rope_mixed_mode=True,
    )
-    model = _create_eva('vit_large_patch16_mrope_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_large_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


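The three registrations above rename `mrope` to `rope_mixed` so the model names spell out the RoPE-Mixed scheme from the paper; the model args themselves are unchanged (note `rope_temperature=10.0` and `rope_mixed_mode=True` versus the plain axial variants). A quick hedged check that the mixed-mode frequencies register as learnable parameters:

```python
# Hedged check: in mixed mode the rope frequencies should appear as a
# learnable parameter (the 'rope.freqs' the checkpoint filter preserves).
import timm

model = timm.create_model('vit_small_patch16_rope_mixed_224', pretrained=False)
rope_params = {n: tuple(p.shape) for n, p in model.named_parameters() if 'rope' in n}
print(rope_params)  # expected non-empty here, empty for the axial-only variants
```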
@@ -2170,7 +2163,7 @@ def vit_large_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva:


@register_model
-def vit_small_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_small_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
    """RoPE-Mixed + APE ViT-S/16 from https://github.com/naver-ai/rope-vit"""
    model_args = dict(
        patch_size=16,
@@ -2191,12 +2184,12 @@ def vit_small_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
        rope_mixed_mode=True,
    )

-    model = _create_eva('vit_small_patch16_mrope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_small_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
-def vit_base_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_base_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
    """RoPE-Mixed + APE ViT-B/16 from https://github.com/naver-ai/rope-vit"""
    model_args = dict(
        patch_size=16,
@@ -2216,12 +2209,12 @@ def vit_base_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
        rope_temperature=10.0,
        rope_mixed_mode=True,
    )
-    model = _create_eva('vit_base_patch16_mrope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_base_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
-def vit_large_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_large_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
    """RoPE-Mixed + APE ViT-L/16 from https://github.com/naver-ai/rope-vit"""
    model_args = dict(
        patch_size=16,
@@ -2241,6 +2234,6 @@ def vit_large_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
        rope_temperature=10.0,
        rope_mixed_mode=True,
    )
-    model = _create_eva('vit_large_patch16_mrope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_large_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model

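With the `_ape` renames in place, the twelve variants follow one naming grid: {small, base, large} x {rope, rope_mixed} x {'', _ape}, where `_ape` keeps a learned absolute position embedding alongside the rotary one. Enumerating them is a quick sanity check; the printed lists are illustrative:

```python
# Quick enumeration of the renamed model families via timm's registry.
import timm

for pattern in ('*patch16_rope_224', '*patch16_rope_mixed*', '*rope*ape*'):
    print(pattern, '->', timm.list_models(pattern))
```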