6 changes: 6 additions & 0 deletions .gitignore
@@ -19,5 +19,11 @@ build/
# Pycharm files
.idea/

# Virtual environments
venv/
env/
ENV/
gpt2/
*.venv/

.DS_Store
26 changes: 26 additions & 0 deletions keras_hub/src/models/gpt2/gpt2_tokenizer.py
@@ -1,3 +1,7 @@
import json
import os
import shutil

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
@@ -71,3 +75,25 @@ def __init__(
merges=merges,
**kwargs,
)

def save_assets(self, dir_path):
# Save vocabulary.
if isinstance(self.vocabulary, str):
# If `vocabulary` is a file path, copy it.
shutil.copy(
self.vocabulary, os.path.join(dir_path, "vocabulary.json")
)
else:
# Otherwise, `vocabulary` is a dict. Save it to a JSON file.
with open(os.path.join(dir_path, "vocabulary.json"), "w") as f:
json.dump(self.vocabulary, f)

# Save merges.
if isinstance(self.merges, str):
# If `merges` is a file path, copy it.
shutil.copy(self.merges, os.path.join(dir_path, "merges.txt"))
else:
# Otherwise, `merges` is a list. Save it to a text file.
with open(os.path.join(dir_path, "merges.txt"), "w") as f:
for merge in self.merges:
f.write(f"{merge}\n")
146 changes: 146 additions & 0 deletions keras_hub/src/utils/transformers/export/gpt2.py
@@ -0,0 +1,146 @@
import keras.ops as ops
import transformers


def get_gpt2_config(keras_model):
"""Convert Keras GPT-2 config to Hugging Face GPT2Config."""
return transformers.GPT2Config(
vocab_size=keras_model.vocabulary_size,
n_positions=keras_model.max_sequence_length,
n_embd=keras_model.hidden_dim,
n_layer=keras_model.num_layers,
n_head=keras_model.num_heads,
n_inner=keras_model.intermediate_dim,
activation_function="gelu_new",
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
scale_attn_weights=True,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
)


def get_gpt2_weights_map(keras_model, include_lm_head=False):
"""Create a weights map for a given GPT-2 model."""
weights_map = {}

# Token and position embeddings
weights_map["transformer.wte.weight"] = keras_model.get_layer(
"token_embedding"
).embeddings
weights_map["transformer.wpe.weight"] = keras_model.get_layer(
"position_embedding"
).position_embeddings

for i in range(keras_model.num_layers):
# Attention weights
q_w = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._query_dense.kernel
k_w = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._key_dense.kernel
v_w = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._value_dense.kernel
Review comment (Contributor) on lines +46 to +54, severity: medium

Accessing private layer attributes like _self_attention_layer and its sub-layers makes this code brittle. If the internal structure of TransformerDecoder changes, this export script will break. It would be more robust to expose these weights via a public API on the layer to create a more stable interface. (A sketch of such an accessor follows this file's diff.)

q_b = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._query_dense.bias
k_b = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._key_dense.bias
v_b = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._value_dense.bias

q_w = ops.reshape(q_w, (keras_model.hidden_dim, keras_model.hidden_dim))
k_w = ops.reshape(k_w, (keras_model.hidden_dim, keras_model.hidden_dim))
v_w = ops.reshape(v_w, (keras_model.hidden_dim, keras_model.hidden_dim))

c_attn_w = ops.concatenate([q_w, k_w, v_w], axis=-1)
weights_map[f"transformer.h.{i}.attn.c_attn.weight"] = c_attn_w

q_b = ops.reshape(q_b, [-1])
k_b = ops.reshape(k_b, [-1])
v_b = ops.reshape(v_b, [-1])

c_attn_b = ops.concatenate([q_b, k_b, v_b], axis=-1)
weights_map[f"transformer.h.{i}.attn.c_attn.bias"] = c_attn_b

# Attention projection
c_proj_w = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._output_dense.kernel
c_proj_w = ops.reshape(
c_proj_w, (keras_model.hidden_dim, keras_model.hidden_dim)
)
weights_map[f"transformer.h.{i}.attn.c_proj.weight"] = c_proj_w
weights_map[f"transformer.h.{i}.attn.c_proj.bias"] = (
keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._output_dense.bias
)

# Layer norms
weights_map[f"transformer.h.{i}.ln_1.weight"] = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer_norm.gamma
weights_map[f"transformer.h.{i}.ln_1.bias"] = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer_norm.beta
weights_map[f"transformer.h.{i}.ln_2.weight"] = keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_layer_norm.gamma
weights_map[f"transformer.h.{i}.ln_2.bias"] = keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_layer_norm.beta

# MLP
c_fc_w = keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_intermediate_dense.kernel
weights_map[f"transformer.h.{i}.mlp.c_fc.weight"] = c_fc_w
weights_map[f"transformer.h.{i}.mlp.c_fc.bias"] = keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_intermediate_dense.bias
c_proj_w_mlp = keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_output_dense.kernel
weights_map[f"transformer.h.{i}.mlp.c_proj.weight"] = c_proj_w_mlp
weights_map[f"transformer.h.{i}.mlp.c_proj.bias"] = (
keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_output_dense.bias
)

# Final layer norm
weights_map["transformer.ln_f.weight"] = keras_model.get_layer(
"layer_norm"
).gamma
weights_map["transformer.ln_f.bias"] = keras_model.get_layer(
"layer_norm"
).beta

if include_lm_head:
# lm_head is tied to token embeddings
weights_map["lm_head.weight"] = weights_map["transformer.wte.weight"]

return weights_map


def get_gpt2_tokenizer_config(tokenizer):
return {
"model_type": "gpt2",
"bos_token": "<|endoftext|>",
"eos_token": "<|endoftext|>",
"unk_token": "<|endoftext|>",
}
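To make the review suggestion above concrete, here is a rough sketch of a public accessor on keras_hub's TransformerDecoder. This is hypothetical and not part of this PR; the property name attention_layers and its dict layout are assumptions made for illustration.

import keras


class TransformerDecoder(keras.layers.Layer):
    # ... existing __init__ / build / call unchanged ...

    @property
    def attention_layers(self):
        """Public, stable view of the attention projection sub-layers.

        Callers such as the HF export code depend only on this property,
        so renames of the private attributes stay contained to this class.
        """
        attn = self._self_attention_layer
        return {
            "query": attn._query_dense,
            "key": attn._key_dense,
            "value": attn._value_dense,
            "output": attn._output_dense,
        }


# The exporter could then read weights without touching private names:
#     layer = keras_model.get_layer(f"transformer_layer_{i}")
#     q_dense = layer.attention_layers["query"]
#     q_w, q_b = q_dense.kernel, q_dense.bias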
100 changes: 100 additions & 0 deletions keras_hub/src/utils/transformers/export/gpt2_test.py
@@ -0,0 +1,100 @@
import os
import shutil
import sys
import tempfile
from os.path import abspath
from os.path import dirname

import keras.ops as ops
import numpy as np
from absl.testing import parameterized
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

# Add the project root to the Python path.
sys.path.insert(
0, dirname(dirname(dirname(dirname(dirname(abspath(__file__))))))
)

from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export.hf_exporter import (
export_to_safetensors,
)


def to_numpy(x):
# Torch tensor
if hasattr(x, "detach") and hasattr(x, "cpu"):
return x.detach().cpu().numpy()

# TF tensor
if hasattr(x, "numpy"):
return x.numpy()

# Numpy
if isinstance(x, np.ndarray):
return x

raise TypeError(f"Cannot convert value of type {type(x)} to numpy")


class GPT2ExportTest(TestCase):
@parameterized.named_parameters(
("gpt2_base_en", "gpt2_base_en"),
)
def test_gpt2_export(self, preset):
# Create a temporary directory to save the converted model.
temp_dir = tempfile.mkdtemp()
output_path = os.path.join(temp_dir, preset)

# Load Keras model.
keras_model = GPT2CausalLM.from_preset(preset)

# Export to Hugging Face format.
export_to_safetensors(keras_model, output_path)

# Load the converted model with Hugging Face Transformers.
hf_model = AutoModelForCausalLM.from_pretrained(output_path)
hf_tokenizer = AutoTokenizer.from_pretrained(output_path)

# Assertions for config parameters.
self.assertEqual(
keras_model.backbone.hidden_dim, hf_model.config.hidden_size
)
self.assertEqual(
keras_model.backbone.num_layers, hf_model.config.n_layer
)
self.assertEqual(keras_model.backbone.num_heads, hf_model.config.n_head)
self.assertEqual(
keras_model.backbone.intermediate_dim, hf_model.config.n_inner
)
self.assertEqual(
keras_model.backbone.vocabulary_size, hf_model.config.vocab_size
)
self.assertEqual(
keras_model.backbone.max_sequence_length,
hf_model.config.n_positions,
)

# Test logits.
prompt = "Hello, my name is"
token_ids = ops.array(keras_model.preprocessor.tokenizer([prompt]))
padding_mask = ops.ones_like(token_ids, dtype="int32")
keras_inputs = {"token_ids": token_ids, "padding_mask": padding_mask}
keras_logits = keras_model(keras_inputs)

hf_inputs = hf_tokenizer(prompt, return_tensors="pt")
hf_logits = hf_model(**hf_inputs).logits

# Compare logits.
# Convert Keras logits (TF/Torch/JAX) -> numpy
keras_logits_np = to_numpy(keras_logits)

# Convert HF logits (Torch) -> numpy
hf_logits_np = to_numpy(hf_logits)

self.assertAllClose(keras_logits_np, hf_logits_np, atol=1e-3, rtol=1e-3)

# Clean up the temporary directory.
shutil.rmtree(temp_dir)
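Beyond the logits check above, a minimal sketch of consuming the exported directory for generation (assuming export_to_safetensors has already written output_path, as exercised in the test; the path below is hypothetical):

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

# Hypothetical path previously written by export_to_safetensors.
output_path = "./gpt2_base_en_export"

hf_tokenizer = AutoTokenizer.from_pretrained(output_path)
hf_model = AutoModelForCausalLM.from_pretrained(output_path)

# Generate a short continuation with the exported weights.
inputs = hf_tokenizer("Hello, my name is", return_tensors="pt")
generated = hf_model.generate(**inputs, max_new_tokens=20)
print(hf_tokenizer.decode(generated[0], skip_special_tokens=True))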