
Commit ff7728e

Support zero num_assistant_tokens (disables speculative decoding)
1 parent 0b2c7ed commit ff7728e

1 file changed: +97 −48 lines

src/cpp/src/speculative_decoding/speculative_decoding_stateful_eagle3.cpp

Lines changed: 97 additions & 48 deletions
@@ -69,15 +69,17 @@ ov::Tensor extract_last_hidden_state(const ov::Tensor& hidden_features) {
 }
 
 // Validate and set num_assistant_tokens
-void ensure_num_assistant_tokens_is_set(ov::genai::GenerationConfig& config) {
+void ensure_num_assistant_tokens_is_set(ov::genai::GenerationConfig& config, bool allow_zero = false) {
     OPENVINO_ASSERT(config.assistant_confidence_threshold == 0.f,
                     "Stateful Speculative Decoding only supports num_assistant_tokens, "
                     "not assistant_confidence_threshold. Set it to 0.f.");
 
     constexpr std::size_t DEFAULT_NUM_ASSISTANT_TOKENS = 4;
-    if (config.num_assistant_tokens == 0) {
+    // If num_assistant_tokens is not explicitly set (is 0) and zero is not allowed, use default
+    if (config.num_assistant_tokens == 0 && !allow_zero) {
         config.num_assistant_tokens = DEFAULT_NUM_ASSISTANT_TOKENS;
     }
+    // If num_assistant_tokens is 0 and allow_zero is true, keep it as 0 (speculative decoding disabled)
 }
 
 } // anonymous namespace
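The new `allow_zero` flag only changes the zero case: without it, a zero value is silently bumped to the default of 4; with it, zero is preserved and interpreted as "speculative decoding off". A minimal standalone sketch of that behaviour (the `DummyConfig` struct below is a stand-in for `ov::genai::GenerationConfig`, not the real class, and the assert is omitted):

```cpp
#include <cstddef>
#include <iostream>

// Stand-in for ov::genai::GenerationConfig; only the field used by the helper.
struct DummyConfig {
    std::size_t num_assistant_tokens = 0;
};

// Mirrors the logic of ensure_num_assistant_tokens_is_set() above.
void ensure_num_assistant_tokens_is_set(DummyConfig& config, bool allow_zero = false) {
    constexpr std::size_t DEFAULT_NUM_ASSISTANT_TOKENS = 4;
    if (config.num_assistant_tokens == 0 && !allow_zero) {
        config.num_assistant_tokens = DEFAULT_NUM_ASSISTANT_TOKENS;  // unset -> default
    }
    // With allow_zero == true, 0 is kept and means "speculative decoding disabled".
}

int main() {
    DummyConfig a, b, c;
    c.num_assistant_tokens = 2;
    ensure_num_assistant_tokens_is_set(a);                       // 0 -> 4
    ensure_num_assistant_tokens_is_set(b, /*allow_zero=*/true);  // stays 0
    ensure_num_assistant_tokens_is_set(c, /*allow_zero=*/true);  // stays 2
    std::cout << a.num_assistant_tokens << ' ' << b.num_assistant_tokens << ' '
              << c.num_assistant_tokens << '\n';                 // prints: 4 0 2
}
```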
@@ -745,9 +747,21 @@ void StatefulEagle3LLMPipeline::set_verbose(bool verbose) {
 GenerationConfig StatefulEagle3LLMPipeline::resolve_generation_config(OptionalGenerationConfig generation_config) {
     GenerationConfig config = generation_config.value_or(m_generation_config);
 
-    ensure_num_assistant_tokens_is_set(config);
+    std::size_t prev_draft_iterations = m_draft_iterations;
+    // Allow num_assistant_tokens to be 0 (disables speculative decoding)
+    ensure_num_assistant_tokens_is_set(config, /*allow_zero=*/true);
     m_draft_iterations = config.num_assistant_tokens;
 
+    // Log whenever draft_iterations changes from its previous value
+    if (m_draft_iterations != prev_draft_iterations) {
+        if (m_draft_iterations == 0) {
+            log_info("Speculative decoding DISABLED (num_assistant_tokens=0), using target model only");
+        } else if (is_verbose()) {
+            log_debug("Draft iterations updated: " + std::to_string(prev_draft_iterations) + " -> " +
+                      std::to_string(m_draft_iterations));
+        }
+    }
+
     if (config.stop_token_ids.empty())
         config.stop_token_ids = m_generation_config.stop_token_ids;
     if (config.eos_token_id == -1)
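With `resolve_generation_config()` now passing `allow_zero=true`, a caller can turn speculation off per request simply by leaving `num_assistant_tokens` at 0. A hypothetical caller-side sketch; only the `ov::genai::GenerationConfig` fields shown are assumed, and pipeline construction plus the generate call are omitted:

```cpp
#include "openvino/genai/generation_config.hpp"

// Hypothetical helper: builds a config that either proposes 4 draft tokens per
// speculative iteration or disables speculation entirely (num_assistant_tokens = 0).
ov::genai::GenerationConfig make_config(bool use_speculation) {
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;
    // Before this commit, 0 was silently replaced by the default of 4; now it is
    // kept and the pipeline falls back to target-model-only generation.
    config.num_assistant_tokens = use_speculation ? 4 : 0;
    return config;
}
```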
@@ -935,68 +949,104 @@ EncodedResults StatefulEagle3LLMPipeline::generate(const EncodedInputs& inputs,
     log_debug("Initial token generated: " + std::to_string(initial_token));
     log_sequence_state("after initial token generation");
 
-    // Main generation loop using speculative decoding
+    // Main generation loop
     std::size_t max_new_tokens = config.max_new_tokens;
     std::size_t generated_tokens = 1; // Count initial token
-    std::size_t token_count = m_draft_model->get_sequence_length();
-    auto target_hidden_states = main_hidden_features;
     bool eos_reached = false;
 
     // Track metrics for speculative decoding
     std::size_t total_draft_accepted = 0; // Number of draft tokens accepted by main model
     std::size_t total_draft_generated = 0; // Total draft tokens generated (including rejected)
     std::size_t total_iterations = 0; // Number of speculative iterations
 
-    while (!eos_reached && generated_tokens < max_new_tokens &&
-           m_main_model->get_sequence_length() < prompt_len + max_new_tokens &&
-           (streaming_status == ov::genai::StreamingStatus::RUNNING)) {
-        log_generation_step("Speculative Decoding Iteration", generated_tokens);
-        log_sequence_state("iteration start");
-
-        auto result =
-            run_speculative_iteration(target_hidden_states, token_count, static_cast<int64_t>(config.eos_token_id));
+    // Check if speculative decoding is disabled
+    if (m_draft_iterations == 0) {
+        // Standard autoregressive generation with target model only
+        log_info("Running standard autoregressive generation (no speculative decoding)");
+
+        int64_t current_token = initial_token;
+        while (!eos_reached && generated_tokens < max_new_tokens &&
+               m_main_model->get_sequence_length() < prompt_len + max_new_tokens &&
+               (streaming_status == ov::genai::StreamingStatus::RUNNING)) {
+            // Check for EOS
+            if (current_token == static_cast<int64_t>(config.eos_token_id)) {
+                eos_reached = true;
+                log_debug("EOS reached - terminating generation");
+                break;
+            }
 
-        // Stream validated tokens
-        streaming_status = stream_generated_tokens(streamer_ptr, result.validated_tokens);
+            // Generate next token using target model
+            ov::Tensor next_input_ids, next_attention_mask, next_position_ids;
+            m_main_model->build_model_inputs(1, next_input_ids, next_attention_mask, next_position_ids);
 
-        // Update iteration counter
-        total_iterations++;
+            auto output = m_main_model->infer(next_input_ids, next_attention_mask, next_position_ids);
+            current_token = std::get<int64_t>(m_main_model->sample_tokens(output.logits, 1));
 
-        // Update draft token statistics
-        total_draft_generated += m_draft_iterations; // Each iteration generates m_draft_iterations draft tokens
-        total_draft_accepted +=
-            result.accepted_tokens_count; // Number of draft tokens accepted (not including main model's token)
+            m_main_model->append_tokens({current_token});
 
-        // Update metrics
-        if (result.new_token == static_cast<int64_t>(config.eos_token_id) || result.eos_reached) {
-            eos_reached = true;
-            log_debug("EOS reached - terminating generation");
-        }
+            // Stream the token
+            streaming_status = stream_generated_tokens(streamer_ptr, std::vector<int64_t>{current_token});
 
-        if (result.new_token != -1) {
             generated_tokens++;
-            log_debug("Generated token " + std::to_string(generated_tokens) + ": " + std::to_string(result.new_token) +
-                      ", accepted " + std::to_string(result.accepted_tokens_count) + " draft tokens out of " +
-                      std::to_string(m_draft_iterations));
+            log_debug("Generated token " + std::to_string(generated_tokens) + ": " + std::to_string(current_token));
         }
+    } else {
+        // Speculative decoding loop
+        std::size_t token_count = m_draft_model->get_sequence_length();
+        auto target_hidden_states = main_hidden_features;
+
+        while (!eos_reached && generated_tokens < max_new_tokens &&
+               m_main_model->get_sequence_length() < prompt_len + max_new_tokens &&
+               (streaming_status == ov::genai::StreamingStatus::RUNNING)) {
+            log_generation_step("Speculative Decoding Iteration", generated_tokens);
+            log_sequence_state("iteration start");
+
+            auto result =
+                run_speculative_iteration(target_hidden_states, token_count, static_cast<int64_t>(config.eos_token_id));
+
+            // Stream validated tokens
+            streaming_status = stream_generated_tokens(streamer_ptr, result.validated_tokens);
+
+            // Update iteration counter
+            total_iterations++;
+
+            // Update draft token statistics
+            total_draft_generated += m_draft_iterations; // Each iteration generates m_draft_iterations draft tokens
+            total_draft_accepted +=
+                result.accepted_tokens_count; // Number of draft tokens accepted (not including main model's token)
+
+            // Update metrics
+            if (result.new_token == static_cast<int64_t>(config.eos_token_id) || result.eos_reached) {
+                eos_reached = true;
+                log_debug("EOS reached - terminating generation");
+            }
 
-        // Prepare for next iteration
-        token_count = result.next_window_size > 0 ? result.next_window_size
-                                                  : std::min<std::size_t>(1, m_main_model->get_sequence_length());
-        target_hidden_states =
-            result.next_hidden_window ? result.next_hidden_window : m_main_model->get_hidden_features();
+            if (result.new_token != -1) {
+                generated_tokens++;
+                log_debug("Generated token " + std::to_string(generated_tokens) + ": " +
+                          std::to_string(result.new_token) + ", accepted " +
+                          std::to_string(result.accepted_tokens_count) + " draft tokens out of " +
+                          std::to_string(m_draft_iterations));
+            }
 
-        log_debug("Next iteration: token_count=" + std::to_string(token_count) + ", hidden_states_size=" +
-                  (target_hidden_states ? std::to_string(target_hidden_states.get_size()) : "0"));
+            // Prepare for next iteration
+            token_count = result.next_window_size > 0 ? result.next_window_size
+                                                      : std::min<std::size_t>(1, m_main_model->get_sequence_length());
+            target_hidden_states =
+                result.next_hidden_window ? result.next_hidden_window : m_main_model->get_hidden_features();
 
-        // Safety check to prevent infinite loops
-        if (result.next_window_size == 0 && result.new_token == -1) {
-            log_debug("No progress made, terminating generation");
-            break;
-        }
+            log_debug("Next iteration: token_count=" + std::to_string(token_count) + ", hidden_states_size=" +
+                      (target_hidden_states ? std::to_string(target_hidden_states.get_size()) : "0"));
 
-        log_sequence_state("iteration end");
-    }
+            // Safety check to prevent infinite loops
+            if (result.next_window_size == 0 && result.new_token == -1) {
+                log_debug("No progress made, terminating generation");
+                break;
+            }
+
+            log_sequence_state("iteration end");
+        }
+    } // End of speculative decoding / standard generation branch
 
     m_streaming_was_cancelled = (streaming_status == ov::genai::StreamingStatus::CANCEL);
     if (streamer_ptr) { // push streamer's cache
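The speculative branch keeps the `total_draft_accepted`, `total_draft_generated`, and `total_iterations` counters. An illustrative helper (not part of this commit) showing how those counters could be turned into summary metrics, using the commit's convention that each iteration commits the accepted draft tokens plus one target-model token:

```cpp
#include <cstddef>

// Illustrative only: summary metrics derived from the counters above.
struct SpecMetrics {
    double acceptance_rate;       // fraction of proposed draft tokens the target model accepted
    double tokens_per_iteration;  // average tokens committed per speculative iteration
};

SpecMetrics summarize(std::size_t total_draft_accepted,
                      std::size_t total_draft_generated,
                      std::size_t total_iterations) {
    SpecMetrics m{0.0, 0.0};
    if (total_draft_generated > 0) {
        m.acceptance_rate = static_cast<double>(total_draft_accepted) / total_draft_generated;
    }
    if (total_iterations > 0) {
        // Each iteration commits the accepted draft tokens plus one token from the target model.
        m.tokens_per_iteration =
            static_cast<double>(total_draft_accepted + total_iterations) / total_iterations;
    }
    return m;
}
```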
@@ -1220,7 +1270,7 @@ StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_spec
     result.next_window_size = total_accepted_tokens; // Total tokens for next iteration (draft + main)
     result.new_token = main_predicted_token; // Main model's prediction
     result.next_hidden_window = next_hidden;
-    result.validated_tokens = tokens_to_append; // Return validated tokens for streaming
+    result.validated_tokens = tokens_to_append;  // Return validated tokens for streaming
 
     log_debug("Speculative iteration completed - accepted " + std::to_string(accepted_count) +
               " draft tokens + 1 main prediction = " + std::to_string(total_accepted_tokens) +
@@ -1306,8 +1356,7 @@ void StatefulEagle3LLMPipeline::log_sequence_state(const std::string& context) c
     std::cout << "[EAGLE3-PIPELINE] Sequence state (" << context << "):" << std::endl;
     std::cout << " Prompt length: " << m_prompt_length << " tokens" << std::endl;
     std::cout << " Main model tokens: " << m_main_model->get_sequence_length() << " tokens" << std::endl;
-    std::cout << " Draft model tokens: " << m_draft_model->get_sequence_length() << " tokens"
-              << std::endl;
+    std::cout << " Draft model tokens: " << m_draft_model->get_sequence_length() << " tokens" << std::endl;
 
     // Show all tokens and positions from main model
     const auto& main_tokens = m_main_model->get_tokens();
