Skip to content

Commit b00c4ff

Browse files
authored
[Matrix][IR] Cap stride bitwidth at 64 (#163729)
a1ef81d added overloads for `llvm.matrix.column.major.store` and `llvm.matrix.column.major.load` that allow strides to occupy an arbitrary bitwidth. This change wasn't reflected in the verifier, causing an assertion to trip when given strides overflowing 64-bit. This patch explicitly caps the bitwidth at 64, repairing the crash and avoiding future complexity dealing with strides that overflow 64 bits. PR: llvm/llvm-project#163729
1 parent 8c72b2a commit b00c4ff

File tree

5 files changed

+43
-346
lines changed

5 files changed

+43
-346
lines changed

llvm/docs/LangRef.rst

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21062,12 +21062,15 @@ integer element type.
2106221062

2106321063
Syntax:
2106421064
"""""""
21065-
This is an overloaded intrinsic.
21065+
This is an overloaded intrinsic. You can use ``llvm.matrix.column.major.load``
21066+
to load any vector type with a stride of any bitwidth up to 64.
2106621067

2106721068
::
2106821069

21069-
declare vectorty @llvm.matrix.column.major.load.*(
21070+
declare <4 x i32> @llvm.matrix.column.major.load.v4i32.i64(
2107021071
ptrty %Ptr, i64 %Stride, i1 <IsVolatile>, i32 <Rows>, i32 <Cols>)
21072+
declare <9 x double> @llvm.matrix.column.major.load.v9f64.i32(
21073+
ptrty %Ptr, i32 %Stride, i1 <IsVolatile>, i32 <Rows>, i32 <Cols>)
2107121074

2107221075
Overview:
2107321076
"""""""""
@@ -21086,9 +21089,9 @@ Arguments:
2108621089

2108721090
The first argument ``%Ptr`` is a pointer type to the returned vector type, and
2108821091
corresponds to the start address to load from. The second argument ``%Stride``
21089-
is a positive, constant integer with ``%Stride >= <Rows>``. ``%Stride`` is used
21090-
to compute the column memory addresses. I.e., for a column ``C``, its start
21091-
memory addresses is calculated with ``%Ptr + C * %Stride``. The third Argument
21092+
is a positive integer for which ``%Stride >= <Rows>``. ``%Stride`` is used to
21093+
compute the column memory addresses. I.e., for a column ``C``, its start memory
21094+
addresses is calculated with ``%Ptr + C * %Stride``. The third Argument
2109221095
``<IsVolatile>`` is a boolean value. The fourth and fifth arguments,
2109321096
``<Rows>`` and ``<Cols>``, correspond to the number of rows and columns,
2109421097
respectively, and must be positive, constant integers. The returned vector must
@@ -21103,11 +21106,17 @@ The :ref:`align <attr_align>` parameter attribute can be provided for the
2110321106

2110421107
Syntax:
2110521108
"""""""
21109+
This is an overloaded intrinsic. ``llvm.matrix.column.major.store`` to store
21110+
any vector type with a stride of any bitwidth up to 64.
2110621111

2110721112
::
2110821113

21109-
declare void @llvm.matrix.column.major.store.*(
21110-
vectorty %In, ptrty %Ptr, i64 %Stride, i1 <IsVolatile>, i32 <Rows>, i32 <Cols>)
21114+
declare void @llvm.matrix.column.major.store.v4i32.i64(
21115+
<4 x i32> %In, ptrty %Ptr, i64 %Stride, i1 <IsVolatile>, i32 <Rows>,
21116+
i32 <Cols>)
21117+
declare void @llvm.matrix.column.major.store.v9f64.i32(
21118+
<9 x double> %In, ptrty %Ptr, i32 %Stride, i1 <IsVolatile>, i32
21119+
<Rows>, i32 <Cols>)
2111121120

2111221121
Overview:
2111321122
"""""""""
@@ -21127,7 +21136,7 @@ Arguments:
2112721136
The first argument ``%In`` is a vector that corresponds to a ``<Rows> x
2112821137
<Cols>`` matrix to be stored to memory. The second argument ``%Ptr`` is a
2112921138
pointer to the vector type of ``%In``, and is the start address of the matrix
21130-
in memory. The third argument ``%Stride`` is a positive, constant integer with
21139+
in memory. The third argument ``%Stride`` is a positive integer for which
2113121140
``%Stride >= <Rows>``. ``%Stride`` is used to compute the column memory
2113221141
addresses. I.e., for a column ``C``, its start memory addresses is calculated
2113321142
with ``%Ptr + C * %Stride``. The fourth argument ``<IsVolatile>`` is a boolean

llvm/lib/IR/Verifier.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6479,9 +6479,12 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
64796479
NumRows->getZExtValue() * NumColumns->getZExtValue(),
64806480
"Result of a matrix operation does not fit in the returned vector!");
64816481

6482-
if (Stride)
6482+
if (Stride) {
6483+
Check(Stride->getBitWidth() <= 64, "Stride bitwidth cannot exceed 64!",
6484+
IF);
64836485
Check(Stride->getZExtValue() >= NumRows->getZExtValue(),
64846486
"Stride must be greater or equal than the number of rows!", IF);
6487+
}
64856488

64866489
break;
64876490
}

0 commit comments

Comments
 (0)