
Commit bf1150d

Fix version of Transformers fused_bert.py (#2157)
* Fix version of Transformers fused_bert.py
* Update fused_bert.py
* Update fused_bert.py
1 parent 5cd59d7 commit bf1150d

File tree

1 file changed: +15 −9 lines changed

intel_extension_for_pytorch/cpu/tpp/fused_bert.py

Lines changed: 15 additions & 9 deletions
@@ -703,7 +703,7 @@ def backward(ctx, *grad_outs):
 class BertEmbeddings(BlockedModule):
     """Construct the embeddings from word, position and token_type embeddings."""
 
-    def __init__(self, config):
+    def __init__(self, config, position_ids_persistent=False):
         super().__init__()
         self.word_embeddings = nn.Embedding(
             config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
@@ -725,11 +725,14 @@ def __init__(self, config):
         self.pad_token_id = config.pad_token_id
 
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer(
-            "position_ids",
-            torch.arange(config.max_position_embeddings).expand((1, -1)),
-            persistent=False,
-        )
+        if not position_ids_persistent:
+            self.register_buffer(
+                "position_ids",
+                torch.arange(config.max_position_embeddings).expand((1, -1)),
+                persistent=False,
+            )
+        else:
+            self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(
             config, "position_embedding_type", "absolute"
         )
@@ -1243,7 +1246,7 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
     # tpp bert optimization depends on the transformers repo to implementate the related module
     installed_pkg = {pkg.key for pkg in pkg_resources.working_set}
     min_version = "4.6.0"
-    max_version = "4.20.0"
+    max_version = "4.31.0"
     if "transformers" not in installed_pkg:
         raise RuntimeError(
             "Please installed the transformers with version: between {} and {}".format(
@@ -1263,6 +1266,9 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
                 min_version, max_version, trans_version
             )
         )
+    position_ids_persistent = False
+    if version.parse(trans_version) < version.parse("4.31.0"):
+        position_ids_persistent = True
     PT_OPTIMIZER_TO_TPP_OPTIMIZER = {
         torch.optim.AdamW: AdamW,
         transformers.optimization.AdamW: AdamW,
@@ -1297,7 +1303,7 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
         assert isinstance(
             new_model.embeddings, transformers.models.bert.modeling_bert.BertEmbeddings
         )
-        new_model.embeddings = BertEmbeddings(model.config)
+        new_model.embeddings = BertEmbeddings(model.config, position_ids_persistent=position_ids_persistent)
         assert isinstance(
             new_model.encoder, transformers.models.bert.modeling_bert.BertEncoder
         )
@@ -1309,7 +1315,7 @@ def fast_bert(model, dtype=torch.float, optimizer=None, unpad=False):
             new_model.bert.embeddings,
             transformers.models.bert.modeling_bert.BertEmbeddings,
         )
-        new_model.bert.embeddings = BertEmbeddings(model.bert.config)
+        new_model.bert.embeddings = BertEmbeddings(model.bert.config, position_ids_persistent=position_ids_persistent)
         assert isinstance(
             new_model.bert.encoder, transformers.models.bert.modeling_bert.BertEncoder
         )
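
For context, a minimal sketch that is not part of the commit: a buffer registered with persistent=False is kept on the module but excluded from state_dict(), so the new position_ids_persistent flag controls whether the replacement BertEmbeddings serializes position_ids, matching what the installed transformers release expects (persistent below 4.31.0, non-persistent from 4.31.0 on, per the gate added in fast_bert). TinyEmbeddings below is a hypothetical stand-in, reduced to just the buffer toggle.

import torch
import torch.nn as nn

class TinyEmbeddings(nn.Module):
    # Hypothetical stand-in for BertEmbeddings; only the buffer toggle matters here.
    def __init__(self, max_position_embeddings=8, position_ids_persistent=False):
        super().__init__()
        if not position_ids_persistent:
            # persistent=False: buffer is usable at runtime but left out of state_dict()
            self.register_buffer(
                "position_ids",
                torch.arange(max_position_embeddings).expand((1, -1)),
                persistent=False,
            )
        else:
            # default persistent=True: buffer is serialized with the checkpoint
            self.register_buffer(
                "position_ids",
                torch.arange(max_position_embeddings).expand((1, -1)),
            )

print("position_ids" in TinyEmbeddings(position_ids_persistent=True).state_dict())   # True
print("position_ids" in TinyEmbeddings(position_ids_persistent=False).state_dict())  # False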

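A usage sketch under stated assumptions: it presumes a transformers release inside the accepted window (4.6.0 to 4.31.0) and that fast_bert is reachable at the package top level as ipex.fast_bert; the checkpoint name is only illustrative.

import torch
import transformers
import intel_extension_for_pytorch as ipex

# The commit raises the accepted transformers ceiling from 4.20.0 to 4.31.0;
# outside the [min_version, max_version] window fast_bert raises a RuntimeError.
model = transformers.BertModel.from_pretrained("bert-base-uncased")
model.eval()

# Signature per the diff: fast_bert(model, dtype=torch.float, optimizer=None, unpad=False)
model = ipex.fast_bert(model, dtype=torch.bfloat16)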