70 changes: 67 additions & 3 deletions src/frontends/onnx/frontend/src/op/lstm.cpp
@@ -12,7 +12,11 @@
#include "openvino/op/gather.hpp"
#include "openvino/op/lstm_sequence.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/util/common_util.hpp"
#include "utils/reshape.hpp"
#include "utils/split.hpp"
@@ -37,6 +41,49 @@ enum class LSTMInput {
LSTM_INPUT_P
};

// Helper function to reduce tensor rank to target_rank by squeezing or reshaping
std::shared_ptr<ov::Node> reduce_tensor_rank(const ov::Output<ov::Node>& input,
int64_t target_rank,
const std::string& debug_name) {
const auto& input_shape = input.get_partial_shape();

if (!input_shape.rank().is_static()) {
return input.get_node_shared_ptr();
}
Contributor (suggested change):

-    if (!input_shape.rank().is_static()) {
-        return input.get_node_shared_ptr();
-    }
+    if (input_shape.rank().is_dynamic()) {
+        return input;
+    }

Contributor Author: Fixed.

const auto input_rank = input_shape.rank().get_length();

if (input_rank <= target_rank) {
return input.get_node_shared_ptr();
Contributor (suggested change):

-        return input.get_node_shared_ptr();
+        return input;

Contributor Author: Fixed.

}

// Strategy: Try to squeeze all leading dimensions that are equal to 1
std::vector<int64_t> axes_to_squeeze;
for (int64_t i = 0; i < input_rank - target_rank; ++i) {
if (input_shape[i].is_static() && input_shape[i].get_length() == 1) {
axes_to_squeeze.push_back(i);
}
}

if (axes_to_squeeze.size() == static_cast<size_t>(input_rank - target_rank)) {
// All extra dimensions are 1, we can squeeze to get target rank
auto axes_const = v0::Constant::create(ov::element::i64, Shape{axes_to_squeeze.size()}, axes_to_squeeze);
return std::make_shared<v0::Squeeze>(input, axes_const);
} else {
// Some extra dimensions are dynamic or not equal to 1, so fall back to Reshape
auto shape_of_input = std::make_shared<v3::ShapeOf>(input);
auto start_idx = v0::Constant::create(ov::element::i64, Shape{1}, {input_rank - target_rank});
auto stop_idx = v0::Constant::create(ov::element::i64, Shape{1}, {input_rank});
auto step = v0::Constant::create(ov::element::i64, Shape{1}, {1});

// Get last target_rank dimensions: shape[-target_rank:]
auto last_dims = std::make_shared<v8::Slice>(shape_of_input, start_idx, stop_idx, step);

// Reshape to extract last target_rank dimensions
return std::make_shared<v1::Reshape>(input, last_dims, false);
Contributor: Can the extra input dimensions be something other than 1? In that case the Reshape will fail.

Contributor Author: Investigation shows how ONNX Runtime handles this case: it strictly validates rank == 3 (https://github.com/microsoft/onnxruntime/blob/423a03f1fc80d3cbed4f973574ee96f31521a3d3/onnxruntime/core/providers/cpu/rnn/lstm_base.cc#L191-L192):

if (X_shape.NumDimensions() != 3)
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
        "Input X must have 3 dimensions only. Actual:", X_shape);

The ONNX specification also requires rank == 3 (https://github.com/onnx/onnx/blob/main/onnx/defs/rnn/defs.cc#L26-L28):

if (first_input_shape.dim_size() != 3) {
    fail_shape_inference("First input tensor must have rank 3");
}

ONNX Runtime does not support LSTM inputs with rank > 3; it simply fails with an error. Our approach in OpenVINO is an extension that handles models where Unsqueeze operations precede LSTM nodes.

I updated the solution:
- I removed the Reshape fallback, since it is incorrect for dimensions != 1.
- We now always use Squeeze for all leading dimensions.
- Per the ONNX spec, LSTM requires rank == 3, so the extra dimensions MUST be equal to 1 (coming from Unsqueeze operations).
- If the extra dimensions are somehow != 1 at runtime, Squeeze fails with a clear error message, which is the correct behavior since such input violates the ONNX LSTM specification.

}
}
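
As a standalone sketch of the two behaviors discussed above (not part of this patch; it assumes only OpenVINO's public C++ op API, with shapes borrowed from the test model added below): leading 1-dims squeeze away cleanly, while a statically non-1 leading dim should fail node validation, which is the failure mode the author describes.

#include <iostream>
#include <memory>
#include <vector>

#include "openvino/core/except.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/squeeze.hpp"

int main() {
    using namespace ov;

    // Case 1: rank-5 input whose two extra leading dims are statically 1
    // (the Unsqueeze-before-LSTM pattern). Squeezing axes {0, 1} yields the
    // rank-3 [seq_length, batch_size, input_size] layout ONNX LSTM expects.
    auto x = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{1, 1, 3, 2, 4});
    auto axes = op::v0::Constant::create(element::i64, Shape{2}, std::vector<int64_t>{0, 1});
    auto squeezed = std::make_shared<op::v0::Squeeze>(x, axes);
    std::cout << squeezed->get_output_partial_shape(0) << "\n";  // [3,2,4]

    // Case 2: a leading dim of 2 cannot be squeezed; node validation is
    // expected to throw, the intended outcome for inputs that violate the
    // ONNX LSTM rank-3 requirement.
    auto bad = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{2, 1, 3, 2, 4});
    try {
        auto s = std::make_shared<op::v0::Squeeze>(bad, axes);
    } catch (const Exception& e) {
        std::cout << "validation failed as expected: " << e.what() << "\n";
    }
    return 0;
}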

struct LSTMNgInputMap {
explicit LSTMNgInputMap(const Node& node) {
const auto& ng_inputs = node.get_ov_inputs();
@@ -48,7 +95,14 @@ struct LSTMNgInputMap {
// Packed input sequences.
// ONNX Shape: [seq_length, batch_size, input_size]
// OpenVino Shape: [batch_size, seq_length, input_size]
m_input_map[LSTMInput::LSTM_INPUT_X] = ov::op::util::reorder_axes(ng_inputs.at(0), {1, 0, 2});

// First reduce rank if needed, THEN reorder axes
// This is important because Squeeze changes dimension indices
auto input_x = ng_inputs.at(0);
input_x = reduce_tensor_rank(input_x, 3, "X_before_reorder");
input_x = ov::op::util::reorder_axes(input_x, {1, 0, 2});

m_input_map[LSTMInput::LSTM_INPUT_X] = input_x;
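
A quick hand-trace of the X shape flow (using the static shapes from the test model below) shows why the rank reduction must happen before the axis reorder:

// X (ONNX):               [1, 1, 3, 2, 4]   extra leading 1-dims from Unsqueeze
// reduce_tensor_rank(3):  Squeeze {0, 1} -> [3, 2, 4]   [seq_length, batch_size, input_size]
// reorder_axes {1, 0, 2}: Transpose      -> [2, 3, 4]   [batch_size, seq_length, input_size]
// Applying {1, 0, 2} to the rank-5 tensor first would target the wrong axes
// (or fail outright), since Squeeze shifts the dimension indices.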

// Weight tensor for the gates.
// Shape: [num_directions, 4*hidden_size, input_size]
@@ -124,7 +178,12 @@ struct LSTMNgInputMap {
// ONNX Shape: [num_directions, batch_size, hidden_size]
// OpenVino Shape: [batch_size, num_directions, hidden_size]
if (ng_inputs.size() > 5 && !ov::op::util::is_null(ng_inputs.at(5))) {
m_input_map[LSTMInput::LSTM_INPUT_INIT_H] = ov::op::util::reorder_axes(ng_inputs.at(5), {1, 0, 2});
auto init_h = ng_inputs.at(5);
// First reduce rank, THEN reorder axes
init_h = reduce_tensor_rank(init_h, 3, "initial_h_before_reorder");
init_h = ov::op::util::reorder_axes(init_h, {1, 0, 2});

m_input_map[LSTMInput::LSTM_INPUT_INIT_H] = init_h;
} else {
auto init_h_shape =
std::make_shared<v0::Concat>(ov::OutputVector{batch_size_node, num_directions_node, hidden_size_node},
@@ -137,7 +196,12 @@
// ONNX Shape: [num_directions, batch_size, hidden_size]
// OpenVino Shape: [batch_size, num_directions, hidden_size]
if (ng_inputs.size() > 6 && !ov::op::util::is_null(ng_inputs.at(6))) {
m_input_map[LSTMInput::LSTM_INPUT_INIT_C] = ov::op::util::reorder_axes(ng_inputs.at(6), {1, 0, 2});
auto init_c = ng_inputs.at(6);
// First reduce rank, THEN reorder axes
init_c = reduce_tensor_rank(init_c, 3, "initial_c_before_reorder");
init_c = ov::op::util::reorder_axes(init_c, {1, 0, 2});

m_input_map[LSTMInput::LSTM_INPUT_INIT_C] = init_c;
} else {
auto init_c_shape =
std::make_shared<v0::Concat>(ov::OutputVector{batch_size_node, num_directions_node, hidden_size_node},
167 changes: 167 additions & 0 deletions src/frontends/onnx/tests/models/lstm_high_rank_input.prototxt
@@ -0,0 +1,167 @@
ir_version: 4
producer_name: "OpenVINO ONNX Frontend"
graph {
node {
input: "X"
input: "W"
input: "R"
input: "B"
output: "Y"
output: "Y_h"
output: "Y_c"
op_type: "LSTM"
attribute {
name: "direction"
s: "forward"
type: STRING
}
attribute {
name: "hidden_size"
i: 2
type: INT
}
}
name: "compute_graph"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 2
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 8
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 8
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 16
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y_c"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 7
}
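
For reference, the tensor shapes in this test model line up as follows (dimension names per the ONNX LSTM operator spec; the rank-5 X is the point of the test):

// X:   [1, 1, 3, 2, 4] -> squeezed to [seq_length=3, batch_size=2, input_size=4]
// W:   [num_directions=1, 4*hidden_size=8, input_size=4]
// R:   [num_directions=1, 4*hidden_size=8, hidden_size=2]
// B:   [num_directions=1, 8*hidden_size=16]
// Y:   [seq_length=3, num_directions=1, batch_size=2, hidden_size=2]
// Y_h: [num_directions=1, batch_size=2, hidden_size=2]
// Y_c: [num_directions=1, batch_size=2, hidden_size=2]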