-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Fix LSTM conversion for models with rank > 3 inputs from Unsqueeze operations #33023
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 5 commits
d1689ce
de3353d
b06e7f2
c58c5bc
5ba2f09
08c7dda
5bc7dca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,11 @@ | |||||||||||||
| #include "openvino/op/gather.hpp" | ||||||||||||||
| #include "openvino/op/lstm_sequence.hpp" | ||||||||||||||
| #include "openvino/op/multiply.hpp" | ||||||||||||||
| #include "openvino/op/reshape.hpp" | ||||||||||||||
| #include "openvino/op/shape_of.hpp" | ||||||||||||||
| #include "openvino/op/slice.hpp" | ||||||||||||||
| #include "openvino/op/squeeze.hpp" | ||||||||||||||
| #include "openvino/op/tile.hpp" | ||||||||||||||
| #include "openvino/util/common_util.hpp" | ||||||||||||||
| #include "utils/reshape.hpp" | ||||||||||||||
| #include "utils/split.hpp" | ||||||||||||||
|
|
@@ -37,6 +41,49 @@ enum class LSTMInput { | |||||||||||||
| LSTM_INPUT_P | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| // Helper function to reduce tensor rank to target_rank by squeezing or reshaping | ||||||||||||||
| std::shared_ptr<ov::Node> reduce_tensor_rank(const ov::Output<ov::Node>& input, | ||||||||||||||
| int64_t target_rank, | ||||||||||||||
| const std::string& debug_name) { | ||||||||||||||
| const auto& input_shape = input.get_partial_shape(); | ||||||||||||||
|
|
||||||||||||||
| if (!input_shape.rank().is_static()) { | ||||||||||||||
| return input.get_node_shared_ptr(); | ||||||||||||||
| } | ||||||||||||||
|
||||||||||||||
| if (!input_shape.rank().is_static()) { | |
| return input.get_node_shared_ptr(); | |
| } | |
| if (input_shape.rank().is_dynamic()) { | |
| return input; | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| return input.get_node_shared_ptr(); | |
| return input; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can the extra input dimensions be something other than 1? In that case the Reshape will fail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I investigated how ONNX Runtime handles this case. ONNX Runtime strictly validates that the input has rank 3: https://github.com/microsoft/onnxruntime/blob/423a03f1fc80d3cbed4f973574ee96f31521a3d3/onnxruntime/core/providers/cpu/rnn/lstm_base.cc#L191-L192
if (X_shape.NumDimensions() != 3)
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input X must have 3 dimensions only. Actual:", X_shape);
The ONNX specification also requires rank 3: https://github.com/onnx/onnx/blob/main/onnx/defs/rnn/defs.cc#L26-L28
if (first_input_shape.dim_size() != 3) {
fail_shape_inference("First input tensor must have rank 3");
}
ONNX Runtime does not support inputs with rank > 3 for LSTM. It simply fails with an error. Our approach in OpenVINO is an extension that handles models where Unsqueeze operations precede LSTM nodes.
I updated the solution:
- I removed the Reshape fallback, since it is incorrect when the extra dimensions are != 1.
- Now we always use Squeeze for all leading dimensions.
- Per the ONNX spec, LSTM requires rank = 3, so the extra dimensions MUST be == 1 (they come from Unsqueeze operations).
- If the extra dimensions somehow are != 1 at runtime, Squeeze will fail with a clear error message, which is the correct behavior since such input violates the ONNX LSTM specification.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,167 @@ | ||
| ir_version: 4 | ||
| producer_name: "OpenVINO ONNX Frontend" | ||
| graph { | ||
| node { | ||
| input: "X" | ||
| input: "W" | ||
| input: "R" | ||
| input: "B" | ||
| output: "Y" | ||
| output: "Y_h" | ||
| output: "Y_c" | ||
| op_type: "LSTM" | ||
| attribute { | ||
| name: "direction" | ||
| s: "forward" | ||
| type: STRING | ||
| } | ||
| attribute { | ||
| name: "hidden_size" | ||
| i: 2 | ||
| type: INT | ||
| } | ||
| } | ||
| name: "compute_graph" | ||
| input { | ||
| name: "X" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 3 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| dim { | ||
| dim_value: 4 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| input { | ||
| name: "W" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 8 | ||
| } | ||
| dim { | ||
| dim_value: 4 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| input { | ||
| name: "R" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 8 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| input { | ||
| name: "B" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 16 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| output { | ||
| name: "Y" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 3 | ||
| } | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| output { | ||
| name: "Y_h" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| output { | ||
| name: "Y_c" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 1 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| dim { | ||
| dim_value: 2 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| opset_import { | ||
| version: 7 | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.