|
10 | 10 | import yaml |
11 | 11 | from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict |
12 | 12 | from loguru import logger |
13 | | -from pydantic import BeforeValidator, Field, PositiveFloat, PositiveInt |
| 13 | +from pydantic import BeforeValidator, Field, PositiveFloat, PositiveInt, SkipValidation |
14 | 14 | from transformers.tokenization_utils_base import ( # type: ignore[import] |
15 | 15 | PreTrainedTokenizerBase, |
16 | 16 | ) |
@@ -115,16 +115,18 @@ class Config: |
115 | 115 | # types like PreTrainedTokenizerBase |
116 | 116 | arbitrary_types_allowed = True |
117 | 117 |
|
118 | | - data: ( |
| 118 | + data: Annotated[ |
119 | 119 | Iterable[str] |
120 | 120 | | Iterable[dict[str, Any]] |
121 | 121 | | Dataset |
122 | 122 | | DatasetDict |
123 | 123 | | IterableDataset |
124 | 124 | | IterableDatasetDict |
125 | 125 | | str |
126 | | - | Path |
127 | | - ) |
| 126 | + | Path, |
| 127 | + # BUG: See https://github.com/pydantic/pydantic/issues/9541 |
| 128 | + SkipValidation, |
| 129 | + ] |
128 | 130 | profile: StrategyType | ProfileType | Profile |
129 | 131 | rate: Annotated[list[PositiveFloat] | None, BeforeValidator(parse_float_list)] = ( |
130 | 132 | None |
@@ -159,7 +161,7 @@ def enable_scenarios(func: Callable) -> Any: |
159 | 161 | @wraps(func) |
160 | 162 | async def decorator(*args, scenario: Scenario | None = None, **kwargs) -> Any: |
161 | 163 | if scenario is not None: |
162 | | - kwargs.update(**vars(scenario)) |
| 164 | + kwargs.update(**scenario.model_dump()) |
163 | 165 | return await func(*args, **kwargs) |
164 | 166 |
|
165 | 167 | # Modify the signature of the decorator to include the `scenario` argument |
|
0 commit comments