
Commit a47dbaf

Support eagle 3 top 1 for NPU
1 parent 257978f commit a47dbaf

File tree: 9 files changed (+1966, -26 lines)

samples/python/text_generation/speculative_decoding_lm.py
Lines changed: 3 additions & 3 deletions

@@ -21,15 +21,15 @@ def main():
     # User can run main and draft model on different devices.
     # Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in `openvino_genai.draft_model` for draft.
     # CPU, GPU and NPU can be used. For NPU, the preferred configuration is when both the main and draft models use NPU.
-    main_device = 'CPU'
-    draft_device = 'CPU'
+    main_device = 'NPU'
+    draft_device = 'NPU'
 
     draft_model = openvino_genai.draft_model(args.draft_model_dir, draft_device)
 
     pipe = openvino_genai.LLMPipeline(args.model_dir, main_device, draft_model=draft_model)
 
     config = openvino_genai.GenerationConfig()
-    config.max_new_tokens = 100
+    config.max_new_tokens = 20
     # Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded.
     # Add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration.
     # NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial
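
Note: the C++ text_generation sample follows the same pattern. For reference, a minimal sketch of the equivalent flow through the public ov::genai C++ API; the model paths and prompt are placeholders, not part of this commit:

```cpp
// Minimal sketch, not from this commit: speculative decoding with both models
// on NPU via the public ov::genai C++ API. Paths and prompt are placeholders.
#include <iostream>
#include <string>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    const std::string main_device = "NPU";   // preferred NPU setup: main and
    const std::string draft_device = "NPU";  // draft model both on NPU

    ov::genai::LLMPipeline pipe(
        "main_model_dir",                    // placeholder model path
        main_device,
        ov::genai::draft_model("draft_model_dir", draft_device));

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 20;
    config.num_assistant_tokens = 5;  // draft proposes 5 candidate tokens per step

    std::cout << pipe.generate("What is OpenVINO?", config) << std::endl;
    return 0;
}
```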

src/cpp/src/continuous_batching/pipeline.cpp
Lines changed: 4 additions & 3 deletions

@@ -12,6 +12,7 @@
 #include "continuous_batching/pipeline_impl.hpp"
 #include "speculative_decoding/speculative_decoding_impl.hpp"
 #include "speculative_decoding/speculative_decoding_eagle3_impl.hpp"
+#include "speculative_decoding/speculative_decoding_utils.hpp"
 #include "prompt_lookup/prompt_lookup_impl.hpp"
 #include "continuous_batching/timer.hpp"
 #include "utils.hpp"
@@ -85,7 +86,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p
     auto properties_without_draft_model = properties;
     auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model);
     auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model);
-    auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties, models_path);
+    auto eagle_rt_info = speculative_decoding::extract_eagle_mode_from_config(draft_model_desr.properties, models_path);
 
     auto model = utils::read_model(models_path, properties);
     auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model);
@@ -132,7 +133,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline(
     auto properties_without_draft_model = properties;
     auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model);
     auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model);
-    auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties, models_path);
+    auto eagle_rt_info = speculative_decoding::extract_eagle_mode_from_config(draft_model_desr.properties, models_path);
     auto model = utils::read_model(models_path, properties_without_draft_model);
     auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model);
     properties_without_draft_model_without_gguf[ov::cache_model_path.name()] = models_path;
@@ -182,7 +183,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline(
     auto properties_without_draft_model = properties;
     auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model);
     auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model);
-    auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties, std::filesystem::path(model_str));
+    auto eagle_rt_info = speculative_decoding::extract_eagle_mode_from_config(draft_model_desr.properties, std::filesystem::path(model_str));
     auto model = utils::singleton_core().read_model(model_str, weights_tensor);
 
     auto rt_info = model->get_rt_info();
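
All three call sites now reach the helper through the speculative_decoding namespace. Its definition lives in speculative_decoding_utils.hpp, which is not shown in this diff; below is a hypothetical sketch of its interface, inferred purely from the call sites here and in llm/pipeline.cpp (the struct name EagleRTInfo is made up for illustration):

```cpp
// Hypothetical sketch inferred from the call sites; the actual declaration is
// in speculative_decoding/speculative_decoding_utils.hpp (not in this diff).
#include <filesystem>
#include <vector>
#include "openvino/core/any.hpp"

namespace ov::genai::speculative_decoding {

struct EagleRTInfo {
    bool eagle3_mode = false;             // set when the draft model requests EAGLE-3 decoding
    std::vector<int> hidden_layers_list;  // main-model layers whose hidden states are exported
};

// Reads EAGLE-related entries from the draft model's properties; models_path
// lets hidden_layers_list be auto-deduced from <models_path>/config.json.
EagleRTInfo extract_eagle_mode_from_config(ov::AnyMap& properties,
                                           const std::filesystem::path& models_path);

}  // namespace ov::genai::speculative_decoding
```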

src/cpp/src/llm/pipeline.cpp
Lines changed: 37 additions & 8 deletions

@@ -13,6 +13,8 @@
 #include "llm/pipeline_continuous_batching_adapter.hpp"
 #include "speculative_decoding/speculative_decoding_impl.hpp"
 #include "speculative_decoding/speculative_decoding_stateful.hpp"
+#include "speculative_decoding/speculative_decoding_stateful_eagle3.hpp"
+#include "speculative_decoding/speculative_decoding_utils.hpp"
 #include "utils.hpp"
 
 namespace {
@@ -142,7 +144,8 @@ static std::unique_ptr<LLMPipelineImplBase> create(
         tokenizer,
         device,
         properties,
-        utils::from_config_json_if_exists(models_path));
+        utils::from_config_json_if_exists(models_path),
+        models_path);
 }
 
 static std::unique_ptr<LLMPipelineImplBase> create(
@@ -157,17 +160,43 @@ static std::unique_ptr<LLMPipelineImplBase> create(
     const ov::genai::Tokenizer& tokenizer,
     const std::string& device,
     const ov::AnyMap& properties,
-    const ov::genai::GenerationConfig& generation_config) {
+    const ov::genai::GenerationConfig& generation_config,
+    const std::filesystem::path& models_path = {}) {
 
     auto properties_without_draft_model = properties;
     auto draft_model_descr = ov::genai::utils::extract_draft_model_from_config(properties_without_draft_model);
+
     if (draft_model_descr.model != nullptr) {
-        // FIXME: Add support for StatefulSpeculativeLLMPipeline for non-NPU devices for both models.
-        OPENVINO_ASSERT(device == "NPU" || draft_model_descr.device == "NPU",
-                        "Stateful Speculative Decoding is expected to be launched when NPU is requested as "
-                        "execution device for one or both models.");
-        auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, {}, generation_config);
-        return std::make_unique<StatefulSpeculativeLLMPipeline>(main_model_descr, draft_model_descr);
+        // Extract Eagle3 configuration from draft model properties
+        // Pass models_path for auto-deducing hidden_layers_list from config.json
+        auto eagle_rt_info = ov::genai::speculative_decoding::extract_eagle_mode_from_config(
+            draft_model_descr.properties,
+            models_path
+        );
+
+        if (eagle_rt_info.eagle3_mode) {
+            // Eagle3 Speculative Decoding mode
+            OPENVINO_ASSERT(device == "NPU" || draft_model_descr.device == "NPU",
+                            "Stateful Eagle3 Speculative Decoding is expected to be launched when NPU is requested as "
+                            "execution device for one or both models.");
+
+            auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device,
+                                                         properties_without_draft_model, {}, generation_config);
+            return std::make_unique<StatefulEagle3LLMPipeline>(
+                main_model_descr,
+                draft_model_descr,
+                eagle_rt_info.hidden_layers_list
+            );
+        } else {
+            // Standard Speculative Decoding mode
+            OPENVINO_ASSERT(device == "NPU" || draft_model_descr.device == "NPU",
+                            "Stateful Speculative Decoding is expected to be launched when NPU is requested as "
+                            "execution device for one or both models.");
+
+            auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device,
+                                                         properties_without_draft_model, {}, generation_config);
+            return std::make_unique<StatefulSpeculativeLLMPipeline>(main_model_descr, draft_model_descr);
+        }
    }
 
     return std::make_unique<StatefulLLMPipeline>(model, tokenizer, device,
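
The Eagle3 branch constructs a new pipeline type declared in speculative_decoding_stateful_eagle3.hpp, which is not shown in this diff. A hypothetical declaration, inferred only from the make_unique call site and the factory's return type:

```cpp
// Hypothetical: shape inferred from the call site above and from create()
// returning std::unique_ptr<LLMPipelineImplBase>; see the real header for details.
class StatefulEagle3LLMPipeline : public LLMPipelineImplBase {
public:
    StatefulEagle3LLMPipeline(const ov::genai::ModelDesc& main_model_descr,
                              const ov::genai::ModelDesc& draft_model_descr,
                              const std::vector<int>& hidden_layers_list);
};
```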

src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp
Lines changed: 74 additions & 10 deletions

@@ -52,7 +52,7 @@ void share_embedding_weights(std::shared_ptr<ov::Model>& main_model, std::shared
     }
 }
 
-std::shared_ptr<ov::op::v0::Constant> extract_d2t_mapping_table(std::shared_ptr<ov::Model>& model) {
+std::shared_ptr<ov::op::v0::Constant> extract_d2t_mapping_table(const std::shared_ptr<ov::Model>& model) {
     // extract result nodes from model
     for (const auto& result : model->get_results()) {
         auto input_node = result->input_value(0).get_node_shared_ptr();
@@ -62,14 +62,36 @@ std::shared_ptr<ov::op::v0::Constant> extract_d2t_mapping_table(std::shared_ptr<
     }
     return nullptr;
 }
+
+void remove_d2t_result_node(std::shared_ptr<ov::Model>& model) {
+    // Find and remove the d2t Result node
+    std::shared_ptr<ov::op::v0::Result> d2t_result_to_remove = nullptr;
+
+    for (const auto& result : model->get_results()) {
+        auto input_node = result->input_value(0).get_node_shared_ptr();
+        if (ov::is_type<ov::op::v0::Constant>(input_node) &&
+            input_node->get_friendly_name().find("d2t") != std::string::npos) {
+            d2t_result_to_remove = result;
+            break;
+        }
+    }
+
+    if (d2t_result_to_remove) {
+        model->remove_result(d2t_result_to_remove);
+        model->validate_nodes_and_infer_types();
+    }
+}
+
 void extract_hidden_state_generic(std::shared_ptr<ov::Model>& model,
-                                  const std::vector<int>& hidden_layers_to_abstract) {
+                                  const std::vector<int>& hidden_layers_to_abstract,
+                                  const std::string& device) {
     ov::pass::Manager pm;
-    pm.register_pass<EagleModelTransform>(hidden_layers_to_abstract);
+    pm.register_pass<EagleModelTransform>(hidden_layers_to_abstract, device);
     pm.run_passes(model);
 }
 
-EagleModelTransform::EagleModelTransform(const std::vector<int>& layers) : m_layer_ids(layers) {
+EagleModelTransform::EagleModelTransform(const std::vector<int>& layers, const std::string& device)
+    : m_layer_ids(layers), m_device(device) {
 }
 
 bool EagleModelTransform::run_on_model(const std::shared_ptr<ov::Model>& model) {
@@ -82,7 +104,7 @@ bool EagleModelTransform::run_on_model(const std::shared_ptr<ov::Model>& model)
     manager.register_pass<EagleBaseTransform>(m_new_results);
     // input transform for draft
     // here we apply a trick for the fc layer in draft model
-    manager.register_pass<EagleInputTransform>(m_new_parameters);
+    manager.register_pass<EagleInputTransform>(m_new_parameters, m_device);
     manager.run_passes(model);
 
     model->add_parameters(m_new_parameters);
@@ -109,7 +131,8 @@ bool EagleModelTransform::run_on_model(const std::shared_ptr<ov::Model>& model)
     return false;
 }
 
-EagleInputTransform::EagleInputTransform(std::vector<std::shared_ptr<v0::Parameter>>& params) {
+EagleInputTransform::EagleInputTransform(std::vector<std::shared_ptr<v0::Parameter>>& params, const std::string& device)
+    : m_device(device) {
     register_matcher(
         std::make_shared<ov::pass::pattern::Matcher>(ov::pass::pattern::wrap_type<v0::MatMul>(), this->get_type_info().name),
         ([&params, this](ov::pass::pattern::Matcher& m) {
@@ -126,6 +149,7 @@ EagleInputTransform::EagleInputTransform(std::vector<std::shared_ptr<v0::Paramet
         })
     );
 }
+
 bool EagleInputTransform::apply(NodePtr node, std::vector<std::shared_ptr<v0::Parameter>>& params) {
     if (ov::is_type<v0::MatMul>(node)) {
         auto matmul_node = ov::as_type_ptr<v0::MatMul>(node);
@@ -135,16 +159,56 @@ bool EagleInputTransform::apply(NodePtr node, std::vector<std::shared_ptr<v0::Pa
             return false;
         }
 
+        auto matmul_input0 = matmul_node->input_value(0);
+        auto matmul_input1 = matmul_node->input_value(1);
+
+        std::shared_ptr<ov::Node> matmul_output_node;
+
+        // Apply scaling optimization for NPU devices to prevent FP16 overflow
+        if (m_device.find("NPU") != std::string::npos) {
+            // Scale input down by 100x before MatMul to avoid FP16 overflow, then scale result back up
+            // The factor 100 (0.01 and 100.0) is an empirical value
+            auto scale_down_const = std::make_shared<v0::Constant>(matmul_input0.get_element_type(), ov::Shape{}, 0.01f);
+            auto multiply_scale_down = std::make_shared<v1::Multiply>(matmul_input0, scale_down_const);
+            multiply_scale_down->set_friendly_name(matmul_node->get_friendly_name() + "/multiply_scale_down");
+
+            // Create new MatMul with scaled input
+            auto new_matmul = std::make_shared<v0::MatMul>(multiply_scale_down, matmul_input1,
+                                                           matmul_node->get_transpose_a(),
+                                                           matmul_node->get_transpose_b());
+            new_matmul->set_friendly_name(matmul_node->get_friendly_name() + "/matmul");
+
+            // Scale result back up to maintain numerical equivalence
+            auto scale_up_const = std::make_shared<v0::Constant>(new_matmul->get_element_type(), ov::Shape{}, 100.0f);
+            auto multiply_scale_up = std::make_shared<v1::Multiply>(new_matmul->output(0), scale_up_const);
+            multiply_scale_up->set_friendly_name(matmul_node->get_friendly_name() + "/multiply_scale_up");
+
+            matmul_output_node = multiply_scale_up;
+        } else {
+            // Default behavior: Use MatMul directly without scaling
+            auto new_matmul = std::make_shared<v0::MatMul>(matmul_input0, matmul_input1,
+                                                           matmul_node->get_transpose_a(),
+                                                           matmul_node->get_transpose_b());
+            new_matmul->set_friendly_name(matmul_node->get_friendly_name() + "/matmul");
+
+            matmul_output_node = new_matmul;
+        }
+
         auto shape = node->get_output_partial_shape(0);
         auto internal_hidden_state = std::make_shared<v0::Parameter>(node->get_element_type(), node->get_output_partial_shape(0));
         internal_hidden_state->output(0).set_names({"internal_hidden_states"});
         internal_hidden_state->set_friendly_name("internal_hidden_states");
-        // create new eltwise node to add output of MatMul node and internal hidden state input from last cycle of itself
-        auto new_eltwise = std::make_shared<v1::Add>(internal_hidden_state, matmul_node->output(0));
+
+        // Create new Add node (MatMul output + internal_hidden_state)
+        auto new_eltwise = std::make_shared<v1::Add>(internal_hidden_state, matmul_output_node->output(0));
+        new_eltwise->set_friendly_name(matmul_node->get_friendly_name() + "/add");
+
+        // Replace the original MatMul node with the new Add
         ov::replace_node(matmul_node, new_eltwise);
         params.push_back(internal_hidden_state);
         return true;
     }
+    return false;
 }
 
 EagleBaseTransform::EagleBaseTransform(std::vector<std::shared_ptr<v0::Result>>& results) {
@@ -303,8 +367,8 @@ ContinuousBatchingPipeline::Eagle3DecodingImpl::Eagle3DecodingImpl(const ov::gen
     // target model: hidden state extraction, draft model: hidden state import , hidden state extraction
     // eagle3 specific : dt importing
     share_embedding_weights(main_model, draft_model);
-    extract_hidden_state_generic(main_model, hidden_layers);
-    extract_hidden_state_generic(draft_model, { -1 });
+    extract_hidden_state_generic(main_model, hidden_layers, main_device);
+    extract_hidden_state_generic(draft_model, { -1 }, draft_device);
 
     // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode
     m_main_pipeline = std::make_shared<ContinuousBatchingForEagle3DecodingImpl>(main_model,
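
Reviewer note on the NPU-only rewrite above: MatMul is linear in its activation input, so the scale-down/scale-up pair is an identity in exact arithmetic while the intermediate product shrinks 100x, keeping it under the FP16 maximum normal value of 65504. With activations $x$ and fc weights $W$:

$$100\,\bigl((0.01\,x)\,W\bigr) = x\,W, \qquad \bigl\|(0.01\,x)\,W\bigr\|_\infty = \tfrac{1}{100}\,\bigl\|x\,W\bigr\|_\infty.$$

The rewrite therefore tolerates $\|x\,W\|_\infty$ up to roughly $100 \times 65504 \approx 6.55\times 10^{6}$ before overflow, at the cost of extra rounding on the scaled-down side, which is presumably why the commit labels the factor 100 empirical.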

src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp
Lines changed: 11 additions & 2 deletions

@@ -18,6 +18,13 @@
 #include "openvino/pass/manager.hpp"
 
 namespace ov::genai {
+
+// Forward declarations for Eagle3 transformation functions
+void share_embedding_weights(std::shared_ptr<ov::Model>& main_model, std::shared_ptr<ov::Model>& draft_model);
+void extract_hidden_state_generic(std::shared_ptr<ov::Model>& model, const std::vector<int>& hidden_layers_to_abstract, const std::string& device = "");
+std::shared_ptr<ov::op::v0::Constant> extract_d2t_mapping_table(const std::shared_ptr<ov::Model>& model);
+void remove_d2t_result_node(std::shared_ptr<ov::Model>& model);
+
 class ContinuousBatchingPipeline::Eagle3DecodingImpl : public ContinuousBatchingPipeline::SpeculativeDecodingImpl {
 public:
     template<class Impl>
@@ -73,13 +80,14 @@ class EagleInputTransform : public ov::pass::MatcherPass { // eagle3 specific fo
 public:
     using NodePtr = std::shared_ptr<ov::Node>;
     OPENVINO_MATCHER_PASS_RTTI("EagleInputTransform");
-    EagleInputTransform(std::vector<std::shared_ptr<ov::op::v0::Parameter>>& params);
+    EagleInputTransform(std::vector<std::shared_ptr<ov::op::v0::Parameter>>& params, const std::string& device = "");
 
     ~EagleInputTransform() = default;
 
 private:
     bool apply(NodePtr node, std::vector<std::shared_ptr<ov::op::v0::Parameter>>& params);
     size_t applied = 0;
+    std::string m_device;
 };
 class Eagle3Transform : public ov::pass::MatcherPass {
 public:
@@ -95,11 +103,12 @@ class Eagle3Transform : public ov::pass::MatcherPass {
 
 class EagleModelTransform : public ov::pass::ModelPass {
 public:
-    EagleModelTransform(const std::vector<int>& layer_ids);
+    EagleModelTransform(const std::vector<int>& layer_ids, const std::string& device = "");
     bool run_on_model(const std::shared_ptr<Model>& model) override;
 
 private:
     const std::vector<int> m_layer_ids;
+    std::string m_device;
     std::vector<std::shared_ptr<ov::op::v0::Result>> m_new_results;
     std::vector<std::shared_ptr<ov::op::v0::Parameter>> m_new_parameters;
     std::vector<Output<Node>> m_hidden_layer_outputs;
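
For orientation, the newly exposed helpers compose in the order used by Eagle3DecodingImpl's constructor (see the .cpp diff above). A minimal sketch; that the d2t steps run on the draft model is an assumption based on the helper names and the "dt importing" comment:

```cpp
// Sketch only: mirrors the call order in speculative_decoding_eagle3_impl.cpp.
// The d2t handling on the draft model is an assumption, not shown in this diff.
#include <memory>
#include <string>
#include <vector>
#include "speculative_decoding/speculative_decoding_eagle3_impl.hpp"

void prepare_eagle3_models(std::shared_ptr<ov::Model>& main_model,
                           std::shared_ptr<ov::Model>& draft_model,
                           const std::vector<int>& hidden_layers,
                           const std::string& main_device,
                           const std::string& draft_device) {
    // Draft model reuses the main model's embedding weights.
    ov::genai::share_embedding_weights(main_model, draft_model);
    // Expose the selected hidden states as extra outputs (main) and wire them in
    // as inputs (draft); the device string enables the NPU FP16-scaling rewrite.
    ov::genai::extract_hidden_state_generic(main_model, hidden_layers, main_device);
    ov::genai::extract_hidden_state_generic(draft_model, { -1 }, draft_device);
    // Read the draft-to-target token mapping table, then drop the Result node
    // that carried it so it is no longer treated as a model output.
    auto d2t = ov::genai::extract_d2t_mapping_table(draft_model);
    ov::genai::remove_d2t_result_node(draft_model);
    (void)d2t;  // consumed by the decoding pipeline in the real implementation
}
```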
