@@ -69,15 +69,17 @@ ov::Tensor extract_last_hidden_state(const ov::Tensor& hidden_features) {
 }
 
 // Validate and set num_assistant_tokens
-void ensure_num_assistant_tokens_is_set(ov::genai::GenerationConfig& config) {
+void ensure_num_assistant_tokens_is_set(ov::genai::GenerationConfig& config, bool allow_zero = false) {
     OPENVINO_ASSERT(config.assistant_confidence_threshold == 0.f,
                     "Stateful Speculative Decoding only supports num_assistant_tokens, "
                     "not assistant_confidence_threshold. Set it to 0.f.");
 
     constexpr std::size_t DEFAULT_NUM_ASSISTANT_TOKENS = 4;
-    if (config.num_assistant_tokens == 0) {
+    // If num_assistant_tokens is not explicitly set (is 0) and zero is not allowed, use the default
+    if (config.num_assistant_tokens == 0 && !allow_zero) {
         config.num_assistant_tokens = DEFAULT_NUM_ASSISTANT_TOKENS;
     }
+    // If num_assistant_tokens is 0 and allow_zero is true, keep it at 0 (speculative decoding disabled)
 }
 
 }  // anonymous namespace
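
For reference, the effect of the new allow_zero flag can be summarized with a minimal, self-contained sketch; the config struct below is a reduced stand-in for ov::genai::GenerationConfig with only the two fields the helper touches, not the real type:

    #include <cstddef>

    // Illustrative stand-in for ov::genai::GenerationConfig (not the real type).
    struct ConfigStub {
        float assistant_confidence_threshold = 0.f;
        std::size_t num_assistant_tokens = 0;
    };

    // Mirrors the patched helper: 0 picks the default unless allow_zero is set,
    // in which case 0 is preserved and means "speculative decoding disabled".
    void ensure_num_assistant_tokens_is_set(ConfigStub& config, bool allow_zero = false) {
        constexpr std::size_t DEFAULT_NUM_ASSISTANT_TOKENS = 4;
        if (config.num_assistant_tokens == 0 && !allow_zero) {
            config.num_assistant_tokens = DEFAULT_NUM_ASSISTANT_TOKENS;
        }
    }

    // ConfigStub c;
    // ensure_num_assistant_tokens_is_set(c);                       // c.num_assistant_tokens == 4
    // ensure_num_assistant_tokens_is_set(c, /*allow_zero=*/true);  // c.num_assistant_tokens == 0
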
@@ -745,9 +747,21 @@ void StatefulEagle3LLMPipeline::set_verbose(bool verbose) {
 GenerationConfig StatefulEagle3LLMPipeline::resolve_generation_config(OptionalGenerationConfig generation_config) {
     GenerationConfig config = generation_config.value_or(m_generation_config);
 
-    ensure_num_assistant_tokens_is_set(config);
+    std::size_t prev_draft_iterations = m_draft_iterations;
+    // Allow num_assistant_tokens to be 0 (disables speculative decoding)
+    ensure_num_assistant_tokens_is_set(config, /*allow_zero=*/true);
     m_draft_iterations = config.num_assistant_tokens;
 
+    // Log if the number of draft iterations changed from its previous value
+    if (m_draft_iterations != prev_draft_iterations) {
+        if (m_draft_iterations == 0) {
+            log_info("Speculative decoding DISABLED (num_assistant_tokens=0), using target model only");
+        } else if (is_verbose()) {
+            log_debug("Draft iterations updated: " + std::to_string(prev_draft_iterations) + " -> " +
+                      std::to_string(m_draft_iterations));
+        }
+    }
+
     if (config.stop_token_ids.empty())
         config.stop_token_ids = m_generation_config.stop_token_ids;
     if (config.eos_token_id == -1)
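
Assuming the usual OpenVINO GenAI configuration flow, a caller would now opt out of speculative decoding by passing num_assistant_tokens = 0; a hedged usage sketch follows, where `pipeline` and `prompt` are assumed to exist and the surrounding setup is omitted:

    // Hedged usage sketch: with this change, num_assistant_tokens == 0 survives
    // resolve_generation_config() and the pipeline falls back to plain
    // autoregressive decoding with the target model only.
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;
    config.num_assistant_tokens = 0;  // 0 now means "disable speculative decoding"
    auto results = pipeline.generate(prompt, config);  // `pipeline` and `prompt` assumed
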
@@ -935,68 +949,104 @@ EncodedResults StatefulEagle3LLMPipeline::generate(const EncodedInputs& inputs,
     log_debug("Initial token generated: " + std::to_string(initial_token));
     log_sequence_state("after initial token generation");
 
-    // Main generation loop using speculative decoding
+    // Main generation loop
     std::size_t max_new_tokens = config.max_new_tokens;
     std::size_t generated_tokens = 1;  // Count initial token
-    std::size_t token_count = m_draft_model->get_sequence_length();
-    auto target_hidden_states = main_hidden_features;
     bool eos_reached = false;
 
     // Track metrics for speculative decoding
     std::size_t total_draft_accepted = 0;   // Number of draft tokens accepted by main model
     std::size_t total_draft_generated = 0;  // Total draft tokens generated (including rejected)
     std::size_t total_iterations = 0;       // Number of speculative iterations
 
-    while (!eos_reached && generated_tokens < max_new_tokens &&
-           m_main_model->get_sequence_length() < prompt_len + max_new_tokens &&
-           (streaming_status == ov::genai::StreamingStatus::RUNNING)) {
-        log_generation_step("Speculative Decoding Iteration", generated_tokens);
-        log_sequence_state("iteration start");
-
-        auto result =
-            run_speculative_iteration(target_hidden_states, token_count, static_cast<int64_t>(config.eos_token_id));
+    // Check if speculative decoding is disabled
+    if (m_draft_iterations == 0) {
+        // Standard autoregressive generation with target model only
+        log_info("Running standard autoregressive generation (no speculative decoding)");
+
+        int64_t current_token = initial_token;
+        while (!eos_reached && generated_tokens < max_new_tokens &&
+               m_main_model->get_sequence_length() < prompt_len + max_new_tokens &&
+               (streaming_status == ov::genai::StreamingStatus::RUNNING)) {
+            // Check for EOS
+            if (current_token == static_cast<int64_t>(config.eos_token_id)) {
+                eos_reached = true;
+                log_debug("EOS reached - terminating generation");
+                break;
+            }
 
-        // Stream validated tokens
-        streaming_status = stream_generated_tokens(streamer_ptr, result.validated_tokens);
+            // Generate next token using target model
+            ov::Tensor next_input_ids, next_attention_mask, next_position_ids;
+            m_main_model->build_model_inputs(1, next_input_ids, next_attention_mask, next_position_ids);
 
-        // Update iteration counter
-        total_iterations++;
+            auto output = m_main_model->infer(next_input_ids, next_attention_mask, next_position_ids);
+            current_token = std::get<int64_t>(m_main_model->sample_tokens(output.logits, 1));
 
-        // Update draft token statistics
-        total_draft_generated += m_draft_iterations;  // Each iteration generates m_draft_iterations draft tokens
-        total_draft_accepted +=
-            result.accepted_tokens_count;  // Number of draft tokens accepted (not including main model's token)
+            m_main_model->append_tokens({current_token});
 
-        // Update metrics
-        if (result.new_token == static_cast<int64_t>(config.eos_token_id) || result.eos_reached) {
-            eos_reached = true;
-            log_debug("EOS reached - terminating generation");
-        }
+            // Stream the token
+            streaming_status = stream_generated_tokens(streamer_ptr, std::vector<int64_t>{current_token});
 
-        if (result.new_token != -1) {
             generated_tokens++;
-            log_debug("Generated token " + std::to_string(generated_tokens) + ": " + std::to_string(result.new_token) +
-                      ", accepted " + std::to_string(result.accepted_tokens_count) + " draft tokens out of " +
-                      std::to_string(m_draft_iterations));
+            log_debug("Generated token " + std::to_string(generated_tokens) + ": " + std::to_string(current_token));
         }
+    } else {
+        // Speculative decoding loop
+        std::size_t token_count = m_draft_model->get_sequence_length();
+        auto target_hidden_states = main_hidden_features;
+
+        while (!eos_reached && generated_tokens < max_new_tokens &&
+               m_main_model->get_sequence_length() < prompt_len + max_new_tokens &&
+               (streaming_status == ov::genai::StreamingStatus::RUNNING)) {
+            log_generation_step("Speculative Decoding Iteration", generated_tokens);
+            log_sequence_state("iteration start");
+
+            auto result =
+                run_speculative_iteration(target_hidden_states, token_count, static_cast<int64_t>(config.eos_token_id));
+
+            // Stream validated tokens
+            streaming_status = stream_generated_tokens(streamer_ptr, result.validated_tokens);
+
+            // Update iteration counter
+            total_iterations++;
+
+            // Update draft token statistics
+            total_draft_generated += m_draft_iterations;  // Each iteration generates m_draft_iterations draft tokens
+            total_draft_accepted +=
+                result.accepted_tokens_count;  // Number of draft tokens accepted (not including main model's token)
+
+            // Check whether the main model predicted EOS
+            if (result.new_token == static_cast<int64_t>(config.eos_token_id) || result.eos_reached) {
+                eos_reached = true;
+                log_debug("EOS reached - terminating generation");
+            }
 
-        // Prepare for next iteration
-        token_count = result.next_window_size > 0 ? result.next_window_size
-                                                   : std::min<std::size_t>(1, m_main_model->get_sequence_length());
-        target_hidden_states =
-            result.next_hidden_window ? result.next_hidden_window : m_main_model->get_hidden_features();
+            if (result.new_token != -1) {
+                generated_tokens++;
+                log_debug("Generated token " + std::to_string(generated_tokens) + ": " +
+                          std::to_string(result.new_token) + ", accepted " +
+                          std::to_string(result.accepted_tokens_count) + " draft tokens out of " +
+                          std::to_string(m_draft_iterations));
+            }
 
-        log_debug("Next iteration: token_count=" + std::to_string(token_count) + ", hidden_states_size=" +
-                  (target_hidden_states ? std::to_string(target_hidden_states.get_size()) : "0"));
+            // Prepare for next iteration
+            token_count = result.next_window_size > 0 ? result.next_window_size
+                                                       : std::min<std::size_t>(1, m_main_model->get_sequence_length());
+            target_hidden_states =
+                result.next_hidden_window ? result.next_hidden_window : m_main_model->get_hidden_features();
 
-        // Safety check to prevent infinite loops
-        if (result.next_window_size == 0 && result.new_token == -1) {
-            log_debug("No progress made, terminating generation");
-            break;
-        }
+            log_debug("Next iteration: token_count=" + std::to_string(token_count) + ", hidden_states_size=" +
+                      (target_hidden_states ? std::to_string(target_hidden_states.get_size()) : "0"));
 
-        log_sequence_state("iteration end");
-    }
+            // Safety check to prevent infinite loops
+            if (result.next_window_size == 0 && result.new_token == -1) {
+                log_debug("No progress made, terminating generation");
+                break;
+            }
+
+            log_sequence_state("iteration end");
+        }
+    }  // End of speculative decoding / standard generation branch
 
     m_streaming_was_cancelled = (streaming_status == ov::genai::StreamingStatus::CANCEL);
     if (streamer_ptr) {  // push streamer's cache
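
The counters kept by the loop (total_draft_accepted, total_draft_generated, total_iterations) are enough to derive the usual speculative-decoding summary statistics; a hypothetical reporting helper, not part of this patch, might look like the following sketch:

    #include <cstddef>
    #include <iostream>

    // Hypothetical helper (not in this patch): acceptance rate is accepted draft
    // tokens over generated draft tokens; tokens per iteration adds the main
    // model's own prediction (+1) to the accepted draft tokens.
    void report_speculative_stats(std::size_t total_draft_accepted,
                                  std::size_t total_draft_generated,
                                  std::size_t total_iterations) {
        if (total_draft_generated == 0 || total_iterations == 0) {
            return;  // speculative decoding was disabled or never ran
        }
        const double acceptance_rate =
            static_cast<double>(total_draft_accepted) / static_cast<double>(total_draft_generated);
        const double tokens_per_iteration =
            static_cast<double>(total_draft_accepted) / static_cast<double>(total_iterations) + 1.0;
        std::cout << "Draft acceptance rate: " << acceptance_rate * 100.0
                  << "%, tokens/iteration: " << tokens_per_iteration << std::endl;
    }
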
@@ -1220,7 +1270,7 @@ StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_spec
     result.next_window_size = total_accepted_tokens;  // Total tokens for next iteration (draft + main)
     result.new_token = main_predicted_token;           // Main model's prediction
     result.next_hidden_window = next_hidden;
-    result.validated_tokens = tokens_to_append; // Return validated tokens for streaming
+    result.validated_tokens = tokens_to_append;        // Return validated tokens for streaming
 
     log_debug("Speculative iteration completed - accepted " + std::to_string(accepted_count) +
               " draft tokens + 1 main prediction = " + std::to_string(total_accepted_tokens) +
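
For orientation, the shape of SpeculativeResult can be read off from this function and the generate() loop above; a sketch follows, with field types inferred from usage rather than copied from the header:

    // Inferred shape of StatefulEagle3LLMPipeline::SpeculativeResult (types guessed from usage).
    struct SpeculativeResult {
        std::size_t accepted_tokens_count = 0;  // draft tokens accepted by the main model
        std::size_t next_window_size = 0;       // tokens carried into the next iteration (draft + main)
        int64_t new_token = -1;                 // main model's prediction, -1 if none
        bool eos_reached = false;               // main model predicted EOS
        ov::Tensor next_hidden_window;          // hidden states for the next draft window
        std::vector<int64_t> validated_tokens;  // tokens to stream to the user
    };
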
@@ -1306,8 +1356,7 @@ void StatefulEagle3LLMPipeline::log_sequence_state(const std::string& context) c
     std::cout << "[EAGLE3-PIPELINE] Sequence state (" << context << "):" << std::endl;
     std::cout << "  Prompt length: " << m_prompt_length << " tokens" << std::endl;
     std::cout << "  Main model tokens: " << m_main_model->get_sequence_length() << " tokens" << std::endl;
-    std::cout << "  Draft model tokens: " << m_draft_model->get_sequence_length() << " tokens"
-              << std::endl;
+    std::cout << "  Draft model tokens: " << m_draft_model->get_sequence_length() << " tokens" << std::endl;
 
     // Show all tokens and positions from main model
     const auto& main_tokens = m_main_model->get_tokens();