11from __future__ import annotations
22
3- from collections .abc import Iterable
43from pathlib import Path
54from typing import Any , Literal
65
7- from datasets import Dataset , DatasetDict , IterableDataset , IterableDatasetDict
8- from transformers import ( # type: ignore[import]
9- PreTrainedTokenizerBase ,
10- )
11-
126from guidellm .backends import (
137 Backend ,
148 BackendType ,
159 GenerationRequest ,
1610 GenerationResponse ,
1711)
1812from guidellm .benchmark .aggregator import (
19- Aggregator ,
20- CompilableAggregator ,
2113 GenerativeRequestsAggregator ,
2214 GenerativeStatsProgressAggregator ,
2315 SchedulerStatsAggregator ,
2921 GenerativeBenchmarkerOutput ,
3022)
3123from guidellm .benchmark .profile import Profile , ProfileType
32- from guidellm .benchmark .progress import (
33- BenchmarkerProgress ,
34- BenchmarkerProgressGroup ,
35- )
24+ from guidellm .benchmark .progress import BenchmarkerProgressGroup
3625from guidellm .benchmark .scenario import enable_scenarios
26+ from guidellm .benchmark .type import OutputFormatType , DataInputType , ProcessorInputType , ProgressInputType , \
27+ AggregatorInputType
3728from guidellm .request import GenerativeRequestLoader
3829from guidellm .scheduler import (
3930 ConstraintInitializer ,
5142_CURRENT_WORKING_DIR = Path .cwd ()
5243
5344
54- # Data types
55-
56- DataType = (
57- Iterable [str ]
58- | Iterable [dict [str , Any ]]
59- | Dataset
60- | DatasetDict
61- | IterableDataset
62- | IterableDatasetDict
63- | str
64- | Path
65- )
66-
67- OutputFormatType = (
68- tuple [str , ...]
69- | list [str ]
70- | dict [str , str | dict [str , Any ] | GenerativeBenchmarkerOutput ]
71- | None
72- )
73-
74-
7545# Helper functions
7646
7747async def initialize_backend (
@@ -147,7 +117,7 @@ async def finalize_outputs(
147117@enable_scenarios
148118async def benchmark_generative_text ( # noqa: C901
149119 target : str ,
150- data : DataType ,
120+ data : DataInputType ,
151121 profile : StrategyType | ProfileType | Profile ,
152122 rate : list [float ] | None = None ,
153123 random_seed : int = 42 ,
@@ -156,20 +126,18 @@ async def benchmark_generative_text( # noqa: C901
156126 backend_kwargs : dict [str , Any ] | None = None ,
157127 model : str | None = None ,
158128 # Data configuration
159- processor : str | Path | PreTrainedTokenizerBase | None = None ,
129+ processor : ProcessorInputType | None = None ,
160130 processor_args : dict [str , Any ] | None = None ,
161131 data_args : dict [str , Any ] | None = None ,
162132 data_sampler : Literal ["random" ] | None = None ,
163133 # Output configuration
164134 output_path : str | Path | None = _CURRENT_WORKING_DIR ,
165135 output_formats : OutputFormatType = ("console" , "json" , "html" , "csv" ),
166136 # Updates configuration
167- progress : tuple [ str , ...] | list [ str ] | list [ BenchmarkerProgress ] | None = None ,
137+ progress : ProgressInputType | None = None ,
168138 print_updates : bool = False ,
169139 # Aggregators configuration
170- add_aggregators : (
171- dict [str , str | dict [str , Any ] | Aggregator | CompilableAggregator ] | None
172- ) = None ,
140+ add_aggregators : AggregatorInputType | None = None ,
173141 warmup : float | None = None ,
174142 cooldown : float | None = None ,
175143 request_samples : int | None = 20 ,
0 commit comments