diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 7ec5890f3c..3ec5ad056f 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -534,7 +534,7 @@ jobs: run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).tokenizers.test }} timeout: 60 - name: 'API tests' - cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py' + cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "not eagle3" ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test || fromJSON(needs.smart_ci.outputs.affected_components).sampling.test || fromJSON(needs.smart_ci.outputs.affected_components).text_streamer.test }} timeout: 60 - name: 'Rag tests' @@ -551,6 +551,12 @@ jobs: python -m pytest -v ./tools/who_what_benchmark/tests -m nanollava run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).WWB.test }} timeout: 90 + - name: 'EAGLE3 speculative decoding tests' + cmd: | + python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@ea9607daf32919024cdd4390deec9693a7b64d23 + python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3" + run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }} + timeout: 90 defaults: run: shell: bash diff --git a/.github/workflows/manylinux_2_28.yml b/.github/workflows/manylinux_2_28.yml index 97b9e6291d..72dc3aa55b 100644 --- a/.github/workflows/manylinux_2_28.yml +++ b/.github/workflows/manylinux_2_28.yml @@ -472,7 +472,7 @@ jobs: run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).tokenizers.test }} timeout: 60 - name: 'API tests' - cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py' + cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "not eagle3" ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test || fromJSON(needs.smart_ci.outputs.affected_components).sampling.test || fromJSON(needs.smart_ci.outputs.affected_components).text_streamer.test }} timeout: 60 - name: 'Rag tests' @@ -489,6 +489,12 @@ jobs: python -m pytest -v ./tools/who_what_benchmark/tests -m nanollava run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).WWB.test }} timeout: 90 + - name: 'EAGLE3 speculative decoding tests' + cmd: | + python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@ea9607daf32919024cdd4390deec9693a7b64d23 + python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3" + run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }} + timeout: 90 defaults: run: shell: bash diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index a85ac74351..64ac66ffd0 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -623,7 +623,7 @@ jobs: run_condition: ${{ 
fromJSON(needs.smart_ci.outputs.affected_components).tokenizers.test }} timeout: 60 - name: 'API tests' - cmd: 'python -m pytest -s -v tests/python_tests/test_continuous_batching.py tests/python_tests/test_generation_config.py tests/python_tests/test_sampling.py tests/python_tests/test_text_streamer.py' + cmd: 'python -m pytest -s -v tests/python_tests/test_continuous_batching.py -k "not eagle3" tests/python_tests/test_generation_config.py tests/python_tests/test_sampling.py tests/python_tests/test_text_streamer.py' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test || fromJSON(needs.smart_ci.outputs.affected_components).sampling.test || fromJSON(needs.smart_ci.outputs.affected_components).text_streamer.test }} timeout: 60 - name: 'Rag tests' @@ -640,6 +640,12 @@ jobs: python -m pytest -v ./tools/who_what_benchmark/tests -m nanollava run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).WWB.test }} timeout: 90 + - name: 'EAGLE3 speculative decoding tests' + cmd: | + python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@ea9607daf32919024cdd4390deec9693a7b64d23 + python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3" + run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }} + timeout: 90 defaults: run: shell: pwsh diff --git a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp index 6a5486e591..67a522d2fc 100644 --- a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp +++ b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp @@ -65,13 +65,18 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline { class ContinuousBatchingImpl; class ContinuousBatchingForSpeculativeDecodingImpl; + class ContinuousBatchingForEagle3DecodingImpl; class ContinuousBatchingForPromptLookupImpl; class SpeculativeDecodingImpl; + class Eagle3DecodingImpl; class PromptLookupImpl; friend class ContinuousBatchingForSpeculativeDecodingImpl; + friend class ContinuousBatchingForPromptLookupImpl; + friend class ContinuousBatchingForEagle3DecodingImpl; friend class SpeculativeDecodingImpl; + friend class Eagle3DecodingImpl; friend class PromptLookupImpl; std::shared_ptr m_impl; diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index 2d8e814922..fded073f97 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -24,6 +24,27 @@ inline std::string get_paged_attention_score_output_for_decoder_layer(size_t dec return ss.str(); } +enum HiddenStateFlags : uint8_t { + HS_NONE = 0, + HS_EXPORT = 1 << 0, + HS_IMPORT = 1 << 1, + HS_INTERNAL = 1 << 2 +}; + +struct SequenceKey { + size_t request_id{}; + size_t grouped_sequence_id{}; + bool operator<(const SequenceKey& other) const { + return std::tie(request_id, grouped_sequence_id) < + std::tie(other.request_id, other.grouped_sequence_id); + } +}; + +struct HiddenStateRange { + size_t start_token_idx{}; + size_t length{}; +}; + /** * @brief Runs the LLM infer request, parsing the continuous batching scheduler output into proper inputs in terms of OV API (e.g. token input IDs, * KV cache block indices etc.) and returning the logit scores for the next token to be generated for each of the currently scheduled sequences. @@ -48,6 +69,11 @@ class ModelRunner { // Input shape: [N, conversation length]. 
// Output shape: [1, conversation length, hidden_size]. EmbeddingsModel::Ptr m_embedding; + uint8_t m_hidden_state_flags = HS_NONE; + std::map m_sequence_hidden_state_mapping; + // a container which use sequence group id and request id as key to store hidden states + std::map m_initial_hidden_states; // shape: [N, seq_len, hidden_size] + size_t m_adjust_factor = 1; // to adjust the hidden size of draft model input std::shared_ptr m_inputs_embedder; @@ -107,11 +133,19 @@ class ModelRunner { return m_request; } + void enable_hidden_state_export(bool on) { on ? m_hidden_state_flags |= HS_EXPORT : m_hidden_state_flags &= ~HS_EXPORT; } + void enable_hidden_state_import(bool on) { on ? m_hidden_state_flags |= HS_IMPORT : m_hidden_state_flags &= ~HS_IMPORT; } + void enable_hidden_state_internal(bool on) { on ? m_hidden_state_flags |= HS_INTERNAL : m_hidden_state_flags &= ~HS_INTERNAL; } + void set_inputs_embedder(const std::shared_ptr& inputs_embedder) { m_inputs_embedder = inputs_embedder; m_embedding = inputs_embedder->get_embedding_model(); } + void set_adjust_factor(size_t adjust_factor) { + m_adjust_factor = adjust_factor; + } + /** * @return A map of sequence IDs to vectors of ov::Tensor per-token attention scores. Each vector element is associated with its own * decoder layer, in order of their execution in the model. Each ov::Tensor has a shape of {N_k}, where N_k is the length of @@ -134,6 +168,10 @@ class ModelRunner { m_cache_rotation_deltas_for_each_layer = std::move(rotation_deltas_for_each_layer); } + void set_initial_hidden_state(uint64_t request_id, const ov::Tensor& hidden_state) { + m_initial_hidden_states[request_id] = hidden_state; + } + /** * Runs the forward inference call on the underlying LLM's ov::InferRequest, scheduling for inferencing tokens for given sequences * taking into account the supplied scheduler output struct. @@ -142,6 +180,7 @@ class ModelRunner { * @return An ov::Tensor with next-token logit scores for each sequence processed during this `forward` call. 
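 * @note When hidden-state flags are enabled on this runner (HS_EXPORT / HS_IMPORT / HS_INTERNAL, used for Eagle3 speculative decoding), this call additionally feeds the stored hidden states to the model as extra inputs and/or writes each running sequence's slice of the last_hidden_state output back into the sequence.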
*/ ov::Tensor forward(const std::vector & sequence_groups, const Scheduler::Output& scheduler_output) { + m_sequence_hidden_state_mapping.clear(); size_t num_sequence_groups = scheduler_output.m_scheduled_sequence_groups_ids.size(); size_t batch_size_in_sequences = 0; @@ -185,6 +224,12 @@ class ModelRunner { ov::Tensor score_aggregation_window = _get_or_resize_tensor(m_cached_score_aggregation_window, "score_aggregation_window", {batch_size_in_sequences}, ov::element::i32); + ov::Tensor hidden_state_input = _prepare_hidden_state_input(total_num_tokens, hidden_size); + float* hidden_state_data = nullptr; + if (hidden_state_input) { + hidden_state_data = hidden_state_input.data(); + } + ov::Tensor generated_ids_embeds; float *generated_ids_embeds_data = nullptr; @@ -234,6 +279,7 @@ class ModelRunner { matmul_gathering_is_available = true; } catch (const ov::Exception&) {} + size_t current_token_idx = 0; std::map> seq_id_to_skipped_blocks_map; size_t position_ids_idx = 0; for (size_t i = 0; i < num_sequence_groups; ++i) { @@ -265,6 +311,64 @@ class ModelRunner { output_seq_len = 0; Sequence::CPtr sequence = running_sequences[seq_idx]; + if (_is_hs_export()) { + size_t start_token_idx = current_token_idx; + size_t sequence_length = num_scheduled_tokens; + + SequenceKey key{sequence_group->get_request_id(), sequence->get_grouped_id()}; + m_sequence_hidden_state_mapping[key] = HiddenStateRange{start_token_idx, sequence_length}; + } + if (_is_hs_import()) { + auto it = m_initial_hidden_states.find(sequence_group->get_request_id()); + + if (it != m_initial_hidden_states.end()) { + const auto& stored_hidden_state = it->second; + + if (stored_hidden_state.get_size() > 0) { + auto stored_shape = stored_hidden_state.get_shape(); + + if (stored_shape.size() >= 2) { + size_t stored_seq_len = stored_shape[0]; + size_t stored_hidden_size = stored_shape[stored_shape.size() - 1]; + + if (stored_hidden_size == hidden_size) { + if (stored_seq_len == total_num_tokens) { + hidden_state_input = stored_hidden_state; // all tokens from eagle are accepted + } else { + size_t copy_length = std::min(stored_seq_len, num_scheduled_tokens); + + size_t source_start_idx = + stored_seq_len >= copy_length ? stored_seq_len - copy_length : 0; + _copy_roi_between_tensors(stored_hidden_state, source_start_idx, copy_length, hidden_state_input, current_token_idx); + } + } + } + } else { + OPENVINO_ASSERT(false, "missing hidden state from target model to eagle draft model"); + } + } + } else if (_is_hs_internal()) { + // fill hidden_state_data with m_hidden_states + if (hidden_state_data) { + OPENVINO_ASSERT(num_scheduled_tokens == 1, "unexpected num_scheduled_tokens in speculative drafting stage in eagle3 mode"); + std::memset(hidden_state_data + current_token_idx * hidden_size, + 0, + num_scheduled_tokens * hidden_size * sizeof(float)); + auto hidden_state = running_sequences[seq_idx]->get_hidden_state(); + if (hidden_state.get_size() > 0) { + auto shape = hidden_state.get_shape(); + if (shape.size() >= 2 && shape[shape.size() - 1] == hidden_size) { + size_t seq_len = shape[0]; + size_t copy_length = std::min(seq_len, num_scheduled_tokens); + + size_t src_start_idx = seq_len >= copy_length ? 
seq_len - copy_length : 0; + auto target_shape = ov::Shape{num_scheduled_tokens, 1, hidden_size}; + ov::Tensor target_base(ov::element::f32, target_shape, hidden_state_data + current_token_idx * hidden_size); + _copy_roi_between_tensors(hidden_state, src_start_idx, copy_length, target_base, 0); + } + } + } + } for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) { // compute token for current sequence if (sequence_group_type == SequenceGroupType::TOKENS) { @@ -343,6 +447,7 @@ class ModelRunner { *score_aggregation_window_data = 1; } } + current_token_idx += num_scheduled_tokens; past_lens_data += 1; subsequence_begins_data += 1; block_indices_begins_data += 1; @@ -367,6 +472,31 @@ class ModelRunner { m_request.set_tensor("token_type_ids", token_type_ids); } } + if (hidden_state_input && hidden_state_input.get_size() > 0) { + if (_is_hs_import()) { + try { + m_request.set_tensor("hidden_states", hidden_state_input); + auto shape = hidden_state_input.get_shape(); + shape[shape.size() - 1] = shape[shape.size() - 1] / m_adjust_factor; + ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); + auto fake_data = fake_tensor.data(); + std::memset(fake_data, 0, fake_tensor.get_byte_size()); + m_request.set_tensor("internal_hidden_states", fake_tensor); + } catch (const ov::Exception& e) { + } + } else { + try { + m_request.set_tensor("internal_hidden_states", hidden_state_input); + auto shape = hidden_state_input.get_shape(); + shape[shape.size() - 1] = shape[shape.size() - 1] * m_adjust_factor; + ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); + auto fake_data = fake_tensor.data(); + std::memset(fake_data, 0, fake_tensor.get_byte_size()); + m_request.set_tensor("hidden_states", fake_tensor); + } catch (const ov::Exception& e) { + } + } + } if (position_ids.get_shape().size() == 3) { // flatten positions ids for 3D position ids case position_ids.set_shape({ov::shape_size(position_ids.get_shape())}); @@ -424,6 +554,23 @@ class ModelRunner { _reset_cache_rotation_coefficients(); + if (_is_hs_export()) { + try { + m_hidden_states = m_request.get_tensor("last_hidden_state"); + for (size_t i = 0; i < num_sequence_groups; ++i) { + size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i]; + SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id]; + std::vector running_sequences = sequence_group->get_running_sequences(); + for (size_t seq_idx = 0; seq_idx < running_sequences.size(); ++seq_idx) { + Sequence::Ptr sequence = running_sequences[seq_idx]; + sequence->update_hidden_state( + _get_hidden_state(sequence_group->get_request_id(), sequence->get_grouped_id())); + } + } + } catch (const ov::Exception&) { + m_hidden_states = ov::Tensor(); + } + } // return logits return m_request.get_tensor("logits"); } @@ -490,6 +637,116 @@ class ModelRunner { } private: + ov::Tensor m_hidden_states; + + // Hidden state flags and helpers + bool _is_hs_export() const { return m_hidden_state_flags & HS_EXPORT; } + bool _is_hs_import() const { return m_hidden_state_flags & HS_IMPORT; } + bool _is_hs_internal() const { return m_hidden_state_flags & HS_INTERNAL; } + + ov::Tensor _get_hidden_state(uint64_t request_id, uint64_t seq_grouped_id) const { + if (m_hidden_states.get_size() == 0) { + return ov::Tensor(); + } + + SequenceKey key{request_id, seq_grouped_id}; + const auto it = m_sequence_hidden_state_mapping.find(key); + if (it == 
m_sequence_hidden_state_mapping.end()) { + return ov::Tensor(); + } + + size_t start_idx = it->second.start_token_idx; + size_t length = it->second.length; + + auto shape = m_hidden_states.get_shape(); + if (shape.size() < 2) { + return ov::Tensor(); + } + + ov::Coordinate start_coord(shape.size(), 0); + ov::Coordinate end_coord(shape.size(), 0); + + start_coord[0] = start_idx; + end_coord[0] = start_idx + length; + + for (size_t i = 1; i < shape.size(); ++i) { + start_coord[i] = 0; + end_coord[i] = shape[i]; + } + + return ov::Tensor(m_hidden_states, start_coord, end_coord); + } + + ov::Tensor _prepare_hidden_state_input(size_t total_num_tokens, + size_t& hidden_size /*in/out*/) { + if (!(m_hidden_state_flags & (HS_IMPORT | HS_INTERNAL))) { + return {}; + } + + if (hidden_size == 0) { + for (const auto& kv : m_initial_hidden_states) { + const auto& initial_hidden_states = kv.second; + if (initial_hidden_states && initial_hidden_states.get_shape().size() >= 2) { + auto hidden_states_shape = initial_hidden_states.get_shape(); + hidden_size = hidden_states_shape.back(); + if (!(m_hidden_state_flags & HS_IMPORT)) { + hidden_size /= m_adjust_factor; + } + break; + } + } + } + if (hidden_size == 0) { + return {}; + } + + ov::Tensor hs(ov::element::f32, {total_num_tokens, 1, hidden_size}); + std::memset(hs.data(), 0, hs.get_byte_size()); + return hs; + } + + // Common helper to copy a contiguous slice (first-dim range) from src to dst using ROI tensors. + // src_start_idx: start index along src first dimension + // copy_length: number of elements along first dim to copy + // dst_base: destination base tensor (may be full buffer or a wrapper around a raw pointer) + // dst_first_dim_start: start index in first dimension of dst_base where copy should be placed + static void _copy_roi_between_tensors(const ov::Tensor& src, + size_t src_start_idx, + size_t copy_length, + const ov::Tensor& dst_base, + size_t dst_first_dim_start) { + if (copy_length == 0) { + return; + } + + // prepare source ROI coords + const auto src_shape = src.get_shape(); + OPENVINO_ASSERT(!src_shape.empty(), "source tensor rank is zero"); + ov::Coordinate src_start(src_shape.size(), 0), src_end(src_shape.size(), 0); + src_start[0] = src_start_idx; + src_end[0] = src_start_idx + copy_length; + for (size_t d = 1; d < src_shape.size(); ++d) { + src_start[d] = 0; + src_end[d] = src_shape[d]; + } + ov::Tensor src_roi(src, src_start, src_end); + + // prepare destination ROI coords + const auto dst_shape = dst_base.get_shape(); + OPENVINO_ASSERT(!dst_shape.empty(), "destination tensor rank is zero"); + ov::Coordinate tgt_start(dst_shape.size(), 0), tgt_end(dst_shape.size(), 0); + tgt_start[0] = dst_first_dim_start; + tgt_end[0] = dst_first_dim_start + copy_length; + for (size_t d = 1; d < dst_shape.size(); ++d) { + tgt_start[d] = 0; + tgt_end[d] = dst_shape[d]; + } + ov::Tensor tgt_roi(dst_base, tgt_start, tgt_end); + + // bulk copy + src_roi.copy_to(tgt_roi); + } + ov::Tensor _get_or_resize_tensor(ov::Tensor& cached_tensor, const std::string& tensor_name, const ov::Shape& required_shape, diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index 16eb169de7..92db6a7da7 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -11,14 +11,18 @@ #include "openvino/genai/tokenizer.hpp" #include "continuous_batching/pipeline_impl.hpp" #include "speculative_decoding/speculative_decoding_impl.hpp" +#include 
"speculative_decoding/speculative_decoding_eagle3_impl.hpp" +#include "speculative_decoding/speculative_decoding_utils.hpp" #include "prompt_lookup/prompt_lookup_impl.hpp" #include "continuous_batching/timer.hpp" #include "utils.hpp" #include "visual_language/inputs_embedder.hpp" +#include "json_utils.hpp" using namespace ov::genai; namespace { + bool extract_prompt_lookup_from_config(ov::AnyMap& config) { bool res = false; @@ -45,6 +49,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p auto properties_without_draft_model = properties; auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); + auto eagle_rt_info = speculative_decoding::extract_eagle_mode_from_config(draft_model_desr.properties, models_path); auto model = utils::read_model(models_path, properties); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); @@ -63,6 +68,10 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model_without_gguf, generation_config); + } else if (draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { + OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); @@ -87,13 +96,12 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - + auto eagle_rt_info = speculative_decoding::extract_eagle_mode_from_config(draft_model_desr.properties, models_path); auto model = utils::read_model(models_path, properties_without_draft_model); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); properties_without_draft_model_without_gguf[ov::cache_model_path.name()] = models_path; auto generation_config = utils::from_config_json_if_exists(models_path); - std::shared_ptr embedder; if (std::filesystem::exists(models_path / "openvino_text_embeddings_model.xml")) { embedder = std::make_shared(models_path, device, properties_without_draft_model_without_gguf); @@ -105,6 +113,13 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); 
OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model_without_gguf, generation_config); + } else if (draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { + OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); + // Eagle speculative decoding does not support dynamic_split_fuse mode + // because it requires hidden state interaction from main model to draft model + // to be implemented future + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); @@ -131,6 +146,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); + auto eagle_rt_info = speculative_decoding::extract_eagle_mode_from_config(draft_model_desr.properties, std::filesystem::path(model_str)); auto model = utils::singleton_core().read_model(model_str, weights_tensor); auto rt_info = model->get_rt_info(); @@ -150,6 +166,10 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model, generation_config); + } else if (draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { + OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index e057d5da72..6d94353755 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -13,6 +13,8 @@ #include "llm/pipeline_continuous_batching_adapter.hpp" #include "speculative_decoding/speculative_decoding_impl.hpp" #include "speculative_decoding/speculative_decoding_stateful.hpp" +#include "speculative_decoding/speculative_decoding_stateful_eagle3.hpp" +#include "speculative_decoding/speculative_decoding_utils.hpp" #include "utils.hpp" namespace { @@ -88,6 +90,20 @@ std::pair generation_config(const GenerationConfig& config) { 
return {utils::CONFIG_ARG_NAME, Any::make(config)}; } +inline void apply_eagle_rt_info(std::shared_ptr& model, ov::AnyMap& properties, const std::filesystem::path& mapping_path) { + if (model->has_rt_info("eagle3_mode") && model->get_rt_info("eagle3_mode")) { + properties["eagle3_mode"] = true; + if (model->has_rt_info("hidden_layers_list")) + properties["hidden_layers_list"] = model->get_rt_info>("hidden_layers_list"); + } +} + +inline void apply_eagle_rt_info(std::shared_ptr& model, + ov::AnyMap& properties, + const std::string& mapping_path) { + apply_eagle_rt_info(model, properties, std::filesystem::path(mapping_path)); +} + std::pair draft_model( const std::filesystem::path& models_path, const std::string& device, @@ -96,6 +112,7 @@ std::pair draft_model( std::filesystem::path openvino_model_name = "openvino_model.xml"; auto model = utils::singleton_core().read_model(models_path / openvino_model_name, {}, plugin_config); + apply_eagle_rt_info(model, plugin_config, models_path); auto generation_config = utils::from_config_json_if_exists(models_path); auto tokenizer = ov::genai::Tokenizer(models_path); return { utils::DRAFT_MODEL_ARG_NAME, Any::make(model, tokenizer, device, plugin_config, scheduler_config, generation_config) }; @@ -111,6 +128,7 @@ std::pair draft_model( auto [plugin_config, scheduler_config] = utils::extract_scheduler_config(properties); auto model = utils::singleton_core().read_model(model_str, weights_tensor); + apply_eagle_rt_info(model, plugin_config, model_str); return { utils::DRAFT_MODEL_ARG_NAME, Any::make(model, tokenizer, device, plugin_config, scheduler_config, generation_config) }; } @@ -126,7 +144,8 @@ static std::unique_ptr create( tokenizer, device, properties, - utils::from_config_json_if_exists(models_path)); + utils::from_config_json_if_exists(models_path), + models_path); } static std::unique_ptr create( @@ -141,17 +160,43 @@ static std::unique_ptr create( const ov::genai::Tokenizer& tokenizer, const std::string& device, const ov::AnyMap& properties, - const ov::genai::GenerationConfig& generation_config) { + const ov::genai::GenerationConfig& generation_config, + const std::filesystem::path& models_path = {}) { auto properties_without_draft_model = properties; auto draft_model_descr = ov::genai::utils::extract_draft_model_from_config(properties_without_draft_model); + if (draft_model_descr.model != nullptr) { - // FIXME: Add support for StatefulSpeculativeLLMPipeline for non-NPU devices for both models. 
- OPENVINO_ASSERT(device == "NPU" || draft_model_descr.device == "NPU", - "Stateful Speculative Decoding is expected to be launched when NPU is requested as " - "execution device for one or both models."); - auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, {}, generation_config); - return std::make_unique(main_model_descr, draft_model_descr); + // Extract Eagle3 configuration from draft model properties + // Pass models_path for auto-deducing hidden_layers_list from config.json + auto eagle_rt_info = ov::genai::speculative_decoding::extract_eagle_mode_from_config( + draft_model_descr.properties, + models_path + ); + + if (eagle_rt_info.eagle3_mode) { + // Eagle3 Speculative Decoding mode + OPENVINO_ASSERT(device == "NPU" || draft_model_descr.device == "NPU", + "Stateful Eagle3 Speculative Decoding is expected to be launched when NPU is requested as " + "execution device for one or both models."); + + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, + properties_without_draft_model, {}, generation_config); + return std::make_unique( + main_model_descr, + draft_model_descr, + eagle_rt_info.hidden_layers_list + ); + } else { + // Standard Speculative Decoding mode + OPENVINO_ASSERT(device == "NPU" || draft_model_descr.device == "NPU", + "Stateful Speculative Decoding is expected to be launched when NPU is requested as " + "execution device for one or both models."); + + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, + properties_without_draft_model, {}, generation_config); + return std::make_unique(main_model_descr, draft_model_descr); + } } return std::make_unique(model, tokenizer, device, diff --git a/src/cpp/src/sampling/sampler.cpp b/src/cpp/src/sampling/sampler.cpp index f0320c7409..0ec3f0de60 100644 --- a/src/cpp/src/sampling/sampler.cpp +++ b/src/cpp/src/sampling/sampler.cpp @@ -852,6 +852,11 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr } } } + if (!is_validation_mode_enabled && m_draft2target_mapping) { // compute token offset for draft model in speculative sampling + ov::Tensor d2t_tensor = m_draft2target_mapping->get_tensor_view(); + auto d2t = d2t_tensor.data(); + sampled_token.m_index = sampled_token.m_index + (d2t? 
d2t[sampled_token.m_index] : 0); + } // flag to add sampled token to generated sequence or extend logit processors only bool is_extend_sequence = logit_token_offset == 0 || is_generate_n_tokens || !is_validation_passed; if (is_validation_mode_enabled && !is_extend_sequence) { diff --git a/src/cpp/src/sampling/sampler.hpp b/src/cpp/src/sampling/sampler.hpp index ffbbcac3e3..c4def2f871 100644 --- a/src/cpp/src/sampling/sampler.hpp +++ b/src/cpp/src/sampling/sampler.hpp @@ -99,6 +99,7 @@ class Sampler { Tokenizer m_tokenizer; ThreadPool m_thread_pool; + std::shared_ptr m_draft2target_mapping; // Tensor to store draft2target mapping for eagle model public: Sampler(const Sampler& rhs) = delete; Sampler(Sampler&& rhs) = delete; @@ -125,6 +126,10 @@ class Sampler { // pair with map with backend name and corresponding compiler init time, and vector of compile times for each concrete grammar std::pair, std::vector> get_structured_output_times(); void clear_structured_output_compile_times(); + + void set_d2t_for_decoding(std::shared_ptr& d2t) { + m_draft2target_mapping = d2t; + }; }; class Sampler::GroupBeamSearcher { diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index e1e2187498..54bd46c37a 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -44,6 +44,7 @@ class Sequence { LogProbs m_generated_log_probs; uint64_t m_grouped_id; uint64_t m_id = _get_next_global_sequence_id(); + ov::Tensor m_hidden_state = ov::Tensor(); SequenceStatus m_status = SequenceStatus::RUNNING; GenerationFinishReason m_finish_reason = GenerationFinishReason::NONE; float m_cumulative_log_prob = 0.0f; @@ -70,6 +71,7 @@ class Sequence { m_generated_ids(seq.m_generated_ids), m_generated_log_probs(seq.m_generated_log_probs), m_grouped_id(id), + m_hidden_state(seq.m_hidden_state), m_status(seq.m_status), m_cumulative_log_prob(seq.m_cumulative_log_prob), m_sequence_group(seq.m_sequence_group), @@ -142,6 +144,14 @@ class Sequence { m_generated_ids.push_back(token_id); } + void update_hidden_state(const ov::Tensor& tensor) { + m_hidden_state = tensor; + } + + ov::Tensor get_hidden_state() const { + return m_hidden_state; + } + // removes n last tokens and updates cumulative log prob // used to remove stop_string from the output void remove_last_tokens(int n) { @@ -643,7 +653,7 @@ class SequenceGroup : public std::enable_shared_from_this { m_num_validation_tokens = k; } - size_t get_num_tokens_to_validate() { + size_t get_num_tokens_to_validate() const { return m_num_validation_tokens; } diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index 4853b8bac6..bcc00f9a8a 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -63,12 +63,48 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::get_ge for (const auto& sequence : request->get_running_sequences()) { const auto& sequence_id = sequence->get_grouped_id(); OPENVINO_ASSERT(!generated_request.count(sequence_id)); - generated_request.insert({{sequence_id, { sequence->get_generated_ids(), sequence->get_generated_log_probs() } }}); + generated_request.insert({{sequence_id, { sequence->get_generated_ids(), sequence->get_generated_log_probs(), sequence->get_hidden_state() } }}); } } return result; } +ov::Tensor 
truncate_hidden_state_from_end(const ov::Tensor& hidden_state, size_t tokens_to_remove) { + if (hidden_state.get_size() == 0 || tokens_to_remove == 0) { + return hidden_state; + } + + auto shape = hidden_state.get_shape(); + if (shape.size() < 2) { + return hidden_state; + } + + size_t seq_len_dim = 0; + size_t current_seq_len = shape[seq_len_dim]; + + if (tokens_to_remove >= current_seq_len) { + ov::Shape new_shape = shape; + new_shape[seq_len_dim] = 0; + return ov::Tensor(hidden_state.get_element_type(), new_shape); + } + + size_t new_seq_len = current_seq_len - tokens_to_remove; + + ov::Coordinate start_coord(shape.size(), 0); + ov::Coordinate end_coord(shape.size(), 0); + + for (size_t i = 0; i < shape.size(); ++i) { + start_coord[i] = 0; + if (i == seq_len_dim) { + end_coord[i] = new_seq_len; + } else { + end_coord[i] = shape[i]; + } + } + + return ov::Tensor(hidden_state, start_coord, end_coord); +} + // { min_len_of_prefix, min_length_of_candidate } std::pair get_prefix_len( @@ -227,6 +263,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update std::vector running_sequences = request->get_running_sequences(); OPENVINO_ASSERT(running_sequences.size() > 0); size_t min_generated_tokens, min_candidate_len; + size_t validate_length = 0; if (running_sequences.front()->get_generated_len() == 0 && !request->get_num_tokens_to_validate()) { m_sampler->create_logit_processor(request_id, request->get_sampling_parameters(), request->get_prompt_ids()); auto& logit_processor = m_sampler->get_logit_processor(request_id); @@ -234,6 +271,9 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update min_generated_tokens = result.inserted_tokens_cnt; running_sequences = request->get_running_sequences(); min_candidate_len = result.inserted_tokens_cnt; + if (eagle_mode_enabled && !m_is_validation_mode_enabled) + m_model_runner->set_initial_hidden_state(request_id, + candidates.begin()->second.hidden_states); } else { // update existing sequences by the candidates auto& logit_processor = m_sampler->get_logit_processor(request_id); @@ -252,6 +292,16 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update candidate_token_ids.resize(min_candidate_len); candidate_token_log_probs.resize(min_candidate_len); result.inserted_tokens_cnt = insert_tokens_to_sequence(running_sequence, candidate_token_ids, candidate_token_log_probs, logit_processor, is_update_logit_processor); + // handle hidden states for eagle mode + if (eagle_mode_enabled && !m_is_validation_mode_enabled && result.inserted_tokens_cnt > 0) { // update hidden states for draft model + // at least there should be one bonus token from main + auto& hidden_state = candidate_sequence.hidden_states; + ov::Tensor pruned_hidden_state = truncate_hidden_state_from_end(hidden_state, result.removed_tokens_cnt); + m_model_runner->set_initial_hidden_state(request_id, + pruned_hidden_state); + const auto& shape = pruned_hidden_state.get_shape(); + validate_length = shape.size() > 0 ? shape[0] : 0; + } } // we should update a logit processor just for draft model to generate the same tokens // logit processors of main model will be updated in sampler while validation mode @@ -266,14 +316,22 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update updated_context_len = min_candidate_len + prompt_len, max_new_tokens = request->get_max_new_tokens(); size_t generated_len = request->get_context_len() >= request->get_prompt_len() ? 
request->get_context_len() - request->get_prompt_len() + 1 : 0; - if (generated_len > 0 && result.removed_tokens_cnt > 0) { - request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1); + if (validate_length > 0) { + if (generated_len > 0) { + request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1 - (validate_length - 1)); + } + } else { // fast draft or main model for eagle speculative + if (generated_len > 0 && result.removed_tokens_cnt > 0) { + request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1); + } } - if (result.inserted_tokens_cnt > 0 && result.removed_tokens_cnt == 0) { + if (validate_length == 0 && result.inserted_tokens_cnt > 0 && result.removed_tokens_cnt == 0) { request->set_num_validated_tokens(result.inserted_tokens_cnt); + } else if (validate_length > 0) { + request->set_num_validated_tokens(validate_length - 1); // in generation stage } // to pause `draft_model` generation in case of `generated_len >= max_new_tokens - 1` to generate last token by `main_model` - if (!m_is_validation_mode_enabled) { + if (!m_is_validation_mode_enabled && result.inserted_tokens_cnt != 0) { bool pause_gen_status = false; generated_len -= result.removed_tokens_cnt; generated_len += result.inserted_tokens_cnt; @@ -328,6 +386,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m raw_perf_metrics.m_batch_sizes.emplace_back(num_generated_tokens); } + if (eagle_mode_enabled) + m_model_runner->enable_hidden_state_import(false); to_generate = false; for (auto& request : m_requests) { const auto& sampling_params = request->get_sampling_parameters(); @@ -351,5 +411,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m to_generate |= request->can_generate_tokens(); } } + if (eagle_mode_enabled) + m_model_runner->enable_hidden_state_import(true); } } diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp index 40db6a2ddd..5d6d220028 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp @@ -40,5 +40,56 @@ class ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl : protected: void finish_request(SequenceGroup::Ptr request); void _pull_awaiting_requests() override {}; + bool eagle_mode_enabled = false; +}; + +class ContinuousBatchingPipeline::ContinuousBatchingForEagle3DecodingImpl + : public ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl { +public: + ContinuousBatchingForEagle3DecodingImpl() = default; + + ContinuousBatchingForEagle3DecodingImpl(const std::shared_ptr& model, + const Tokenizer& tokenizer, + const GenerationConfig& generation_config, + const SchedulerConfig& scheduler_config, + const std::string& device, + const ov::AnyMap& plugin_config, + bool is_validation_mode_enabled) : ContinuousBatchingForSpeculativeDecodingImpl( + model, tokenizer, generation_config, + scheduler_config, device, plugin_config, + is_validation_mode_enabled) { + eagle_mode_enabled = true; + }; + + bool is_requests_empty(); + + void set_d2t_for_draft_decoding(std::shared_ptr& d2t) { + if (m_sampler) { + m_sampler->set_d2t_for_decoding(d2t); + } + } + void set_hidden_state_export_needed(bool is_needed) { + if (m_model_runner) { + 
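// Sets or clears HS_EXPORT on the model runner; when enabled, forward() stores each running sequence's slice of the last_hidden_state output back into the sequence.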
m_model_runner->enable_hidden_state_export(is_needed); + } + } + + void set_hidden_state_import_needed(bool is_needed) { + if (m_model_runner) { + m_model_runner->enable_hidden_state_import(is_needed); + } + } + + void set_hidden_state_internal_needed(bool is_needed) { + if (m_model_runner) { + m_model_runner->enable_hidden_state_internal(is_needed); + } + } + + void set_adjust_factor(size_t adjust_factor) { + if (m_model_runner) { + m_model_runner->set_adjust_factor(adjust_factor); + } + } }; } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp new file mode 100644 index 0000000000..daf3dfa8b7 --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp @@ -0,0 +1,502 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +#include "speculative_decoding_eagle3_impl.hpp" +#include "logger.hpp" + +namespace ov::genai { +void share_embedding_weights(std::shared_ptr& main_model, std::shared_ptr& draft_model) { + // extract embedding weight from main model + auto find_embedding_gather = [](const std::shared_ptr& model) + -> std::shared_ptr { + constexpr size_t MIN_VOCAB_SIZE_THRESHOLD = 1000; + for (const auto& node : model->get_ordered_ops()) { + auto gather = std::dynamic_pointer_cast(node); + if (!gather) continue; + // [vocab, hidden_size] * [batch, seq_len] -> [batch, seq_len, hidden_size] + auto data_node = gather->input_value(0).get_node_shared_ptr(); + auto indices_node = gather->input_value(1).get_node_shared_ptr(); + if (!data_node || !indices_node) continue; + // indices_node should be on parameter path, maybe this is better rule + ov::PartialShape ps = data_node->get_output_partial_shape(0); + if (ps.rank().is_static() && ps.rank().get_length() >= 2) { + if (ps[0].is_static() && ps[0].get_length() > MIN_VOCAB_SIZE_THRESHOLD) { // Heuristic: vocab size > 1000 + return gather; + } + } + std::string fname = data_node->get_friendly_name(); + if (fname.find("embed_tokens") != std::string::npos || + fname.find("embedding") != std::string::npos) { + return gather; + } + } + return nullptr; + }; + auto main_gather = find_embedding_gather(main_model); + auto draft_gather = find_embedding_gather(draft_model); + if (!main_gather || !draft_gather) { + return; + } + auto main_weight_node = main_gather->input_value(0).get_node_shared_ptr(); + auto draft_weight_node = draft_gather->input_value(0).get_node_shared_ptr(); + + if (main_weight_node.get() == draft_weight_node.get()) { + return; + } + + try { + draft_weight_node->output(0).replace(main_weight_node->output(0)); + } catch (const std::exception& e) { + Logger::warn(std::string("Error: failed to import embedding weights from main model to draft model. Exception: ") + e.what()); + } catch (...) 
{ + Logger::warn("Error: failed to import embedding weights from main model to draft model due to unknown exception."); + } +} + +std::shared_ptr extract_d2t_mapping_table(const std::shared_ptr& model) { + // extract result nodes from model + for (const auto& result : model->get_results()) { + auto input_node = result->input_value(0).get_node_shared_ptr(); + if (ov::is_type(input_node) && input_node->get_friendly_name().find("d2t") != std::string::npos) { + return ov::as_type_ptr(input_node); + } + } + return nullptr; +} + +void remove_d2t_result_node(std::shared_ptr& model) { + // Find and remove the d2t Result node + std::shared_ptr d2t_result_to_remove = nullptr; + + for (const auto& result : model->get_results()) { + auto input_node = result->input_value(0).get_node_shared_ptr(); + if (ov::is_type(input_node) && + input_node->get_friendly_name().find("d2t") != std::string::npos) { + d2t_result_to_remove = result; + break; + } + } + + if (d2t_result_to_remove) { + model->remove_result(d2t_result_to_remove); + model->validate_nodes_and_infer_types(); + } +} + +void extract_hidden_state_generic(std::shared_ptr& model, + const std::vector& hidden_layers_to_abstract, + const std::string& device) { + ov::pass::Manager pm; + pm.register_pass(hidden_layers_to_abstract, device); + pm.run_passes(model); +} + +EagleModelTransform::EagleModelTransform(const std::vector& layers, const std::string& device) + : m_layer_ids(layers), m_device(device) { +} + +bool EagleModelTransform::run_on_model(const std::shared_ptr& model) { + // share the embedding weights from main model to draft model + m_new_parameters.clear(); + m_new_results.clear(); + if (m_layer_ids.size() == 1 && m_layer_ids[0] == -1) { + ov::pass::Manager manager; + manager.set_per_pass_validation(false); + manager.register_pass(m_new_results); + // input transform for draft + // here we apply a trick for the fc layer in draft model + manager.register_pass(m_new_parameters, m_device); + manager.run_passes(model); + + model->add_parameters(m_new_parameters); + model->add_results(m_new_results); + return true; + } else { + ov::pass::Manager manager; + manager.set_per_pass_validation(false); + manager.register_pass(m_layer_ids, m_hidden_layer_outputs); + manager.run_passes(model); + if (!m_hidden_layer_outputs.empty()) { + auto concat = std::make_shared(m_hidden_layer_outputs, -1); + concat->set_friendly_name("eagle3_hidden_states_concat"); + + auto result = std::make_shared(concat); + std::string output_name = "last_hidden_state"; + result->output(0).set_names({output_name}); + result->set_friendly_name(output_name); + model->add_results({result}); + return true; + } + } + + return false; +} + +EagleInputTransform::EagleInputTransform(std::vector>& params, const std::string& device) + : m_device(device) { + register_matcher( + std::make_shared(ov::pass::pattern::wrap_type(), this->get_type_info().name), + ([¶ms, this](ov::pass::pattern::Matcher& m) { + auto node = m.get_match_root(); + try { + if (apply(node, params)) { + ++applied; + return true; + } + } catch (...) 
{ + OPENVINO_ASSERT(false, "EagleTransform failed to apply"); + } + return false; + }) + ); +} + +bool EagleInputTransform::apply(NodePtr node, std::vector>& params) { + if (ov::is_type(node)) { + auto matmul_node = ov::as_type_ptr(node); + // check the input of matmul node, if it is a node with name "hidden_states", then it's the node we want + auto input_node = matmul_node->get_input_node_shared_ptr(0); + if (!ov::as_type_ptr(input_node)) { + return false; + } + + auto matmul_input0 = matmul_node->input_value(0); + auto matmul_input1 = matmul_node->input_value(1); + + std::shared_ptr matmul_output_node; + + // Apply scaling optimization for NPU devices to prevent FP16 overflow + if (m_device.find("NPU") != std::string::npos) { + // Scale input down by 100x before MatMul to avoid FP16 overflow, then scale result back up + // The factor 100 (0.01 and 100.0) is an empirical value + auto scale_down_const = std::make_shared(matmul_input0.get_element_type(), ov::Shape{}, 0.01f); + auto multiply_scale_down = std::make_shared(matmul_input0, scale_down_const); + multiply_scale_down->set_friendly_name(matmul_node->get_friendly_name() + "/multiply_scale_down"); + + // Create new MatMul with scaled input + auto new_matmul = std::make_shared(multiply_scale_down, matmul_input1, + matmul_node->get_transpose_a(), + matmul_node->get_transpose_b()); + new_matmul->set_friendly_name(matmul_node->get_friendly_name() + "/matmul"); + + // Scale result back up to maintain numerical equivalence + auto scale_up_const = std::make_shared(new_matmul->get_element_type(), ov::Shape{}, 100.0f); + auto multiply_scale_up = std::make_shared(new_matmul->output(0), scale_up_const); + multiply_scale_up->set_friendly_name(matmul_node->get_friendly_name() + "/multiply_scale_up"); + + matmul_output_node = multiply_scale_up; + } else { + // Default behavior: Use MatMul directly without scaling + auto new_matmul = std::make_shared(matmul_input0, matmul_input1, + matmul_node->get_transpose_a(), + matmul_node->get_transpose_b()); + new_matmul->set_friendly_name(matmul_node->get_friendly_name() + "/matmul"); + + matmul_output_node = new_matmul; + } + + auto shape = node->get_output_partial_shape(0); + auto internal_hidden_state = std::make_shared(node->get_element_type(), node->get_output_partial_shape(0)); + internal_hidden_state->output(0).set_names({"internal_hidden_states"}); + internal_hidden_state->set_friendly_name("internal_hidden_states"); + + // Create new Add node (MatMul output + internal_hidden_state) + auto new_eltwise = std::make_shared(internal_hidden_state, matmul_output_node->output(0)); + new_eltwise->set_friendly_name(matmul_node->get_friendly_name() + "/add"); + + // Replace the original MatMul node with the new Add + ov::replace_node(matmul_node, new_eltwise); + params.push_back(internal_hidden_state); + return true; + } + return false; +} + +EagleBaseTransform::EagleBaseTransform(std::vector>& results) { + register_matcher( + std::make_shared(ov::pass::pattern::wrap_type(), this->get_type_info().name), + ([&results, this](ov::pass::pattern::Matcher& m) { + auto node = m.get_match_root(); + try { + if (apply(node, results)) { + ++applied; + return true; + } + } catch (...) 
{ + OPENVINO_ASSERT(false, "EagleTransform failed to apply"); + } + return false; + }) + ); +} + +std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node, + std::set& visited_nodes) { + if (visited_nodes.count(start_node.get())) { + return nullptr; + } + + visited_nodes.insert(start_node.get()); + + if (ov::is_type(start_node)) { + // check the input nodes of MatMul, if found Gather node, return the gather node, otherwise ,retrun the matmul node + for (size_t i = 0; i < start_node->get_input_size(); ++i) { + auto input_node = start_node->get_input_node_shared_ptr(i); + if (!input_node) continue; + if (ov::as_type_ptr(input_node)) { + return start_node; // return the Add node itself + } + } + } + + for (size_t i = 0; i < start_node->get_input_size(); ++i) { + auto input_node = start_node->get_input_node_shared_ptr(i); + if (!input_node) continue; + + auto result = find_last_residual_node(input_node, visited_nodes); + if (result) { + return result; + } + } + return nullptr; +} + +std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node) { + std::set visited_nodes; + return find_last_residual_node(start_node, visited_nodes); +} + +bool EagleBaseTransform::apply(NodePtr node, std::vector>& results) { + { + // 1. without normalization layer 2. add extra input + if (ov::is_type(node)) { + // we are applying transformation to the last hidden state, eagle2 mode + NodePtr input_node = node->get_input_node_shared_ptr(0); + if (!input_node) { + return false; + } + auto last_residual_node = find_last_residual_node(input_node); + if (!last_residual_node) { + return false; + } + auto result = std::make_shared(last_residual_node); + std::string output_name = "last_hidden_state"; + result->output(0).set_names({output_name}); + result->set_friendly_name(output_name); + results.push_back(result); + return true; + } + return false; + } +} + +Eagle3Transform::Eagle3Transform(const std::vector& layers, std::vector>& hidden_state_outputs) : m_layers(layers) { + auto is_target_pattern = [&](const Output& output) { + auto add_node = ov::as_type_ptr(output.get_node_shared_ptr()); + auto add_node_name = add_node->get_friendly_name(); + if (add_node_name.find("self_attn") != std::string::npos) + return false; // Skip self-attention layers + bool layer_matched = false; + for (auto layer_idx : m_layers) { + if (add_node_name.find("layers." 
+ std::to_string(layer_idx) + "/") != std::string::npos) { + layer_matched = true; + break; + } + } + + if (!layer_matched) { + return false; // Skip layers that are not in the specified layers + } + auto input0 = add_node->get_input_node_shared_ptr(1); + if (!input0 || !ov::is_type(input0)) { + return false; + } + auto matmul_node = input0; + auto matmul_input = matmul_node->get_input_node_shared_ptr(0); + if (!matmul_input) { + return false; + } + + bool has_multiply = ov::is_type(matmul_input); // ACT(up) dot gate + return has_multiply; + }; + + auto hidden_layer = ov::pass::pattern::wrap_type(is_target_pattern); + register_matcher(std::make_shared(hidden_layer, "Eagle3Transform::hidden_extraction"), + [&hidden_state_outputs, this](ov::pass::pattern::Matcher& m) { + auto node = m.get_match_root(); + if (ov::is_type(node)) { + hidden_state_outputs.push_back(node->output(0)); + return true; + } + return false; + } + ); +} + +ContinuousBatchingPipeline::Eagle3DecodingImpl::Eagle3DecodingImpl(const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc, + const std::vector& hidden_layers) + : m_hidden_layers_to_abstract(hidden_layers) { + auto scheduler_configs = init_speculative_models(main_model_desc, draft_model_desc); + // Eagle speculative decoding does not support dynamic_split_fuse mode + // because it requires hidden state interaction from main model to draft model + // to be implemented future + if (scheduler_configs.first.dynamic_split_fuse) { + Logger::warn( + "Note: disable dynamic split fuse for eagle3 speculative decoding" + ); + scheduler_configs.first.dynamic_split_fuse = false; + scheduler_configs.second.dynamic_split_fuse = false; + } + auto main_model = main_model_desc.model; + auto draft_model = draft_model_desc.model; + + auto main_device = main_model_desc.device; + std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; + + ov::AnyMap draft_properties = + draft_model_desc.properties.empty() ? 
main_model_desc.properties : draft_model_desc.properties; + + // main and draft model can have different tokenizers + // to do: support retokenization: 154103 + Tokenizer main_model_tokenizer = main_model_desc.tokenizer; + Tokenizer draft_model_tokenizer = draft_model_desc.tokenizer; + m_tokenizer = main_model_tokenizer; + // for eagle model, we need to obtain hidden layer state as extra output + // apply transformations needed to run eagle model + // target model: hidden state extraction, draft model: hidden state import , hidden state extraction + // eagle3 specific : dt importing + share_embedding_weights(main_model, draft_model); + extract_hidden_state_generic(main_model, hidden_layers, main_device); + extract_hidden_state_generic(draft_model, { -1 }, draft_device); + + // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode + m_main_pipeline = std::make_shared(main_model, + main_model_tokenizer, + main_model_desc.generation_config, + scheduler_configs.first, + main_device, + main_model_desc.properties, + true); + m_draft_pipeline = std::make_shared(draft_model, + draft_model_tokenizer, + draft_model_desc.generation_config, + scheduler_configs.second, + draft_device, + draft_properties, + false); + m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); + m_perf_metrics.raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}}; + m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; + + // specific params update for eagle pipeline + // check draft_model, retrieve d2t table if exists + auto d2t_tensor = extract_d2t_mapping_table(draft_model); + update_eagle_pipeline_params(d2t_tensor); +} + +ov::Tensor ContinuousBatchingPipeline::Eagle3DecodingImpl::create_draft_input_ids(const ov::Tensor& original_input_ids) { + auto shape = original_input_ids.get_shape(); + if (shape.size() == 0 || shape.back() <= 1) { + return ov::Tensor(original_input_ids); + } + + size_t original_length = shape.back(); + size_t new_length = original_length - 1; + + ov::Tensor draft_input_ids(ov::element::i64, {1, new_length}); + + const int64_t* src_data = original_input_ids.data(); + int64_t* dst_data = draft_input_ids.data(); + + std::copy(src_data + 1, src_data + original_length, dst_data); + + return draft_input_ids; +} + +void ContinuousBatchingPipeline::Eagle3DecodingImpl::update_eagle_pipeline_params(std::shared_ptr& d2t_tensor) { + auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); + auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); + m_main_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_import_needed(true); + m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); + m_draft_eagle_pipeline->set_adjust_factor( + m_hidden_layers_to_abstract.size() > 0 ? 
m_hidden_layers_to_abstract.size() : 1); + m_draft_eagle_pipeline->set_d2t_for_draft_decoding(d2t_tensor); +} + +GenerationHandle +ContinuousBatchingPipeline::Eagle3DecodingImpl::add_request(uint64_t request_id, + const ov::Tensor& input_ids, + const ov::genai::GenerationConfig& sampling_params, + std::optional token_type_ids) { + std::lock_guard lock(m_draft_generations_mutex); + auto draft_sampling_params = sampling_params; + draft_sampling_params.ignore_eos = true; + draft_sampling_params.stop_strings = {}; + // remove first token from input_ids to create draft_input_ids + ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params, token_type_ids)}); + return m_main_pipeline->add_request(request_id, input_ids, sampling_params, token_type_ids); +} + +GenerationHandle +ContinuousBatchingPipeline::Eagle3DecodingImpl::add_request(uint64_t request_id, + const std::string& prompt, + const ov::genai::GenerationConfig& sampling_params) { + std::lock_guard lock(m_draft_generations_mutex); + auto draft_sampling_params = sampling_params; + draft_sampling_params.ignore_eos = true; + draft_sampling_params.stop_strings = {}; + // remove first token from input_ids to create draft_input_ids + // add_special_tokens is false for better compress rate + auto input_ids = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false)).input_ids; + ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); + return m_main_pipeline->add_request(request_id, input_ids, sampling_params); +} + +std::vector ContinuousBatchingPipeline::Eagle3DecodingImpl::generate( + const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + const std::optional>& token_type_ids, + const std::optional>>>& position_ids) { + GenerateStrategy strategy; + strategy.prepare_request = [this](size_t, + const ov::Tensor& in_ids, + GenerationConfig& main_cfg, + GenerationConfig& draft_cfg, + ov::Tensor& main_in, + ov::Tensor& draft_in) { + OPENVINO_ASSERT(main_cfg.assistant_confidence_threshold == 0.f, + "Eagle3 only supports num_assistant_tokens (assistant_confidence_threshold must be 0.f)"); + if (main_cfg.num_assistant_tokens == 0) { + main_cfg.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; + draft_cfg.num_assistant_tokens = main_cfg.num_assistant_tokens; + } + draft_cfg.ignore_eos = true; + draft_cfg.stop_strings = {}; + main_in = in_ids; + draft_in = create_draft_input_ids(in_ids); + }; + + strategy.check_streaming = [](const std::shared_ptr& streamer_ptr, + const std::vector& input_ids, + const std::vector& sampling_params) { + OPENVINO_ASSERT(!streamer_ptr->has_callback() || + (input_ids.size() == 1 && + (sampling_params[0].is_greedy_decoding())), + "Eagle3 streaming only supports batch size=1 with greedy"); + }; + strategy.start_timer = [](){ + return std::chrono::steady_clock::now(); + }; + strategy.stop_timer = [](TimePoint start){ + return PerfMetrics::get_microsec(std::chrono::steady_clock::now() - start); + }; + + return generate_common(this, input_ids, sampling_params, streamer, token_type_ids, strategy); +} +} // namespace ov::genai \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp 
b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp new file mode 100644 index 0000000000..ecd6b2131f --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp @@ -0,0 +1,115 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "speculative_decoding_impl.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/result.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/manager.hpp" + +namespace ov::genai { + +void share_embedding_weights(std::shared_ptr& main_model, std::shared_ptr& draft_model); +void extract_hidden_state_generic(std::shared_ptr& model, const std::vector& hidden_layers_to_abstract, const std::string& device = ""); +std::shared_ptr extract_d2t_mapping_table(const std::shared_ptr& model); +void remove_d2t_result_node(std::shared_ptr& model); + +class ContinuousBatchingPipeline::Eagle3DecodingImpl : public ContinuousBatchingPipeline::SpeculativeDecodingImpl { +public: + template + friend std::vector generate_common( + Impl*, + const std::vector&, + const std::vector&, + const StreamerVariant&, + std::optional>, + GenerateStrategy&); + Eagle3DecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc, const std::vector& hidden_layers_to_abstract); + + std::vector + generate(const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + const std::optional>& token_type_ids = std::nullopt, + const std::optional>>>& position_ids = std::nullopt) override; + + GenerationHandle add_request(uint64_t request_id, + const ov::Tensor& input_ids, + const ov::genai::GenerationConfig& sampling_params, + std::optional token_type_ids = std::nullopt) override; + + GenerationHandle add_request(uint64_t request_id, + const std::string& prompt, + const ov::genai::GenerationConfig& sampling_params) override; +protected: + void update_eagle_pipeline_params(std::shared_ptr& d2t_tensor); + ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); + std::vector m_hidden_layers_to_abstract; +}; + +using NodePtr = std::shared_ptr; +using namespace ov::op; + +class EagleBaseTransform : public ov::pass::MatcherPass { +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("EagleBaseTransform"); + EagleBaseTransform(std::vector>& results); + + ~EagleBaseTransform() = default; + +private: + bool apply(NodePtr node, std::vector>& results); + size_t applied = 0; + std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node); + std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node, + std::set& visited_nodes); +}; +class EagleInputTransform : public ov::pass::MatcherPass { // eagle3 specific for draft model +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("EagleInputTransform"); + EagleInputTransform(std::vector>& params, const std::string& device = ""); + + ~EagleInputTransform() = default; + +private: + bool apply(NodePtr node, std::vector>& params); + size_t applied = 0; + std::string m_device; +}; +class Eagle3Transform : public ov::pass::MatcherPass { +public: + using NodePtr = 
std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("Eagle3Transform"); + Eagle3Transform(const std::vector& layers, std::vector>& hidden_state_outputs); + + ~Eagle3Transform() = default; + +private: + std::vector m_layers; // layers to be abstracted +}; + +class EagleModelTransform : public ov::pass::ModelPass { +public: + EagleModelTransform(const std::vector& layer_ids, const std::string& device = ""); + bool run_on_model(const std::shared_ptr& model) override; + +private: + const std::vector m_layer_ids; + std::string m_device; + std::vector> m_new_results; + std::vector> m_new_parameters; + std::vector> m_hidden_layer_outputs; +}; +} diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 1ca63e9de7..4293e2e03c 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -25,26 +25,19 @@ bool are_tokenizers_equal(Tokenizer& lhs, Tokenizer& rhs) { lhs.get_bos_token_id() == rhs.get_bos_token_id() && lhs.get_pad_token_id() == rhs.get_pad_token_id(); } -ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, - const ov::genai::ModelDesc& draft_model_desc) { - auto main_model = main_model_desc.model; - auto draft_model = draft_model_desc.model; - - OPENVINO_ASSERT(main_model != nullptr, "Main model cannot be null"); - OPENVINO_ASSERT(draft_model != nullptr, "Draft model cannot be null"); +std::pair +ContinuousBatchingPipeline::SpeculativeDecodingImpl::init_speculative_models(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc) { + OPENVINO_ASSERT(main_model_desc.model != nullptr, "Main model cannot be null"); + OPENVINO_ASSERT(draft_model_desc.model != nullptr, "Draft model cannot be null"); + utils::apply_paged_attention_transformations(main_model_desc.model, main_model_desc.scheduler_config.use_cache_eviction); + utils::apply_paged_attention_transformations(draft_model_desc.model, main_model_desc.scheduler_config.use_cache_eviction); - auto main_scheduler_config = main_model_desc.scheduler_config; - auto main_device = main_model_desc.device; + utils::apply_gather_before_matmul_transformation(main_model_desc.model); + utils::apply_gather_before_matmul_transformation(draft_model_desc.model); - utils::apply_paged_attention_transformations(main_model, main_model_desc.scheduler_config.use_cache_eviction); - utils::apply_paged_attention_transformations(draft_model, main_model_desc.scheduler_config.use_cache_eviction); - - utils::apply_gather_before_matmul_transformation(main_model); - utils::apply_gather_before_matmul_transformation(draft_model); - - std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; bool is_draft_scheduler_undefined = draft_model_desc.scheduler_config == SchedulerConfig(); + auto main_scheduler_config = main_model_desc.scheduler_config; ov::genai::SchedulerConfig main_scheduler_config_updated = main_scheduler_config, draft_scheduler_config = is_draft_scheduler_undefined ? 
main_scheduler_config : draft_model_desc.scheduler_config; @@ -63,8 +56,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con } return total_hidden_size; }; - float main_model_hidden_size = compute_total_hidden_size(main_model), - draft_model_hidden_size = compute_total_hidden_size(draft_model); + float main_model_hidden_size = compute_total_hidden_size(main_model_desc.model), + draft_model_hidden_size = compute_total_hidden_size(draft_model_desc.model); auto k = draft_model_hidden_size / (main_model_hidden_size + draft_model_hidden_size); // TODO: work with KV blocks as it will be more precise instead of GBs @@ -82,8 +75,15 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con draft_scheduler_config.max_num_batched_tokens = main_scheduler_config_updated.max_num_batched_tokens; } - ov::AnyMap draft_properties = draft_model_desc.properties.empty() ? main_model_desc.properties : draft_model_desc.properties; + return std::make_pair(main_scheduler_config_updated, draft_scheduler_config); +} +ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc) { + auto scheduler_configs = init_speculative_models(main_model_desc, draft_model_desc); + + auto main_device = main_model_desc.device; + std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; // main and draft model can have different tokenizers // to do: support retokenization: 154103 Tokenizer main_model_tokenizer = main_model_desc.tokenizer; @@ -91,16 +91,15 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con // todo: remove this condition after support of CVS-154103 OPENVINO_ASSERT(are_tokenizers_equal(main_model_tokenizer, draft_model_tokenizer), "Tokenizers for draft and main models are different!"); - m_tokenizer = main_model_tokenizer; - + ov::AnyMap draft_properties = draft_model_desc.properties.empty() ? 
main_model_desc.properties : draft_model_desc.properties; // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode m_main_pipeline = std::make_shared( - main_model, main_model_tokenizer, main_model_desc.generation_config, - main_scheduler_config_updated, main_device, main_model_desc.properties, true); + main_model_desc.model, main_model_tokenizer, main_model_desc.generation_config, + scheduler_configs.first, main_device, main_model_desc.properties, true); m_draft_pipeline = std::make_shared( - draft_model, draft_model_tokenizer, draft_model_desc.generation_config, - draft_scheduler_config, draft_device, draft_properties, false); + draft_model_desc.model, draft_model_tokenizer, draft_model_desc.generation_config, + scheduler_configs.second, draft_device, draft_properties, false); m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; @@ -241,116 +240,37 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< const StreamerVariant& streamer, const std::optional>& token_type_ids, const std::optional>>>& position_ids) { - OPENVINO_ASSERT(!token_type_ids.has_value()); - m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); - m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; - - OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request"); - OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); - - const auto generate_start = std::chrono::steady_clock::now(); - - // checks that all requests has the same LoRA adapters property value - for (size_t i = 1; i < sampling_params.size(); ++i) { - OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters, - "LoRA adapters value must be the same for all requests"); - } - m_main_pipeline->set_adapters(sampling_params[0].adapters); - m_draft_pipeline->set_adapters(sampling_params[0].adapters); - - const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); - - OPENVINO_ASSERT(!streamer_ptr->has_callback() || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), - "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); - - std::vector main_generations; - for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { - OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); - auto main_sampling_params = sampling_params[request_id]; - if (main_sampling_params.assistant_confidence_threshold == 0.f) { - if (main_sampling_params.num_assistant_tokens == 0) { - main_sampling_params.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; + GenerateStrategy strategy; + strategy.prepare_request = [this](size_t, + const ov::Tensor& in_ids, + GenerationConfig& main_cfg, + GenerationConfig& draft_cfg, + ov::Tensor& main_in, + ov::Tensor& draft_in) { + if (main_cfg.assistant_confidence_threshold == 0.f) { + if (main_cfg.num_assistant_tokens == 0) { + main_cfg.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; } } - main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], main_sampling_params)); - - auto draft_sampling_params = main_sampling_params; - // set the parameters do not stop draft generation 
without stopping of the same request for main pipeline - draft_sampling_params.ignore_eos = true; - draft_sampling_params.stop_strings = {}; - std::lock_guard lock(m_draft_generations_mutex); - m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, input_ids[request_id], draft_sampling_params)}); - } - auto all_requests = get_awaiting_requests(); - - GenerationHandle& generation = main_generations.at(0); - - streamer_ptr->start(); - - while (has_non_finished_requests()) { - try { - step(); - } catch (...) { - drop_requests(); // remove all requests from pipeline state in case of exception - streamer_ptr->end(); - std::rethrow_exception(std::current_exception()); - } - stream_tokens(streamer_ptr, generation); - } - - // waiting for completion of streaming - streamer_ptr->end(); - - OPENVINO_ASSERT(is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed"); - - std::vector results; - results.reserve(all_requests.size()); - - m_perf_metrics.draft_model_metrics.raw_metrics = m_draft_pipeline->raw_perf_metrics; - - const auto generate_end = std::chrono::steady_clock::now(); - const auto generate_duration = PerfMetrics::get_microsec(generate_end - generate_start); - - for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) { - const auto& request = all_requests[request_id]; - auto sampling_params = request->get_sampling_parameters(); - const auto& sequences = request->get_finished_sequences(); - size_t num_outputs = std::min(sampling_params.num_return_sequences, sequences.size()); - - EncodedGenerationResult result; - result.m_request_id = request_id; - result.m_generation_ids.resize(num_outputs); - result.m_scores.resize(num_outputs); - result.m_status = request->get_generation_stream()->get_status(); - - for (size_t i = 0; i < num_outputs; ++i) { - const auto & sequence = sequences[i]; - const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob(); - const auto & generated_ids = sequence->get_generated_ids(); - - if (sampling_params.echo) { - result.m_generation_ids[i] = request->get_prompt_ids(); - } - std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i])); - result.m_scores[i] = score; - } - - result.m_status = main_generations[request_id]->get_status(); - - // The same perf metrics for each sequence, only tokenization/detokenization will differ. 
- m_perf_metrics.raw_metrics.generate_durations.clear(); - m_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_duration); - m_perf_metrics.num_input_tokens = request->get_prompt_len(); - m_perf_metrics.evaluate_statistics(generate_start); - - result.perf_metrics = m_perf_metrics; - result.extended_perf_metrics = std::make_shared(m_perf_metrics); - results.push_back(std::move(result)); - } - - OPENVINO_ASSERT(results.size() == input_ids.size()); - - return results; + draft_cfg.ignore_eos = true; + draft_cfg.stop_strings = {}; + main_in = in_ids; + draft_in = in_ids; + }; + strategy.check_streaming = [](const std::shared_ptr& streamer_ptr, + const std::vector& input_ids, + const std::vector& sampling_params) { + OPENVINO_ASSERT(!streamer_ptr->has_callback() || + (input_ids.size() == 1 && + (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial())), + "Streaming only supports batch size=1 with greedy/multinomial"); + }; + strategy.start_timer = [](){ return std::chrono::steady_clock::now(); }; + strategy.stop_timer = [](TimePoint start){ + return PerfMetrics::get_microsec(std::chrono::steady_clock::now() - start); + }; + + return generate_common(this, input_ids, sampling_params, streamer, token_type_ids, strategy); } SpeculativeDecodingMetrics @@ -375,4 +295,5 @@ std::vector ContinuousBatchingPipeline::SpeculativeDecodingI OPENVINO_ASSERT(main_awaiting_requests.size() == draft_awaiting_requests.size()); return main_awaiting_requests; } + } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 87ae8ab60d..d913fb2979 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -11,6 +11,129 @@ #include "utils.hpp" namespace ov::genai { +struct GenerateStrategy { + std::function prepare_request; + std::function&, + const std::vector&, + const std::vector&)> check_streaming; + std::function start_timer; + std::function stop_timer; +}; + +template +std::vector generate_common( + Impl* self, + const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + std::optional> token_type_ids, + GenerateStrategy& strategy) { + + OPENVINO_ASSERT(!token_type_ids.has_value()); + self->perf_metrics() = ov::genai::SDPerModelsPerfMetrics(); + self->draft_pipeline()->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; + + OPENVINO_ASSERT(!self->has_non_finished_requests(), + "Generate cannot be called while ContinuousBatchingPipeline is already running"); + OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); + + auto t_start = strategy.start_timer(); + + for (size_t i = 1; i < sampling_params.size(); ++i) { + OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters, + "LoRA adapters must be same for all requests"); + } + self->main_pipeline()->set_adapters(sampling_params[0].adapters); + self->draft_pipeline()->set_adapters(sampling_params[0].adapters); + + auto streamer_ptr = std::make_shared(streamer, self->tokenizer()); + + strategy.check_streaming(streamer_ptr, input_ids, sampling_params); + + std::vector main_generations; + { + std::lock_guard lock(self->draft_generations_mutex()); + for (size_t rid = 0; rid < input_ids.size(); ++rid) { + GenerationConfig main_cfg = sampling_params[rid]; + GenerationConfig draft_cfg = main_cfg; + ov::Tensor main_in, draft_in; + strategy.prepare_request(rid, input_ids[rid], 
+ main_cfg, draft_cfg, + main_in, draft_in); + main_generations.push_back(self->main_pipeline()->add_request(rid, main_in, main_cfg)); + self->draft_generations().insert({rid, + self->draft_pipeline()->add_request(rid, draft_in, draft_cfg)}); + } + } + + auto all_requests = self->get_awaiting_requests(); + GenerationHandle& generation = main_generations.at(0); + + streamer_ptr->start(); + while (self->has_non_finished_requests()) { + try { + self->step(); + } catch (...) { + self->drop_requests(); + streamer_ptr->end(); + std::rethrow_exception(std::current_exception()); + } + self->stream_tokens(streamer_ptr, generation); + } + streamer_ptr->end(); + + OPENVINO_ASSERT(self->is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed"); + + self->perf_metrics().draft_model_metrics.raw_metrics = self->draft_pipeline()->raw_perf_metrics; + uint64_t generate_duration_us = strategy.stop_timer(t_start); + + std::vector results; + results.reserve(all_requests.size()); + + for (size_t rid = 0; rid < all_requests.size(); ++rid) { + const auto& request = all_requests[rid]; + auto cfg = request->get_sampling_parameters(); + const auto& seqs = request->get_finished_sequences(); + size_t num_out = std::min(cfg.num_return_sequences, seqs.size()); + + EncodedGenerationResult result; + result.m_request_id = rid; + result.m_generation_ids.resize(num_out); + result.m_scores.resize(num_out); + result.m_status = main_generations[rid]->get_status(); + + for (size_t i = 0; i < num_out; ++i) { + const auto& seq = seqs[i]; + float score = cfg.is_beam_search() ? + seq->get_beam_search_score(cfg) : + seq->get_cumulative_log_prob(); + const auto& gen_ids = seq->get_generated_ids(); + if (cfg.echo) { + result.m_generation_ids[i] = request->get_prompt_ids(); + } + std::copy(gen_ids.begin(), gen_ids.end(), + std::back_inserter(result.m_generation_ids[i])); + result.m_scores[i] = score; + } + + self->perf_metrics().raw_metrics.generate_durations.clear(); + self->perf_metrics().raw_metrics.generate_durations.emplace_back(generate_duration_us); + self->perf_metrics().num_input_tokens = request->get_prompt_len(); + self->perf_metrics().evaluate_statistics(t_start); + + result.perf_metrics = self->perf_metrics(); + result.extended_perf_metrics = std::make_shared(self->perf_metrics()); + results.push_back(std::move(result)); + } + + OPENVINO_ASSERT(results.size() == input_ids.size()); + return results; +} class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline { protected: @@ -26,8 +149,18 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat void drop_requests(); bool is_requests_empty(); std::vector get_awaiting_requests(); - + std::pair init_speculative_models(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); public: + template + friend std::vector generate_common( + Impl* self, + const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + std::optional> token_type_ids, + GenerateStrategy& strategy); + + SpeculativeDecodingImpl() = default; SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); GenerationHandle add_request(uint64_t request_id, @@ -50,6 +183,16 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat const std::optional>>>& position_ids = std::nullopt) override; 
SpeculativeDecodingMetrics get_speculative_decoding_metrics(); + SDPerModelsPerfMetrics& perf_metrics() { return m_perf_metrics; } + SDPerModelsPerfMetrics const& perf_metrics() const { return m_perf_metrics; } + std::shared_ptr& draft_pipeline() { return m_draft_pipeline; } + std::shared_ptr& main_pipeline() { return m_main_pipeline; } + + Tokenizer& tokenizer() { return m_tokenizer; } + const Tokenizer& tokenizer() const { return m_tokenizer; } + + std::mutex& draft_generations_mutex() { return m_draft_generations_mutex; } + std::map& draft_generations() { return m_draft_generations; } }; -} +} // namespace ov::genai \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp index ac1acb4652..cc1101dd3f 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "speculative_decoding_stateful.hpp" +#include "speculative_decoding_utils.hpp" #include "continuous_batching/timer.hpp" #include "openvino/runtime/core.hpp" #include "openvino/core/parallel.hpp" @@ -42,18 +43,6 @@ void update_perf_stat_by_infer_duration(ov::genai::RawPerfMetrics& raw_perf_coun raw_perf_counters.m_batch_sizes.emplace_back(num_generated_tokens); } -void ensure_num_assistant_tokens_is_set(ov::genai::GenerationConfig& generation_config) { - auto assistant_confidence_threshold = generation_config.assistant_confidence_threshold; - OPENVINO_ASSERT(assistant_confidence_threshold == 0.f, - "Stateful (non Continuous Batching) Speculative Decoding pipeline only supports `num_assistant_tokens` " - "as parameter in GenerationConfig and doesn't work with `assistant_confidence_threshold`.\nPlease " - "remove its specification or set it to 0.f."); - - constexpr std::size_t default_num_assistant_tokens = 5; - if (generation_config.num_assistant_tokens == 0) { - generation_config.num_assistant_tokens = default_num_assistant_tokens; - } -} }// anonymous namespace namespace ov { @@ -392,7 +381,7 @@ StatefulSpeculativeLLMPipeline::StatefulSpeculativeLLMPipeline( OPENVINO_ASSERT(m_draft_request != nullptr, "Failed to create draft model inference wrapper"); // Specifying number candidates to generate - ensure_num_assistant_tokens_is_set(m_generation_config); + ov::genai::speculative_decoding::ensure_num_assistant_tokens_is_set(m_generation_config); m_candidates_num = m_generation_config.num_assistant_tokens; // We set the upper limit for candidates number as two times the number requested // by user. @@ -412,7 +401,7 @@ StatefulSpeculativeLLMPipeline::StatefulSpeculativeLLMPipeline( GenerationConfig StatefulSpeculativeLLMPipeline::resolve_generation_config(OptionalGenerationConfig generation_config) { GenerationConfig config = generation_config.value_or(m_generation_config); - ensure_num_assistant_tokens_is_set(config); + ov::genai::speculative_decoding::ensure_num_assistant_tokens_is_set(config); m_candidates_num = config.num_assistant_tokens; // We set the upper limit for candidates number as two times the number // requested by user. 
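Before the new stateful EAGLE3 pipeline below, note the input convention it shares with the continuous-batching implementation above: the draft model consumes the prompt shifted left by one token, with positions restarting at 0 (see create_draft_input_ids() and Eagle3DraftModelWrapper::initialize_sequence()). The following is a minimal standalone sketch of that shift in plain C++ — a hypothetical make_draft_inputs() helper with made-up token IDs, no OpenVINO types — not the pipeline's actual API.

// Sketch only: mirrors the token shift done by create_draft_input_ids() /
// Eagle3DraftModelWrapper::initialize_sequence(), using plain std C++.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

struct DraftInputs {
    std::vector<int64_t> tokens;     // prompt[1:]
    std::vector<int64_t> positions;  // 0 .. prompt.size() - 2
};

DraftInputs make_draft_inputs(const std::vector<int64_t>& prompt) {
    DraftInputs out;
    if (prompt.size() <= 1) {
        // Degenerate prompt: nothing to shift, keep it unchanged
        // (the real create_draft_input_ids() also returns the input as-is here).
        out.tokens = prompt;
        out.positions.resize(prompt.size());
        std::iota(out.positions.begin(), out.positions.end(), 0);
        return out;
    }
    // Drop the first token; positions for the draft model start from 0 again.
    out.tokens.assign(prompt.begin() + 1, prompt.end());
    out.positions.resize(out.tokens.size());
    std::iota(out.positions.begin(), out.positions.end(), 0);
    return out;
}

int main() {
    const std::vector<int64_t> prompt = {101, 42, 7, 9, 11};  // hypothetical token IDs
    const DraftInputs draft = make_draft_inputs(prompt);
    for (std::size_t i = 0; i < draft.tokens.size(); ++i)
        std::cout << "pos " << draft.positions[i] << " -> token " << draft.tokens[i] << "\n";
    return 0;
}

With prompt {101, 42, 7, 9, 11}, the draft side would see tokens {42, 7, 9, 11} at positions 0..3, which is consistent with initialize_sequence() requiring at least two prompt tokens.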
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful_eagle3.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful_eagle3.cpp new file mode 100644 index 0000000000..f17ec52aa0 --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_stateful_eagle3.cpp @@ -0,0 +1,1411 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "speculative_decoding_stateful_eagle3.hpp" +#include "speculative_decoding_utils.hpp" + +#include +#include +#include +#include +#include +#include + +#include "continuous_batching/timer.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/genai/text_streamer.hpp" +#include "openvino/runtime/core.hpp" +#include "openvino/runtime/properties.hpp" +#include "speculative_decoding_eagle3_impl.hpp" +#include "utils.hpp" + +namespace ov::genai { +template +struct overloaded : Ts... { + using Ts::operator()...; +}; +template +overloaded(Ts...) -> overloaded; +} // namespace ov::genai + +namespace { + +// Stream generated tokens to output +ov::genai::StreamingStatus stream_generated_tokens(std::shared_ptr streamer_ptr, + const std::vector& tokens) { + if (streamer_ptr) { + return streamer_ptr->write(tokens); + } + return ov::genai::StreamingStatus{}; +} + +// Extract last token's hidden state using zero-copy ROI +// Input: [batch=1, seq_len, hidden_size] -> Output: [1, 1, hidden_size] +ov::Tensor extract_last_hidden_state(const ov::Tensor& hidden_features) { + OPENVINO_ASSERT(hidden_features && hidden_features.get_size() > 0, "Hidden features tensor is empty"); + + auto shape = hidden_features.get_shape(); + OPENVINO_ASSERT(shape.size() == 3 && shape[0] == 1 && shape[1] > 0, + "Expected shape [1, seq_len, hidden_size], got [", + shape.size() == 3 + ? 
std::to_string(shape[0]) + ", " + std::to_string(shape[1]) + ", " + std::to_string(shape[2]) + : "invalid", + "]"); + + std::size_t seq_len = shape[1]; + std::size_t hidden_size = shape[2]; + + return ov::Tensor(hidden_features, {0, seq_len - 1, 0}, {1, seq_len, hidden_size}); +} + +} // anonymous namespace + +namespace ov { +namespace genai { + +//================================================================================================== +// Eagle3InferWrapperBase Implementation +//================================================================================================== + +Eagle3InferWrapperBase::Eagle3InferWrapperBase(const ov::genai::ModelDesc& model_desc) + : m_device(model_desc.device), + m_properties(model_desc.properties), + m_generation_config(model_desc.generation_config), + m_tokenizer(model_desc.tokenizer) { + log_debug("Initializing for device: " + m_device); + + m_kv_axes_pos = ov::genai::utils::get_kv_axes_pos(model_desc.model); + + if (m_device == "NPU") { + auto [compiled, kv_desc] = + ov::genai::utils::compile_decoder_for_npu(model_desc.model, m_properties, m_kv_axes_pos); + m_max_prompt_len = kv_desc.max_prompt_len; + m_kv_cache_capacity = kv_desc.max_prompt_len + kv_desc.min_response_len; + m_request = compiled.create_infer_request(); + + log_debug("NPU compiled: max_prompt=" + std::to_string(m_max_prompt_len) + + ", kv_capacity=" + std::to_string(m_kv_cache_capacity)); + } else { + m_request = ov::genai::utils::singleton_core() + .compile_model(model_desc.model, m_device, m_properties) + .create_infer_request(); + log_debug(m_device + " compiled successfully"); + } + + // Initialize metrics + m_raw_perf_metrics.m_inference_durations = {ov::genai::MicroSeconds(0.0f)}; + m_raw_perf_metrics.tokenization_durations = {ov::genai::MicroSeconds(0.0f)}; + m_raw_perf_metrics.detokenization_durations = {ov::genai::MicroSeconds(0.0f)}; + + log_debug("Initialization completed"); +} + +void Eagle3InferWrapperBase::append_tokens(const std::vector& tokens) { + if (tokens.empty()) + return; + + std::size_t old_size = m_tokens.size(); + m_tokens.insert(m_tokens.end(), tokens.begin(), tokens.end()); + + for (std::size_t i = 0; i < tokens.size(); ++i) { + m_positions.push_back(static_cast(old_size + i)); + } + + log_debug("Appended " + std::to_string(tokens.size()) + " tokens, total: " + std::to_string(m_tokens.size())); +} + +void Eagle3InferWrapperBase::truncate_sequence(std::size_t size) { + if (size < m_tokens.size()) { + m_tokens.resize(size); + m_positions.resize(size); + log_debug("Truncated to: " + std::to_string(size)); + } +} + +void Eagle3InferWrapperBase::trim_kv_cache(std::size_t tokens_to_remove) { + if (tokens_to_remove == 0 || m_processed_tokens == 0) { + return; + } + + OPENVINO_ASSERT(tokens_to_remove < m_processed_tokens, "Cannot trim more tokens than processed"); + + log_debug("Trimming KV cache: " + std::to_string(tokens_to_remove) + " tokens"); + + // NPU handles KV trimming via position IDs + if (m_device != "NPU") { + ov::genai::utils::KVCacheState state; + state.num_tokens_to_trim = tokens_to_remove; + state.seq_length_axis = m_kv_axes_pos.seq_len; + state.reset_mem_state = false; + ov::genai::utils::trim_kv_cache(m_request, state, {}); + } + + m_processed_tokens -= tokens_to_remove; + log_debug("KV trimmed, processed: " + std::to_string(m_processed_tokens)); +} + +void Eagle3InferWrapperBase::reset_state() { + m_tokens.clear(); + m_positions.clear(); + m_processed_tokens = 0; + m_last_sampled_token = -1; + + m_raw_perf_metrics.m_inference_durations 
= {ov::genai::MicroSeconds(0.0f)}; + m_raw_perf_metrics.m_durations.clear(); + m_raw_perf_metrics.m_batch_sizes.clear(); + + log_debug("State reset"); +} + +void Eagle3InferWrapperBase::release_memory() { + m_request.get_compiled_model().release_memory(); + log_debug("Memory released"); +} + +void Eagle3InferWrapperBase::build_model_inputs(std::size_t token_count, + ov::Tensor& input_ids, + ov::Tensor& attention_mask, + ov::Tensor& position_ids) { + OPENVINO_ASSERT(!m_tokens.empty() && token_count > 0, "Cannot build inputs: empty sequence or zero token count"); + OPENVINO_ASSERT(!m_positions.empty(), "Position IDs not initialized"); + + const std::size_t seq_len = m_tokens.size(); + OPENVINO_ASSERT(token_count <= seq_len, "Requested ", token_count, " tokens but only ", seq_len, " available"); + + const std::size_t start_pos = seq_len - token_count; + + input_ids = ov::Tensor(ov::element::i64, {1, token_count}); + position_ids = ov::Tensor(ov::element::i64, {1, token_count}); + + int64_t* input_ids_ptr = input_ids.data(); + int64_t* position_ids_ptr = position_ids.data(); + + std::memcpy(input_ids_ptr, m_tokens.data() + start_pos, token_count * sizeof(int64_t)); + std::memcpy(position_ids_ptr, m_positions.data() + start_pos, token_count * sizeof(int64_t)); + + // Attention mask length = last_position_id + 1 (total KV cache size) + const std::size_t attention_mask_len = static_cast(position_ids_ptr[token_count - 1] + 1); + + attention_mask = ov::Tensor(ov::element::i64, {1, attention_mask_len}); + std::fill_n(attention_mask.data(), attention_mask_len, 1); +} + +ov::Tensor Eagle3InferWrapperBase::create_hidden_state_placeholder(const ov::Shape& shape) const { + ov::Tensor tensor(ov::element::f32, shape); + std::fill_n(tensor.data(), tensor.get_size(), 0.0f); + return tensor; +} + +// TODO: Use already provided Sampler API, that will support both greedy and +// multinomial decoding. 
+std::variant> Eagle3InferWrapperBase::sample_tokens(const ov::Tensor& logits, + std::size_t count) { + auto shape = logits.get_shape(); + OPENVINO_ASSERT(shape.size() == 3 && shape[0] == 1, "Invalid logits shape for sampling"); + + std::size_t seq_len = shape[1]; + std::size_t vocab_size = shape[2]; + OPENVINO_ASSERT(count <= seq_len, "Requested count exceeds sequence length"); + + log_debug("Sampling " + std::to_string(count) + " tokens from logits [" + std::to_string(shape[0]) + ", " + + std::to_string(seq_len) + ", " + std::to_string(vocab_size) + "]"); + + auto sample_single = [&](std::size_t pos) -> int64_t { + const float* data = logits.data() + pos * vocab_size; + auto max_it = std::max_element(data, data + vocab_size); + int64_t token = static_cast(max_it - data); + + if (m_verbose) { + log_debug("Pos " + std::to_string(pos) + ": token " + std::to_string(token) + + " (logit: " + std::to_string(*max_it) + ")"); + } + return token; + }; + + if (count == 1) { + int64_t token = sample_single(seq_len - 1); + m_last_sampled_token = token; + log_debug("Sampled: " + std::to_string(token)); + return token; + } + + std::vector tokens; + tokens.reserve(count); + for (std::size_t i = 0; i < count; ++i) { + tokens.push_back(sample_single(seq_len - count + i)); + } + if (!tokens.empty()) { + m_last_sampled_token = tokens.back(); + } + + if (m_verbose) { + std::cout << "[EAGLE3-WRAPPER] Sampled " << count << " tokens: "; + for (std::size_t i = 0; i < tokens.size(); ++i) { + std::cout << tokens[i]; + if (i + 1 < tokens.size()) + std::cout << ", "; + } + std::cout << std::endl; + } + + return tokens; +} + +ov::Tensor Eagle3InferWrapperBase::get_logits() const { + return m_request.get_tensor("logits"); +} + +ov::Tensor Eagle3InferWrapperBase::get_hidden_features() const { + auto hidden_state = m_request.get_tensor("last_hidden_state"); + auto shape = hidden_state.get_shape(); + + OPENVINO_ASSERT(shape.size() == 3 && shape[0] == 1, + "Expected [1, seq_len, hidden_size], got [", + shape.size() == 3 + ? 
std::to_string(shape[0]) + ", " + std::to_string(shape[1]) + ", " + std::to_string(shape[2]) + : "invalid", + "]"); + + std::size_t output_seq_len = shape[1]; + std::size_t hidden_size = shape[2]; + + auto input_ids = m_request.get_tensor("input_ids"); + std::size_t actual_seq_len = input_ids.get_shape()[1]; + + if (output_seq_len == actual_seq_len) { + return hidden_state; + } + + OPENVINO_ASSERT(actual_seq_len <= output_seq_len, + "Actual length (", + actual_seq_len, + ") exceeds output (", + output_seq_len, + ")"); + + log_debug("Trimming hidden: " + std::to_string(output_seq_len) + " -> " + std::to_string(actual_seq_len)); + + // if NPU device is used, the output may be padded, trim it via ROI + std::size_t start_offset = output_seq_len - actual_seq_len; + ov::Tensor trimmed(hidden_state, {0, start_offset, 0}, {1, output_seq_len, hidden_size}); + + log_debug("Trimmed via ROI: [1, " + std::to_string(actual_seq_len) + ", " + std::to_string(hidden_size) + "]"); + + return trimmed; +} + +uint64_t Eagle3InferWrapperBase::execute_inference() { + auto start = std::chrono::steady_clock::now(); + m_request.infer(); + auto end = std::chrono::steady_clock::now(); + + // Update processed tokens to current sequence length + m_processed_tokens = m_tokens.size(); + + return std::chrono::duration_cast(end - start).count(); +} + +void Eagle3InferWrapperBase::update_performance_metrics(uint64_t inference_time_us, std::size_t tokens_count) { + m_raw_perf_metrics.m_durations.emplace_back(static_cast(inference_time_us)); + m_raw_perf_metrics.m_inference_durations[0] += ov::genai::MicroSeconds(static_cast(inference_time_us)); + m_raw_perf_metrics.m_batch_sizes.emplace_back(tokens_count); +} + +void Eagle3InferWrapperBase::log_debug(const std::string& message) const { + if (m_verbose) { + std::cout << "[EAGLE3-WRAPPER] " << message << std::endl; + } +} + +void Eagle3InferWrapperBase::log_tensor_info(const std::string& name, const ov::Tensor& tensor) const { + if (!m_verbose) + return; + + auto shape = tensor.get_shape(); + std::cout << "[EAGLE3-WRAPPER] " << name << " shape: ["; + for (std::size_t i = 0; i < shape.size(); ++i) { + std::cout << shape[i]; + if (i + 1 < shape.size()) + std::cout << ", "; + } + std::cout << "], type: " << tensor.get_element_type() << std::endl; +} + +void Eagle3InferWrapperBase::log_tensor_content(const std::string& name, + const ov::Tensor& tensor, + std::size_t max_elements) const { + if (!m_verbose || !tensor) + return; + + auto shape = tensor.get_shape(); + std::size_t total_elements = tensor.get_size(); + + // For input_ids, position_ids, attention_mask, always show all elements, ignore max_elements + bool show_all = (name == "input_ids" || name == "position_ids" || name == "attention_mask"); + std::size_t elements_to_show = show_all ? 
total_elements : std::min(max_elements, total_elements); + + std::cout << "[EAGLE3-WRAPPER] " << name << " content (" << elements_to_show << " elements): "; + + if (tensor.get_element_type() == ov::element::i64) { + const int64_t* data = tensor.data(); + for (std::size_t i = 0; i < elements_to_show; ++i) { + std::cout << data[i]; + if (i + 1 < elements_to_show) + std::cout << ", "; + } + } else if (tensor.get_element_type() == ov::element::f32) { + const float* data = tensor.data(); + for (std::size_t i = 0; i < elements_to_show; ++i) { + std::cout << std::fixed << std::setprecision(4) << data[i]; + if (i + 1 < elements_to_show) + std::cout << ", "; + } + } else if (tensor.get_element_type() == ov::element::i32) { + const int32_t* data = tensor.data(); + for (std::size_t i = 0; i < elements_to_show; ++i) { + std::cout << data[i]; + if (i + 1 < elements_to_show) + std::cout << ", "; + } + } + + if (!show_all && elements_to_show < total_elements) { + std::cout << " ... (+" << (total_elements - elements_to_show) << " more)"; + } + std::cout << std::endl; +} + +void Eagle3InferWrapperBase::log_model_inputs(const ov::Tensor& input_ids, + const ov::Tensor& attention_mask, + const ov::Tensor& position_ids) const { + if (!m_verbose) + return; + + std::cout << "[EAGLE3-WRAPPER] ========== MODEL INPUTS ==========" << std::endl; + log_tensor_info("input_ids", input_ids); + log_tensor_content("input_ids", input_ids, 0); // 0 means show all for input_ids + + log_tensor_info("attention_mask", attention_mask); + log_tensor_content("attention_mask", attention_mask, 0); + + log_tensor_info("position_ids", position_ids); + log_tensor_content("position_ids", position_ids, 0); // 0 means show all for position_ids + std::cout << "[EAGLE3-WRAPPER] =================================" << std::endl; +} + +void Eagle3InferWrapperBase::log_model_outputs(const ov::Tensor& logits, const ov::Tensor& hidden_features) const { + if (!m_verbose) + return; + + std::cout << "[EAGLE3-WRAPPER] ========== MODEL OUTPUTS =========" << std::endl; + log_tensor_info("logits", logits); + if (logits && logits.get_size() > 0) { + // For logits, show top 10 values for each position + auto logits_shape = logits.get_shape(); + if (logits_shape.size() == 3) { + std::size_t seq_len = logits_shape[1]; + std::size_t vocab_size = logits_shape[2]; + + // If seq_len > 5, only show last 5 positions + std::size_t start_pos = (seq_len > 5) ? 
(seq_len - 5) : 0; + std::size_t positions_to_show = seq_len - start_pos; + + if (start_pos > 0) { + std::cout << "[EAGLE3-WRAPPER] Showing only last " << positions_to_show + << " positions (total seq_len: " << seq_len << ")" << std::endl; + } + + // Show top 10 logit values for each position (only last 5 if seq_len > 5) + for (std::size_t pos = start_pos; pos < seq_len; ++pos) { + const float* logits_data = logits.data() + pos * vocab_size; + std::vector> top_logits; + + for (std::size_t i = 0; i < vocab_size; ++i) { + top_logits.emplace_back(logits_data[i], static_cast(i)); + } + + std::sort(top_logits.begin(), top_logits.end(), std::greater>()); + + std::cout << "[EAGLE3-WRAPPER] Position " << pos << " - Top 10 logits: "; + for (std::size_t i = 0; i < std::min(10, top_logits.size()); ++i) { + std::cout << "token_" << top_logits[i].second << ":" << std::fixed << std::setprecision(3) + << top_logits[i].first; + if (i + 1 < std::min(10, top_logits.size())) + std::cout << ", "; + } + std::cout << std::endl; + + // Show first 20 raw logit values for each position + std::cout << "[EAGLE3-WRAPPER] Position " << pos << " - First 20 raw logits: "; + for (std::size_t i = 0; i < std::min(20, vocab_size); ++i) { + std::cout << std::fixed << std::setprecision(4) << logits_data[i]; + if (i + 1 < std::min(20, vocab_size)) + std::cout << ", "; + } + std::cout << std::endl; + } + } + } + + if (hidden_features && hidden_features.get_size() > 0) { + log_tensor_info("hidden_features", hidden_features); + // Show first few elements of the last hidden state + auto hidden_shape = hidden_features.get_shape(); + if (hidden_shape.size() == 3 && hidden_shape[1] > 0) { + std::size_t seq_len = hidden_shape[1]; + std::size_t hidden_dim = hidden_shape[2]; + const float* hidden_data = hidden_features.data() + (seq_len - 1) * hidden_dim; + + std::cout << "[EAGLE3-WRAPPER] Last hidden state (first 10 dims): "; + for (std::size_t i = 0; i < std::min(10, hidden_dim); ++i) { + std::cout << std::fixed << std::setprecision(4) << hidden_data[i]; + if (i + 1 < std::min(10, hidden_dim)) + std::cout << ", "; + } + if (hidden_dim > 10) + std::cout << " ... 
(+" << (hidden_dim - 10) << " more)"; + std::cout << std::endl; + } + } + std::cout << "[EAGLE3-WRAPPER] =================================" << std::endl; +} + +//================================================================================================== +// Eagle3TargetModelWrapper Implementation +//================================================================================================== + +Eagle3TargetModelWrapper::Eagle3TargetModelWrapper(const ov::genai::ModelDesc& model_desc) + : Eagle3InferWrapperBase(model_desc) { + log_debug("Target model initialized"); +} + +void Eagle3TargetModelWrapper::initialize_sequence(const ov::Tensor& input_ids, const ov::Tensor& position_ids) { + const int64_t* ids_data = input_ids.data(); + std::size_t seq_len = input_ids.get_size(); + m_tokens.assign(ids_data, ids_data + seq_len); + + if (position_ids) { + const int64_t* pos_data = position_ids.data(); + m_positions.assign(pos_data, pos_data + position_ids.get_size()); + } + + log_debug("Sequence initialized: " + std::to_string(m_tokens.size()) + " tokens"); +} + +InferenceOutput Eagle3TargetModelWrapper::infer(const ov::Tensor& input_ids, + const ov::Tensor& attention_mask, + const ov::Tensor& position_ids) { + log_debug("Target inference start"); + log_model_inputs(input_ids, attention_mask, position_ids); + + if (m_device == "NPU") { + auto prompt_len = input_ids.get_shape()[1]; + OPENVINO_ASSERT(prompt_len <= m_max_prompt_len, + "NPU prompt length ", + prompt_len, + " exceeds max ", + m_max_prompt_len); + } + + m_request.set_tensor("input_ids", input_ids); + m_request.set_tensor("attention_mask", attention_mask); + m_request.set_tensor("position_ids", position_ids); + + if (m_device != "NPU") { + m_request.get_tensor("beam_idx").set_shape({BATCH_SIZE}); + m_request.get_tensor("beam_idx").data()[0] = 0; + } + + uint64_t time_us = execute_inference(); + update_performance_metrics(time_us, input_ids.get_shape()[1]); + + InferenceOutput output; + output.logits = get_logits(); + output.hidden_features = get_hidden_features(); + log_model_outputs(output.logits, output.hidden_features); + + log_debug("Target inference done: " + std::to_string(time_us / 1000.0) + "ms"); + return output; +} + +//================================================================================================== +// Eagle3DraftModelWrapper Implementation +//================================================================================================== + +Eagle3DraftModelWrapper::Eagle3DraftModelWrapper(const ov::genai::ModelDesc& model_desc) + : Eagle3InferWrapperBase(model_desc) { + log_debug("Draft model initialized"); +} + +void Eagle3DraftModelWrapper::initialize_sequence(const ov::Tensor& input_ids, const ov::Tensor& position_ids) { + // Eagle3: draft uses tokens[1:] with positions [0, 1, ..., n-2] + const int64_t* ids_data = input_ids.data(); + std::size_t total_len = input_ids.get_size(); + + OPENVINO_ASSERT(total_len >= 2, "Draft model requires at least 2 tokens, got ", total_len); + + std::size_t actual_len = total_len - 1; + m_tokens.assign(ids_data + 1, ids_data + total_len); + + if (position_ids) { + m_positions.resize(actual_len); + std::iota(m_positions.begin(), m_positions.end(), 0); + } + + log_debug("Sequence initialized: " + std::to_string(m_tokens.size()) + " tokens (skipped first, positions 0 to " + + std::to_string(actual_len - 1) + ")"); +} + +InferenceOutput Eagle3DraftModelWrapper::infer(const ov::Tensor& input_ids, + const ov::Tensor& attention_mask, + const ov::Tensor& position_ids, + 
const ov::Tensor& target_hidden_features, + const ov::Tensor& internal_hidden_features) { + log_debug("Draft inference start"); + log_model_inputs(input_ids, attention_mask, position_ids); + + m_request.set_tensor("input_ids", input_ids); + m_request.set_tensor("attention_mask", attention_mask); + m_request.set_tensor("position_ids", position_ids); + + // Eagle3 requires exactly one hidden state input + bool has_target = target_hidden_features && target_hidden_features.get_size() > 0; + bool has_internal = internal_hidden_features && internal_hidden_features.get_size() > 0; + + OPENVINO_ASSERT(has_target ^ has_internal, "Draft model requires exactly one of target/internal hidden features"); + + ov::Tensor target_tensor, internal_tensor; + + if (has_target) { + auto t_shape = target_hidden_features.get_shape(); + OPENVINO_ASSERT(t_shape.size() == 3 && t_shape.back() % 3 == 0, "Invalid target hidden features shape"); + + target_tensor = target_hidden_features; + auto internal_shape = t_shape; + internal_shape.back() = t_shape.back() / 3; + internal_tensor = create_hidden_state_placeholder(internal_shape); + + log_tensor_info("target_tensor", target_tensor); + log_tensor_info("internal_placeholder", internal_tensor); + } else { + auto i_shape = internal_hidden_features.get_shape(); + OPENVINO_ASSERT(i_shape.size() == 3, "Invalid internal hidden features shape"); + + internal_tensor = internal_hidden_features; + auto target_shape = i_shape; + target_shape.back() = i_shape.back() * 3; + target_tensor = create_hidden_state_placeholder(target_shape); + + log_tensor_info("internal_tensor", internal_tensor); + log_tensor_info("target_placeholder", target_tensor); + } + + m_request.set_tensor("hidden_states", target_tensor); + m_request.set_tensor("internal_hidden_states", internal_tensor); + + if (m_verbose) { + std::cout << "[EAGLE3-WRAPPER] Hidden State Tensors:" << std::endl; + log_tensor_info("hidden_states", target_tensor); + log_tensor_content("hidden_states", target_tensor, 10); + log_tensor_info("internal_hidden_states", internal_tensor); + log_tensor_content("internal_hidden_states", internal_tensor, 10); + } + + if (m_device != "NPU") { + m_request.get_tensor("beam_idx").set_shape({BATCH_SIZE}); + m_request.get_tensor("beam_idx").data()[0] = 0; + } + + uint64_t time_us = execute_inference(); + update_performance_metrics(time_us, input_ids.get_shape()[1]); + + InferenceOutput output; + output.logits = get_logits(); + output.hidden_features = get_hidden_features(); + log_model_outputs(output.logits, output.hidden_features); + + log_debug("Draft inference done: " + std::to_string(time_us / 1000.0) + "ms"); + return output; +} + +//================================================================================================== +// StatefulEagle3LLMPipeline Implementation +//================================================================================================== + +StatefulEagle3LLMPipeline::StatefulEagle3LLMPipeline(const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc, + const std::vector& hidden_layers_to_abstract) + : LLMPipelineImplBase(main_model_desc.tokenizer, main_model_desc.generation_config), + m_hidden_layers_to_abstract(hidden_layers_to_abstract) { + ov::genai::speculative_decoding::ensure_num_assistant_tokens_is_set(m_generation_config); + m_draft_iterations = m_generation_config.num_assistant_tokens; + + log_info("Initializing Eagle3: main=" + main_model_desc.device + ", draft=" + draft_model_desc.device + + ", iterations=" + 
std::to_string(m_draft_iterations)); + + auto main_model = main_model_desc.model; + auto draft_model = draft_model_desc.model; + m_tokenizer = main_model_desc.tokenizer; + + OPENVINO_ASSERT(!m_hidden_layers_to_abstract.empty(), + "Eagle3 requires hidden_layers_list configuration. " + "Provide it in properties or config.json. " + "Example: config[\"hidden_layers_list\"] = std::vector{2, 16, 29}"); + + // Model transformations + ov::genai::share_embedding_weights(main_model, draft_model); + log_debug("Shared embedding weights"); + + set_draft_target_mapping(draft_model); + + // Currently, the d2t node is stored in the draft model + // If it is not removed, it will affect the splitting and compilation of NPUW + // TODO: Root cause and better to remove this logic in model conversion step + ov::genai::remove_d2t_result_node(draft_model); + log_debug("Removed d2t node"); + + ov::genai::extract_hidden_state_generic(main_model, m_hidden_layers_to_abstract, main_model_desc.device); + log_debug("Extracted main model hidden states"); + + ov::genai::extract_hidden_state_generic(draft_model, {-1}, draft_model_desc.device); + log_debug("Model transformations completed"); + + std::size_t validation_window = m_draft_iterations + 1; + + auto draft_desc = draft_model_desc; + if (draft_desc.device == "NPU") { + draft_desc.properties["NPUW_LLM_MAX_GENERATION_TOKEN_LEN"] = validation_window; + draft_desc.properties["NPUW_DEVICES"] = "CPU"; + // TODO: Partition issue for draft model, low priority since it has only one repeat block + draft_desc.properties["NPUW_ONLINE_PIPELINE"] = "NONE"; + } + m_draft_model = std::make_unique(draft_desc); + + auto main_desc = main_model_desc; + if (main_desc.device == "NPU") { + main_desc.properties["NPUW_LLM_MAX_GENERATION_TOKEN_LEN"] = validation_window; + main_desc.properties["NPUW_DEVICES"] = "CPU"; + // Set rt_info to identify Eagle3 mode in NPUW + main_model->set_rt_info("true", "eagle3_mode"); + } + m_main_model = std::make_unique(main_desc); + + m_sd_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); + m_sd_perf_metrics.raw_metrics.m_inference_durations = {{ov::genai::MicroSeconds(0.0f)}}; + + log_info("Eagle3 initialization completed"); +} + +StatefulEagle3LLMPipeline::~StatefulEagle3LLMPipeline() { + m_main_model->release_memory(); + m_draft_model->release_memory(); + log_debug("Pipeline destroyed"); +} + +void StatefulEagle3LLMPipeline::set_draft_target_mapping(const std::shared_ptr& draft_model) { + OPENVINO_ASSERT(draft_model, "Draft model is null"); + + auto d2t_tensor = ov::genai::extract_d2t_mapping_table(draft_model); + OPENVINO_ASSERT(d2t_tensor, "Draft-to-target mapping not found. Eagle3 requires d2t mapping."); + + OPENVINO_ASSERT(d2t_tensor->get_element_type() == ov::element::i64, "Draft-to-target mapping must be int64"); + + ov::Tensor d2t_mapping(ov::element::i64, d2t_tensor->get_shape()); + std::memcpy(d2t_mapping.data(), d2t_tensor->get_data_ptr(), d2t_tensor->get_byte_size()); + + m_draft_target_mapping = std::move(d2t_mapping); + log_info("D2T mapping: " + std::to_string(m_draft_target_mapping.get_size()) + " entries"); +} + +void StatefulEagle3LLMPipeline::set_verbose(bool verbose) { + if (m_main_model) + m_main_model->set_verbose(verbose); + if (m_draft_model) + m_draft_model->set_verbose(verbose); + log_debug("Verbose: " + std::string(verbose ? 
"on" : "off")); +} + +GenerationConfig StatefulEagle3LLMPipeline::resolve_generation_config(OptionalGenerationConfig generation_config) { + GenerationConfig config = generation_config.value_or(m_generation_config); + + std::size_t prev_draft_iterations = m_draft_iterations; + ov::genai::speculative_decoding::ensure_num_assistant_tokens_is_set(config); + m_draft_iterations = config.num_assistant_tokens; + + // Log if draft_iterations changed from default + if (m_draft_iterations != prev_draft_iterations) { + if (m_draft_iterations == 0) { + log_info("Speculative decoding DISABLED (num_assistant_tokens=0), using target model only"); + } else if (is_verbose()) { + log_debug("Draft iterations updated: " + std::to_string(prev_draft_iterations) + " -> " + + std::to_string(m_draft_iterations)); + } + } + + if (config.stop_token_ids.empty()) + config.stop_token_ids = m_generation_config.stop_token_ids; + if (config.eos_token_id == -1) + config.set_eos_token_id(m_generation_config.eos_token_id); + config.validate(); + return config; +} + +DecodedResults StatefulEagle3LLMPipeline::generate(StringInputs inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer) { + ManualTimer generate_timer("StatefulEagle3LLMPipeline::generate()"); + generate_timer.start(); + ManualTimer encode_timer("Encode"); + encode_timer.start(); + + std::string prompt = + std::visit(overloaded{[](const std::string& prompt_str) { + return prompt_str; + }, + [](std::vector& prompts) { + OPENVINO_ASSERT(prompts.size() == 1u, "Currently only batch size=1 is supported"); + return prompts.front(); + }}, + inputs); + + GenerationConfig config = resolve_generation_config(generation_config); + + ov::genai::TokenizedInputs tokenized_input; + if (m_is_chat_active) { + m_chat_history.push_back({{"role", "user"}, {"content", prompt}}); + constexpr bool add_generation_prompt = true; + prompt = m_tokenizer.apply_chat_template(m_chat_history, add_generation_prompt); + // for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF + tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false)); + } else { + if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) { + ChatHistory history({{{"role", "user"}, {"content", prompt}}}); + constexpr bool add_generation_prompt = true; + auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); + tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)); + } else { + // in case when chat_template was not found in tokenizer_config.json or set + tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(true)); + } + } + + encode_timer.end(); + auto encoded_results = generate(tokenized_input, config, streamer); + + ManualTimer decode_timer("Decode"); + decode_timer.start(); + DecodedResults decoded_results = {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores}; + decode_timer.end(); + + if (m_is_chat_active) { + auto answer = decoded_results.texts[0]; + if (m_streaming_was_cancelled) + // If generation process was cancelled by user, let's rollback to previous state of history + m_chat_history.pop_back(); + else + m_chat_history.push_back({{"role", "assistant"}, {"content", answer}}); + } + + // Update perf metrics + decoded_results.perf_metrics = encoded_results.perf_metrics; + decoded_results.extended_perf_metrics = encoded_results.extended_perf_metrics; + generate_timer.end(); + auto& raw_counters = 
decoded_results.perf_metrics.raw_metrics; + raw_counters.generate_durations.clear(); + raw_counters.generate_durations.emplace_back(generate_timer.get_duration_microsec()); + raw_counters.tokenization_durations.emplace_back(encode_timer.get_duration_microsec()); + raw_counters.detokenization_durations.emplace_back(decode_timer.get_duration_microsec()); + decoded_results.perf_metrics.m_evaluated = false; + decoded_results.perf_metrics.evaluate_statistics(generate_timer.get_start_time()); + + return decoded_results; +} + +DecodedResults StatefulEagle3LLMPipeline::generate(const ChatHistory& history, + OptionalGenerationConfig generation_config, + StreamerVariant streamer) { + ManualTimer generate_timer("StatefulEagle3LLMPipeline::generate()"); + generate_timer.start(); + ManualTimer encode_timer("Encode"); + encode_timer.start(); + + GenerationConfig config = resolve_generation_config(generation_config); + + OPENVINO_ASSERT(config.apply_chat_template, + "Chat template must be applied when using ChatHistory in generate method."); + OPENVINO_ASSERT(!m_tokenizer.get_chat_template().empty(), + "Chat template must not be empty when using ChatHistory in generate method."); + OPENVINO_ASSERT(!history.empty(), "Chat history must not be empty when using ChatHistory in generate method."); + + constexpr bool add_generation_prompt = true; + auto templated_chat_history = m_tokenizer.apply_chat_template(history, add_generation_prompt); + // for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF + auto tokenized_inputs = m_tokenizer.encode(templated_chat_history, ov::genai::add_special_tokens(false)); + encode_timer.end(); + auto encoded_results = generate(tokenized_inputs, config, streamer); + + ManualTimer decode_timer("Decode"); + decode_timer.start(); + DecodedResults decoded_results = {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores}; + decode_timer.end(); + + // Update perf metrics + decoded_results.perf_metrics = encoded_results.perf_metrics; + decoded_results.extended_perf_metrics = encoded_results.extended_perf_metrics; + auto& raw_counters = decoded_results.perf_metrics.raw_metrics; + generate_timer.end(); + raw_counters.generate_durations.clear(); + raw_counters.generate_durations.emplace_back(generate_timer.get_duration_microsec()); + raw_counters.tokenization_durations.emplace_back(encode_timer.get_duration_microsec()); + raw_counters.detokenization_durations.emplace_back(decode_timer.get_duration_microsec()); + decoded_results.perf_metrics.m_evaluated = false; + decoded_results.perf_metrics.evaluate_statistics(generate_timer.get_start_time()); + + return decoded_results; +} + +EncodedResults StatefulEagle3LLMPipeline::generate(const EncodedInputs& inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer) { + ManualTimer generate_timer("StatefulEagle3LLMPipeline::generate"); + generate_timer.start(); + + auto config = resolve_generation_config(generation_config); + + // Create streamer for streaming output + std::shared_ptr streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); + + log_info("Starting Eagle3 generation with max_new_tokens=" + std::to_string(config.max_new_tokens) + + ", draft_iterations=" + std::to_string(m_draft_iterations)); + + // Extract input tensors + ov::Tensor input_ids, attention_mask; + if (auto* tensor_input = std::get_if(&inputs)) { + input_ids = *tensor_input; + attention_mask = ov::genai::utils::init_attention_mask(input_ids); + } else if (auto* tokenized_input = 
std::get_if(&inputs)) { + input_ids = tokenized_input->input_ids; + attention_mask = tokenized_input->attention_mask; + } + + auto prompt_shape = input_ids.get_shape(); + if (prompt_shape[0] != 1) { + throw std::runtime_error("Only batch size 1 is supported"); + } + + std::size_t prompt_len = prompt_shape[1]; + + m_prompt_length = prompt_len; + + log_debug("Prompt length: " + std::to_string(prompt_len) + " tokens"); + + // Initialize position IDs + ov::Tensor position_ids{ov::element::i64, input_ids.get_shape()}; + utils::initialize_position_ids(position_ids, attention_mask); + + // Reset model states and initialize sequences + m_main_model->reset_state(); + m_draft_model->reset_state(); + + // Initialize main model with full sequence + m_main_model->initialize_sequence(input_ids, position_ids); + + // Initialize draft model with sequence starting from second token + m_draft_model->initialize_sequence(input_ids, position_ids); + + // Initial main model inference + log_generation_step("Initial Main Model Inference", 0); + auto main_output = m_main_model->infer(input_ids, attention_mask, position_ids); + auto initial_token = std::get(m_main_model->sample_tokens(main_output.logits, 1)); + + // Get initial hidden features and append first generated token + auto main_hidden_features = main_output.hidden_features; + m_main_model->append_tokens({initial_token}); + m_draft_model->append_tokens({initial_token}); + + // Stream the initial token + auto streaming_status = stream_generated_tokens(streamer_ptr, std::vector{initial_token}); + + log_debug("Initial token generated: " + std::to_string(initial_token)); + log_sequence_state("after initial token generation"); + + // Main generation loop + std::size_t max_new_tokens = config.max_new_tokens; + std::size_t generated_tokens = 1; // Count initial token + bool eos_reached = false; + + // Track metrics for speculative decoding + std::size_t total_draft_accepted = 0; // Number of draft tokens accepted by main model + std::size_t total_draft_generated = 0; // Total draft tokens generated (including rejected) + std::size_t total_iterations = 0; // Number of speculative iterations + + // Speculative decoding loop + std::size_t token_count = m_draft_model->get_sequence_length(); + auto target_hidden_states = main_hidden_features; + + while (!eos_reached && generated_tokens < max_new_tokens && + m_main_model->get_sequence_length() < prompt_len + max_new_tokens && + (streaming_status == ov::genai::StreamingStatus::RUNNING)) { + log_generation_step("Speculative Decoding Iteration", generated_tokens); + log_sequence_state("iteration start"); + + auto result = + run_speculative_iteration(target_hidden_states, token_count, static_cast(config.eos_token_id)); + + // Stream validated tokens + streaming_status = stream_generated_tokens(streamer_ptr, result.validated_tokens); + + // Update iteration counter + total_iterations++; + + // Update draft token statistics + total_draft_generated += m_draft_iterations; // Each iteration generates m_draft_iterations draft tokens + total_draft_accepted += + result.accepted_tokens_count; // Number of draft tokens accepted (not including main model's token) + + if (result.new_token == static_cast(config.eos_token_id) || result.eos_reached) { + eos_reached = true; + log_debug("EOS reached - terminating generation"); + } + + // Validate that speculative iteration produced valid results + OPENVINO_ASSERT(result.new_token != -1, "Speculative iteration must produce a valid token"); + OPENVINO_ASSERT(result.next_window_size > 0, 
"Speculative iteration must produce valid next_window_size"); + OPENVINO_ASSERT(result.next_hidden_window && result.next_hidden_window.get_size() > 0, + "Speculative iteration must produce valid next_hidden_window"); + + generated_tokens++; + log_debug("Generated token " + std::to_string(generated_tokens) + ": " + + std::to_string(result.new_token) + ", accepted " + + std::to_string(result.accepted_tokens_count) + " draft tokens out of " + + std::to_string(m_draft_iterations)); + + // Prepare for next iteration + token_count = result.next_window_size; + target_hidden_states = result.next_hidden_window; + + log_debug("Next iteration: token_count=" + std::to_string(token_count) + + ", hidden_states_size=" + std::to_string(target_hidden_states.get_size())); + + log_sequence_state("iteration end"); + } + + m_streaming_was_cancelled = (streaming_status == ov::genai::StreamingStatus::CANCEL); + if (streamer_ptr) { // push streamer's cache + streamer_ptr->end(); + } + + // Prepare results using main model's tokens as source of truth + EncodedResults results; + results.tokens = {m_main_model->get_tokens()}; + results.scores.resize(1); + results.scores[0] = 0.0f; // Greedy decoding, no scores + + // Display final tokens if verbose + if (is_verbose() && !results.tokens[0].empty()) { + try { + std::string decoded_text = m_tokenizer.decode(results.tokens[0]); + std::cout << "[EAGLE3-FINAL] All tokens decoded (" << results.tokens[0].size() << " tokens): \"" + << decoded_text << "\"" << std::endl; + } catch (const std::exception& e) { + std::cout << "[EAGLE3-FINAL] Failed to decode tokens: " << e.what() << std::endl; + } + } + + // Update performance metrics following the standard stateful speculative decoding pattern + generate_timer.end(); + + m_sd_perf_metrics.num_input_tokens = prompt_len; + m_sd_perf_metrics.load_time = this->m_load_time_ms; + m_sd_perf_metrics.raw_metrics.generate_durations.clear(); + m_sd_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_timer.get_duration_microsec()); + + // Update main and draft model metrics from their RawPerfMetrics + m_sd_perf_metrics.main_model_metrics.raw_metrics = m_main_model->get_raw_perf_metrics(); + m_sd_perf_metrics.draft_model_metrics.raw_metrics = m_draft_model->get_raw_perf_metrics(); + + // Set num_accepted_tokens - this represents draft tokens accepted by main model + m_sd_perf_metrics.num_accepted_tokens = total_draft_accepted; + + // Update speculative decoding metrics based on collected data + if (generated_tokens > 0) { + // Calculate acceptance rate: accepted draft tokens / total draft tokens generated + float acceptance_rate = total_draft_generated > 0 + ? 
(static_cast(total_draft_accepted) / total_draft_generated * 100.0f) + : 0.0f; + + m_sd_metrics.update_acceptance_rate(0, acceptance_rate); + m_sd_metrics.update_draft_accepted_tokens(0, total_draft_accepted); + m_sd_metrics.update_draft_generated_len(0, total_draft_generated); + m_sd_metrics.update_generated_len(generated_tokens); + } + + // Evaluate statistics using standard interface + m_sd_perf_metrics.evaluate_statistics(generate_timer.get_start_time()); + + results.perf_metrics = m_sd_perf_metrics; + results.extended_perf_metrics = std::make_shared(m_sd_perf_metrics); + + return results; +} + +StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_speculative_iteration( + const ov::Tensor& target_hidden_states, + std::size_t token_count, + int64_t eos_token_id) { + SpeculativeResult result; + + log_debug("Starting speculative iteration with token_count=" + std::to_string(token_count)); + + if (!target_hidden_states || target_hidden_states.get_size() == 0) { + log_debug("Invalid target hidden states provided"); + return result; + } + + // Step 1: Initial draft inference using target model's hidden states + ov::Tensor draft_input_ids, draft_attention_mask, draft_position_ids; + m_draft_model->build_model_inputs(token_count, draft_input_ids, draft_attention_mask, draft_position_ids); + + auto draft_output = m_draft_model->infer(draft_input_ids, + draft_attention_mask, + draft_position_ids, + target_hidden_states, + /*internal_hidden_states=*/ov::Tensor{}); + + int64_t first_draft_token = std::get(m_draft_model->sample_tokens(draft_output.logits, 1)); + first_draft_token = map_draft_token(first_draft_token); + + // Record the sequence length before draft generation for rollback + std::size_t pre_draft_main_len = m_main_model->get_sequence_length(); + std::size_t pre_draft_draft_len = m_draft_model->get_sequence_length(); + + // Store all draft tokens temporarily (including first one) + std::vector draft_candidates; + draft_candidates.push_back(first_draft_token); + + // Append first draft token and get its hidden state for subsequent iterations + m_main_model->append_tokens({first_draft_token}); + m_draft_model->append_tokens({first_draft_token}); + auto internal_hidden_states = extract_last_hidden_state(draft_output.hidden_features); + + // Step 2: Additional draft iterations + for (std::size_t i = 0; i < m_draft_iterations - 1; ++i) { + m_draft_model->build_model_inputs(1, draft_input_ids, draft_attention_mask, draft_position_ids); + + auto more_output = m_draft_model->infer(draft_input_ids, + draft_attention_mask, + draft_position_ids, + /*target_hidden_states*/ ov::Tensor{}, + internal_hidden_states); + + int64_t draft_token = std::get(m_draft_model->sample_tokens(more_output.logits, 1)); + draft_token = map_draft_token(draft_token); + + draft_candidates.push_back(draft_token); + m_main_model->append_tokens({draft_token}); + m_draft_model->append_tokens({draft_token}); + internal_hidden_states = extract_last_hidden_state(more_output.hidden_features); + } + + // Step 3: Validation - main model validates draft tokens with shift comparison + log_debug("Starting validation phase with " + std::to_string(m_draft_iterations) + " draft tokens"); + + std::size_t validation_window_size = m_draft_iterations + 1; + + ov::Tensor val_input_ids, val_attention_mask, val_position_ids; + m_main_model->build_model_inputs(validation_window_size, val_input_ids, val_attention_mask, val_position_ids); + + // Run main model validation inference + auto val_output = 
m_main_model->infer(val_input_ids, val_attention_mask, val_position_ids); + auto sampled_tokens = + std::get>(m_main_model->sample_tokens(val_output.logits, validation_window_size)); + + // Compare main model predictions with draft tokens (shift comparison) + const int64_t* existing_tokens = val_input_ids.data(); + std::size_t accepted_count = 0; + + for (std::size_t i = 0; i < m_draft_iterations; ++i) { + if (sampled_tokens[i] == existing_tokens[i + 1]) { + accepted_count++; + } else { + log_debug("Validation mismatch at position " + std::to_string(i) + ": expected " + + std::to_string(existing_tokens[i + 1]) + ", got " + std::to_string(sampled_tokens[i])); + break; + } + } + + // The main model's prediction at accepted_count position becomes the new token + int64_t main_predicted_token = sampled_tokens[accepted_count]; + + // Calculate tokens to accept and reject + std::size_t tokens_to_remove_from_draft = m_draft_iterations - accepted_count; + std::size_t total_accepted_tokens = accepted_count + 1; // accepted drafts + main prediction + + log_debug("Validation result: accepted " + std::to_string(accepted_count) + "/" + + std::to_string(m_draft_iterations) + + " draft tokens, main_predicted_token=" + std::to_string(main_predicted_token)); + + // Rollback both models to pre-draft state, then append accepted tokens + m_main_model->truncate_sequence(pre_draft_main_len); + m_draft_model->truncate_sequence(pre_draft_draft_len); + + // Append accepted draft tokens + main predicted token to both models + std::vector tokens_to_append; + tokens_to_append.reserve(total_accepted_tokens); + for (std::size_t i = 0; i < accepted_count; ++i) { + tokens_to_append.push_back(draft_candidates[i]); + } + tokens_to_append.push_back(main_predicted_token); + + m_main_model->append_tokens(tokens_to_append); + m_draft_model->append_tokens(tokens_to_append); + + // Trim KV cache for rejected draft tokens + if (tokens_to_remove_from_draft > 0) { + m_main_model->trim_kv_cache(tokens_to_remove_from_draft); + m_draft_model->trim_kv_cache(tokens_to_remove_from_draft); + } + + log_debug("Accepted total " + std::to_string(total_accepted_tokens) + " tokens (" + std::to_string(accepted_count) + + " draft + 1 main prediction), rejected " + std::to_string(tokens_to_remove_from_draft) + + " draft tokens."); + + // Build next hidden window for next iteration + auto current_hidden = val_output.hidden_features; + OPENVINO_ASSERT(current_hidden && current_hidden.get_size() > 0, + "Hidden features from validation output must exist"); + + auto h_shape = current_hidden.get_shape(); + OPENVINO_ASSERT(h_shape.size() == 3 && h_shape[0] == 1, + "Invalid hidden state shape for next window construction"); + + std::size_t seq_len = h_shape[1]; + std::size_t hidden_dim = h_shape[2]; + std::size_t next_window_len = total_accepted_tokens; + + OPENVINO_ASSERT(seq_len >= next_window_len, + "Hidden state seq_len (", + seq_len, + ") < next_window_len (", + next_window_len, + ")"); + + // Extract hidden states for accepted tokens + // Input: [1, seq_len, hidden_dim] -> Output: [1, next_window_len, hidden_dim] + ov::Tensor next_hidden = ov::Tensor(current_hidden, {0, 0, 0}, {1, next_window_len, hidden_dim}); + + log_debug("Built next hidden window with " + std::to_string(next_window_len) + " positions (ROI, zero-copy)"); + + // Check for EOS token + if (main_predicted_token == eos_token_id) { + result.eos_reached = true; + log_debug("EOS token detected: " + std::to_string(main_predicted_token)); + } + + // Set result fields + // IMPORTANT: 
accepted_tokens_count is ONLY the number of DRAFT tokens accepted by main model + // It does NOT include the main model's own prediction token + result.accepted_tokens_count = accepted_count; // Only draft tokens accepted + result.next_window_size = total_accepted_tokens; // Total tokens for next iteration (draft + main) + result.new_token = main_predicted_token; // Main model's prediction + result.next_hidden_window = next_hidden; + result.validated_tokens = tokens_to_append; // Return validated tokens for streaming + + log_debug("Speculative iteration completed - accepted " + std::to_string(accepted_count) + + " draft tokens + 1 main prediction = " + std::to_string(total_accepted_tokens) + + " tokens for next iteration"); + + return result; +} + +int64_t StatefulEagle3LLMPipeline::map_draft_token(int64_t draft_token) const { + if (!m_draft_target_mapping || m_draft_target_mapping.get_size() == 0) { + return draft_token; + } + + std::size_t mapping_size = m_draft_target_mapping.get_size(); + if (draft_token < 0 || static_cast<std::size_t>(draft_token) >= mapping_size) { + log_debug("Token " + std::to_string(draft_token) + " out of range, identity mapping"); + return draft_token; + } + + const int64_t* data = m_draft_target_mapping.data<int64_t>(); + int64_t target_token = draft_token + data[draft_token]; + + if (is_verbose()) { + log_debug("Mapped: " + std::to_string(draft_token) + " -> " + std::to_string(target_token)); + } + + return target_token; +} + +std::vector<int64_t> StatefulEagle3LLMPipeline::map_draft_tokens(const std::vector<int64_t>& draft_tokens) const { + if (!m_draft_target_mapping || m_draft_target_mapping.get_size() == 0) { + return draft_tokens; + } + + std::vector<int64_t> mapped; + mapped.reserve(draft_tokens.size()); + for (int64_t token : draft_tokens) { + mapped.push_back(map_draft_token(token)); + } + return mapped; +} + +void StatefulEagle3LLMPipeline::start_chat(const std::string& system_message) { + m_is_chat_active = true; + m_chat_history.clear(); + + if (!system_message.empty()) { + m_chat_history.push_back({{"role", "system"}, {"content", system_message}}); + } + log_info("Chat started"); +} + +void StatefulEagle3LLMPipeline::finish_chat() { + m_is_chat_active = false; + m_chat_history.clear(); + log_info("Chat ended"); +} + +ov::genai::SpeculativeDecodingMetrics StatefulEagle3LLMPipeline::get_speculative_decoding_metrics() const { + return m_sd_metrics; +} + +void StatefulEagle3LLMPipeline::log_info(const std::string& message) const { + std::cout << "[EAGLE3-PIPELINE] " << message << std::endl; +} + +void StatefulEagle3LLMPipeline::log_debug(const std::string& message) const { + if (is_verbose()) { + std::cout << "[EAGLE3-DEBUG] " << message << std::endl; + } +} + +void StatefulEagle3LLMPipeline::log_generation_step(const std::string& step_name, std::size_t step_number) const { + if (is_verbose()) { + std::cout << "\n[EAGLE3] ===== STEP " << step_number << ": " << step_name << " =====" << std::endl; + } +} + +void StatefulEagle3LLMPipeline::log_sequence_state(const std::string& context) const { + if (!is_verbose()) + return; + + std::cout << "[EAGLE3-PIPELINE] Sequence state (" << context << "):" << std::endl; + std::cout << " Prompt length: " << m_prompt_length << " tokens" << std::endl; + std::cout << " Main model tokens: " << m_main_model->get_sequence_length() << " tokens" << std::endl; + std::cout << " Draft model tokens: " << m_draft_model->get_sequence_length() << " tokens" << std::endl; + + // Show all tokens and positions from main model + const auto& main_tokens = m_main_model->get_tokens(); + 
const auto& main_positions = m_main_model->get_positions(); + if (!main_tokens.empty()) { + std::cout << " Main model tokens: "; + for (std::size_t i = 0; i < main_tokens.size(); ++i) { + std::cout << main_tokens[i]; + if (i + 1 < main_tokens.size()) + std::cout << ", "; + } + std::cout << std::endl; + + if (!main_positions.empty()) { + std::cout << " Main model positions: "; + for (std::size_t i = 0; i < main_positions.size(); ++i) { + std::cout << main_positions[i]; + if (i + 1 < main_positions.size()) + std::cout << ", "; + } + std::cout << std::endl; + } + } + + // Show all tokens and positions from draft model + const auto& draft_tokens = m_draft_model->get_tokens(); + const auto& draft_positions = m_draft_model->get_positions(); + if (!draft_tokens.empty()) { + std::cout << " Draft model tokens: "; + for (std::size_t i = 0; i < draft_tokens.size(); ++i) { + std::cout << draft_tokens[i]; + if (i + 1 < draft_tokens.size()) + std::cout << ", "; + } + std::cout << std::endl; + + if (!draft_positions.empty()) { + std::cout << " Draft model positions: "; + for (std::size_t i = 0; i < draft_positions.size(); ++i) { + std::cout << draft_positions[i]; + if (i + 1 < draft_positions.size()) + std::cout << ", "; + } + std::cout << std::endl; + } + } +} + +ov::Tensor StatefulEagle3LLMPipeline::slice_hidden_features(const ov::Tensor& hidden_features, + std::size_t start_pos, + std::size_t length) const { + if (!hidden_features || hidden_features.get_size() == 0) { + return ov::Tensor{}; + } + + auto shape = hidden_features.get_shape(); + if (shape.size() != 3 || shape[0] != 1) { + return hidden_features; + } + + std::size_t seq_len = shape[1]; + std::size_t hidden_dim = shape[2]; + + if (start_pos >= seq_len || length == 0) { + return ov::Tensor{}; + } + + std::size_t actual_len = std::min(length, seq_len - start_pos); + ov::Tensor sliced(ov::element::f32, {1, actual_len, hidden_dim}); + + if (hidden_features.get_element_type() == ov::element::f32) { + const float* src = hidden_features.data() + start_pos * hidden_dim; + std::copy_n(src, actual_len * hidden_dim, sliced.data()); + } + + return sliced; +} + +ov::Tensor StatefulEagle3LLMPipeline::combine_hidden_windows(const ov::Tensor& confirmed_hidden, + const ov::Tensor& new_hidden) const { + if (!confirmed_hidden || confirmed_hidden.get_size() == 0) { + return new_hidden; + } + if (!new_hidden || new_hidden.get_size() == 0) { + return confirmed_hidden; + } + + auto conf_shape = confirmed_hidden.get_shape(); + auto new_shape = new_hidden.get_shape(); + + if (conf_shape.size() != 3 || new_shape.size() != 3 || conf_shape[0] != 1 || new_shape[0] != 1 || + conf_shape[2] != new_shape[2]) { + return confirmed_hidden; + } + + std::size_t conf_len = conf_shape[1]; + std::size_t new_len = new_shape[1]; + std::size_t hidden_dim = conf_shape[2]; + std::size_t total_len = conf_len + new_len; + + ov::Tensor combined(ov::element::f32, {1, total_len, hidden_dim}); + + if (confirmed_hidden.get_element_type() == ov::element::f32 && new_hidden.get_element_type() == ov::element::f32) { + float* dst = combined.data(); + std::copy_n(confirmed_hidden.data(), conf_len * hidden_dim, dst); + std::copy_n(new_hidden.data(), new_len * hidden_dim, dst + conf_len * hidden_dim); + } + + return combined; +} + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_stateful_eagle3.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_stateful_eagle3.hpp new file mode 100644 index 0000000000..eb979a9cd8 --- /dev/null +++ 
b/src/cpp/src/speculative_decoding/speculative_decoding_stateful_eagle3.hpp @@ -0,0 +1,266 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +#include "llm/pipeline_base.hpp" +#include "sampling/sampler.hpp" +#include "speculative_decoding_metrics.hpp" +#include "utils.hpp" + +namespace ov { +namespace genai { + +/** + * @brief Eagle3 model inference output + */ +struct InferenceOutput { + ov::Tensor logits; + ov::Tensor hidden_features; +}; + +/** + * @brief Base class for Eagle3 model inference + * + * Provides shared functionality for target and draft model wrappers: + * - Sequence and KV cache management + * - Tensor building and sampling + * - Performance tracking + */ +class Eagle3InferWrapperBase { +public: + explicit Eagle3InferWrapperBase(const ov::genai::ModelDesc& model_desc); + virtual ~Eagle3InferWrapperBase() = default; + + // Configuration + std::string device() const { + return m_device; + } + ov::genai::GenerationConfig get_generation_config() const { + return m_generation_config; + } + void set_generation_config(ov::genai::GenerationConfig cfg) { + m_generation_config = std::move(cfg); + } + void set_verbose(bool verbose) { + m_verbose = verbose; + } + bool is_verbose() const { + return m_verbose; + } + + // Sequence management + void append_tokens(const std::vector& tokens); + void truncate_sequence(std::size_t size); + void trim_kv_cache(std::size_t tokens_to_remove); + void reset_state(); + void release_memory(); + + // Sequence access + std::size_t get_sequence_length() const { + return m_tokens.size(); + } + const std::vector& get_tokens() const { + return m_tokens; + } + const std::vector& get_positions() const { + return m_positions; + } + int64_t get_last_sampled_token() const { + return m_last_sampled_token; + } + + // Model outputs + ov::Tensor get_logits() const; + ov::Tensor get_hidden_features() const; + + // Tensor operations + void build_model_inputs(std::size_t token_count, + ov::Tensor& input_ids, + ov::Tensor& attention_mask, + ov::Tensor& position_ids); + ov::Tensor create_hidden_state_placeholder(const ov::Shape& shape) const; + + // Sampling + std::variant> sample_tokens(const ov::Tensor& logits, std::size_t count); + + // Performance + ov::genai::RawPerfMetrics& get_raw_perf_metrics() { + return m_raw_perf_metrics; + } + +protected: + static constexpr std::size_t BATCH_SIZE = 1; + + // Inference and metrics + uint64_t execute_inference(); + void update_performance_metrics(uint64_t inference_time_us, std::size_t tokens_count); + + // Debug logging + void log_debug(const std::string& message) const; + void log_tensor_info(const std::string& name, const ov::Tensor& tensor) const; + void log_tensor_content(const std::string& name, const ov::Tensor& tensor, std::size_t max_elements = 10) const; + void log_model_inputs(const ov::Tensor& input_ids, + const ov::Tensor& attention_mask, + const ov::Tensor& position_ids) const; + void log_model_outputs(const ov::Tensor& logits, const ov::Tensor& hidden_features) const; + + // Model and configuration + std::string m_device; + ov::AnyMap m_properties; + ov::genai::GenerationConfig m_generation_config; + ov::genai::Tokenizer m_tokenizer; + mutable ov::InferRequest m_request; + ov::genai::utils::KVAxesPosition m_kv_axes_pos; + + // Device limits (NPU-specific) + std::size_t m_max_prompt_len = 0; + std::size_t m_kv_cache_capacity = 0; + + // Speculative token sequences (may be rolled back) + std::vector m_tokens; + 
std::vector m_positions; + + // State + std::size_t m_processed_tokens = 0; + int64_t m_last_sampled_token = -1; + + // Metrics + ov::genai::RawPerfMetrics m_raw_perf_metrics; + + // Configuration + bool m_verbose = false; +}; + +/** + * @brief Target model wrapper for Eagle3 + * + * Main model that validates draft predictions and generates final output. + */ +class Eagle3TargetModelWrapper : public Eagle3InferWrapperBase { +public: + explicit Eagle3TargetModelWrapper(const ov::genai::ModelDesc& model_desc); + ~Eagle3TargetModelWrapper() = default; + + void initialize_sequence(const ov::Tensor& input_ids, const ov::Tensor& position_ids); + InferenceOutput infer(const ov::Tensor& input_ids, + const ov::Tensor& attention_mask, + const ov::Tensor& position_ids); +}; + +/** + * @brief Draft model wrapper for Eagle3 + * + * Generates candidate tokens using target hidden states or internal features. + * Uses tokens[1:] with position_ids [0, 1, ..., seq_len-2] (Eagle3 specific). + */ +class Eagle3DraftModelWrapper : public Eagle3InferWrapperBase { +public: + explicit Eagle3DraftModelWrapper(const ov::genai::ModelDesc& model_desc); + ~Eagle3DraftModelWrapper() = default; + + void initialize_sequence(const ov::Tensor& input_ids, const ov::Tensor& position_ids); + InferenceOutput infer(const ov::Tensor& input_ids, + const ov::Tensor& attention_mask, + const ov::Tensor& position_ids, + const ov::Tensor& target_hidden_features, + const ov::Tensor& internal_hidden_features); +}; + +/** + * @brief Stateful Eagle3 LLM Pipeline + * + * Eagle3 speculative decoding: draft model generates candidates, main model validates. + */ +class StatefulEagle3LLMPipeline : public ov::genai::LLMPipelineImplBase { +public: + StatefulEagle3LLMPipeline(const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc, + const std::vector& hidden_layers_to_abstract = {}); + ~StatefulEagle3LLMPipeline(); + + // LLMPipelineImplBase interface + DecodedResults generate(StringInputs inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer) override; + DecodedResults generate(const ChatHistory& history, + OptionalGenerationConfig generation_config, + StreamerVariant streamer) override; + EncodedResults generate(const EncodedInputs& inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer) override; + void start_chat(const std::string& system_message) override; + void finish_chat() override; + + // Configuration + void set_draft_target_mapping(const std::shared_ptr& draft_model); + void set_verbose(bool verbose); + bool is_verbose() const { + return m_main_model ? 
m_main_model->is_verbose() : false; + } + GenerationConfig resolve_generation_config(OptionalGenerationConfig generation_config); + + // Metrics + ov::genai::SpeculativeDecodingMetrics get_speculative_decoding_metrics() const; + + // Logging + void log_generation_step(const std::string& step_name, std::size_t step_number) const; + void log_sequence_state(const std::string& context) const; + +private: + struct SpeculativeResult { + ov::Tensor next_hidden_window; + std::size_t accepted_tokens_count = 0; + std::size_t next_window_size = 0; + int64_t new_token = -1; + bool eos_reached = false; + std::vector validated_tokens; // Tokens accepted by main model (draft + main prediction) + }; + + // Core algorithm + SpeculativeResult run_speculative_iteration(const ov::Tensor& target_hidden_states, + std::size_t token_count, + int64_t eos_token_id); + + // Token mapping + int64_t map_draft_token(int64_t draft_token) const; + std::vector map_draft_tokens(const std::vector& draft_tokens) const; + + // Logging + void log_info(const std::string& message) const; + void log_debug(const std::string& message) const; + + // Tensor utilities + ov::Tensor slice_hidden_features(const ov::Tensor& hidden_features, + std::size_t start_pos, + std::size_t length) const; + ov::Tensor combine_hidden_windows(const ov::Tensor& confirmed_hidden, const ov::Tensor& new_hidden) const; + + // Models + std::unique_ptr m_draft_model; + std::unique_ptr m_main_model; + + // Algorithm configuration + std::size_t m_draft_iterations = 5; + ov::Tensor m_draft_target_mapping; + std::vector m_hidden_layers_to_abstract; + + std::size_t m_prompt_length = 0; + + // Metrics + ov::genai::SpeculativeDecodingMetrics m_sd_metrics; + ov::genai::SDPerModelsPerfMetrics m_sd_perf_metrics; + + // Chat state + bool m_is_chat_active = false; + ChatHistory m_chat_history; + bool m_streaming_was_cancelled = false; +}; + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_utils.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_utils.cpp new file mode 100644 index 0000000000..868b48b836 --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_utils.cpp @@ -0,0 +1,59 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "speculative_decoding_utils.hpp" + +#include +#include + +#include "json_utils.hpp" + +namespace ov { +namespace genai { +namespace speculative_decoding { + +void ensure_num_assistant_tokens_is_set(ov::genai::GenerationConfig& config) { + // Only num_assistant_tokens is supported, not assistant_confidence_threshold + OPENVINO_ASSERT( + config.assistant_confidence_threshold == 0.f, + "Speculative Decoding only supports num_assistant_tokens, not assistant_confidence_threshold. 
Set it to 0.f."); + if (config.num_assistant_tokens == 0) { + config.num_assistant_tokens = DEFAULT_NUM_ASSISTANT_TOKENS; + } +} + +Eagle3RTInfo +extract_eagle_mode_from_config(ov::AnyMap& config, const std::filesystem::path& models_path) { + Eagle3RTInfo eagle_rt_info; + if (config.find("eagle3_mode") != config.end()) { + eagle_rt_info.eagle3_mode = config.at("eagle3_mode").as(); + config.erase("eagle3_mode"); + if (config.find("hidden_layers_list") != config.end()) { + eagle_rt_info.hidden_layers_list = config.at("hidden_layers_list").as>(); + config.erase("hidden_layers_list"); + } else { + // compute the layers from number of hidden layers + auto config_file_path = models_path / "config.json"; + if (!std::filesystem::exists(config_file_path)) + OPENVINO_THROW("cannot deduce layers for hidden layer extraction"); + std::ifstream file(config_file_path); + + nlohmann::json data = nlohmann::json::parse(file); + using ov::genai::utils::read_json_param; + int num_decoder_layers = 0; + read_json_param(data, "num_hidden_layers", num_decoder_layers); + OPENVINO_ASSERT(num_decoder_layers > 3, "num_decoder_layers is too small to deduce hidden layers for extraction"); + // The following default hidden layer selection corresponds to the EAGLE reference implementation: + // https://github.com/SafeAILab/EAGLE/blob/0ea94696/eagle/model/modeling_llama_kv.py#L1138 + // These layers (2, num_decoder_layers / 2, num_decoder_layers - 3) are chosen to capture features from + // early, middle, and late stages of the decoder, as recommended by the EAGLE authors. + // If you wish to use different layers, provide the "hidden_layers_list" parameter in the config. + eagle_rt_info.hidden_layers_list = { 2, num_decoder_layers / 2, num_decoder_layers - 3 }; + } + } + return eagle_rt_info; +} + +} // namespace speculative_decoding +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_utils.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_utils.hpp new file mode 100644 index 0000000000..70f7d3fc23 --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_utils.hpp @@ -0,0 +1,49 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "openvino/genai/generation_config.hpp" + +#include "openvino/runtime/core.hpp" + +namespace ov { +namespace genai { +namespace speculative_decoding { + +// Set num_assistant_tokens to default if not specified and check config validity +constexpr std::size_t DEFAULT_NUM_ASSISTANT_TOKENS = 4; +void ensure_num_assistant_tokens_is_set(ov::genai::GenerationConfig& config); + +/** + * @brief Eagle3 runtime configuration information + */ +struct Eagle3RTInfo { + bool eagle3_mode = false; + std::vector hidden_layers_list; + std::filesystem::path dt_mapping_table; +}; + +/** + * @brief Extract Eagle3 configuration from draft model properties + * + * This function extracts Eagle3-specific configuration from the draft model's + * property map. It looks for: + * - eagle3_mode: boolean flag to enable Eagle3 speculative decoding + * - hidden_layers_list: explicit list of layer indices to extract hidden states from + * + * If hidden_layers_list is not provided and models_path is given, the function + * will attempt to auto-deduce the layers from the model's config.json file. 
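As a worked illustration of that fallback, here is a minimal standalone sketch (not part of the patch; the helper name and the int element type are assumptions) equivalent to the default computed in extract_eagle_mode_from_config():

```cpp
#include <stdexcept>
#include <vector>

// Sketch of the default layer-selection rule used when "hidden_layers_list"
// is not provided: take an early, a middle, and a late decoder layer,
// following the EAGLE reference implementation referenced above.
std::vector<int> default_eagle3_hidden_layers(int num_hidden_layers) {
    if (num_hidden_layers <= 3) {
        throw std::invalid_argument("num_hidden_layers is too small to deduce hidden layers");
    }
    return {2, num_hidden_layers / 2, num_hidden_layers - 3};
}

// A 32-layer target model yields {2, 16, 29}, the same values used as the
// hidden_layers_list example in the pipeline's assertion message.
```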
+ * + * @param config Draft model configuration map (will be modified - eagle3 params will be erased) + * @param models_path Optional path to model directory for auto-deducing hidden layers from config.json + * @return Eagle3RTInfo structure with extracted configuration + */ +Eagle3RTInfo extract_eagle_mode_from_config(ov::AnyMap& config, const std::filesystem::path& models_path = {}); + +} // namespace speculative_decoding +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/speculative_decoding/update_request_structs.hpp b/src/cpp/src/speculative_decoding/update_request_structs.hpp index 68f79268f5..4426372507 100644 --- a/src/cpp/src/speculative_decoding/update_request_structs.hpp +++ b/src/cpp/src/speculative_decoding/update_request_structs.hpp @@ -10,11 +10,17 @@ namespace ov::genai { struct GeneratedSequence { std::vector token_ids; std::vector log_probs; - + // Stores the hidden states tensor associated with the generated sequence. + // This field is used for the "eagle speculative" decoding algorithm, + // where hidden states are required to efficiently validate and extend speculative tokens. + // If not using eagle speculative decoding, this field may remain empty. + ov::Tensor hidden_states; GeneratedSequence(const std::vector& generated_token_ids, - const std::vector& generated_log_probs) : + const std::vector& generated_log_probs, + const ov::Tensor generated_hidden_states = {}) : token_ids(generated_token_ids), - log_probs(generated_log_probs) {}; + log_probs(generated_log_probs), + hidden_states(generated_hidden_states) {}; }; struct UpdateRequestResult { diff --git a/tests/python_tests/samples/conftest.py b/tests/python_tests/samples/conftest.py index 8a011ccfe2..f6e0b0eeeb 100644 --- a/tests/python_tests/samples/conftest.py +++ b/tests/python_tests/samples/conftest.py @@ -143,6 +143,14 @@ "tiny-random-SpeechT5ForTextToSpeech": { "name": "hf-internal-testing/tiny-random-SpeechT5ForTextToSpeech", "convert_args": ["--model-kwargs", json.dumps({"vocoder": "fxmarty/speecht5-hifigan-tiny"})] + }, + "Qwen3-1.7B": { + "name": "Qwen/Qwen3-1.7B", + "convert_args": ["--task", "text-generation-with-past", '--trust-remote-code'] + }, + "qwen3_1.7b_eagle3": { + "name": "AngelSlim/Qwen3-1.7B_eagle3", + "convert_args": ["--task", "text-generation-with-past", "--trust-remote-code", "--eagle3"] } } diff --git a/tests/python_tests/samples/test_speculative_decoding_lm.py b/tests/python_tests/samples/test_speculative_decoding_lm.py index 35a1bb285a..0b4ef570a2 100644 --- a/tests/python_tests/samples/test_speculative_decoding_lm.py +++ b/tests/python_tests/samples/test_speculative_decoding_lm.py @@ -10,6 +10,23 @@ convert_draft_model = convert_model +def _run_spec_case(convert_model, convert_draft_model, sample_args, env): + cpp_sample = os.path.join(SAMPLES_CPP_DIR, 'speculative_decoding_lm') + cpp_command =[cpp_sample, convert_model, convert_draft_model, sample_args] + cpp_result = run_sample(cpp_command, env=env) + + py_script = os.path.join(SAMPLES_PY_DIR, "text_generation/speculative_decoding_lm.py") + py_command = [sys.executable, py_script, convert_model, convert_draft_model, sample_args] + py_result = run_sample(py_command, env=env) + + cpp_sample_ref = os.path.join(SAMPLES_CPP_DIR, 'greedy_causal_lm') + cpp_command_ref = [cpp_sample_ref, convert_model, sample_args] + cpp_result_ref = run_sample(cpp_command_ref, env=env) + + assert cpp_result_ref.stdout.strip() in py_result.stdout.strip(), "Python and CPP results should match" + assert cpp_result_ref.stdout.strip() in 
cpp_result.stdout.strip(), "Greedy and speculative decoding results should match" + return cpp_result, py_result, cpp_result_ref + class TestSpeculativeDecodingLM: @pytest.mark.llm @pytest.mark.samples @@ -26,22 +43,27 @@ def test_sample_speculative_decoding_lm(self, convert_model, convert_draft_model pytest.xfail("Ticket 173586") env = os.environ.copy() env["OPENVINO_LOG_LEVEL"] = "0" - # Test CPP sample - cpp_sample = os.path.join(SAMPLES_CPP_DIR, 'speculative_decoding_lm') - cpp_command =[cpp_sample, convert_model, convert_draft_model, sample_args] - cpp_result = run_sample(cpp_command, env=env) - - # Test Python sample - py_script = os.path.join(SAMPLES_PY_DIR, "text_generation/speculative_decoding_lm.py") - py_command = [sys.executable, py_script, convert_model, convert_draft_model, sample_args] - py_result = run_sample(py_command, env=env) - - # Greedy decoding - cpp_sample_ref = os.path.join(SAMPLES_CPP_DIR, 'greedy_causal_lm') - cpp_command_ref = [cpp_sample_ref, convert_model, sample_args] - cpp_result_ref = run_sample(cpp_command_ref, env=env) - - # Compare results - assert cpp_result_ref.stdout.strip() in py_result.stdout.strip(), "Python and CPP results should match" - assert cpp_result_ref.stdout.strip() in cpp_result.stdout.strip(), "Greedy and speculative decoding results should match" + _run_spec_case(convert_model, convert_draft_model, sample_args, env) + +test_prompt = """Code: +def add(a, b): + return a + b +Question: Can you please add 2 and 3 +A:""" +class TestEagle3SpeculativeDecodingLM: + @pytest.mark.llm + @pytest.mark.samples + @pytest.mark.parametrize( + "convert_model, convert_draft_model, sample_args", + [ + pytest.param("Qwen3-1.7B", "qwen3_1.7b_eagle3", test_prompt, marks=pytest.mark.skip(reason = 'CVS-171947, CVS-171943, CVS-174959')), + ], + indirect=["convert_model", "convert_draft_model"], + ) + def test_sample_speculative_decoding_lm(self, convert_model, convert_draft_model, sample_args): + if sys.platform == 'darwin': + pytest.xfail("Ticket 173586") + env = os.environ.copy() + env["OPENVINO_LOG_LEVEL"] = "0" + _run_spec_case(convert_model, convert_draft_model, sample_args, env) \ No newline at end of file diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py index b32ba5e8e7..975952406c 100644 --- a/tests/python_tests/test_continuous_batching.py +++ b/tests/python_tests/test_continuous_batching.py @@ -16,8 +16,9 @@ from utils.generation_config import get_greedy, get_beam_search, \ get_multinomial_all_parameters, get_multinomial_temperature_and_num_return_sequence, \ get_multinomial_temperature_and_top_k, get_multinomial_temperature, get_multinomial_temperature_and_top_p -from utils.hugging_face import download_and_convert_model -from utils.ov_genai_pipelines import create_ov_pipeline, create_ov_cb_pipeline, PipelineType, dict_to_scheduler_config, generate_and_compare, prepare_generation_config_by_pipe_type, GenerationChatInputsType +from utils.hugging_face import download_and_convert_model, run_hugging_face +from utils.ov_genai_pipelines import create_ov_pipeline, create_ov_cb_pipeline, PipelineType, dict_to_scheduler_config, generate_and_compare, prepare_generation_config_by_pipe_type, convert_decoded_results_to_generation_result, GenerationChatInputsType +from utils.comparation import compare_generation_results from data.models import get_chat_models_list from data.test_dataset import get_test_dataset @@ -478,21 +479,44 @@ def get_data_by_pipeline_type(model_path: Path, pipeline_type: str, 
generation_c return pipe, prompt, generation_config -def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType): +def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType, draft_model_id: str): _, _, model_path = download_and_convert_model(model_id) - ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type) + draft_model_path = None + if draft_model_id is not None: + _,_, draft_model_path = download_and_convert_model(draft_model_id) + ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type, draft_model_path = draft_model_path) return ov_pipe.generate([prompt], generation_config).extended_perf_metrics +eagle_models_and_input = [ + ("Qwen/Qwen3-1.7B", "AngelSlim/Qwen3-1.7B_eagle3", """Code: +def add(a, b): + return a + b +Question: Can you please add 2 and 3 +A:""")] + +speculative_cases = [ + ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None, "Why is the Sun yellow?"), + eagle_models_and_input[0], +] @pytest.mark.parametrize("pipeline_type", [PipelineType.PAGED_ATTENTION, PipelineType.SPECULATIVE_DECODING]) -def test_speculative_decoding_extended_perf_metrics(pipeline_type): +@pytest.mark.parametrize("main_model_id,draft_model_id, prompt", speculative_cases) +@pytest.mark.precommit +def test_speculative_decoding_extended_perf_metrics(pipeline_type, main_model_id, draft_model_id, prompt): import time start_time = time.perf_counter() - model_id : str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) - extended_perf_metrics = run_extended_perf_metrics_collection(model_id, generation_config, "Why is the Sun yellow?", pipeline_type) - total_time = (time.perf_counter() - start_time) * 1000 + extended_perf_metrics = None + if draft_model_id is None: + generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) + extended_perf_metrics = run_extended_perf_metrics_collection(main_model_id, generation_config, prompt, pipeline_type, draft_model_id) + total_time = (time.perf_counter() - start_time) * 1000 + else: + if (pipeline_type == PipelineType.SPECULATIVE_DECODING): + generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) + extended_perf_metrics = run_extended_perf_metrics_collection(main_model_id, generation_config, prompt, pipeline_type, draft_model_id) + total_time = (time.perf_counter() - start_time) * 1000 + if (pipeline_type == PipelineType.SPECULATIVE_DECODING): assert not extended_perf_metrics is None assert not extended_perf_metrics.main_model_metrics is None @@ -530,3 +554,31 @@ def test_speculative_decoding_extended_perf_metrics(pipeline_type): assert std_gen_duration == 0 else: assert extended_perf_metrics is None + +devices = [ + ('CPU', 'CPU') +] +@pytest.mark.parametrize("main_model,draft_model,prompt", eagle_models_and_input) +@pytest.mark.parametrize("main_device,draft_device", devices) +@pytest.mark.precommit +def test_eagle3_sd_string_inputs(main_model, main_device, draft_model, draft_device, prompt): + # Download and convert model: + main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(main_model) + __, __, draft_model_path = download_and_convert_model(draft_model) + + # Create OpenVINO GenAI pipeline: + + ov_pipe = create_ov_pipeline(main_model_path, pipeline_type = 
PipelineType.SPECULATIVE_DECODING, draft_model_path = draft_model_path) + + # Run reference HF model: + ov_generation_config = GenerationConfig(max_new_tokens=20) + ref_gen_results = run_hugging_face(main_opt_model, main_hf_tokenizer, [prompt], ov_generation_config) + + # Run OpenVINO GenAI pipeline: + ov_decoded_results = ov_pipe.generate([prompt], ov_generation_config) + ov_gen_results = convert_decoded_results_to_generation_result(ov_decoded_results, 1, 1, False) + + del ov_pipe + + # Compare results: + compare_generation_results([prompt], ref_gen_results, ov_gen_results, ov_generation_config) \ No newline at end of file diff --git a/tests/python_tests/utils/hugging_face.py b/tests/python_tests/utils/hugging_face.py index ec2535dcbe..dcfe3dc7cc 100644 --- a/tests/python_tests/utils/hugging_face.py +++ b/tests/python_tests/utils/hugging_face.py @@ -166,9 +166,14 @@ def run_hugging_face( # download HF model or read converted model def get_huggingface_models(model_id: str | Path, model_class: Type[OVModel], local_files_only=False): - hf_tokenizer = retry_request(lambda: AutoTokenizer.from_pretrained(model_id, local_files_only=local_files_only)) - opt_model = retry_request(lambda: model_class.from_pretrained(model_id, export=isinstance(model_id, str), compile=False, load_in_8bit=False, ov_config=get_default_llm_properties(), local_files_only=local_files_only)) - return opt_model, hf_tokenizer + if "eagle3" not in str(model_id).lower(): + hf_tokenizer = retry_request(lambda: AutoTokenizer.from_pretrained(model_id, local_files_only=local_files_only)) + opt_model = retry_request(lambda: model_class.from_pretrained(model_id, export=isinstance(model_id, str), compile=False, load_in_8bit=False, ov_config=get_default_llm_properties(), local_files_only=local_files_only)) + return opt_model, hf_tokenizer + else: + hf_tokenizer = None + opt_model = retry_request(lambda: model_class.from_pretrained(model_id, eagle3=True, export=isinstance(model_id, str), compile=False, load_in_8bit=False, ov_config=get_default_llm_properties(), local_files_only=local_files_only)) + return opt_model, hf_tokenizer def convert_and_save_tokenizer(hf_tokenizer : AutoTokenizer, @@ -192,9 +197,10 @@ def convert_models(opt_model : OVModelForCausalLM, opt_model.config.save_pretrained(models_path) # to store tokenizer config jsons with special tokens - hf_tokenizer.save_pretrained(models_path) - # convert tokenizers as well - convert_and_save_tokenizer(hf_tokenizer, models_path, **tokenizer_kwargs) + if hf_tokenizer: + hf_tokenizer.save_pretrained(models_path) + # convert tokenizers as well + convert_and_save_tokenizer(hf_tokenizer, models_path, **tokenizer_kwargs) def download_and_convert_model(model_id: str, **tokenizer_kwargs):
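For context, a minimal C++ sketch of driving the EAGLE3 speculative-decoding path end to end, mirroring the speculative_decoding_lm sample and the test_eagle3_sd_string_inputs case above; the model directory names are placeholders for IR converted from Qwen/Qwen3-1.7B (target) and AngelSlim/Qwen3-1.7B_eagle3 (draft), and the prompt is only an example:

```cpp
#include <iostream>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Target model plus EAGLE3 draft model, both previously converted to OpenVINO IR
    // (directory names are placeholders).
    ov::genai::LLMPipeline pipe(
        "Qwen3-1.7B-ov",
        "CPU",
        ov::genai::draft_model("Qwen3-1.7B_eagle3-ov", "CPU"));

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 20;
    config.num_assistant_tokens = 5;  // tokens drafted per speculative iteration

    // Greedy generation; the draft model proposes tokens, the target model validates them.
    ov::genai::DecodedResults results = pipe.generate("Why is the Sun yellow?", config);
    std::cout << results.texts[0] << std::endl;
    return 0;
}
```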