@@ -52,7 +52,7 @@ void share_embedding_weights(std::shared_ptr<ov::Model>& main_model, std::shared
5252 }
5353}
5454
55- std::shared_ptr<ov::op::v0::Constant> extract_d2t_mapping_table (std::shared_ptr<ov::Model>& model) {
55+ std::shared_ptr<ov::op::v0::Constant> extract_d2t_mapping_table (const std::shared_ptr<ov::Model>& model) {
5656 // extract result nodes from model
5757 for (const auto & result : model->get_results ()) {
5858 auto input_node = result->input_value (0 ).get_node_shared_ptr ();
@@ -62,14 +62,36 @@ std::shared_ptr<ov::op::v0::Constant> extract_d2t_mapping_table(std::shared_ptr<
6262 }
6363 return nullptr ;
6464}
65+
66+ void remove_d2t_result_node (std::shared_ptr<ov::Model>& model) {
67+ // Find and remove the d2t Result node
68+ std::shared_ptr<ov::op::v0::Result> d2t_result_to_remove = nullptr ;
69+
70+ for (const auto & result : model->get_results ()) {
71+ auto input_node = result->input_value (0 ).get_node_shared_ptr ();
72+ if (ov::is_type<ov::op::v0::Constant>(input_node) &&
73+ input_node->get_friendly_name ().find ("d2t") != std::string::npos) {
74+ d2t_result_to_remove = result;
75+ break ;
76+ }
77+ }
78+
79+ if (d2t_result_to_remove) {
80+ model->remove_result (d2t_result_to_remove);
81+ model->validate_nodes_and_infer_types ();
82+ }
83+ }
84+
6585void extract_hidden_state_generic (std::shared_ptr<ov::Model>& model,
66- const std::vector<int >& hidden_layers_to_abstract) {
86+ const std::vector<int >& hidden_layers_to_abstract,
87+ const std::string& device) {
6788 ov::pass::Manager pm;
68- pm.register_pass <EagleModelTransform>(hidden_layers_to_abstract);
89+ pm.register_pass <EagleModelTransform>(hidden_layers_to_abstract, device );
6990 pm.run_passes (model);
7091}
7192
72- EagleModelTransform::EagleModelTransform (const std::vector<int >& layers) : m_layer_ids(layers) {
93+ EagleModelTransform::EagleModelTransform (const std::vector<int >& layers, const std::string& device)
94+ : m_layer_ids(layers), m_device(device) {
7395}
7496
7597bool EagleModelTransform::run_on_model (const std::shared_ptr<ov::Model>& model) {
@@ -82,7 +104,7 @@ bool EagleModelTransform::run_on_model(const std::shared_ptr<ov::Model>& model)
82104 manager.register_pass <EagleBaseTransform>(m_new_results);
83105 // input transform for draft
84106 // here we apply a trick for the fc layer in draft model
85- manager.register_pass <EagleInputTransform>(m_new_parameters);
107+ manager.register_pass <EagleInputTransform>(m_new_parameters, m_device );
86108 manager.run_passes (model);
87109
88110 model->add_parameters (m_new_parameters);
@@ -109,7 +131,8 @@ bool EagleModelTransform::run_on_model(const std::shared_ptr<ov::Model>& model)
109131 return false ;
110132}
111133
112- EagleInputTransform::EagleInputTransform (std::vector<std::shared_ptr<v0::Parameter>>& params) {
134+ EagleInputTransform::EagleInputTransform (std::vector<std::shared_ptr<v0::Parameter>>& params, const std::string& device)
135+ : m_device(device) {
113136 register_matcher (
114137 std::make_shared<ov::pass::pattern::Matcher>(ov::pass::pattern::wrap_type<v0::MatMul>(), this ->get_type_info ().name ),
115138 ([&params, this](ov::pass::pattern::Matcher& m) {
@@ -126,6 +149,7 @@ EagleInputTransform::EagleInputTransform(std::vector<std::shared_ptr<v0::Paramet
126149 })
127150 );
128151}
152+
129153bool EagleInputTransform::apply (NodePtr node, std::vector<std::shared_ptr<v0::Parameter>>& params) {
130154 if (ov::is_type<v0::MatMul>(node)) {
131155 auto matmul_node = ov::as_type_ptr<v0::MatMul>(node);
@@ -135,16 +159,56 @@ bool EagleInputTransform::apply(NodePtr node, std::vector<std::shared_ptr<v0::Pa
135159 return false ;
136160 }
137161
162+ auto matmul_input0 = matmul_node->input_value (0 );
163+ auto matmul_input1 = matmul_node->input_value (1 );
164+
165+ std::shared_ptr<ov::Node> matmul_output_node;
166+
167+ // Apply scaling optimization for NPU devices to prevent FP16 overflow
168+ if (m_device.find ("NPU") != std::string::npos) {
169+ // Scale input down by 100x before MatMul to avoid FP16 overflow, then scale result back up
170+ // The factor 100 (0.01 and 100.0) is an empirical value
171+ auto scale_down_const = std::make_shared<v0::Constant>(matmul_input0.get_element_type (), ov::Shape{}, 0.01f);
172+ auto multiply_scale_down = std::make_shared<v1::Multiply>(matmul_input0, scale_down_const);
173+ multiply_scale_down->set_friendly_name (matmul_node->get_friendly_name () + " /multiply_scale_down" );
174+
175+ // Create new MatMul with scaled input
176+ auto new_matmul = std::make_shared<v0::MatMul>(multiply_scale_down, matmul_input1,
177+ matmul_node->get_transpose_a (),
178+ matmul_node->get_transpose_b ());
179+ new_matmul->set_friendly_name (matmul_node->get_friendly_name () + " /matmul" );
180+
181+ // Scale result back up to maintain numerical equivalence
182+ auto scale_up_const = std::make_shared<v0::Constant>(new_matmul->get_element_type (), ov::Shape{}, 100.0f);
183+ auto multiply_scale_up = std::make_shared<v1::Multiply>(new_matmul->output (0 ), scale_up_const);
184+ multiply_scale_up->set_friendly_name (matmul_node->get_friendly_name () + " /multiply_scale_up" );
185+
186+ matmul_output_node = multiply_scale_up;
187+ } else {
188+ // Default behavior: Use MatMul directly without scaling
189+ auto new_matmul = std::make_shared<v0::MatMul>(matmul_input0, matmul_input1,
190+ matmul_node->get_transpose_a (),
191+ matmul_node->get_transpose_b ());
192+ new_matmul->set_friendly_name (matmul_node->get_friendly_name () + " /matmul" );
193+
194+ matmul_output_node = new_matmul;
195+ }
196+
138197 auto shape = node->get_output_partial_shape (0 );
139198 auto internal_hidden_state = std::make_shared<v0::Parameter>(node->get_element_type (), node->get_output_partial_shape (0 ));
140199 internal_hidden_state->output (0 ).set_names ({"internal_hidden_states"});
141200 internal_hidden_state->set_friendly_name ("internal_hidden_states");
142- // create new eltwise node to add output of MatMul node and internal hidden state input from last cycle of itself
143- auto new_eltwise = std::make_shared<v1::Add>(internal_hidden_state, matmul_node->output (0 ));
201+
202+ // Create new Add node (MatMul output + internal_hidden_state)
203+ auto new_eltwise = std::make_shared<v1::Add>(internal_hidden_state, matmul_output_node->output (0 ));
204+ new_eltwise->set_friendly_name (matmul_node->get_friendly_name () + " /add" );
205+
206+ // Replace the original MatMul node with the new Add
144207 ov::replace_node (matmul_node, new_eltwise);
145208 params.push_back (internal_hidden_state);
146209 return true ;
147210 }
211+ return false ;
148212}
149213
150214EagleBaseTransform::EagleBaseTransform (std::vector<std::shared_ptr<v0::Result>>& results) {
@@ -303,8 +367,8 @@ ContinuousBatchingPipeline::Eagle3DecodingImpl::Eagle3DecodingImpl(const ov::gen
303367 // target model: hidden state extraction, draft model: hidden state import , hidden state extraction
304368 // eagle3 specific : dt importing
305369 share_embedding_weights (main_model, draft_model);
306- extract_hidden_state_generic (main_model, hidden_layers);
307- extract_hidden_state_generic (draft_model, { -1 });
370+ extract_hidden_state_generic (main_model, hidden_layers, main_device );
371+ extract_hidden_state_generic (draft_model, { -1 }, draft_device );
308372
309373 // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode
310374 m_main_pipeline = std::make_shared<ContinuousBatchingForEagle3DecodingImpl>(main_model,
0 commit comments