Skip to content

Commit be25814

Browse files
authored
Consolidated optimizations: ahash, dary_heap, compact_str (#1799)
* Free speed/memory optimizations with ahash, dary_heap, and compact_str (bindings broken without library refactor)
* Rebased ahash.
* Removing data files.
* Fixing (dummily) the python pyobject conversion.
* Fixing the surface by not providing ahash.
* Fixing the python side with the ahash public api removal.
* Cleanup.
* Removing test file.
* Remove dead file.
* Bad conflict resolution.
* Bad merge 2.
1 parent f81b262 commit be25814

File tree

34 files changed

+318
-252
lines changed

34 files changed

+318
-252
lines changed

bindings/node/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ napi = "2"
1414
napi-derive = "2"
1515
serde = { version = "1.0.163", features = ["derive"] }
1616
tokenizers = { path = "../../tokenizers/" }
17+
ahash = { version = "0.8.11", features = ["serde"] }
1718

1819
[build-dependencies]
1920
napi-build = "2"

bindings/node/src/models.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
use crate::arc_rwlock_serde;
22
use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFromFilesTask};
33
use crate::trainers::Trainer;
4+
use ahash::AHashMap;
45
use napi::bindgen_prelude::*;
56
use napi_derive::napi;
67
use serde::{Deserialize, Serialize};
78
use std::collections::HashMap;
89
use std::path::{Path, PathBuf};
910
use std::sync::{Arc, RwLock};
1011
use tokenizers as tk;
11-
use tokenizers::models::bpe::{BpeBuilder, Merges, Vocab};
12+
use tokenizers::models::bpe::{BpeBuilder, Merges};
1213
use tokenizers::models::wordlevel::WordLevelBuilder;
1314
use tokenizers::models::wordpiece::WordPieceBuilder;
1415

@@ -44,8 +45,13 @@ impl Bpe {
4445
}
4546

4647
#[napi(factory, ts_return_type = "Model")]
47-
pub fn init(vocab: Vocab, merges: Merges, options: Option<BpeOptions>) -> Result<Model> {
48+
pub fn init(
49+
vocab: HashMap<String, u32>,
50+
merges: Merges,
51+
options: Option<BpeOptions>,
52+
) -> Result<Model> {
4853
let options = options.unwrap_or_default();
54+
let vocab: AHashMap<_, _> = vocab.into_iter().collect();
4955
let mut builder = tk::models::bpe::BPE::builder().vocab_and_merges(vocab, merges);
5056
builder = options.apply_to_bpe_builder(builder);
5157
let model = builder
@@ -206,10 +212,11 @@ pub struct WordPiece {}
206212
#[napi]
207213
impl WordPiece {
208214
#[napi(factory, ts_return_type = "Model")]
209-
pub fn init(vocab: Vocab, options: Option<WordPieceOptions>) -> Result<Model> {
215+
pub fn init(vocab: HashMap<String, u32>, options: Option<WordPieceOptions>) -> Result<Model> {
210216
let options = options.unwrap_or_default();
211217

212-
let mut builder = tk::models::wordpiece::WordPiece::builder().vocab(vocab);
218+
let mut builder = tk::models::wordpiece::WordPiece::builder()
219+
.vocab(vocab.into_iter().collect::<AHashMap<_, _>>());
213220
builder = options.apply_to_wordpiece_builder(builder);
214221
let model = builder
215222
.build()
@@ -263,9 +270,10 @@ pub struct WordLevel {}
263270
#[napi]
264271
impl WordLevel {
265272
#[napi(factory, ts_return_type = "Model")]
266-
pub fn init(vocab: Vocab, options: Option<WordLevelOptions>) -> Result<Model> {
273+
pub fn init(vocab: HashMap<String, u32>, options: Option<WordLevelOptions>) -> Result<Model> {
267274
let options = options.unwrap_or_default();
268-
let mut builder = tk::models::wordlevel::WordLevel::builder().vocab(vocab);
275+
let mut builder =
276+
tk::models::wordlevel::WordLevel::builder().vocab(vocab.into_iter().collect());
269277
builder = options.apply_to_wordlevel_builder(builder);
270278
let model = builder
271279
.build()

bindings/python/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pyo3 = { version = "0.25", features = ["abi3", "abi3-py39", "py-clone"] }
1818
numpy = "0.25"
1919
ndarray = "0.16"
2020
itertools = "0.14"
21+
ahash = { version = "0.8.11", features = ["serde"] }
2122

2223
[dependencies.tokenizers]
2324
path = "../../tokenizers"

bindings/python/src/models.rs

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ use std::sync::{Arc, RwLock};
44

55
use crate::token::PyToken;
66
use crate::trainers::PyTrainer;
7+
use ahash::AHashMap;
78
use pyo3::exceptions;
89
use pyo3::prelude::*;
910
use pyo3::types::*;
1011
use serde::{Deserialize, Serialize};
11-
use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
12+
use tk::models::bpe::{BpeBuilder, Merges, BPE};
1213
use tk::models::unigram::Unigram;
1314
use tk::models::wordlevel::WordLevel;
1415
use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
@@ -347,9 +348,10 @@ macro_rules! setter {
347348

348349
#[derive(FromPyObject)]
349350
enum PyVocab {
350-
Vocab(Vocab),
351+
Vocab(HashMap<String, u32>),
351352
Filename(String),
352353
}
354+
353355
#[derive(FromPyObject)]
354356
enum PyMerges {
355357
Merges(Merges),
@@ -454,6 +456,7 @@ impl PyBPE {
454456
if let (Some(vocab), Some(merges)) = (vocab, merges) {
455457
match (vocab, merges) {
456458
(PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => {
459+
let vocab: AHashMap<_, _> = vocab.into_iter().collect();
457460
builder = builder.vocab_and_merges(vocab, merges);
458461
}
459462
(PyVocab::Filename(vocab_filename), PyMerges::Filename(merges_filename)) => {
@@ -494,13 +497,15 @@ impl PyBPE {
494497
/// The vocabulary and merges loaded into memory
495498
#[staticmethod]
496499
#[pyo3(text_signature = "(self, vocab, merges)")]
497-
fn read_file(vocab: &str, merges: &str) -> PyResult<(Vocab, Merges)> {
498-
BPE::read_file(vocab, merges).map_err(|e| {
500+
fn read_file(vocab: &str, merges: &str) -> PyResult<(HashMap<String, u32>, Merges)> {
501+
let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| {
499502
exceptions::PyException::new_err(format!(
500503
"Error while reading vocab & merges files: {}",
501504
e
502505
))
503-
})
506+
})?;
507+
let vocab = vocab.into_iter().collect();
508+
Ok((vocab, merges))
504509
}
505510

506511
/// Instantiate a BPE model from the given files.
@@ -536,6 +541,7 @@ impl PyBPE {
536541
let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| {
537542
exceptions::PyException::new_err(format!("Error while reading BPE files: {}", e))
538543
})?;
544+
let vocab = vocab.into_iter().collect();
539545
Py::new(
540546
py,
541547
PyBPE::new(
@@ -668,6 +674,7 @@ impl PyWordPiece {
668674
if let Some(vocab) = vocab {
669675
match vocab {
670676
PyVocab::Vocab(vocab) => {
677+
let vocab: AHashMap<_, _> = vocab.into_iter().collect();
671678
builder = builder.vocab(vocab);
672679
}
673680
PyVocab::Filename(vocab_filename) => {
@@ -699,10 +706,11 @@ impl PyWordPiece {
699706
/// :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict`
700707
#[staticmethod]
701708
#[pyo3(text_signature = "(vocab)")]
702-
fn read_file(vocab: &str) -> PyResult<Vocab> {
703-
WordPiece::read_file(vocab).map_err(|e| {
709+
fn read_file(vocab: &str) -> PyResult<HashMap<String, u32>> {
710+
let vocab = WordPiece::read_file(vocab).map_err(|e| {
704711
exceptions::PyException::new_err(format!("Error while reading WordPiece file: {}", e))
705-
})
712+
})?;
713+
Ok(vocab.into_iter().collect())
706714
}
707715

708716
/// Instantiate a WordPiece model from the given file
@@ -734,6 +742,7 @@ impl PyWordPiece {
734742
let vocab = WordPiece::read_file(vocab).map_err(|e| {
735743
exceptions::PyException::new_err(format!("Error while reading WordPiece file: {}", e))
736744
})?;
745+
let vocab = vocab.into_iter().collect();
737746
Py::new(
738747
py,
739748
PyWordPiece::new(py, Some(PyVocab::Vocab(vocab)), kwargs)?,
@@ -778,6 +787,7 @@ impl PyWordLevel {
778787
if let Some(vocab) = vocab {
779788
match vocab {
780789
PyVocab::Vocab(vocab) => {
790+
let vocab = vocab.into_iter().collect();
781791
builder = builder.vocab(vocab);
782792
}
783793
PyVocab::Filename(vocab_filename) => {
@@ -818,10 +828,12 @@ impl PyWordLevel {
818828
/// :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict`
819829
#[staticmethod]
820830
#[pyo3(text_signature = "(vocab)")]
821-
fn read_file(vocab: &str) -> PyResult<Vocab> {
822-
WordLevel::read_file(vocab).map_err(|e| {
831+
fn read_file(vocab: &str) -> PyResult<HashMap<String, u32>> {
832+
let vocab = WordLevel::read_file(vocab).map_err(|e| {
823833
exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e))
824-
})
834+
})?;
835+
let vocab: HashMap<_, _> = vocab.into_iter().collect();
836+
Ok(vocab)
825837
}
826838

827839
/// Instantiate a WordLevel model from the given file
@@ -853,6 +865,7 @@ impl PyWordLevel {
853865
let vocab = WordLevel::read_file(vocab).map_err(|e| {
854866
exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e))
855867
})?;
868+
let vocab = vocab.into_iter().collect();
856869
Py::new(
857870
py,
858871
PyWordLevel::new(py, Some(PyVocab::Vocab(vocab)), unk_token)?,

tokenizers/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ fancy-regex = { version = "0.14", optional = true}
6666
getrandom = { version = "0.3" }
6767
esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
6868
monostate = "0.1.12"
69+
ahash = { version = "0.8.11", features = ["serde"] }
70+
dary_heap = { version = "0.3.6", features = ["serde"] }
71+
compact_str = { version = "0.9", features = ["serde"] }
6972

7073
[features]
7174
default = ["progressbar", "onig", "esaxx_fast"]

tokenizers/src/models/bpe/model.rs

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@ use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
22
use crate::tokenizer::{Model, Result, Token};
33
use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH};
44
use crate::utils::iter::ResultShunt;
5+
use ahash::AHashMap;
56
use serde_json::Value;
67
use std::borrow::Cow;
8+
9+
use std::collections::HashMap;
710
use std::{
8-
collections::HashMap,
911
fs::File,
1012
io::prelude::*,
1113
io::{BufRead, BufReader},
1214
path::{Path, PathBuf},
1315
};
1416

15-
pub type Vocab = HashMap<String, u32>;
16-
type VocabR = HashMap<u32, String>;
17-
pub type MergeMap = HashMap<Pair, (u32, u32)>;
17+
pub type Vocab = AHashMap<String, u32>;
18+
type VocabR = AHashMap<u32, String>;
19+
pub type MergeMap = AHashMap<Pair, (u32, u32)>;
1820
pub type Merges = Vec<(String, String)>;
1921

2022
struct Config {
@@ -41,7 +43,7 @@ impl Default for BpeBuilder {
4143
Self {
4244
config: Config {
4345
files: None,
44-
vocab: HashMap::new(),
46+
vocab: AHashMap::new(),
4547
merges: vec![],
4648
cache_capacity: DEFAULT_CACHE_CAPACITY,
4749
dropout: None,
@@ -71,8 +73,12 @@ impl BpeBuilder {
7173

7274
/// Set the vocab (token -> ID) and merges mappings.
7375
#[must_use]
74-
pub fn vocab_and_merges(mut self, vocab: Vocab, merges: Merges) -> Self {
75-
self.config.vocab = vocab;
76+
pub fn vocab_and_merges<V: Into<AHashMap<String, u32>>>(
77+
mut self,
78+
vocab: V,
79+
merges: Merges,
80+
) -> Self {
81+
self.config.vocab = vocab.into();
7682
self.config.merges = merges;
7783
self
7884
}
@@ -324,7 +330,7 @@ impl BPE {
324330
let mut buffer = String::new();
325331
vocab_file.read_to_string(&mut buffer)?;
326332
let json: Value = serde_json::from_str(&buffer)?;
327-
let mut vocab = HashMap::new();
333+
let mut vocab = AHashMap::new();
328334
match json {
329335
Value::Object(m) => {
330336
for (token, id) in m {
@@ -361,8 +367,8 @@ impl BPE {
361367
}
362368
}
363369

364-
pub fn get_vocab(&self) -> Vocab {
365-
self.vocab.clone()
370+
pub fn get_vocab(&self) -> HashMap<String, u32> {
371+
self.vocab.clone().into_iter().collect()
366372
}
367373

368374
pub fn get_unk_token(&self) -> &Option<String> {
@@ -494,7 +500,7 @@ impl Model for BPE {
494500
type Trainer = BpeTrainer;
495501

496502
fn get_vocab(&self) -> HashMap<String, u32> {
497-
self.vocab.clone()
503+
self.vocab.clone().into_iter().collect()
498504
}
499505

500506
fn get_vocab_size(&self) -> usize {

tokenizers/src/models/bpe/serialization.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
2+
use ahash::AHashMap;
23
use serde::{
34
de::{Error, MapAccess, Visitor},
45
ser::SerializeStruct,
56
Deserialize, Deserializer, Serialize, Serializer,
67
};
7-
use std::collections::HashMap;
88

99
impl Serialize for BPE {
1010
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@@ -80,7 +80,7 @@ impl<'de> Visitor<'de> for BPEVisitor {
8080
V: MapAccess<'de>,
8181
{
8282
let mut builder = BpeBuilder::new();
83-
let mut vocab: Option<HashMap<String, u32>> = None;
83+
let mut vocab: Option<AHashMap<String, u32>> = None;
8484

8585
#[derive(Debug, Deserialize)]
8686
#[serde(untagged)]

0 commit comments

Comments
 (0)