|
1 | 1 | import json |
2 | 2 | import random |
3 | 3 | from collections.abc import Iterable, Iterator |
| 4 | +from itertools import cycle |
4 | 5 | from pathlib import Path |
5 | 6 | from typing import Any, Literal, Optional, Union |
6 | 7 |
|
|
25 | 26 |
|
26 | 27 |
|
27 | 28 | class SyntheticDatasetConfig(BaseModel): |
| 29 | + prefix_tokens: int = Field( |
| 30 | + description="The number of shared prefix tokens to prepend to each prompt.", |
| 31 | + ge=0, |
| 32 | + default=0, |
| 33 | + ) |
28 | 34 | prompt_tokens: int = Field( |
29 | 35 | description="The average number of text tokens generated for prompts.", |
30 | 36 | gt=0, |
@@ -163,39 +169,54 @@ def __iter__( |
163 | 169 | ) |
164 | 170 | # ensure diff distribution from output tokens |
165 | 171 | rand = random.Random(self.random_seed + 2) # noqa: S311 |
| 172 | + unique_prefix_iter = cycle(self.processor.get_vocab().values()) |
| 173 | + |
| 174 | + prefix_index = rand.randint(0, len(self.text_creator.words)) |
| 175 | + prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index) |
166 | 176 |
|
167 | 177 | for _, prompt_tokens, output_tokens in zip( |
168 | 178 | range(self.config.samples), |
169 | 179 | prompt_tokens_sampler, |
170 | 180 | output_tokens_sampler, |
171 | 181 | ): |
172 | 182 | start_index = rand.randint(0, len(self.text_creator.words)) |
| 183 | + prompt_text = self.processor.decode( |
| 184 | + prefix_tokens |
| 185 | + + self._create_prompt( |
| 186 | + prompt_tokens, start_index, next(unique_prefix_iter) |
| 187 | + ), |
| 188 | + skip_special_tokens=True, |
| 189 | + ) |
173 | 190 | yield { |
174 | | - "prompt": self._create_prompt(prompt_tokens, start_index), |
175 | | - "prompt_tokens_count": prompt_tokens, |
| 191 | + "prompt": prompt_text, |
| 192 | + "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens, |
176 | 193 | "output_tokens_count": output_tokens, |
177 | 194 | } |
178 | 195 |
|
def _create_prompt(
    self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
) -> list[int]:
    """Build a prompt as a list of token ids, targeting ``prompt_tokens`` tokens.

    A binary search is run over the end position passed to
    ``self.text_creator.create_text``; each candidate text is encoded with
    ``self.processor.encode`` and its length (including the optional leading
    ``unique_prefix`` token) is compared against ``prompt_tokens``.

    :param prompt_tokens: Desired total number of tokens, counting the
        ``unique_prefix`` token when one is given. Values ``<= 0`` yield
        an empty list.
    :param start_index: Start position handed to
        ``self.text_creator.create_text``; the search upper bound is
        ``start_index + 4 * prompt_tokens``.
    :param unique_prefix: Optional token id placed in front of the encoded
        text. ``None`` means "no prefix"; any integer, including ``0``,
        is used as given.
    :return: The token ids of the generated prompt. If the search finishes
        without an exact length match, the encoding of the text ending at
        the final search position is returned, so the length may differ
        from ``prompt_tokens``.
    """
    if prompt_tokens <= 0:
        return []

    # Compare against None explicitly: token id 0 is a valid vocabulary
    # entry and must not be discarded by a truthiness test.
    start_tokens = [unique_prefix] if unique_prefix is not None else []

    def encode_until(end: int) -> list[int]:
        # Tokens for the text spanning start_index..end, prefix included.
        text = self.text_creator.create_text(start_index, end - start_index)
        return start_tokens + self.processor.encode(text)

    left = start_index
    right = start_index + 4 * prompt_tokens

    while left < right:
        mid = (left + right) // 2
        test_tokens = encode_until(mid)

        if len(test_tokens) == prompt_tokens:
            return test_tokens
        if len(test_tokens) < prompt_tokens:
            left = mid + 1
        else:
            right = mid

    # No exact match was found; fall back to the text ending at the
    # position where the search converged.
    return encode_until(left)
199 | 220 |
|
200 | 221 |
|
201 | 222 | class SyntheticDatasetCreator(DatasetCreator): |
|
0 commit comments