-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[None] [feat] add eos_token_id in generation_config to sampling params #9514
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
bd0da06
223b6a3
654428b
380a6b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -368,14 +368,6 @@ def _setup( | |||||||||||
| if self.end_id is None: | ||||||||||||
| self.end_id = tokenizer.eos_token_id | ||||||||||||
| self.pad_id = tokenizer.pad_token_id | ||||||||||||
| # kimi_k2 model uses the eos_token_id in generation config | ||||||||||||
| if ( | ||||||||||||
| hf_model_config is not None | ||||||||||||
| and hf_model_config.model_type == "kimi_k2" | ||||||||||||
| and generation_config is not None | ||||||||||||
| and isinstance(generation_config.eos_token_id, int) | ||||||||||||
| ): | ||||||||||||
| self.end_id = generation_config.eos_token_id | ||||||||||||
|
|
||||||||||||
| if self.pad_id is None: | ||||||||||||
| self.pad_id = self.end_id | ||||||||||||
|
|
@@ -395,24 +387,26 @@ def _encode(tokenizer, text, add_special_tokens): | |||||||||||
| strs = [self.stop] if isinstance(self.stop, str) else self.stop | ||||||||||||
| self._stop_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs] | ||||||||||||
|
|
||||||||||||
| # add generation_config to stop word list, only in qwen3-next now | ||||||||||||
| if ( | ||||||||||||
| hf_model_config is not None | ||||||||||||
| and hf_model_config.model_type == "qwen3_next" | ||||||||||||
| and generation_config is not None | ||||||||||||
| and isinstance(generation_config.eos_token_id, List) | ||||||||||||
| and all(isinstance(i, int) for i in generation_config.eos_token_id) | ||||||||||||
| ): | ||||||||||||
| if self._stop_word_ids: | ||||||||||||
| # Add eos_token_id in generation_config to _stop_word_ids | ||||||||||||
| # Refer to https://huggingface.co/docs/hub/en/transformers#transformers-repository-files and | ||||||||||||
| # https://github.com/huggingface/transformers/blob/1ae4d917ed3badbdb1ffc167e0529f5a6d3c080d/src/transformers/generation/stopping_criteria.py#L451C1-L451C42 | ||||||||||||
| # The eos_token_id in generation_config are really meant to stop the text generation. | ||||||||||||
| if generation_config is not None and generation_config.eos_token_id is not None: | ||||||||||||
| if isinstance(generation_config.eos_token_id, int): | ||||||||||||
| generation_eos_token_ids = [generation_config.eos_token_id] | ||||||||||||
| else: # always List[int] | ||||||||||||
| generation_eos_token_ids = generation_config.eos_token_id | ||||||||||||
|
|
||||||||||||
| if self._stop_word_ids is None: | ||||||||||||
| self._stop_word_ids = [generation_eos_token_ids] | ||||||||||||
| else: | ||||||||||||
| all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist) | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Use a plural variable name for consistency. The variable `all_stop_tokens_id` holds a set of token IDs, so the plural form `all_stop_token_ids` is clearer. Apply this diff: - all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
+ all_stop_token_ids = set(i for sublist in self._stop_word_ids for i in sublist)
from_generation_stop_token_ids = [
- i for i in generation_eos_token_ids if i not in all_stop_tokens_id
+ i for i in generation_eos_token_ids if i not in all_stop_token_ids
]📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||
| from_generation_stop_tokens = [ | ||||||||||||
| i for i in generation_config.eos_token_id if i not in all_stop_tokens_id | ||||||||||||
| from_generation_stop_token_ids = [ | ||||||||||||
| i for i in generation_eos_token_ids if i not in all_stop_tokens_id | ||||||||||||
| ] | ||||||||||||
|
|
||||||||||||
| if from_generation_stop_tokens: | ||||||||||||
| self._stop_word_ids.append(from_generation_stop_tokens) | ||||||||||||
| else: | ||||||||||||
| self._stop_word_ids = [generation_config.eos_token_id] | ||||||||||||
| if from_generation_stop_token_ids: | ||||||||||||
| self._stop_word_ids.append(from_generation_stop_token_ids) | ||||||||||||
|
|
||||||||||||
| return self | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix grammatical error in comment.
Line 393 contains a grammatical error: "are really mean to stop" should be "really means to stop" or "are really meant to stop".
Apply this diff:
🤖 Prompt for AI Agents