diff --git a/src/frontends/onnx/frontend/src/op/lstm.cpp b/src/frontends/onnx/frontend/src/op/lstm.cpp index e004076c0e9ce6..e7bb52ea46e107 100644 --- a/src/frontends/onnx/frontend/src/op/lstm.cpp +++ b/src/frontends/onnx/frontend/src/op/lstm.cpp @@ -13,6 +13,8 @@ #include "openvino/op/lstm_sequence.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/shape_of.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/tile.hpp" #include "openvino/util/common_util.hpp" #include "utils/reshape.hpp" #include "utils/split.hpp" @@ -37,6 +39,32 @@ enum class LSTMInput { LSTM_INPUT_P }; +// Helper function to reduce tensor rank to target_rank by squeezing leading dimensions. +// Per ONNX specification, LSTM requires rank-3 inputs, so extra dimensions must be == 1. +// If extra dimensions are != 1 at runtime, Squeeze will fail with a clear error. +ov::Output<ov::Node> reduce_tensor_rank(const ov::Output<ov::Node>& input, int64_t target_rank) { + const auto& input_shape = input.get_partial_shape(); + + if (input_shape.rank().is_dynamic()) { + return input; + } + + const auto input_rank = input_shape.rank().get_length(); + + if (input_rank <= target_rank) { + return input; + } + + // Squeeze all leading dimensions to reduce rank to target_rank + std::vector<int64_t> axes_to_squeeze; + for (int64_t i = 0; i < input_rank - target_rank; ++i) { + axes_to_squeeze.push_back(i); + } + + auto axes_const = v0::Constant::create(ov::element::i64, Shape{axes_to_squeeze.size()}, axes_to_squeeze); + return std::make_shared<v0::Squeeze>(input, axes_const); +} + struct LSTMNgInputMap { explicit LSTMNgInputMap(const Node& node) { const auto& ng_inputs = node.get_ov_inputs(); @@ -48,7 +76,14 @@ struct LSTMNgInputMap { // Packed input sequences. 
// ONNX Shape: [seq_length, batch_size, input_size] // OpenVino Shape: [batch_size, seq_length, input_size] - m_input_map[LSTMInput::LSTM_INPUT_X] = ov::op::util::reorder_axes(ng_inputs.at(0), {1, 0, 2}); + + // First reduce rank if needed, THEN reorder axes + // This is important because Squeeze changes dimension indices + auto input_x = ng_inputs.at(0); + input_x = reduce_tensor_rank(input_x, 3); + input_x = ov::op::util::reorder_axes(input_x, {1, 0, 2}); + + m_input_map[LSTMInput::LSTM_INPUT_X] = input_x; // Weight tensor for the gates. // Shape: [num_directions, 4*hidden_size, input_size] @@ -124,7 +159,12 @@ struct LSTMNgInputMap { // ONNX Shape: [num_directions, batch_size, hidden_size] // OpenVino Shape: [batch_size, num_directions, hidden_size] if (ng_inputs.size() > 5 && !ov::op::util::is_null(ng_inputs.at(5))) { - m_input_map[LSTMInput::LSTM_INPUT_INIT_H] = ov::op::util::reorder_axes(ng_inputs.at(5), {1, 0, 2}); + auto init_h = ng_inputs.at(5); + // First reduce rank, THEN reorder axes + init_h = reduce_tensor_rank(init_h, 3); + init_h = ov::op::util::reorder_axes(init_h, {1, 0, 2}); + + m_input_map[LSTMInput::LSTM_INPUT_INIT_H] = init_h; } else { auto init_h_shape = std::make_shared<v0::Concat>(ov::OutputVector{batch_size_node, num_directions_node, hidden_size_node}, @@ -137,7 +177,12 @@ struct LSTMNgInputMap { // ONNX Shape: [num_directions, batch_size, hidden_size] // OpenVino Shape: [batch_size, num_directions, hidden_size] if (ng_inputs.size() > 6 && !ov::op::util::is_null(ng_inputs.at(6))) { - m_input_map[LSTMInput::LSTM_INPUT_INIT_C] = ov::op::util::reorder_axes(ng_inputs.at(6), {1, 0, 2}); + auto init_c = ng_inputs.at(6); + // First reduce rank, THEN reorder axes + init_c = reduce_tensor_rank(init_c, 3); + init_c = ov::op::util::reorder_axes(init_c, {1, 0, 2}); + + m_input_map[LSTMInput::LSTM_INPUT_INIT_C] = init_c; } else { auto init_c_shape = std::make_shared<v0::Concat>(ov::OutputVector{batch_size_node, num_directions_node, hidden_size_node}, diff --git 
a/src/frontends/onnx/tests/models/lstm_high_rank_input.prototxt b/src/frontends/onnx/tests/models/lstm_high_rank_input.prototxt new file mode 100644 index 00000000000000..7ced96e260b29a --- /dev/null +++ b/src/frontends/onnx/tests/models/lstm_high_rank_input.prototxt @@ -0,0 +1,167 @@ +ir_version: 4 +producer_name: "OpenVINO ONNX Frontend" +graph { + node { + input: "X" + input: "W" + input: "R" + input: "B" + output: "Y" + output: "Y_h" + output: "Y_c" + op_type: "LSTM" + attribute { + name: "direction" + s: "forward" + type: STRING + } + attribute { + name: "hidden_size" + i: 2 + type: INT + } + } + name: "compute_graph" + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "W" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "R" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "Y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "Y_h" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "Y_c" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 7 +} diff --git a/src/frontends/onnx/tests/models/lstm_rank4_with_unsqueeze.prototxt 
b/src/frontends/onnx/tests/models/lstm_rank4_with_unsqueeze.prototxt new file mode 100644 index 00000000000000..313dc2892b405b --- /dev/null +++ b/src/frontends/onnx/tests/models/lstm_rank4_with_unsqueeze.prototxt @@ -0,0 +1,587 @@ +ir_version: 9 +producer_name: "OpenVINO Test" +graph { + node { + input: "X_input" + input: "unsqueeze_axes" + output: "X_unsqueezed" + op_type: "Unsqueeze" + } + node { + input: "X_unsqueezed" + input: "W" + input: "R" + input: "B" + output: "Y" + output: "Y_h" + output: "Y_c" + op_type: "LSTM" + attribute { + name: "direction" + s: "forward" + type: STRING + } + attribute { + name: "hidden_size" + i: 8 + type: INT + } + } + name: "lstm_rank4_with_unsqueeze" + initializer { + dims: 1 + dims: 32 + dims: 4 + data_type: 1 + float_data: 0.496714145 + float_data: -0.138264298 + float_data: 0.647688568 + float_data: 1.5230298 + float_data: -0.234153375 + float_data: -0.234136954 + float_data: 1.57921278 + float_data: 0.767434716 + float_data: -0.469474375 + float_data: 0.542560041 + float_data: -0.463417679 + float_data: -0.465729743 + float_data: 0.241962269 + float_data: -1.91328025 + float_data: -1.72491789 + float_data: -0.562287509 + float_data: -1.01283109 + float_data: 0.31424734 + float_data: -0.908024073 + float_data: -1.41230369 + float_data: 1.46564877 + float_data: -0.2257763 + float_data: 0.0675282 + float_data: -1.42474818 + float_data: -0.544382751 + float_data: 0.11092259 + float_data: -1.15099359 + float_data: 0.37569803 + float_data: -0.600638688 + float_data: -0.291693747 + float_data: -0.601706624 + float_data: 1.85227823 + float_data: -0.013497225 + float_data: -1.05771089 + float_data: 0.822544932 + float_data: -1.22084367 + float_data: 0.208863601 + float_data: -1.95967007 + float_data: -1.32818604 + float_data: 0.196861237 + float_data: 0.738466561 + float_data: 0.171368286 + float_data: -0.115648285 + float_data: -0.301103681 + float_data: -1.47852194 + float_data: -0.719844222 + float_data: -0.460638762 + 
float_data: 1.05712223 + float_data: 0.343618304 + float_data: -1.76304018 + float_data: 0.324083984 + float_data: -0.385082275 + float_data: -0.676922 + float_data: 0.611676276 + float_data: 1.03099954 + float_data: 0.931280136 + float_data: -0.839217544 + float_data: -0.309212387 + float_data: 0.331263423 + float_data: 0.975545108 + float_data: -0.479174227 + float_data: -0.185658976 + float_data: -1.10633492 + float_data: -1.19620657 + float_data: 0.812525809 + float_data: 1.35624 + float_data: -0.0720101222 + float_data: 1.00353289 + float_data: 0.361636 + float_data: -0.645119727 + float_data: 0.361395597 + float_data: 1.53803658 + float_data: -0.0358260386 + float_data: 1.56464362 + float_data: -2.61974502 + float_data: 0.821902514 + float_data: 0.0870470703 + float_data: -0.299007356 + float_data: 0.0917607769 + float_data: -1.98756886 + float_data: -0.21967189 + float_data: 0.357112557 + float_data: 1.47789407 + float_data: -0.518270195 + float_data: -0.808493614 + float_data: -0.501757 + float_data: 0.915402114 + float_data: 0.328751117 + float_data: -0.529760182 + float_data: 0.513267457 + float_data: 0.0970775485 + float_data: 0.968645 + float_data: -0.70205307 + float_data: -0.32766214 + float_data: -0.392108142 + float_data: -1.46351492 + float_data: 0.296120286 + float_data: 0.261055261 + float_data: 0.00511345686 + float_data: -0.234587133 + float_data: -1.4153707 + float_data: -0.420645326 + float_data: -0.342714518 + float_data: -0.802277267 + float_data: -0.161285713 + float_data: 0.404050857 + float_data: 1.88618588 + float_data: 0.174577817 + float_data: 0.257550389 + float_data: -0.0744459182 + float_data: -1.91877127 + float_data: -0.0265138745 + float_data: 0.0602302104 + float_data: 2.46324205 + float_data: -0.192360967 + float_data: 0.301547348 + float_data: -0.0347117707 + float_data: -1.16867805 + float_data: 1.14282286 + float_data: 0.751933038 + float_data: 0.791031957 + float_data: -0.909387469 + float_data: 1.40279436 + float_data: 
-1.40185106 + float_data: 0.58685708 + float_data: 2.19045568 + float_data: -0.990536332 + float_data: -0.56629771 + name: "W" + } + initializer { + dims: 1 + dims: 32 + dims: 8 + data_type: 1 + float_data: 0.0996513665 + float_data: -0.503475666 + float_data: -1.55066347 + float_data: 0.068562977 + float_data: -1.06230366 + float_data: 0.47359243 + float_data: -0.919424236 + float_data: 1.54993439 + float_data: -0.783253312 + float_data: -0.322061509 + float_data: 0.813517213 + float_data: -1.23086429 + float_data: 0.227459937 + float_data: 1.30714273 + float_data: -1.60748327 + float_data: 0.184633866 + float_data: 0.259882808 + float_data: 0.78182286 + float_data: -1.23695076 + float_data: -1.32045662 + float_data: 0.521941543 + float_data: 0.296984673 + float_data: 0.250492841 + float_data: 0.346448213 + float_data: -0.680024743 + float_data: 0.2322537 + float_data: 0.293072462 + float_data: -0.714351416 + float_data: 1.86577451 + float_data: 0.473832935 + float_data: -1.19130349 + float_data: 0.656553626 + float_data: -0.974681675 + float_data: 0.787084579 + float_data: 1.15859556 + float_data: -0.820682347 + float_data: 0.963376105 + float_data: 0.412780941 + float_data: 0.822060168 + float_data: 1.89679301 + float_data: -0.24538812 + float_data: -0.753736138 + float_data: -0.889514446 + float_data: -0.815810263 + float_data: -0.0771017075 + float_data: 0.341151983 + float_data: 0.276690811 + float_data: 0.827183247 + float_data: 0.0130018918 + float_data: 1.45353413 + float_data: -0.264656842 + float_data: 2.72016907 + float_data: 0.625667334 + float_data: -0.857157528 + float_data: -1.07089245 + float_data: 0.48247242 + float_data: -0.22346279 + float_data: 0.714000523 + float_data: 0.473237634 + float_data: -0.0728289112 + float_data: -0.846793711 + float_data: -1.51484728 + float_data: -0.446514964 + float_data: 0.856398821 + float_data: 0.214093745 + float_data: -1.24573874 + float_data: 0.173180923 + float_data: 0.385317385 + float_data: -0.883857429 + 
float_data: 0.153725103 + float_data: 0.0582087189 + float_data: -1.14297032 + float_data: 0.357787371 + float_data: 0.560784519 + float_data: 1.0830512 + float_data: 1.05380201 + float_data: -1.37766933 + float_data: -0.937825 + float_data: 0.515035272 + float_data: 0.513785958 + float_data: 0.515047669 + float_data: 3.85273147 + float_data: 0.570890486 + float_data: 1.13556564 + float_data: 0.954001784 + float_data: 0.651391268 + float_data: -0.315269232 + float_data: 0.758969247 + float_data: -0.772825241 + float_data: -0.236818612 + float_data: -0.485363543 + float_data: 0.0818741396 + float_data: 2.31465864 + float_data: -1.86726522 + float_data: 0.686260164 + float_data: -1.61271584 + float_data: -0.471931875 + float_data: 1.08895063 + float_data: 0.0642800182 + float_data: -1.07774472 + float_data: -0.715303719 + float_data: 0.679597735 + float_data: -0.730366647 + float_data: 0.216458589 + float_data: 0.0455718413 + float_data: -0.651600361 + float_data: 2.14394403 + float_data: 0.633919 + float_data: -2.02514267 + float_data: 0.186454311 + float_data: -0.661786437 + float_data: 0.852433324 + float_data: -0.792520761 + float_data: -0.114736438 + float_data: 0.504987299 + float_data: 0.8657552 + float_data: -1.2002964 + float_data: -0.334501237 + float_data: -0.474945307 + float_data: -0.653329253 + float_data: 1.76545429 + float_data: 0.404981703 + float_data: -1.26088393 + float_data: 0.917861938 + float_data: 2.12215614 + float_data: 1.03246522 + float_data: -1.51937 + float_data: -0.484234065 + float_data: 1.26691115 + float_data: -0.707669437 + float_data: 0.443819433 + float_data: 0.774634063 + float_data: -0.926930487 + float_data: -0.0595253557 + float_data: -3.24126744 + float_data: -1.0243876 + float_data: -0.252568156 + float_data: -1.24778318 + float_data: 1.63241136 + float_data: -1.43014133 + float_data: -0.440044492 + float_data: 0.130740583 + float_data: 1.44127333 + float_data: -1.43586218 + float_data: 1.16316378 + float_data: 0.0102330614 
+ float_data: -0.981508672 + float_data: 0.462103486 + float_data: 0.199059695 + float_data: -0.600216866 + float_data: 0.0698020831 + float_data: -0.3853136 + float_data: 0.113517344 + float_data: 0.662130654 + float_data: 1.58601677 + float_data: -1.2378155 + float_data: 2.13303328 + float_data: -1.95208776 + float_data: -0.151785091 + float_data: 0.588317215 + float_data: 0.280991882 + float_data: -0.622699499 + float_data: -0.208122253 + float_data: -0.493000925 + float_data: -0.589364767 + float_data: 0.849602103 + float_data: 0.357015491 + float_data: -0.692909598 + float_data: 0.89959985 + float_data: 0.307299525 + float_data: 0.812862098 + float_data: 0.629628837 + float_data: -0.828995 + float_data: -0.560181 + float_data: 0.747293591 + float_data: 0.610370278 + float_data: -0.0209015943 + float_data: 0.117327385 + float_data: 1.2776649 + float_data: -0.591571391 + float_data: 0.547097385 + float_data: -0.202192649 + float_data: -0.217681199 + float_data: 1.09877682 + float_data: 0.825416327 + float_data: 0.813509643 + float_data: 1.30547881 + float_data: 0.0210038424 + float_data: 0.681952953 + float_data: -0.310266763 + float_data: 0.324166358 + float_data: -0.130143061 + float_data: 0.0969959646 + float_data: 0.595157 + float_data: -0.818220675 + float_data: 2.0923872 + float_data: -1.00601733 + float_data: -1.21418858 + float_data: 1.15811086 + float_data: 0.791662693 + float_data: 0.624119818 + float_data: 0.62834549 + float_data: -0.0122467726 + float_data: -0.897254348 + float_data: 0.0758045614 + float_data: -0.677161694 + float_data: 0.97511971 + float_data: -0.147057384 + float_data: -0.82549721 + float_data: -0.321385831 + float_data: 0.412931442 + float_data: -0.563724577 + float_data: -0.822220385 + float_data: 0.243687212 + float_data: 0.244966567 + float_data: -0.506943166 + float_data: -0.471038312 + float_data: 0.232049942 + float_data: -1.44808435 + float_data: -1.40746379 + float_data: -0.718444228 + float_data: -0.213447154 + 
float_data: 0.310907573 + float_data: 1.47535622 + float_data: 0.857659638 + float_data: -0.159938529 + float_data: -0.0190162081 + float_data: -1.00252938 + float_data: -0.0185131356 + float_data: -0.288658649 + float_data: 0.322718561 + float_data: -0.82723093 + float_data: 0.519346535 + float_data: 1.53273892 + float_data: -0.108760148 + float_data: 0.401711732 + float_data: 0.690144 + float_data: -0.401220471 + float_data: 0.224092484 + float_data: 0.0125924 + float_data: 0.0976761 + float_data: -0.773009777 + float_data: 0.024510175 + float_data: 0.497998297 + float_data: 1.45114362 + float_data: 0.959270835 + float_data: 2.15318251 + float_data: -0.767347574 + float_data: 0.872320652 + float_data: 0.18334201 + float_data: 2.18980289 + float_data: -0.80829829 + float_data: -0.839721859 + float_data: -0.599392653 + float_data: -2.12389565 + float_data: -0.525755048 + name: "R" + } + initializer { + dims: 1 + dims: 64 + data_type: 1 + float_data: -0.759132683 + float_data: 0.150393784 + float_data: 0.341756 + float_data: 1.87617087 + float_data: 0.950423837 + float_data: -0.576903641 + float_data: -0.898414671 + float_data: 0.49191916 + float_data: -1.32023323 + float_data: 1.83145881 + float_data: 1.17944014 + float_data: -0.469175667 + float_data: -1.71313453 + float_data: 1.35387242 + float_data: -0.114539847 + float_data: 1.23781633 + float_data: -1.5944277 + float_data: -0.599375 + float_data: 0.00524369953 + float_data: 0.0469805934 + float_data: -0.450065464 + float_data: 0.622849941 + float_data: -1.0676204 + float_data: -0.142379478 + float_data: 0.120295629 + float_data: 0.514438808 + float_data: 0.711614907 + float_data: -1.12464213 + float_data: -1.53411412 + float_data: 1.27767682 + float_data: 0.332314 + float_data: -0.748486519 + float_data: 1.55115199 + float_data: 0.115674637 + float_data: 1.17929721 + float_data: 0.0675184801 + float_data: 2.06074786 + float_data: 1.75534081 + float_data: -0.248964146 + float_data: 0.971570969 + float_data: 
0.645375967 + float_data: 1.3686316 + float_data: -0.964923441 + float_data: 0.686051488 + float_data: 1.05842447 + float_data: -1.75873947 + float_data: -1.18325853 + float_data: -2.03923225 + float_data: -0.269406825 + float_data: 0.717542231 + float_data: 1.50235701 + float_data: 0.0740947798 + float_data: 1.6286155 + float_data: -1.38010144 + float_data: -1.70338249 + float_data: -0.0555477 + float_data: 0.384065449 + float_data: -0.0326947495 + float_data: -2.06744218 + float_data: -0.0891200379 + float_data: -1.30446947 + float_data: 0.669672549 + float_data: 0.366598248 + float_data: -0.939879775 + name: "B" + } + initializer { + dims: 1 + data_type: 7 + int64_data: 0 + name: "unsqueeze_axes" + } + input { + name: "X_input" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "Y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "Y_h" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "Y_c" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/src/frontends/onnx/tests/models/lstm_rank5_squeeze.prototxt b/src/frontends/onnx/tests/models/lstm_rank5_squeeze.prototxt new file mode 100644 index 00000000000000..b45a56b0b67999 --- /dev/null +++ b/src/frontends/onnx/tests/models/lstm_rank5_squeeze.prototxt @@ -0,0 +1,122 @@ +ir_version: 7 +producer_name: "OpenVINO LSTM Test" +graph { + node { + input: "X" + input: "W" + input: "R" + input: "B" + output: "Y" + output: "Y_h" + output: "Y_c" + op_type: "LSTM" + attribute { + name: "direction" + s: "forward" + type: 
STRING + } + attribute { + name: "hidden_size" + i: 2 + type: INT + } + } + name: "lstm_rank5_squeeze" + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_value: 1 } + dim { dim_value: 1 } + dim { dim_value: 3 } + dim { dim_value: 2 } + dim { dim_value: 4 } + } + } + } + } + input { + name: "W" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_value: 1 } + dim { dim_value: 8 } + dim { dim_value: 4 } + } + } + } + } + input { + name: "R" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_value: 1 } + dim { dim_value: 8 } + dim { dim_value: 2 } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_value: 1 } + dim { dim_value: 16 } + } + } + } + } + output { + name: "Y" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_value: 3 } + dim { dim_value: 1 } + dim { dim_value: 2 } + dim { dim_value: 2 } + } + } + } + } + output { + name: "Y_h" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_value: 1 } + dim { dim_value: 2 } + dim { dim_value: 2 } + } + } + } + } + output { + name: "Y_c" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_value: 1 } + dim { dim_value: 2 } + dim { dim_value: 2 } + } + } + } + } +} +opset_import { + domain: "" + version: 14 +} diff --git a/src/frontends/onnx/tests/onnx_import_rnn.in.cpp b/src/frontends/onnx/tests/onnx_import_rnn.in.cpp index db26d0b0245c3b..3c333f0b443cc1 100644 --- a/src/frontends/onnx/tests/onnx_import_rnn.in.cpp +++ b/src/frontends/onnx/tests/onnx_import_rnn.in.cpp @@ -425,6 +425,190 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_lstm_dynamic_batch_size_and_seq_len) { test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1); } +// Test for LSTM with high-rank input (rank reduction) +// Tests the reduce_tensor_rank() function in lstm.cpp +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_lstm_rank5_squeeze) { + if (std::string("${BACKEND_NAME}") == std::string("IE_GPU")) { + GTEST_SKIP() << "GPU backend has 
known accuracy issues with LSTMSequence Y_h/Y_c outputs. " + << "Root cause: fp16/fp32 buffer mismatch in OneDNN integration."; + } + + auto model = convert_model("lstm_rank5_squeeze.onnx"); + auto test_case = ov::test::TestCase(model, s_device); + + // Input X shape: [1, 1, 3, 2, 4] + std::vector X = {0.8125258088111877f, 1.3562400341033936f, -0.07201012223958969f, 1.003532886505127f, + 0.3616360127925873f, -0.6451197266578674f, 0.36139559745788574f, 1.538036584854126f, + -0.03582603856921196f, 1.5646436214447021f, -2.6197450160980225f, 0.8219025135040283f, + 0.08704707026481628f, -0.2990073561668396f, 0.0917607769370079f, -1.9875688552856445f, + -0.21967189013957977f, 0.3571125566959381f, 1.4778940677642822f, -0.5182701945304871f, + -0.8084936141967773f, -0.501757025718689f, 0.9154021143913269f, 0.3287511169910431f}; + + // Weight W shape: [1, 8, 4] + std::vector W = {0.49671414494514465f, -0.13826429843902588f, 0.6476885676383972f, 1.5230298042297363f, + -0.2341533750295639f, -0.23413695394992828f, 1.5792127847671509f, 0.7674347162246704f, + -0.4694743752479553f, 0.5425600409507751f, -0.4634176790714264f, -0.4657297432422638f, + 0.241962268948555f, -1.9132802486419678f, -1.7249178886413574f, -0.5622875094413757f, + -1.0128310918807983f, 0.31424733996391296f, -0.9080240726470947f, -1.4123036861419678f, + 1.4656487703323364f, -0.2257762998342514f, 0.06752820312976837f, -1.424748182296753f, + -0.5443827509880066f, 0.11092258989810944f, -1.1509935855865479f, 0.3756980299949646f, + -0.6006386876106262f, -0.2916937470436096f, -0.6017066240310669f, 1.852278232574463f}; + + // Recurrent weight R shape: [1, 8, 2] + std::vector R = {-0.013497225008904934f, + -1.057710886001587f, + 0.8225449323654175f, + -1.2208436727523804f, + 0.20886360108852386f, + -1.959670066833496f, + -1.32818603515625f, + 0.19686123728752136f, + 0.7384665608406067f, + 0.1713682860136032f, + -0.1156482845544815f, + -0.3011036813259125f, + -1.4785219430923462f, + -0.7198442220687866f, + 
-0.46063876152038574f, + 1.0571222305297852f}; + + // Bias B shape: [1, 16] + std::vector B = {0.3436183035373688f, + -1.7630401849746704f, + 0.32408398389816284f, + -0.38508227467536926f, + -0.6769220232963562f, + 0.6116762757301331f, + 1.0309995412826538f, + 0.9312801361083984f, + -0.8392175436019897f, + -0.3092123866081238f, + 0.3312634229660034f, + 0.9755451083183289f, + -0.4791742265224457f, + -0.18565897643566132f, + -1.106334924697876f, + -1.1962065696716309f}; + + // Expected outputs from ONNX Runtime + // Y shape: [3, 1, 2, 2] + std::vector expected_Y = {0.04541102051734924f, + 0.007742831949144602f, + -0.05228624865412712f, + 0.24427476525306702f, + 0.24919648468494415f, + 0.016252165660262108f, + -0.10558786988258362f, + 0.337483674287796f, + -0.1856241226196289f, + -0.027182353660464287f, + -0.1143757626414299f, + 0.14773108065128326f}; + + // Y_h shape: [1, 2, 2] + std::vector expected_Y_h = {-0.1856241226196289f, + -0.027182353660464287f, + -0.1143757626414299f, + 0.14773108065128326f}; + + // Y_c shape: [1, 2, 2] + std::vector expected_Y_c = {-0.3040353059768677f, + -0.4779679775238037f, + -0.31235989928245544f, + 0.3412477970123291f}; + + // Add inputs + test_case.add_input(Shape{1, 1, 3, 2, 4}, X); + test_case.add_input(Shape{1, 8, 4}, W); + test_case.add_input(Shape{1, 8, 2}, R); + test_case.add_input(Shape{1, 16}, B); + + // Add expected outputs + test_case.add_expected_output(Shape{3, 1, 2, 2}, expected_Y); + test_case.add_expected_output(Shape{1, 2, 2}, expected_Y_h); + test_case.add_expected_output(Shape{1, 2, 2}, expected_Y_c); + + test_case.run_with_tolerance_as_fp(1.0e-4f); +} + +// Test for LSTM with Unsqueeze preprocessing (reproduces silero_vad structure) +// Input goes through Unsqueeze(axes=[0]) to get rank-4: [1, seq, batch, input] +// Then LSTM applies rank reduction (Squeeze path) internally +// W, R, B are embedded as initializers, not passed as test inputs +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_lstm_rank4_with_unsqueeze) { + 
auto model = convert_model("lstm_rank4_with_unsqueeze.onnx"); + auto test_case = ov::test::TestCase(model, s_device); + + // Input X_input shape: [3, 2, 4] + // Inside model: Unsqueeze(axes=[0]) -> [1, 3, 2, 4] + // Then LSTM reduces rank-4 back to rank-3 + std::vector X_input = { + -0.5138669013977051f, -1.0592135190963745f, -0.06267909705638885f, 0.9551423192024231f, -0.9857260584831238f, + 0.5040464997291565f, -0.5302576422691345f, -0.7928728461265564f, -0.10703036189079285f, -1.0352423191070557f, + -0.5536493062973022f, -1.1978778839111328f, 1.964725136756897f, 0.0352635532617569f, -0.6997255086898804f, + 0.21397991478443146f, -0.11232805252075195f, -0.22096960246562958f, 0.6141666769981384f, 0.7575076818466187f, + -0.530501127243042f, -0.5758182406425476f, -0.27505168318748474f, -2.3019211292266846f}; + + // Expected outputs from ONNX Runtime + std::vector expected_Y = { + -0.041506070643663406f, 0.5078101754188538f, -0.1956426054239273f, -0.6245316863059998f, + -0.46069785952568054f, 0.16945451498031616f, -0.036582063883543015f, -0.14519499242305756f, + 0.03458118066191673f, 0.22937199473381042f, -0.33998003602027893f, -0.24845707416534424f, + -0.10919304192066193f, 0.041263557970523834f, -0.01278171967715025f, -0.09499483555555344f, + -0.05660216882824898f, 0.4603622555732727f, -0.6351584196090698f, 0.03479098901152611f, + -0.88917076587677f, 0.037890415638685226f, 0.006493804045021534f, -0.5771874189376831f, + 0.08443683385848999f, -0.06232306361198425f, -0.8248671293258667f, 0.05733451619744301f, + -0.454610675573349f, 0.02996906079351902f, 0.0028792552184313536f, -0.045823417603969574f, + 0.3284722566604614f, 0.4379877746105194f, -0.22600656747817993f, -0.107466921210289f, + -0.6092900633811951f, 0.4817439317703247f, -0.06534680724143982f, -0.039009325206279755f, + 0.191391259431839f, 0.07052464038133621f, -0.33953383564949036f, 0.14264167845249176f, + -0.7233807444572449f, 0.01202191412448883f, 0.009454132989048958f, -0.23135091364383698f}; + + 
std::vector expected_Y_h = {0.3284722566604614f, + 0.4379877746105194f, + -0.22600656747817993f, + -0.107466921210289f, + -0.6092900633811951f, + 0.4817439317703247f, + -0.06534680724143982f, + -0.039009325206279755f, + 0.191391259431839f, + 0.07052464038133621f, + -0.33953383564949036f, + 0.14264167845249176f, + -0.7233807444572449f, + 0.01202191412448883f, + 0.009454132989048958f, + -0.23135091364383698f}; + + std::vector expected_Y_c = {0.843663215637207f, + 0.4875003397464752f, + -0.501198947429657f, + -0.2064143717288971f, + -1.8175270557403564f, + 0.9995222091674805f, + -0.0911954939365387f, + -1.0586004257202148f, + 0.2041035294532776f, + 0.0721229761838913f, + -1.948965311050415f, + 1.1799299716949463f, + -2.050872564315796f, + 0.5028446316719055f, + 0.1787012219429016f, + -0.24886609613895416f}; + + // Add input (only X, weights are embedded) + test_case.add_input(Shape{3, 2, 4}, X_input); + + // Add expected outputs + test_case.add_expected_output(Shape{3, 1, 2, 8}, expected_Y); + test_case.add_expected_output(Shape{1, 2, 8}, expected_Y_h); + test_case.add_expected_output(Shape{1, 2, 8}, expected_Y_c); + + test_case.run_with_tolerance_as_fp(1.0e-4f); +} + // RNNLikeSequenceOp test fixture for test setup reuse class GRUSequenceOp : public testing::Test { public: