 import os
 import re
 from pathlib import Path
-from typing import Callable, Optional
+from typing import Optional
 
 import numpy as np
 from datasets import load_dataset
@@ -13,34 +13,31 @@
 MAX_CHAR = 1000
 
 
-def create_token_estimator(
-    model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
-) -> Callable[[str], int]:
-    _tokenizer: Optional[AutoTokenizer] = None
+class TokenCounter:
+    def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2"):
+        self.model_name = model_name
+        self._tokenizer: Optional[AutoTokenizer] = None
 
-    def initialize() -> None:
-        nonlocal _tokenizer
-        if _tokenizer is None:
+    def _initialize_tokenizer(self) -> None:
+        if self._tokenizer is None:
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
             try:
-                _tokenizer = AutoTokenizer.from_pretrained(model_name)
+                self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
             except (OSError, ImportError, ValueError) as e:
                 raise RuntimeError(f"Failed to initialize tokenizer: {e}") from e
 
-    def estimate_num_tokens(text: str) -> int:
-        initialize()
+    def estimate_num_tokens(self, text: str) -> int:
+        self._initialize_tokenizer()
 
-        if _tokenizer is None:
+        if self._tokenizer is None:
             return 0
 
         try:
-            encoding = _tokenizer(text, return_tensors=None)
+            encoding = self._tokenizer(text, return_tensors=None)
             return len(encoding["input_ids"])
         except (AttributeError, TypeError, RuntimeError) as e:
             raise ValueError(f"Error processing text: {e}") from e
 
-    return estimate_num_tokens
-
 
 def extract_and_save_with_filtering(file):
     """substract human prompts and apply filtering conditions"""
@@ -93,7 +90,7 @@ def extract_and_save_with_filtering(file):
     with Path(sharegpt_file).open("r", encoding="utf-8") as file:
         data = json.load(file)
 
-    estimate_tokens = create_token_estimator()
+    counter = TokenCounter()
     num_of_ids = len(data)
     data = data[: int(num_of_ids * args.parse)]
     for d in data:
@@ -102,9 +99,9 @@ def extract_and_save_with_filtering(file):
         gpt_tokens = []
         for conv in d["conversations"]:
             if conv["from"] == "human":
-                human_tokens.append(estimate_tokens(conv["value"]))
+                human_tokens.append(counter.estimate_num_tokens(conv["value"]))
             if conv["from"] == "gpt":
-                token_number = estimate_tokens(conv["value"])
+                token_number = counter.estimate_num_tokens(conv["value"])
                 conv["num_tokens"] = token_number
                 gpt_tokens.append(token_number)
         if len(human_tokens) == 0:
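
Usage note: the refactor swaps the closure returned by create_token_estimator() for a TokenCounter class, but keeps the lazy tokenizer initialization, the default Mistral model name, and the error handling. A minimal sketch of calling the new class, assuming the patched script is importable as a module named prepare_dataset (the module name and the sample prompt are illustrative, not taken from the diff):

    # Hypothetical standalone usage of the refactored class; module name is assumed.
    from prepare_dataset import TokenCounter

    counter = TokenCounter()  # defaults to "mistralai/Mistral-7B-Instruct-v0.2"
    # The Hugging Face tokenizer is only downloaded/loaded on the first call.
    n = counter.estimate_num_tokens("How many tokens is this prompt?")
    print(n)

A single counter instance is created once before the conversion loop, so the tokenizer is loaded at most once per run, same as the old closure.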