@@ -703,7 +703,7 @@ def backward(ctx, *grad_outs):
 class BertEmbeddings(BlockedModule):
     """Construct the embeddings from word, position and token_type embeddings."""
 
-    def __init__(self, config):
+    def __init__(self, config, position_ids_persistent=False):
         super().__init__()
         self.word_embeddings = nn.Embedding(
             config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
@@ -725,11 +725,14 @@ def __init__(self, config):
         self.pad_token_id = config.pad_token_id
 
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer(
-            "position_ids",
-            torch.arange(config.max_position_embeddings).expand((1, -1)),
-            persistent=False,
-        )
+        if not position_ids_persistent:
+            self.register_buffer(
+                "position_ids",
+                torch.arange(config.max_position_embeddings).expand((1, -1)),
+                persistent=False,
+            )
+        else:
+            self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(
             config, "position_embedding_type", "absolute"
         )
@@ -1243,7 +1246,7 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
     # tpp bert optimization depends on the transformers repo to implement the related module
     installed_pkg = {pkg.key for pkg in pkg_resources.working_set}
     min_version = "4.6.0"
-    max_version = "4.20.0"
+    max_version = "4.31.0"
     if "transformers" not in installed_pkg:
         raise RuntimeError(
             "Please install the transformers package with a version between {} and {}".format(
@@ -1263,6 +1266,9 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
                 min_version, max_version, trans_version
             )
         )
+    position_ids_persistent = False
+    if version.parse(trans_version) < version.parse("4.31.0"):
+        position_ids_persistent = True
     PT_OPTIMIZER_TO_TPP_OPTIMIZER = {
         torch.optim.AdamW: AdamW,
         transformers.optimization.AdamW: AdamW,
@@ -1297,7 +1303,7 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
         assert isinstance(
             new_model.embeddings, transformers.models.bert.modeling_bert.BertEmbeddings
         )
-        new_model.embeddings = BertEmbeddings(model.config)
+        new_model.embeddings = BertEmbeddings(model.config, position_ids_persistent=position_ids_persistent)
         assert isinstance(
             new_model.encoder, transformers.models.bert.modeling_bert.BertEncoder
         )
@@ -1309,7 +1315,7 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
             new_model.bert.embeddings,
             transformers.models.bert.modeling_bert.BertEmbeddings,
         )
-        new_model.bert.embeddings = BertEmbeddings(model.bert.config)
+        new_model.bert.embeddings = BertEmbeddings(model.bert.config, position_ids_persistent=position_ids_persistent)
         assert isinstance(
             new_model.bert.encoder, transformers.models.bert.modeling_bert.BertEncoder
         )
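
The new position_ids_persistent flag only controls whether the position_ids buffer is exported by state_dict(). Matching the persistence used by the installed transformers release (persistent before 4.31, non-persistent from 4.31 onward) appears to be what keeps the checkpoint keys of the rebuilt model in line with the stock BertEmbeddings it replaces. A minimal sketch of the underlying PyTorch behaviour, using a hypothetical ToyEmbeddings module that is not part of this patch:

```python
import torch
import torch.nn as nn


class ToyEmbeddings(nn.Module):
    """Hypothetical stand-in for BertEmbeddings; only the position_ids buffer matters here."""

    def __init__(self, max_position_embeddings=512, persistent=True):
        super().__init__()
        self.register_buffer(
            "position_ids",
            torch.arange(max_position_embeddings).expand((1, -1)),
            persistent=persistent,
        )


# A persistent buffer shows up in state_dict(); a non-persistent one does not.
persistent_sd = ToyEmbeddings(persistent=True).state_dict()
non_persistent_sd = ToyEmbeddings(persistent=False).state_dict()
print("position_ids" in persistent_sd)      # True
print("position_ids" in non_persistent_sd)  # False

# Loading a checkpoint that still contains position_ids into a module that
# registered the buffer as non-persistent reports it as an unexpected key
# (and fails outright with strict=True), which is the mismatch the flag avoids.
target = ToyEmbeddings(persistent=False)
result = target.load_state_dict(persistent_sd, strict=False)
print(result.unexpected_keys)  # ['position_ids']
```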