33from collections .abc import Iterable , Iterator
44from itertools import cycle
55from pathlib import Path
6- from typing import Any , Literal , Optional , Union
6+ from typing import Any , Optional , TypedDict , Union
77
88import yaml
99from datasets import (
@@ -69,6 +69,26 @@ class SyntheticDatasetConfig(BaseModel):
6969 gt = 0 ,
7070 default = None ,
7171 )
72+ turns : int = Field (
73+ description = "The number of turns in the conversation." ,
74+ gt = 0 ,
75+ default = 1 ,
76+ )
77+ turns_stdev : Optional [int ] = Field (
78+ description = "The standard deviation of the number of turns." ,
79+ gt = 0 ,
80+ default = None ,
81+ )
82+ turns_min : Optional [int ] = Field (
83+ description = "The minimum number of turns in the conversation." ,
84+ gt = 0 ,
85+ default = None ,
86+ )
87+ turns_max : Optional [int ] = Field (
88+ description = "The maximum number of turns in the conversation." ,
89+ gt = 0 ,
90+ default = None ,
91+ )
7292 samples : int = Field (
7393 description = "The number of samples to generate for the dataset." ,
7494 gt = 0 ,
@@ -124,14 +144,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
124144 return SyntheticDatasetConfig (** config_dict )
125145
126146
127- class SyntheticTextItemsGenerator (
128- Iterable [
129- dict [
130- Literal ["prompt" , "prompt_tokens_count" , "output_tokens_count" ],
131- Union [str , int ],
132- ]
133- ]
134- ):
147+ class SyntheticDatasetRow (TypedDict ):
148+ prompt : list [str ]
149+ prompt_tokens_count : list [int ]
150+ output_tokens_count : list [int ]
151+
152+
153+ class SyntheticTextItemsGenerator (Iterable [SyntheticDatasetRow ]):
135154 def __init__ (
136155 self ,
137156 config : SyntheticDatasetConfig ,
@@ -147,12 +166,7 @@ def __init__(
147166
148167 def __iter__ (
149168 self ,
150- ) -> Iterator [
151- dict [
152- Literal ["prompt" , "prompt_tokens_count" , "output_tokens_count" ],
153- Union [str , int ],
154- ]
155- ]:
169+ ) -> Iterator [SyntheticDatasetRow ]:
156170 prompt_tokens_sampler = IntegerRangeSampler (
157171 average = self .config .prompt_tokens ,
158172 variance = self .config .prompt_tokens_stdev ,
@@ -167,31 +181,56 @@ def __iter__(
167181 max_value = self .config .output_tokens_max ,
168182 random_seed = self .random_seed + 1 , # ensure diff dist from prompts
169183 )
184+ turns_sampler = IntegerRangeSampler (
185+ average = self .config .turns ,
186+ variance = self .config .turns_stdev ,
187+ min_value = self .config .turns_min ,
188+ max_value = self .config .turns_max ,
189+ random_seed = self .random_seed + 7 , # ensure diff dist
190+ )
170191 # ensure diff distribution from output tokens
171192 rand = random .Random (self .random_seed + 2 ) # noqa: S311
172193 unique_prefix_iter = cycle (self .processor .get_vocab ().values ())
173194
174195 prefix_index = rand .randint (0 , len (self .text_creator .words ))
175196 prefix_tokens = self ._create_prompt (self .config .prefix_tokens , prefix_index )
176197
177- for _ , prompt_tokens , output_tokens in zip (
178- range (self .config .samples ),
179- prompt_tokens_sampler ,
180- output_tokens_sampler ,
181- ):
182- start_index = rand .randint (0 , len (self .text_creator .words ))
183- prompt_text = self .processor .decode (
184- prefix_tokens
185- + self ._create_prompt (
186- prompt_tokens , start_index , next (unique_prefix_iter )
187- ),
188- skip_special_tokens = True ,
189- )
190- yield {
191- "prompt" : prompt_text ,
192- "prompt_tokens_count" : self .config .prefix_tokens + prompt_tokens ,
193- "output_tokens_count" : output_tokens ,
198+ for _ , turns in zip (range (self .config .samples ), turns_sampler ):
199+ row : SyntheticDatasetRow = {
200+ "prompt" : [],
201+ "prompt_tokens_count" : [],
202+ "output_tokens_count" : [],
194203 }
204+ for i , prompt_tokens , output_tokens in zip (
205+ range (turns ),
206+ prompt_tokens_sampler ,
207+ output_tokens_sampler ,
208+ ):
209+ start_index = rand .randint (0 , len (self .text_creator .words ))
210+ # Append the prefix tokens only for the first turn
211+ if i == 0 :
212+ prompt_text = self .processor .decode (
213+ prefix_tokens
214+ + self ._create_prompt (
215+ prompt_tokens , start_index , next (unique_prefix_iter )
216+ ),
217+ skip_special_tokens = True ,
218+ )
219+ row ["prompt" ].append (prompt_text )
220+ row ["prompt_tokens_count" ].append (self .config .prefix_tokens + prompt_tokens )
221+ row ["output_tokens_count" ].append (output_tokens )
222+ else :
223+ prompt_text = self .processor .decode (
224+ self ._create_prompt (
225+ prompt_tokens , start_index , next (unique_prefix_iter )
226+ ),
227+ skip_special_tokens = True ,
228+ )
229+ row ["prompt" ].append (prompt_text )
230+ row ["prompt_tokens_count" ].append (prompt_tokens )
231+ row ["output_tokens_count" ].append (output_tokens )
232+
233+ yield row
195234
196235 def _create_prompt (
197236 self , prompt_tokens : int , start_index : int , unique_prefix : Optional [int ] = None
0 commit comments