From 661f6b53f6868675932224d2a3aef9233eee3b60 Mon Sep 17 00:00:00 2001 From: Sander Land Date: Tue, 22 Jul 2025 13:59:54 +0200 Subject: [PATCH 1/3] implement enforce utf8 boundaries --- bindings/python/src/trainers.rs | 1 + .../python/tests/bindings/test_trainers.py | 43 ++++ tokenizers/src/models/bpe/trainer.rs | 195 +++++++++++++++++- tokenizers/src/pre_tokenizers/byte_level.rs | 2 +- tokenizers/src/tokenizer/mod.rs | 31 +++ 5 files changed, 268 insertions(+), 4 deletions(-) diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index ef2c31e56..8eecf5c61 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -345,6 +345,7 @@ impl PyBpeTrainer { } "limit_alphabet" => builder = builder.limit_alphabet(val.extract()?), "max_token_length" => builder = builder.max_token_length(val.extract()?), + "enforce_utf8_boundaries" => builder = builder.enforce_utf8_boundaries(val.extract()?), "initial_alphabet" => { let alphabet: Vec = val.extract()?; builder = builder.initial_alphabet( diff --git a/bindings/python/tests/bindings/test_trainers.py b/bindings/python/tests/bindings/test_trainers.py index 38b599448..336a78db2 100644 --- a/bindings/python/tests/bindings/test_trainers.py +++ b/bindings/python/tests/bindings/test_trainers.py @@ -74,6 +74,49 @@ def test_can_pickle(self): ) + def test_enforce_utf8_boundaries(self): + # This input is designed to have a very frequent but invalid merge candidate: + # a space (0x20) followed by the first byte of different 4-byte encodings (0xF0). + # A less frequent but valid candidate is the first two bytes of an emoji (0xF0, 0x9F). + data = [" 🤗"] * 10 + [" 𝟑"] * 9 + + # Setup a tokenizer with a ByteLevel pre-tokenizer + tokenizer = Tokenizer(models.BPE()) + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) + + # 1. Train with `enforce_utf8_boundaries=False` (unconstrained) + unconstrained_trainer = trainers.BpeTrainer( + vocab_size=260, + special_tokens=[""], + enforce_utf8_boundaries=False, + show_progress=False, + ) + tokenizer.train_from_iterator(data, trainer=unconstrained_trainer) + vocab = tokenizer.get_vocab() + + # The pre-tokenizer maps byte 0x20 to `Ġ` and 0xF0 to `ð`. + # The invalid merge of these two should be present. + invalid_token = "Ġð" # Bytes: [20, F0] + assert invalid_token in vocab, "Unconstrained trainer should learn the invalid merge" + + # 2. 
Train with `enforce_utf8_boundaries=True` (constrained) + # We must re-initialize the tokenizer to start with a fresh model + tokenizer = Tokenizer(models.BPE()) + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) + + # Train with enforce_utf8_boundaries=True + constrained_trainer = trainers.BpeTrainer( + vocab_size=260, + special_tokens=[""], + enforce_utf8_boundaries=True, + show_progress=False, + ) + tokenizer.train_from_iterator(data, trainer=constrained_trainer) + vocab = tokenizer.get_vocab() + + # The invalid merge should not be present when enforcing UTF-8 boundaries + assert invalid_token not in vocab, "Constrained trainer should not learn invalid merges" + class TestWordPieceTrainer: def test_can_modify(self): trainer = trainers.WordPieceTrainer( diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index b3a6fd4b2..b37c32df7 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -4,12 +4,14 @@ use super::{Pair, WithFirstLastIterator, Word, BPE}; use crate::parallelism::*; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::progress::{ProgressBar, ProgressStyle}; +use crate::pre_tokenizers::byte_level::bytes_char; use ahash::{AHashMap, AHashSet}; use compact_str::CompactString; use dary_heap::OctonaryHeap; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::collections::HashSet; +use std::sync::LazyLock; #[derive(Debug, Eq)] struct Merge { @@ -48,6 +50,7 @@ struct Config { continuing_subword_prefix: Option, end_of_word_suffix: Option, max_token_length: Option, + enforce_utf8_boundaries: bool, } /// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom @@ -69,6 +72,7 @@ impl Default for BpeTrainerBuilder { continuing_subword_prefix: None, end_of_word_suffix: None, max_token_length: None, + enforce_utf8_boundaries: true, }, } } @@ -144,6 +148,13 @@ impl BpeTrainerBuilder { self } + /// Whether to enforce UTF-8 character boundaries during merges + #[must_use] + pub fn enforce_utf8_boundaries(mut self, enforce: bool) -> Self { + self.config.enforce_utf8_boundaries = enforce; + self + } + /// Constructs the final BpeTrainer pub fn build(self) -> BpeTrainer { BpeTrainer { @@ -156,6 +167,7 @@ impl BpeTrainerBuilder { continuing_subword_prefix: self.config.continuing_subword_prefix, end_of_word_suffix: self.config.end_of_word_suffix, max_token_length: self.config.max_token_length, + enforce_utf8_boundaries: self.config.enforce_utf8_boundaries, words: AHashMap::new(), } } @@ -199,6 +211,11 @@ pub struct BpeTrainer { pub end_of_word_suffix: Option, /// An optional parameter to limit the max length of any single token pub max_token_length: Option, + /// Whether to enforce UTF-8 character boundaries during merges. When true, only allows merging: + /// 1. Complete UTF-8 characters with each other + /// 2. Single bytes that are part of the same UTF-8 character, from left to right + /// This is useful to avoid creating tokens that are not valid UTF-8 sequences, at no cost to compression. 
+ pub enforce_utf8_boundaries: bool, words: AHashMap, } @@ -209,7 +226,12 @@ impl Default for BpeTrainer { } } +/// for utf8 boundaries, we need to map gpt2 encoded bytes back +static CHAR_BYTES: LazyLock> = + LazyLock::new(|| bytes_char().into_iter().map(|(b, c)| (c, b)).collect()); + impl BpeTrainer { + pub fn new(min_frequency: u64, vocab_size: usize) -> Self { Self { min_frequency, @@ -270,6 +292,68 @@ impl BpeTrainer { } } + /// helper for is_merge_allowed, to get the original bytes of a part + fn get_original_bytes(&self, part: &str) -> Option> { + part.chars().map(|c| CHAR_BYTES.get(&c).copied()).collect() + } + /// Determines if a merge is allowed under UTF-8 boundary constraints. + /// + /// This check is only performed if `enforce_utf8_boundaries` is true. + /// A merge is allowed if it meets one of the following criteria: + /// 1. Both tokens consist of complete characters. + /// 2. Both tokens are part of the same single character, and the second is a single byte. + /// This allows building multi-byte characters from their individual bytes left-to-right. + /// All other combinations, such as merging a complete character with a partial byte, are disallowed. + /// This function is designed to work on the character-mapped output of a `ByteLevel` + /// pre-tokenizer by reversing the mapping to check the original bytes. + /// Determines if a merge is allowed under UTF-8 boundary constraints. + /// This function is designed to work on the character-mapped output of a `ByteLevel` + /// pre-tokenizer by reversing the mapping to check the original bytes. + fn is_merge_allowed(&self, pair: &Pair, id_to_word: &[CompactString]) -> bool { + if !self.enforce_utf8_boundaries { + return true; + } + + let part_a = &id_to_word[pair.0 as usize]; + let part_b = &id_to_word[pair.1 as usize]; + + // Get the original bytes by reversing the ByteLevel character mapping. + let bytes_a = self.get_original_bytes(part_a.as_ref()).unwrap_or_default(); + let bytes_b = self.get_original_bytes(part_b.as_ref()).unwrap_or_default(); + + // A "complete" token is one whose underlying bytes form a valid UTF-8 string. + // For ByteLevel, this means single-byte ASCII chars (like a space) are complete, + // but single bytes from a multi-byte sequence (like 0xF0) are not. + let is_a_complete = std::str::from_utf8(&bytes_a).is_ok(); + let is_b_complete = std::str::from_utf8(&bytes_b).is_ok(); + + // Rule 1: Allow merging two complete tokens. + if is_a_complete && is_b_complete { + return true; + } + + // Rule 3 (Implicit): Any mix of complete and incomplete is disallowed. + if is_a_complete || is_b_complete { + return false; + } + + // Rule 2: Both tokens are incomplete. Allow merge only if building a valid + // UTF-8 prefix by appending a single byte. + if bytes_b.len() == 1 { + let mut merged = bytes_a; + merged.extend_from_slice(&bytes_b); + match std::str::from_utf8(&merged) { + // The merged bytes form one or more complete characters. Valid. + Ok(_) => true, + // The merged bytes are an incomplete but valid prefix. Valid. + Err(e) => e.error_len().is_none(), + } + } else { + // If part_b is not a single byte, it's not a valid continuation merge. 
+ false + } + } + /// Compute the initial alphabet and limit it if relevant fn compute_alphabet( &self, @@ -455,7 +539,7 @@ impl BpeTrainer { let mut queue = OctonaryHeap::with_capacity(pair_counts.len()); where_to_update.drain().for_each(|(pair, pos)| { let count = pair_counts[&pair]; - if count > 0 { + if count > 0 && self.is_merge_allowed(&pair, &id_to_word) { queue.push(Merge { pair, count: count as u64, @@ -550,13 +634,13 @@ impl BpeTrainer { for ((pair, change), iw) in changes { let count = change * counts[iw] as i32; *pair_counts.entry(pair).or_default() += count; - if change > 0 { + if change > 0 && self.is_merge_allowed(&pair, &id_to_word) { where_to_update.entry(pair).or_default().insert(iw); } } where_to_update.drain().for_each(|(pair, pos)| { let count = pair_counts[&pair]; - if count > 0 { + if count > 0 && self.is_merge_allowed(&pair, &id_to_word) { queue.push(Merge { pair, count: count as u64, @@ -644,8 +728,14 @@ impl Trainer for BpeTrainer { #[cfg(test)] mod tests { use super::{BpeTrainer, Pair, BPE}; + use crate::pre_tokenizers::byte_level::{bytes_char, ByteLevel}; + use crate::tokenizer::{ + OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer, Result, Trainer, + }; use ahash::AHashMap; use compact_str::CompactString; + use std::collections::HashMap; + use std::sync::LazyLock; #[test] fn test_train() { @@ -762,6 +852,7 @@ mod tests { ) } } + #[test] fn bpe_test_max_token_length_direct_assert() { /* more direct version of bpe_test_max_token_length test @@ -831,4 +922,102 @@ mod tests { .collect(); assert_eq!(trained_vocab, expected_vocab) } + + // The CHAR_TO_BYTE mapping is kept here *only* for the debug printing helper, + // to make the test output readable. It is not used in the core test logic. + static BYTE_TO_CHAR: LazyLock> = LazyLock::new(bytes_char); + static CHAR_TO_BYTE: LazyLock> = + LazyLock::new(|| BYTE_TO_CHAR.iter().map(|(b, c)| (*c, *b)).collect()); + + #[test] + fn test_bpe_utf8_boundary_enforcement_with_byte_level_pretokenizer() { + /// A local helper to print the vocabulary with original hex byte representations for clarity. + fn print_vocab_with_hex(vocab: &HashMap, title: &str) { + println!("\n--- {} ---", title); + let mut vocab_items: Vec<_> = vocab.iter().collect(); + vocab_items.sort_by_key(|(_, id)| *id); + for (token, id) in vocab_items { + // De-mangle the token back to its original bytes for printing + let bytes: Vec = token + .chars() + .map(|c| format!("{:02X}", CHAR_TO_BYTE.get(&c).unwrap_or(&0))) + .collect(); + println!( + "ID {:<3} Token: {:<12} Bytes: [{}]", + id, + format!("{:?}", token), + bytes.join(" ") + ); + } + } + + // Use the actual ByteLevel pre-tokenizer to process the input string. 
+ let byte_level_pretok = ByteLevel::new(false, false, false); + let process_fn = |s: &str| -> Result> { + let mut pretokenized = PreTokenizedString::from(s); + byte_level_pretok.pre_tokenize(&mut pretokenized)?; + Ok(pretokenized + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(word, _, _)| word.to_string()) + .collect()) + }; + + let sequence = " 🤗 🦒 🐹 🦦 🤗 𝟑".to_string(); + let vocab_size = 25; + + // --- Part 1: Unconstrained BPE --- + let mut unconstrained_trainer = BpeTrainer::builder() + .vocab_size(vocab_size) + .show_progress(false) + .enforce_utf8_boundaries(false) + .build(); + unconstrained_trainer + .feed(std::iter::once(&sequence), &process_fn) + .unwrap(); + let mut unconstrained_model = BPE::default(); + unconstrained_trainer + .train(&mut unconstrained_model) + .unwrap(); + print_vocab_with_hex( + &unconstrained_model.get_vocab(), + "Unconstrained Vocabulary", + ); + let invalid_merge_token: String = + [BYTE_TO_CHAR[&b' '], BYTE_TO_CHAR[&0xF0]].iter().collect(); + assert!( + unconstrained_model + .get_vocab() + .contains_key(&invalid_merge_token), + "Unconstrained vocab SHOULD contain the top frequency merge (bytes [20 F0])" + ); + + // --- Part 2: Constrained BPE --- + let mut constrained_trainer = BpeTrainer::builder() + .vocab_size(vocab_size) + .show_progress(false) + .enforce_utf8_boundaries(true) + .build(); + constrained_trainer + .feed(std::iter::once(&sequence), &process_fn) + .unwrap(); + let mut constrained_model = BPE::default(); + constrained_trainer.train(&mut constrained_model).unwrap(); + print_vocab_with_hex(&constrained_model.get_vocab(), "Constrained Vocabulary"); + + let valid_merge_token: String = + [BYTE_TO_CHAR[&0xF0], BYTE_TO_CHAR[&0x9F]].iter().collect(); + assert!( + !constrained_model + .get_vocab() + .contains_key(&invalid_merge_token), + "Constrained vocab MUST NOT contain the invalid merge (bytes [20 F0])" + ); + assert!( + constrained_model + .get_vocab() + .contains_key(&valid_merge_token), + "Constrained vocab SHOULD contain the next valid merge (bytes [F0 9F])" + ); + } } diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 8bc0f30af..200967f78 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -12,7 +12,7 @@ use crate::utils::macro_rules_attribute; /// Converts bytes to unicode characters. /// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9 -pub(crate) fn bytes_char() -> AHashMap { +pub fn bytes_char() -> AHashMap { let mut bs: Vec = vec![]; bs.extend(b'!'..=b'~'); bs.extend(b'\xA1'..=b'\xAC'); diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 84f77a523..ebd225efc 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -15,6 +15,12 @@ use std::{ io::{prelude::*, BufReader}, ops::{Deref, DerefMut}, path::{Path, PathBuf}, + any::Any, +}; + +use crate::{ + models::bpe, + pre_tokenizers, }; use serde::de::DeserializeOwned; @@ -534,6 +540,29 @@ where PP: PostProcessor, D: Decoder, { + /// Validates compatibility between a trainer and the current tokenizer configuration. + /// Currently only checks: + // For BpeTrainer with `enforce_utf8_boundaries=True` => pretokenizer must be ByteLevel. 
+ fn _check_trainer_compat + 'static>(&self, trainer: &T) -> Result<()> { + // Use `Any` to safely check for the BpeTrainer type at runtime + if let Some(bpe_trainer) = (trainer as &dyn Any).downcast_ref::() { + if bpe_trainer.enforce_utf8_boundaries { + // Now check if the pre_tokenizer is ByteLevel + let is_byte_level = self.pre_tokenizer.as_ref().map_or(false, |pretok| { + (pretok as &dyn Any).is::() + }); + + if !is_byte_level { + return Err( + "`enforce_utf8_boundaries=True` can only be used with a `ByteLevel` pre-tokenizer." + .into() + ); + } + } + } + Ok(()) + } + /// Instantiate a new Tokenizer, with the given Model pub fn new(model: M) -> Self { Self { @@ -1345,6 +1374,7 @@ where where T: Trainer + Sync, { + self._check_trainer_compat(trainer)?; // check that settings are compatible let mut len = 0; for file in files.iter() { len += File::open(file) @@ -1420,6 +1450,7 @@ where I: Iterator + Send, S: AsRef + Send, { + self._check_trainer_compat(trainer)?; // check that settings are compatible let (lower, upper) = sequences.size_hint(); let len = upper.unwrap_or(lower) as u64; let progress = if trainer.should_show_progress() { From 2ad4194b6fa97f9099ddf76a3398e1c21e4068a4 Mon Sep 17 00:00:00 2001 From: Sander Land Date: Tue, 22 Jul 2025 14:12:59 +0200 Subject: [PATCH 2/3] default false --- tokenizers/src/models/bpe/trainer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index b37c32df7..bacc27b30 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -72,7 +72,7 @@ impl Default for BpeTrainerBuilder { continuing_subword_prefix: None, end_of_word_suffix: None, max_token_length: None, - enforce_utf8_boundaries: true, + enforce_utf8_boundaries: false, }, } } From b18070530ab9a1f75ca7a4c7887d963db4175818 Mon Sep 17 00:00:00 2001 From: Sander Land Date: Sun, 3 Aug 2025 14:23:29 +0200 Subject: [PATCH 3/3] review comments --- tokenizers/src/models/bpe/trainer.rs | 48 +++------------------ tokenizers/src/pre_tokenizers/byte_level.rs | 4 +- tokenizers/src/tokenizer/mod.rs | 27 ++---------- 3 files changed, 13 insertions(+), 66 deletions(-) diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index bacc27b30..977d0e224 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -4,14 +4,13 @@ use super::{Pair, WithFirstLastIterator, Word, BPE}; use crate::parallelism::*; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::progress::{ProgressBar, ProgressStyle}; -use crate::pre_tokenizers::byte_level::bytes_char; +use crate::pre_tokenizers::byte_level::CHAR_BYTES; use ahash::{AHashMap, AHashSet}; use compact_str::CompactString; use dary_heap::OctonaryHeap; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::collections::HashSet; -use std::sync::LazyLock; #[derive(Debug, Eq)] struct Merge { @@ -226,10 +225,6 @@ impl Default for BpeTrainer { } } -/// for utf8 boundaries, we need to map gpt2 encoded bytes back -static CHAR_BYTES: LazyLock> = - LazyLock::new(|| bytes_char().into_iter().map(|(b, c)| (c, b)).collect()); - impl BpeTrainer { pub fn new(min_frequency: u64, vocab_size: usize) -> Self { @@ -327,18 +322,17 @@ impl BpeTrainer { let is_a_complete = std::str::from_utf8(&bytes_a).is_ok(); let is_b_complete = std::str::from_utf8(&bytes_b).is_ok(); - // Rule 1: Allow merging two complete tokens. + // - Allow merging two complete tokens. 
+ // - Any mix of complete and incomplete is disallowed. if is_a_complete && is_b_complete { return true; } - - // Rule 3 (Implicit): Any mix of complete and incomplete is disallowed. - if is_a_complete || is_b_complete { + if is_a_complete ^ is_b_complete { return false; } - // Rule 2: Both tokens are incomplete. Allow merge only if building a valid - // UTF-8 prefix by appending a single byte. + // Here we know both tokens are incomplete. + // Allow merge only if building a valid UTF-8 prefix by appending a single byte. if bytes_b.len() == 1 { let mut merged = bytes_a; merged.extend_from_slice(&bytes_b); @@ -923,34 +917,10 @@ mod tests { assert_eq!(trained_vocab, expected_vocab) } - // The CHAR_TO_BYTE mapping is kept here *only* for the debug printing helper, - // to make the test output readable. It is not used in the core test logic. static BYTE_TO_CHAR: LazyLock> = LazyLock::new(bytes_char); - static CHAR_TO_BYTE: LazyLock> = - LazyLock::new(|| BYTE_TO_CHAR.iter().map(|(b, c)| (*c, *b)).collect()); #[test] fn test_bpe_utf8_boundary_enforcement_with_byte_level_pretokenizer() { - /// A local helper to print the vocabulary with original hex byte representations for clarity. - fn print_vocab_with_hex(vocab: &HashMap, title: &str) { - println!("\n--- {} ---", title); - let mut vocab_items: Vec<_> = vocab.iter().collect(); - vocab_items.sort_by_key(|(_, id)| *id); - for (token, id) in vocab_items { - // De-mangle the token back to its original bytes for printing - let bytes: Vec = token - .chars() - .map(|c| format!("{:02X}", CHAR_TO_BYTE.get(&c).unwrap_or(&0))) - .collect(); - println!( - "ID {:<3} Token: {:<12} Bytes: [{}]", - id, - format!("{:?}", token), - bytes.join(" ") - ); - } - } - // Use the actual ByteLevel pre-tokenizer to process the input string. let byte_level_pretok = ByteLevel::new(false, false, false); let process_fn = |s: &str| -> Result> { @@ -979,10 +949,7 @@ mod tests { unconstrained_trainer .train(&mut unconstrained_model) .unwrap(); - print_vocab_with_hex( - &unconstrained_model.get_vocab(), - "Unconstrained Vocabulary", - ); + let invalid_merge_token: String = [BYTE_TO_CHAR[&b' '], BYTE_TO_CHAR[&0xF0]].iter().collect(); assert!( @@ -1003,7 +970,6 @@ mod tests { .unwrap(); let mut constrained_model = BPE::default(); constrained_trainer.train(&mut constrained_model).unwrap(); - print_vocab_with_hex(&constrained_model.get_vocab(), "Constrained Vocabulary"); let valid_merge_token: String = [BYTE_TO_CHAR[&0xF0], BYTE_TO_CHAR[&0x9F]].iter().collect(); diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 200967f78..54a984282 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -12,7 +12,7 @@ use crate::utils::macro_rules_attribute; /// Converts bytes to unicode characters. 
/// See https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9 -pub fn bytes_char() -> AHashMap { +pub(crate) fn bytes_char() -> AHashMap { let mut bs: Vec = vec![]; bs.extend(b'!'..=b'~'); bs.extend(b'\xA1'..=b'\xAC'); @@ -45,7 +45,7 @@ static RE: LazyLock = LazyLock::new(|| { .unwrap() }); static BYTES_CHAR: LazyLock> = LazyLock::new(bytes_char); -static CHAR_BYTES: LazyLock> = +pub(crate) static CHAR_BYTES: LazyLock> = LazyLock::new(|| bytes_char().into_iter().map(|(c, b)| (b, c)).collect()); #[derive(Copy, Clone, Debug, PartialEq, Eq)] diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index ebd225efc..7a3086dc5 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -15,12 +15,6 @@ use std::{ io::{prelude::*, BufReader}, ops::{Deref, DerefMut}, path::{Path, PathBuf}, - any::Any, -}; - -use crate::{ - models::bpe, - pre_tokenizers, }; use serde::de::DeserializeOwned; @@ -543,23 +537,10 @@ where /// Validates compatibility between a trainer and the current tokenizer configuration. /// Currently only checks: // For BpeTrainer with `enforce_utf8_boundaries=True` => pretokenizer must be ByteLevel. - fn _check_trainer_compat + 'static>(&self, trainer: &T) -> Result<()> { - // Use `Any` to safely check for the BpeTrainer type at runtime - if let Some(bpe_trainer) = (trainer as &dyn Any).downcast_ref::() { - if bpe_trainer.enforce_utf8_boundaries { - // Now check if the pre_tokenizer is ByteLevel - let is_byte_level = self.pre_tokenizer.as_ref().map_or(false, |pretok| { - (pretok as &dyn Any).is::() - }); - - if !is_byte_level { - return Err( - "`enforce_utf8_boundaries=True` can only be used with a `ByteLevel` pre-tokenizer." - .into() - ); - } - } - } + fn _check_trainer_compat( + &self, + _trainer: &T, + ) -> Result<()> { Ok(()) }
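---

Note for reviewers: the boundary rule that `is_merge_allowed` applies can be exercised outside the trainer. Below is a minimal standalone sketch, not part of the patch; `merge_ok` is a hypothetical helper that re-implements the same three checks with only the standard library, run against the byte pairs the tests above care about (0x20 is a space, 0xF0 0x9F are the first two bytes of the 🤗 emoji).

// Standalone sketch of the UTF-8 boundary rule; `merge_ok` mirrors the
// decision logic of `is_merge_allowed` on raw bytes (assumed helper, not in the patch).
fn merge_ok(bytes_a: &[u8], bytes_b: &[u8]) -> bool {
    let a_complete = std::str::from_utf8(bytes_a).is_ok();
    let b_complete = std::str::from_utf8(bytes_b).is_ok();
    if a_complete && b_complete {
        return true; // two complete tokens may always merge
    }
    if a_complete ^ b_complete {
        return false; // mixing a complete token with a partial one is disallowed
    }
    // Both tokens are partial: only allow appending a single byte that keeps
    // the merged bytes a valid UTF-8 prefix.
    if bytes_b.len() != 1 {
        return false;
    }
    let mut merged = bytes_a.to_vec();
    merged.extend_from_slice(bytes_b);
    match std::str::from_utf8(&merged) {
        Ok(_) => true,                     // merge completes one or more characters
        Err(e) => e.error_len().is_none(), // incomplete but still a valid prefix
    }
}

fn main() {
    assert!(merge_ok(b" ", b"a"));           // complete + complete
    assert!(!merge_ok(b" ", &[0xF0]));       // complete + partial: the invalid [20 F0] merge
    assert!(merge_ok(&[0xF0], &[0x9F]));     // partial + partial, valid prefix of an emoji
    assert!(!merge_ok(&[0xF0], &[0xF0]));    // partial + partial, can never become valid UTF-8
    println!("boundary rule checks passed");
}

The `error_len()` check is what separates an incomplete-but-still-valid prefix (more bytes could complete the character, so the error is only "unexpected end of input") from a byte sequence that no continuation could ever make valid.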