Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
40d147a
Basic framework for WaveActive ops.
alsepkow Nov 4, 2025
a0fb36d
Cleanup. Switch to default validation so we get 1 ULP of tolerance fo…
alsepkow Nov 5, 2025
344cafd
check device in test method setup. default to min wave size instead
alsepkow Nov 6, 2025
46011c3
WIP on some other WaveActive ops
alsepkow Nov 6, 2025
0d866c8
Fix XML
alsepkow Nov 6, 2025
b1bbfb9
WIP
alsepkow Nov 6, 2025
9d06bdd
Remove WaveActiveBitAnd
alsepkow Nov 6, 2025
93f43c1
Naming
alsepkow Nov 7, 2025
1148fdd
Remove todo in xml
alsepkow Nov 7, 2025
0e19c03
WaveActiveBit Ops
alsepkow Nov 8, 2025
8779d3f
Unreferenced
alsepkow Nov 8, 2025
c32ba7d
Add AllEqual
alsepkow Nov 8, 2025
70d00d8
All on Wave active macro
alsepkow Nov 12, 2025
3ee0d08
Fix namig for read
alsepkow Nov 12, 2025
916f878
Cleanup. Remove WaveCountBits
alsepkow Nov 13, 2025
470e9af
Fix the prefix ops
alsepkow Nov 13, 2025
842e639
Comments
alsepkow Nov 13, 2025
eefa04e
Clang format
alsepkow Nov 13, 2025
0b1a8d7
merge conflict
alsepkow Nov 13, 2025
1939982
Actually fix merge conflict
alsepkow Nov 13, 2025
4e6f6eb
Fix input set name
alsepkow Nov 13, 2025
87bc68f
MultiWavePrefixBitAnd
alsepkow Nov 14, 2025
c6b5ecf
Move comment for uint16_t WaveActiveBit ops
alsepkow Nov 17, 2025
5edfa7d
Xor and some cleanp
alsepkow Nov 18, 2025
f5e38f9
Clangity clang format
alsepkow Nov 18, 2025
ed88e0e
Finish multi ops. Needs a little tidy
alsepkow Nov 18, 2025
ae9940b
Merge main
alsepkow Nov 18, 2025
e80ce91
Comment cleanup. Fix typo
alsepkow Nov 18, 2025
cd9f2e7
Update to use numeric limits and -1 in bitwise sets
alsepkow Nov 21, 2025
d5f450a
Use keys and WaveMatch for the WaveMulti tests
alsepkow Nov 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tools/clang/unittests/HLSLExec/LongVectorOps.def
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ INPUT_SET(Bitwise)
INPUT_SET(SelectCond)
INPUT_SET(FloatSpecial)
INPUT_SET(AllOnes)
INPUT_SET(WaveMultiPrefixBitwise)

#undef INPUT_SET

Expand Down Expand Up @@ -207,5 +208,10 @@ OP_DEFAULT_DEFINES(Wave, WaveReadLaneAt, 1, "TestWaveReadLaneAt", "", " -DFUNC_W
OP_DEFAULT_DEFINES(Wave, WaveReadLaneFirst, 1, "TestWaveReadLaneFirst", "", " -DFUNC_WAVE_READ_LANE_FIRST=1")
OP_DEFAULT_DEFINES(Wave, WavePrefixSum, 1, "TestWavePrefixSum", "", " -DFUNC_WAVE_PREFIX_SUM=1 -DIS_WAVE_PREFIX_OP=1")
OP_DEFAULT_DEFINES(Wave, WavePrefixProduct, 1, "TestWavePrefixProduct", "", " -DFUNC_WAVE_PREFIX_PRODUCT=1 -DIS_WAVE_PREFIX_OP=1")
OP(Wave, WaveMultiPrefixSum, 1, "TestWaveMultiPrefixSum", "", " -DFUNC_WAVE_MULTI_PREFIX_SUM=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", Default1, Default2, Default3)
OP(Wave, WaveMultiPrefixProduct, 1, "TestWaveMultiPrefixProduct", "", " -DFUNC_WAVE_MULTI_PREFIX_PRODUCT=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", Default1, Default2, Default3)
OP(Wave, WaveMultiPrefixBitAnd, 1, "TestWaveMultiPrefixBitAnd", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_AND=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3)
OP(Wave, WaveMultiPrefixBitOr, 1, "TestWaveMultiPrefixBitOr", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_OR=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3)
OP(Wave, WaveMultiPrefixBitXor, 1, "TestWaveMultiPrefixBitXor", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_XOR=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3)

#undef OP
12 changes: 12 additions & 0 deletions tools/clang/unittests/HLSLExec/LongVectorTestData.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ INPUT_SET(InputSet::Bitwise, std::numeric_limits<int16_t>::min(), -1, 0, 1, 3,
std::numeric_limits<int16_t>::max());
INPUT_SET(InputSet::SelectCond, 0, 1);
INPUT_SET(InputSet::AllOnes, 1);
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF,
-1);
END_INPUT_SETS()

BEGIN_INPUT_SETS(int32_t)
Expand All @@ -304,6 +306,8 @@ INPUT_SET(InputSet::Bitwise, std::numeric_limits<int32_t>::min(), -1, 0, 1, 3,
std::numeric_limits<int32_t>::max());
INPUT_SET(InputSet::SelectCond, 0, 1);
INPUT_SET(InputSet::AllOnes, 1);
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF,
-1);
END_INPUT_SETS()

BEGIN_INPUT_SETS(int64_t)
Expand All @@ -318,6 +322,8 @@ INPUT_SET(InputSet::Bitwise, std::numeric_limits<int64_t>::min(), -1, 0, 1, 3,
std::numeric_limits<int64_t>::max());
INPUT_SET(InputSet::SelectCond, 0, 1);
INPUT_SET(InputSet::AllOnes, 1);
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF,
-1ll);
END_INPUT_SETS()

BEGIN_INPUT_SETS(uint16_t)
Expand All @@ -329,6 +335,8 @@ INPUT_SET(InputSet::Bitwise, 0, 1, 3, 6, 9, 0x5555, 0xAAAA, 0x8000, 127,
std::numeric_limits<uint16_t>::max());
INPUT_SET(InputSet::SelectCond, 0, 1);
INPUT_SET(InputSet::AllOnes, 1);
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF,
std::numeric_limits<uint16_t>::max());
END_INPUT_SETS()

BEGIN_INPUT_SETS(uint32_t)
Expand All @@ -340,6 +348,8 @@ INPUT_SET(InputSet::Bitwise, 0, 1, 3, 6, 9, 0x55555555, 0xAAAAAAAA, 0x80000000,
127, std::numeric_limits<uint32_t>::max());
INPUT_SET(InputSet::SelectCond, 0, 1);
INPUT_SET(InputSet::AllOnes, 1);
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0xA, 0xC, 0xF,
std::numeric_limits<uint32_t>::max());
END_INPUT_SETS()

BEGIN_INPUT_SETS(uint64_t)
Expand All @@ -352,6 +362,8 @@ INPUT_SET(InputSet::Bitwise, 0, 1, 3, 6, 9, 0x5555555555555555,
std::numeric_limits<uint64_t>::max());
INPUT_SET(InputSet::SelectCond, 0, 1);
INPUT_SET(InputSet::AllOnes, 1);
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0xA, 0xC, 0xF,
std::numeric_limits<uint64_t>::max());
END_INPUT_SETS()

BEGIN_INPUT_SETS(HLSLHalf_t)
Expand Down
114 changes: 109 additions & 5 deletions tools/clang/unittests/HLSLExec/LongVectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,7 @@ template <typename T> T waveActiveBitAnd(T A, UINT) {
WAVE_OP(OpType::WaveActiveBitAnd, (waveActiveBitAnd(A, WaveSize)));

template <typename T> T waveActiveBitOr(T A, UINT) {
// We set the LSB to 0 in one of the lanes.
// We set the LSB to 1 in one of the lanes.
return static_cast<T>(A | static_cast<T>(1));
}

Expand All @@ -1362,6 +1362,60 @@ template <typename T> T waveActiveBitXor(T A, UINT) {

WAVE_OP(OpType::WaveActiveBitXor, (waveActiveBitXor(A, WaveSize)));

WAVE_OP(OpType::WaveMultiPrefixBitAnd, waveMultiPrefixBitAnd(A, WaveSize));

template <typename T> T waveMultiPrefixBitAnd(T A, UINT) {
// All lanes in the group mask use a mask to filter for only the second and
// third LSBs.
return static_cast<T>(A & static_cast<T>(0x6));
}

WAVE_OP(OpType::WaveMultiPrefixBitOr, waveMultiPrefixBitOr(A, WaveSize));

template <typename T> T waveMultiPrefixBitOr(T A, UINT) {
// All lanes in the group mask clear the second LSB.
return static_cast<T>(A & ~static_cast<T>(0x2));
}

template <typename T>
struct Op<OpType::WaveMultiPrefixBitXor, T, 1> : StrictValidation {};

template <typename T> struct ExpectedBuilder<OpType::WaveMultiPrefixBitXor, T> {
static std::vector<T> buildExpected(Op<OpType::WaveMultiPrefixBitXor, T, 1> &,
const InputSets<T> &Inputs, UINT) {
DXASSERT_NOMSG(Inputs.size() == 1);

std::vector<T> Expected;
const size_t VectorSize = Inputs[0].size();

// We get a little creative for MultiPrefixBitXor. The mask we use for the
// group in the shader is 0xE (0b1110), which includes lanes 1, 2, and 3.
// Prefix ops don't include the value of the current lane in their result.
// So, for this test we store the result of WaveMultiPrefixBitXor from lane
// 3. This means only the values from lanes 1 and 2 contribute to the result
// at lane 3.
//
// In the shader:
// - Lane 0: Set to 0 (not in mask, shouldn't affect result)
// - Lane 1: Keeps original input values
// - Lane 2: Lower half + last element set to 0, upper half keeps input
// - Lane 3: Stores the prefix XOR result (lanes 1 XOR lanes 2)
//
// Expected result: Lower half matches input (lane 1 XOR 0), upper half is
// 0s, except last element matches input.
for (size_t I = 0; I < VectorSize / 2; ++I)
Expected.push_back(Inputs[0][I]);
for (size_t I = VectorSize / 2; I < VectorSize - 1; ++I)
Expected.push_back(0);

// We also set the last element to 0 on lane 2 so the last element in the
// output vector matches the last element in the input vector.
Expected.push_back(Inputs[0][VectorSize - 1]);

return Expected;
}
};

template <typename T>
struct Op<OpType::WaveActiveAllEqual, T, 1> : StrictValidation {};

Expand Down Expand Up @@ -1420,16 +1474,29 @@ template <typename T> struct ExpectedBuilder<OpType::WaveReadLaneFirst, T> {
WAVE_OP(OpType::WavePrefixSum, (wavePrefixSum(A, WaveSize)));

template <typename T> T wavePrefixSum(T A, UINT WaveSize) {
// We test the prefix sume in the 'middle' lane. This choice is arbitrary.
return static_cast<T>(A * static_cast<T>(WaveSize / 2));
// We test the prefix sum in the 'middle' lane. This choice is arbitrary.
return A * static_cast<T>(WaveSize / 2);
}

WAVE_OP(OpType::WaveMultiPrefixSum, (waveMultiPrefixSum(A, WaveSize)));

template <typename T> T waveMultiPrefixSum(T A, UINT) {
return A * static_cast<T>(2u);
}

WAVE_OP(OpType::WavePrefixProduct, (wavePrefixProduct(A, WaveSize)));

template <typename T> T wavePrefixProduct(T A, UINT) {
// We test the the prefix product in the 3rd lane to avoid overflow issues.
// So the result is A * A.
return static_cast<T>(A * A);
return A * A;
}

WAVE_OP(OpType::WaveMultiPrefixProduct, (waveMultiPrefixProduct(A, WaveSize)));

template <typename T> T waveMultiPrefixProduct(T A, UINT) {
// The group mask has 3 lanes.
return A * A;
}

#undef WAVE_OP
Expand Down Expand Up @@ -2343,6 +2410,11 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, int16_t);
HLK_WAVEOP_TEST(WavePrefixSum, int16_t);
HLK_WAVEOP_TEST(WavePrefixProduct, int16_t);
HLK_WAVEOP_TEST(WaveMultiPrefixSum, int16_t);
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, int16_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int16_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int16_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int16_t);
HLK_WAVEOP_TEST(WaveActiveSum, int32_t);
HLK_WAVEOP_TEST(WaveActiveMin, int32_t);
HLK_WAVEOP_TEST(WaveActiveMax, int32_t);
Expand All @@ -2351,7 +2423,12 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneAt, int32_t);
HLK_WAVEOP_TEST(WaveReadLaneFirst, int32_t);
HLK_WAVEOP_TEST(WavePrefixSum, int32_t);
HLK_WAVEOP_TEST(WaveMultiPrefixSum, int32_t);
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, int32_t);
HLK_WAVEOP_TEST(WavePrefixProduct, int32_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int32_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int32_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int32_t);
HLK_WAVEOP_TEST(WaveActiveSum, int64_t);
HLK_WAVEOP_TEST(WaveActiveMin, int64_t);
HLK_WAVEOP_TEST(WaveActiveMax, int64_t);
Expand All @@ -2361,7 +2438,14 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, int64_t);
HLK_WAVEOP_TEST(WavePrefixSum, int64_t);
HLK_WAVEOP_TEST(WavePrefixProduct, int64_t);
HLK_WAVEOP_TEST(WaveMultiPrefixSum, int64_t);
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, int64_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int64_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int64_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int64_t);

// Note: WaveActiveBit* ops don't support uint16_t in HLSL
// But the WaveMultiPrefixBit ops support all int and uint types
HLK_WAVEOP_TEST(WaveActiveSum, uint16_t);
HLK_WAVEOP_TEST(WaveActiveMin, uint16_t);
HLK_WAVEOP_TEST(WaveActiveMax, uint16_t);
Expand All @@ -2371,11 +2455,15 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint16_t);
HLK_WAVEOP_TEST(WavePrefixSum, uint16_t);
HLK_WAVEOP_TEST(WavePrefixProduct, uint16_t);
HLK_WAVEOP_TEST(WaveMultiPrefixSum, uint16_t);
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, uint16_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint16_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint16_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint16_t);
HLK_WAVEOP_TEST(WaveActiveSum, uint32_t);
HLK_WAVEOP_TEST(WaveActiveMin, uint32_t);
HLK_WAVEOP_TEST(WaveActiveMax, uint32_t);
HLK_WAVEOP_TEST(WaveActiveProduct, uint32_t);
// Note: WaveActiveBit* ops don't support uint16_t in HLSL
HLK_WAVEOP_TEST(WaveActiveBitAnd, uint32_t);
HLK_WAVEOP_TEST(WaveActiveBitOr, uint32_t);
HLK_WAVEOP_TEST(WaveActiveBitXor, uint32_t);
Expand All @@ -2384,6 +2472,11 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint32_t);
HLK_WAVEOP_TEST(WavePrefixSum, uint32_t);
HLK_WAVEOP_TEST(WavePrefixProduct, uint32_t);
HLK_WAVEOP_TEST(WaveMultiPrefixSum, uint32_t);
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, uint32_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint32_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint32_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint32_t);
HLK_WAVEOP_TEST(WaveActiveSum, uint64_t);
HLK_WAVEOP_TEST(WaveActiveMin, uint64_t);
HLK_WAVEOP_TEST(WaveActiveMax, uint64_t);
Expand All @@ -2396,6 +2489,11 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint64_t);
HLK_WAVEOP_TEST(WavePrefixSum, uint64_t);
HLK_WAVEOP_TEST(WavePrefixProduct, uint64_t);
HLK_WAVEOP_TEST(WaveMultiPrefixSum, uint64_t);
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, uint64_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint64_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint64_t);
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint64_t);

HLK_WAVEOP_TEST(WaveActiveSum, HLSLHalf_t);
HLK_WAVEOP_TEST(WaveActiveMin, HLSLHalf_t);
Expand All @@ -2406,6 +2504,8 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, HLSLHalf_t);
HLK_WAVEOP_TEST(WavePrefixSum, HLSLHalf_t);
HLK_WAVEOP_TEST(WavePrefixProduct, HLSLHalf_t);
HLK_WAVEOP_TEST(WaveMultiPrefixSum, HLSLHalf_t);
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, HLSLHalf_t);
HLK_WAVEOP_TEST(WaveActiveSum, float);
HLK_WAVEOP_TEST(WaveActiveMin, float);
HLK_WAVEOP_TEST(WaveActiveMax, float);
Expand All @@ -2415,6 +2515,8 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, float);
HLK_WAVEOP_TEST(WavePrefixSum, float);
HLK_WAVEOP_TEST(WavePrefixProduct, float);
HLK_WAVEOP_TEST(WaveMultiPrefixSum, float);
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, float);
HLK_WAVEOP_TEST(WaveActiveSum, double);
HLK_WAVEOP_TEST(WaveActiveMin, double);
HLK_WAVEOP_TEST(WaveActiveMax, double);
Expand All @@ -2424,6 +2526,8 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, double);
HLK_WAVEOP_TEST(WavePrefixSum, double);
HLK_WAVEOP_TEST(WavePrefixProduct, double);
HLK_WAVEOP_TEST(WaveMultiPrefixSum, double);
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, double);

private:
bool Initialized = false;
Expand Down
Loading