11import asyncio
2+ import codecs
23import json
34from pathlib import Path
45from typing import get_args
89from guidellm .backend import BackendType
910from guidellm .benchmark import ProfileType , benchmark_generative_text
1011from guidellm .config import print_config
12+ from guidellm .preprocess .dataset import ShortPromptStrategy , process_dataset
1113from guidellm .scheduler import StrategyType
1214
1315STRATEGY_PROFILE_CHOICES = set (
@@ -280,6 +282,20 @@ def benchmark(
280282 )
281283
282284
def decode_escaped_str(_ctx, _param, value):
    r"""
    Click option callback that expands backslash escape sequences.

    Click passes the option text through literally, so ``--pad-char "\n"``
    arrives as a backslash followed by ``n`` instead of a newline character.
    This callback converts such sequences into the characters they denote,
    while leaving any non-ASCII characters in the value untouched.

    :param _ctx: The Click context (unused).
    :param _param: The Click parameter being processed (unused).
    :param value: The raw option value, or None when no value is available.
    :return: The value with escape sequences decoded, or None if value is None.
    :raises click.BadParameter: If the value contains a malformed escape
        sequence (for example a trailing lone backslash).
    """
    if value is None:
        return None
    try:
        # Decoding a str directly with "unicode_escape" round-trips it through
        # UTF-8 bytes read back as Latin-1, which corrupts non-ASCII input
        # ("é" becomes "Ã©"). Encoding to Latin-1 first keeps code points up to
        # U+00FF as single bytes, and backslashreplace rewrites everything
        # above that as \uXXXX escapes that the decoder restores.
        raw = value.encode("latin-1", errors="backslashreplace")
        return raw.decode("unicode_escape")
    except UnicodeError as e:
        raise click.BadParameter(f"Could not decode escape sequences: {e}") from e
297+
298+
283299@cli .command (
284300 help = (
285301 "Print out the available configuration settings that can be set "
@@ -290,5 +306,139 @@ def config():
290306 print_config ()
291307
292308
@cli.group(
    help="General preprocessing tools and utilities.",
)
def preprocess():
    """
    Click command group that holds the preprocessing subcommands.

    The group has no behavior of its own; subcommands attach to it through
    ``@preprocess.command``. The user-facing help text comes from the explicit
    ``help`` argument of the decorator, not from this docstring.
    """
312+
313+
@preprocess.command(
    help=(
        "Convert a dataset to have specific prompt and output token sizes.\n "
        "DATA: Path to the input dataset or dataset ID.\n "
        "OUTPUT_PATH: Path to save the converted dataset, including file suffix."
    )
)
@click.argument(
    "data",
    type=str,
    required=True,
)
@click.argument(
    "output_path",
    type=click.Path(file_okay=True, dir_okay=False, writable=True, resolve_path=True),
    required=True,
)
@click.option(
    "--processor",
    type=str,
    required=True,
    help=(
        "The processor or tokenizer to use to calculate token counts for statistics "
        "and synthetic data generation."
    ),
)
@click.option(
    "--processor-args",
    default=None,
    callback=parse_json,
    help=(
        "A JSON string containing any arguments to pass to the processor constructor "
        "as a dict with **kwargs."
    ),
)
@click.option(
    "--data-args",
    callback=parse_json,
    help=(
        "A JSON string containing any arguments to pass to the dataset creation "
        "as a dict with **kwargs."
    ),
)
@click.option(
    "--short-prompt-strategy",
    type=click.Choice([s.value for s in ShortPromptStrategy]),
    default=ShortPromptStrategy.IGNORE.value,
    show_default=True,
    help="Strategy to handle prompts shorter than the target length. ",
)
@click.option(
    "--pad-char",
    type=str,
    default="",
    callback=decode_escaped_str,
    help="The token to pad short prompts with when using the 'pad' strategy.",
)
@click.option(
    "--concat-delimiter",
    type=str,
    default="",
    help=(
        "The delimiter to use when concatenating prompts that are too short."
        " Used when strategy is 'concatenate'."
    ),
)
@click.option(
    "--prompt-tokens",
    type=str,
    default=None,
    help="Prompt tokens config (JSON, YAML file or key=value string)",
)
@click.option(
    "--output-tokens",
    type=str,
    default=None,
    help="Output tokens config (JSON, YAML file or key=value string)",
)
@click.option(
    "--push-to-hub",
    is_flag=True,
    help="Set this flag to push the converted dataset to the Hugging Face Hub.",
)
@click.option(
    "--hub-dataset-id",
    type=str,
    default=None,
    help="The Hugging Face Hub dataset ID to push to. "
    "Required if --push-to-hub is used.",
)
@click.option(
    "--random-seed",
    type=int,
    default=42,
    show_default=True,
    help="Random seed for prompt token sampling and output tokens sampling.",
)
def dataset(
    data,
    output_path,
    processor,
    processor_args,
    data_args,
    short_prompt_strategy,
    pad_char,
    concat_delimiter,
    prompt_tokens,
    output_tokens,
    push_to_hub,
    hub_dataset_id,
    random_seed,
):
    """
    CLI entry point for ``preprocess dataset``.

    Validates the Hub options and then forwards every argument and option,
    unchanged, to ``process_dataset``.

    :param data: Path to the input dataset or a dataset ID.
    :param output_path: Resolved file path where the converted dataset is saved.
    :param processor: Processor or tokenizer identifier used for token counts.
    :param processor_args: Value produced by the ``parse_json`` callback for
        ``--processor-args``, or None when the option is omitted.
    :param data_args: Value produced by the ``parse_json`` callback for
        ``--data-args``, or None when the option is omitted.
    :param short_prompt_strategy: One of the ``ShortPromptStrategy`` values.
    :param pad_char: Padding text, already passed through
        ``decode_escaped_str``.
    :param concat_delimiter: Delimiter text for the 'concatenate' strategy.
    :param prompt_tokens: Prompt token configuration string, or None.
    :param output_tokens: Output token configuration string, or None.
    :param push_to_hub: True when ``--push-to-hub`` was given.
    :param hub_dataset_id: Hugging Face Hub dataset ID, or None.
    :param random_seed: Seed for prompt and output token sampling.
    :raises click.UsageError: If ``--push-to-hub`` is set without
        ``--hub-dataset-id``.
    """
    # The option help documents --hub-dataset-id as required with --push-to-hub;
    # enforce that here so the user gets a usage error before any processing
    # starts rather than a failure at upload time.
    if push_to_hub and not hub_dataset_id:
        raise click.UsageError(
            "--hub-dataset-id is required when --push-to-hub is set."
        )

    process_dataset(
        data=data,
        output_path=output_path,
        processor=processor,
        prompt_tokens=prompt_tokens,
        output_tokens=output_tokens,
        processor_args=processor_args,
        data_args=data_args,
        short_prompt_strategy=short_prompt_strategy,
        pad_char=pad_char,
        concat_delimiter=concat_delimiter,
        push_to_hub=push_to_hub,
        hub_dataset_id=hub_dataset_id,
        random_seed=random_seed,
    )
441+
442+
# Invoke the Click command group only when this module is run as a script,
# so that importing the module does not start the CLI.
if __name__ == "__main__":
    cli()
0 commit comments