@@ -4,11 +4,12 @@ use std::sync::{Arc, RwLock};
44
55use crate :: token:: PyToken ;
66use crate :: trainers:: PyTrainer ;
7+ use ahash:: AHashMap ;
78use pyo3:: exceptions;
89use pyo3:: prelude:: * ;
910use pyo3:: types:: * ;
1011use serde:: { Deserialize , Serialize } ;
11- use tk:: models:: bpe:: { BpeBuilder , Merges , Vocab , BPE } ;
12+ use tk:: models:: bpe:: { BpeBuilder , Merges , BPE } ;
1213use tk:: models:: unigram:: Unigram ;
1314use tk:: models:: wordlevel:: WordLevel ;
1415use tk:: models:: wordpiece:: { WordPiece , WordPieceBuilder } ;
@@ -347,9 +348,10 @@ macro_rules! setter {
347348
348349#[ derive( FromPyObject ) ]
349350enum PyVocab {
350- Vocab ( Vocab ) ,
351+ Vocab ( HashMap < String , u32 > ) ,
351352 Filename ( String ) ,
352353}
354+
353355#[ derive( FromPyObject ) ]
354356enum PyMerges {
355357 Merges ( Merges ) ,
@@ -454,6 +456,7 @@ impl PyBPE {
454456 if let ( Some ( vocab) , Some ( merges) ) = ( vocab, merges) {
455457 match ( vocab, merges) {
456458 ( PyVocab :: Vocab ( vocab) , PyMerges :: Merges ( merges) ) => {
459+ let vocab: AHashMap < _ , _ > = vocab. into_iter ( ) . collect ( ) ;
457460 builder = builder. vocab_and_merges ( vocab, merges) ;
458461 }
459462 ( PyVocab :: Filename ( vocab_filename) , PyMerges :: Filename ( merges_filename) ) => {
@@ -494,13 +497,15 @@ impl PyBPE {
494497 /// The vocabulary and merges loaded into memory
495498 #[ staticmethod]
496499 #[ pyo3( text_signature = "(self, vocab, merges)" ) ]
497- fn read_file ( vocab : & str , merges : & str ) -> PyResult < ( Vocab , Merges ) > {
498- BPE :: read_file ( vocab, merges) . map_err ( |e| {
500+ fn read_file ( vocab : & str , merges : & str ) -> PyResult < ( HashMap < String , u32 > , Merges ) > {
501+ let ( vocab , merges ) = BPE :: read_file ( vocab, merges) . map_err ( |e| {
499502 exceptions:: PyException :: new_err ( format ! (
500503 "Error while reading vocab & merges files: {}" ,
501504 e
502505 ) )
503- } )
506+ } ) ?;
507+ let vocab = vocab. into_iter ( ) . collect ( ) ;
508+ Ok ( ( vocab, merges) )
504509 }
505510
506511 /// Instantiate a BPE model from the given files.
@@ -536,6 +541,7 @@ impl PyBPE {
536541 let ( vocab, merges) = BPE :: read_file ( vocab, merges) . map_err ( |e| {
537542 exceptions:: PyException :: new_err ( format ! ( "Error while reading BPE files: {}" , e) )
538543 } ) ?;
544+ let vocab = vocab. into_iter ( ) . collect ( ) ;
539545 Py :: new (
540546 py,
541547 PyBPE :: new (
@@ -668,6 +674,7 @@ impl PyWordPiece {
668674 if let Some ( vocab) = vocab {
669675 match vocab {
670676 PyVocab :: Vocab ( vocab) => {
677+ let vocab: AHashMap < _ , _ > = vocab. into_iter ( ) . collect ( ) ;
671678 builder = builder. vocab ( vocab) ;
672679 }
673680 PyVocab :: Filename ( vocab_filename) => {
@@ -699,10 +706,11 @@ impl PyWordPiece {
699706 /// :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict`
700707 #[ staticmethod]
701708 #[ pyo3( text_signature = "(vocab)" ) ]
702- fn read_file ( vocab : & str ) -> PyResult < Vocab > {
703- WordPiece :: read_file ( vocab) . map_err ( |e| {
709+ fn read_file ( vocab : & str ) -> PyResult < HashMap < String , u32 > > {
710+ let vocab = WordPiece :: read_file ( vocab) . map_err ( |e| {
704711 exceptions:: PyException :: new_err ( format ! ( "Error while reading WordPiece file: {}" , e) )
705- } )
712+ } ) ?;
713+ Ok ( vocab. into_iter ( ) . collect ( ) )
706714 }
707715
708716 /// Instantiate a WordPiece model from the given file
@@ -734,6 +742,7 @@ impl PyWordPiece {
734742 let vocab = WordPiece :: read_file ( vocab) . map_err ( |e| {
735743 exceptions:: PyException :: new_err ( format ! ( "Error while reading WordPiece file: {}" , e) )
736744 } ) ?;
745+ let vocab = vocab. into_iter ( ) . collect ( ) ;
737746 Py :: new (
738747 py,
739748 PyWordPiece :: new ( py, Some ( PyVocab :: Vocab ( vocab) ) , kwargs) ?,
@@ -778,6 +787,7 @@ impl PyWordLevel {
778787 if let Some ( vocab) = vocab {
779788 match vocab {
780789 PyVocab :: Vocab ( vocab) => {
790+ let vocab = vocab. into_iter ( ) . collect ( ) ;
781791 builder = builder. vocab ( vocab) ;
782792 }
783793 PyVocab :: Filename ( vocab_filename) => {
@@ -818,10 +828,12 @@ impl PyWordLevel {
818828 /// :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict`
819829 #[ staticmethod]
820830 #[ pyo3( text_signature = "(vocab)" ) ]
821- fn read_file ( vocab : & str ) -> PyResult < Vocab > {
822- WordLevel :: read_file ( vocab) . map_err ( |e| {
831+ fn read_file ( vocab : & str ) -> PyResult < HashMap < String , u32 > > {
832+ let vocab = WordLevel :: read_file ( vocab) . map_err ( |e| {
823833 exceptions:: PyException :: new_err ( format ! ( "Error while reading WordLevel file: {}" , e) )
824- } )
834+ } ) ?;
835+ let vocab: HashMap < _ , _ > = vocab. into_iter ( ) . collect ( ) ;
836+ Ok ( vocab)
825837 }
826838
827839 /// Instantiate a WordLevel model from the given file
@@ -853,6 +865,7 @@ impl PyWordLevel {
853865 let vocab = WordLevel :: read_file ( vocab) . map_err ( |e| {
854866 exceptions:: PyException :: new_err ( format ! ( "Error while reading WordLevel file: {}" , e) )
855867 } ) ?;
868+ let vocab = vocab. into_iter ( ) . collect ( ) ;
856869 Py :: new (
857870 py,
858871 PyWordLevel :: new ( py, Some ( PyVocab :: Vocab ( vocab) ) , unk_token) ?,
0 commit comments