@@ -74,15 +74,24 @@ def pl_load(
7474 """
7575 if not isinstance (path_or_url , (str , Path )):
7676 # any sort of BytesIO or similar
77- return torch .load (path_or_url , map_location = map_location , weights_only = False )
77+ # get the torch version
78+ torch_version = torch .__version__
79+ if torch_version < "2.6" :
80+ return torch .load (path_or_url , map_location = map_location ) # for torch version < 2.6
81+ elif torch_version >= "2.6" :
82+ return torch .load (path_or_url , map_location = map_location , weights_only = False )
7883 if str (path_or_url ).startswith ("http" ):
7984 return torch .hub .load_state_dict_from_url (
8085 str (path_or_url ),
8186 map_location = map_location , # type: ignore[arg-type] # upstream annotation is not correct
8287 )
8388 fs = get_filesystem (path_or_url )
8489 with fs .open (path_or_url , "rb" ) as f :
85- return torch .load (f , map_location = map_location , weights_only = False )
90+ torch_version = torch .__version__
91+ if torch_version < "2.6" :
92+ return torch .load (f , map_location = map_location ) # for torch version < 2.6
93+ elif torch_version >= "2.6" :
94+ return torch .load (f , map_location = map_location , weights_only = False )
8695
8796
8897def check_numpy (x ):
0 commit comments