Skip to content
11 changes: 8 additions & 3 deletions keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,14 @@ def __init__(
# is `None`.
self.text_only_model = self.image_converter is None

self.image_placeholder = self.tokenizer.image_placeholder
self.start_of_image_token = self.tokenizer.start_of_image_token
self.end_of_image_token = self.tokenizer.end_of_image_token
if self.text_only_model:
self.image_placeholder = None
self.start_of_image_token = None
self.end_of_image_token = None
else:
self.image_placeholder = self.tokenizer.image_placeholder
self.start_of_image_token = self.tokenizer.start_of_image_token
self.end_of_image_token = self.tokenizer.end_of_image_token

def build(self, input_shape):
# Defer packer creation to `build()` so that we can be sure tokenizer
Expand Down
28 changes: 20 additions & 8 deletions keras_hub/src/models/gemma3/gemma3_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,32 @@ class Gemma3Tokenizer(SentencePieceTokenizer):

backbone_cls = Gemma3Backbone

def __init__(self, proto, is_vision_model=True, **kwargs):
    """Register the tokenizer's special tokens, then initialize the parent.

    Args:
        proto: Forwarded unchanged to the parent constructor as `proto`.
        is_vision_model: bool. When `True` (the default), the image
            placeholder and the start/end-of-image tokens are registered
            as special tokens. When `False`, they are not registered and
            the three corresponding `*_token_id` attributes are set to
            `-1` instead.
        **kwargs: Additional keyword arguments forwarded to the parent
            constructor.
    """
    # Stored on the instance; `get_config()` reads it back.
    self.is_vision_model = is_vision_model

    # The usual tokens, registered for every variant.
    self._add_special_token("<bos>", "start_token")
    self._add_special_token("<eos>", "end_token")
    self._add_special_token("<pad>", "pad_token")

    if is_vision_model:
        # Image placeholder token.
        self._add_special_token("<img>", "image_placeholder")
        # Some tokens which are used in the preprocessor. We need to keep
        # them here so that the preprocessor works with `tf.data`.
        self._add_special_token("<start_of_image>", "start_of_image_token")
        self._add_special_token("<end_of_image>", "end_of_image_token")
    else:
        # Text-only model: the image tokens are not registered. Only the
        # ID attributes are defined here, using -1 as a placeholder value.
        self.start_of_image_token_id = -1
        self.image_placeholder_token_id = -1
        self.end_of_image_token_id = -1

    super().__init__(proto=proto, **kwargs)

def get_config(self):
    """Return the parent config extended with the `is_vision_model` flag."""
    base_config = super().get_config()
    # Add the flag to the parent's config dict in place.
    base_config["is_vision_model"] = self.is_vision_model
    return base_config
Loading
Loading