Skip to content

Commit 67c9849

Browse files
authored
Execution Tests: Long Vectors - Finish the basic WaveOp tests (#7913)
This PR is the third and final PR to resolve #7472. Test cases validated against a private build of WARP with bug fixes.
1 parent 450bbe5 commit 67c9849

File tree

3 files changed

+276
-19
lines changed

3 files changed

+276
-19
lines changed

tools/clang/unittests/HLSLExec/LongVectorOps.def

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,14 @@ OP_LOAD_AND_STORE_SB(LoadAndStore_RD_SB_SRV, "RootDescriptor_SRV")
198198
OP_DEFAULT(Wave, WaveActiveSum, 1, "WaveActiveSum", "")
199199
OP_DEFAULT_DEFINES(Wave, WaveActiveMin, 1, "TestWaveActiveMin", "", " -DFUNC_WAVE_ACTIVE_MIN=1")
200200
OP_DEFAULT_DEFINES(Wave, WaveActiveMax, 1, "TestWaveActiveMax", "", " -DFUNC_WAVE_ACTIVE_MAX=1")
201-
OP(Wave, WaveActiveProduct, 1, "TestWaveActiveProduct", "", " -DFUNC_WAVE_ACTIVE_PRODUCT=1", "LongVectorOp",
202-
AllOnes, Default2, Default3)
201+
OP(Wave, WaveActiveProduct, 1, "TestWaveActiveProduct", "", " -DFUNC_WAVE_ACTIVE_PRODUCT=1", "LongVectorOp", AllOnes, Default2, Default3)
202+
OP_DEFAULT_DEFINES(Wave, WaveActiveBitAnd, 1, "TestWaveActiveBitAnd", "", " -DFUNC_WAVE_ACTIVE_BIT_AND=1")
203+
OP_DEFAULT_DEFINES(Wave, WaveActiveBitOr, 1, "TestWaveActiveBitOr", "", " -DFUNC_WAVE_ACTIVE_BIT_OR=1")
204+
OP_DEFAULT_DEFINES(Wave, WaveActiveBitXor, 1, "TestWaveActiveBitXor", "", " -DFUNC_WAVE_ACTIVE_BIT_XOR=1")
205+
OP_DEFAULT_DEFINES(Wave, WaveActiveAllEqual, 1, "TestWaveActiveAllEqual", "", " -DFUNC_WAVE_ACTIVE_ALL_EQUAL=1")
206+
OP_DEFAULT_DEFINES(Wave, WaveReadLaneAt, 1, "TestWaveReadLaneAt", "", " -DFUNC_WAVE_READ_LANE_AT=1")
207+
OP_DEFAULT_DEFINES(Wave, WaveReadLaneFirst, 1, "TestWaveReadLaneFirst", "", " -DFUNC_WAVE_READ_LANE_FIRST=1")
208+
OP_DEFAULT_DEFINES(Wave, WavePrefixSum, 1, "TestWavePrefixSum", "", " -DFUNC_WAVE_PREFIX_SUM=1 -DIS_WAVE_PREFIX_OP=1")
209+
OP_DEFAULT_DEFINES(Wave, WavePrefixProduct, 1, "TestWavePrefixProduct", "", " -DFUNC_WAVE_PREFIX_PRODUCT=1 -DIS_WAVE_PREFIX_OP=1")
203210

204211
#undef OP

tools/clang/unittests/HLSLExec/LongVectors.cpp

Lines changed: 154 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,7 +1300,7 @@ template <typename T> struct ExpectedBuilder<OpType::ModF, T> {
13001300
// Wave Ops
13011301
//
13021302

1303-
#define WAVE_ACTIVE_OP(OP, IMPL) \
1303+
// Declares the Op<OP, T, 1> specialization for a wave op. The generated
// operator() takes one input value `A` and the wave size `WaveSize` and
// returns the expression IMPL, so IMPL may refer to `A` and `WaveSize` by
// name. The specialization derives from DefaultValidation<T>.
#define WAVE_OP(OP, IMPL) \
13041304
template <typename T> struct Op<OP, T, 1> : DefaultValidation<T> { \
13051305
T operator()(T A, UINT WaveSize) { return IMPL; } \
13061306
};
@@ -1310,7 +1310,7 @@ template <typename T> T waveActiveSum(T A, UINT WaveSize) {
13101310
return A * WaveSizeT;
13111311
}
13121312

1313-
WAVE_ACTIVE_OP(OpType::WaveActiveSum, (waveActiveSum(A, WaveSize)));
1313+
WAVE_OP(OpType::WaveActiveSum, (waveActiveSum(A, WaveSize)));
13141314

13151315
template <typename T> T waveActiveMin(T A, UINT WaveSize) {
13161316
std::vector<T> Values;
@@ -1320,7 +1320,7 @@ template <typename T> T waveActiveMin(T A, UINT WaveSize) {
13201320
return *std::min_element(Values.begin(), Values.end());
13211321
}
13221322

1323-
WAVE_ACTIVE_OP(OpType::WaveActiveMin, (waveActiveMin(A, WaveSize)));
1323+
WAVE_OP(OpType::WaveActiveMin, (waveActiveMin(A, WaveSize)));
13241324

13251325
template <typename T> T waveActiveMax(T A, UINT WaveSize) {
13261326
std::vector<T> Values;
@@ -1330,7 +1330,7 @@ template <typename T> T waveActiveMax(T A, UINT WaveSize) {
13301330
return *std::max_element(Values.begin(), Values.end());
13311331
}
13321332

1333-
WAVE_ACTIVE_OP(OpType::WaveActiveMax, (waveActiveMax(A, WaveSize)));
1333+
WAVE_OP(OpType::WaveActiveMax, (waveActiveMax(A, WaveSize)));
13341334

13351335
template <typename T> T waveActiveProduct(T A, UINT WaveSize) {
13361336
// We want to avoid overflow of a large product. So, the WaveActiveProdFn has
@@ -1339,9 +1339,100 @@ template <typename T> T waveActiveProduct(T A, UINT WaveSize) {
13391339
return A * static_cast<T>(WaveSize - 1);
13401340
}
13411341

1342-
WAVE_ACTIVE_OP(OpType::WaveActiveProduct, (waveActiveProduct(A, WaveSize)));
1342+
WAVE_OP(OpType::WaveActiveProduct, (waveActiveProduct(A, WaveSize)));
13431343

1344-
#undef WAVE_ACTIVE_OP
1344+
template <typename T> T waveActiveBitAnd(T A, UINT) {
1345+
// We set the LSB to 0 in one of the lanes.
1346+
return static_cast<T>(A & ~static_cast<T>(1));
1347+
}
1348+
1349+
WAVE_OP(OpType::WaveActiveBitAnd, (waveActiveBitAnd(A, WaveSize)));
1350+
1351+
template <typename T> T waveActiveBitOr(T A, UINT) {
1352+
// We set the LSB to 0 in one of the lanes.
1353+
return static_cast<T>(A | static_cast<T>(1));
1354+
}
1355+
1356+
WAVE_OP(OpType::WaveActiveBitOr, (waveActiveBitOr(A, WaveSize)));
1357+
1358+
template <typename T> T waveActiveBitXor(T A, UINT) {
1359+
// We clear the LSB in every lane except the last lane which sets it to 1.
1360+
return static_cast<T>(A | static_cast<T>(1));
1361+
}
1362+
1363+
WAVE_OP(OpType::WaveActiveBitXor, (waveActiveBitXor(A, WaveSize)));
1364+
1365+
template <typename T>
1366+
struct Op<OpType::WaveActiveAllEqual, T, 1> : StrictValidation {};
1367+
1368+
template <typename T> struct ExpectedBuilder<OpType::WaveActiveAllEqual, T> {
1369+
static std::vector<HLSLBool_t>
1370+
buildExpected(Op<OpType::WaveActiveAllEqual, T, 1> &,
1371+
const InputSets<T> &Inputs, UINT) {
1372+
DXASSERT_NOMSG(Inputs.size() == 1);
1373+
1374+
std::vector<HLSLBool_t> Expected;
1375+
const size_t VectorSize = Inputs[0].size();
1376+
Expected.assign(VectorSize - 1, static_cast<HLSLBool_t>(true));
1377+
// We set the last element to a different value on a single lane.
1378+
Expected[VectorSize - 1] = static_cast<HLSLBool_t>(false);
1379+
1380+
return Expected;
1381+
}
1382+
};
1383+
1384+
template <typename T>
1385+
struct Op<OpType::WaveReadLaneAt, T, 1> : StrictValidation {};
1386+
1387+
template <typename T> struct ExpectedBuilder<OpType::WaveReadLaneAt, T> {
1388+
static std::vector<T> buildExpected(Op<OpType::WaveReadLaneAt, T, 1> &,
1389+
const InputSets<T> &Inputs, UINT) {
1390+
DXASSERT_NOMSG(Inputs.size() == 1);
1391+
1392+
std::vector<T> Expected;
1393+
const size_t VectorSize = Inputs[0].size();
1394+
// Simple test, on the lane that we read we also fill the vector with the
1395+
// value of the first element.
1396+
Expected.assign(VectorSize, Inputs[0][0]);
1397+
1398+
return Expected;
1399+
}
1400+
};
1401+
1402+
template <typename T>
1403+
struct Op<OpType::WaveReadLaneFirst, T, 1> : StrictValidation {};
1404+
1405+
template <typename T> struct ExpectedBuilder<OpType::WaveReadLaneFirst, T> {
1406+
static std::vector<T> buildExpected(Op<OpType::WaveReadLaneFirst, T, 1> &,
1407+
const InputSets<T> &Inputs, UINT) {
1408+
DXASSERT_NOMSG(Inputs.size() == 1);
1409+
1410+
std::vector<T> Expected;
1411+
const size_t VectorSize = Inputs[0].size();
1412+
// Simple test, on the lane that we read we also fill the vector with the
1413+
// value of the first element.
1414+
Expected.assign(VectorSize, Inputs[0][0]);
1415+
1416+
return Expected;
1417+
}
1418+
};
1419+
1420+
WAVE_OP(OpType::WavePrefixSum, (wavePrefixSum(A, WaveSize)));
1421+
1422+
template <typename T> T wavePrefixSum(T A, UINT WaveSize) {
1423+
// We test the prefix sume in the 'middle' lane. This choice is arbitrary.
1424+
return static_cast<T>(A * static_cast<T>(WaveSize / 2));
1425+
}
1426+
1427+
WAVE_OP(OpType::WavePrefixProduct, (wavePrefixProduct(A, WaveSize)));
1428+
1429+
template <typename T> T wavePrefixProduct(T A, UINT) {
1430+
// We test the the prefix product in the 3rd lane to avoid overflow issues.
1431+
// So the result is A * A.
1432+
return static_cast<T>(A * A);
1433+
}
1434+
1435+
#undef WAVE_OP
13451436

13461437
//
13471438
// dispatchTest
@@ -1384,9 +1475,6 @@ template <OpType OP, typename T> struct ExpectedBuilder {
13841475

13851476
return Expected;
13861477
}
1387-
};
1388-
1389-
template <OpType OP, typename T> struct WaveOpExpectedBuilder {
13901478

13911479
static auto buildExpected(Op<OP, T, 1> Op, const InputSets<T> &Inputs,
13921480
UINT WaveSize) {
@@ -1466,8 +1554,7 @@ void dispatchWaveOpTest(ID3D12Device *D3DDevice, bool VerboseLogging,
14661554
std::vector<std::vector<T>> Inputs =
14671555
buildTestInputs<T>(VectorSize, Operation.InputSets, Operation.Arity);
14681556

1469-
auto Expected =
1470-
WaveOpExpectedBuilder<OP, T>::buildExpected(Op, Inputs, WaveSize);
1557+
auto Expected = ExpectedBuilder<OP, T>::buildExpected(Op, Inputs, WaveSize);
14711558

14721559
runAndVerify(D3DDevice, VerboseLogging, Operation, Inputs, Expected,
14731560
Op.ValidationConfig, AdditionalCompilerOptions);
@@ -2243,44 +2330,100 @@ class DxilConf_SM69_Vectorized {
22432330
HLK_TEST(LoadAndStore_RD_SB_SRV, double);
22442331
HLK_TEST(LoadAndStore_RD_SB_UAV, double);
22452332

2333+
HLK_WAVEOP_TEST(WaveActiveAllEqual, HLSLBool_t);
2334+
HLK_WAVEOP_TEST(WaveReadLaneAt, HLSLBool_t);
2335+
HLK_WAVEOP_TEST(WaveReadLaneFirst, HLSLBool_t);
2336+
22462337
HLK_WAVEOP_TEST(WaveActiveSum, int16_t);
22472338
HLK_WAVEOP_TEST(WaveActiveMin, int16_t);
22482339
HLK_WAVEOP_TEST(WaveActiveMax, int16_t);
22492340
HLK_WAVEOP_TEST(WaveActiveProduct, int16_t);
2341+
HLK_WAVEOP_TEST(WaveActiveAllEqual, int16_t);
2342+
HLK_WAVEOP_TEST(WaveReadLaneAt, int16_t);
2343+
HLK_WAVEOP_TEST(WaveReadLaneFirst, int16_t);
2344+
HLK_WAVEOP_TEST(WavePrefixSum, int16_t);
2345+
HLK_WAVEOP_TEST(WavePrefixProduct, int16_t);
22502346
HLK_WAVEOP_TEST(WaveActiveSum, int32_t);
22512347
HLK_WAVEOP_TEST(WaveActiveMin, int32_t);
22522348
HLK_WAVEOP_TEST(WaveActiveMax, int32_t);
22532349
HLK_WAVEOP_TEST(WaveActiveProduct, int32_t);
2350+
HLK_WAVEOP_TEST(WaveActiveAllEqual, int32_t);
2351+
HLK_WAVEOP_TEST(WaveReadLaneAt, int32_t);
2352+
HLK_WAVEOP_TEST(WaveReadLaneFirst, int32_t);
2353+
HLK_WAVEOP_TEST(WavePrefixSum, int32_t);
2354+
HLK_WAVEOP_TEST(WavePrefixProduct, int32_t);
22542355
HLK_WAVEOP_TEST(WaveActiveSum, int64_t);
22552356
HLK_WAVEOP_TEST(WaveActiveMin, int64_t);
22562357
HLK_WAVEOP_TEST(WaveActiveMax, int64_t);
22572358
HLK_WAVEOP_TEST(WaveActiveProduct, int64_t);
2359+
HLK_WAVEOP_TEST(WaveActiveAllEqual, int64_t);
2360+
HLK_WAVEOP_TEST(WaveReadLaneAt, int64_t);
2361+
HLK_WAVEOP_TEST(WaveReadLaneFirst, int64_t);
2362+
HLK_WAVEOP_TEST(WavePrefixSum, int64_t);
2363+
HLK_WAVEOP_TEST(WavePrefixProduct, int64_t);
22582364

22592365
HLK_WAVEOP_TEST(WaveActiveSum, uint16_t);
22602366
HLK_WAVEOP_TEST(WaveActiveMin, uint16_t);
22612367
HLK_WAVEOP_TEST(WaveActiveMax, uint16_t);
22622368
HLK_WAVEOP_TEST(WaveActiveProduct, uint16_t);
2369+
HLK_WAVEOP_TEST(WaveActiveAllEqual, uint16_t);
2370+
HLK_WAVEOP_TEST(WaveReadLaneAt, uint16_t);
2371+
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint16_t);
2372+
HLK_WAVEOP_TEST(WavePrefixSum, uint16_t);
2373+
HLK_WAVEOP_TEST(WavePrefixProduct, uint16_t);
22632374
HLK_WAVEOP_TEST(WaveActiveSum, uint32_t);
22642375
HLK_WAVEOP_TEST(WaveActiveMin, uint32_t);
22652376
HLK_WAVEOP_TEST(WaveActiveMax, uint32_t);
22662377
HLK_WAVEOP_TEST(WaveActiveProduct, uint32_t);
2378+
// Note: WaveActiveBit* ops don't support uint16_t in HLSL
2379+
HLK_WAVEOP_TEST(WaveActiveBitAnd, uint32_t);
2380+
HLK_WAVEOP_TEST(WaveActiveBitOr, uint32_t);
2381+
HLK_WAVEOP_TEST(WaveActiveBitXor, uint32_t);
2382+
HLK_WAVEOP_TEST(WaveActiveAllEqual, uint32_t);
2383+
HLK_WAVEOP_TEST(WaveReadLaneAt, uint32_t);
2384+
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint32_t);
2385+
HLK_WAVEOP_TEST(WavePrefixSum, uint32_t);
2386+
HLK_WAVEOP_TEST(WavePrefixProduct, uint32_t);
22672387
HLK_WAVEOP_TEST(WaveActiveSum, uint64_t);
22682388
HLK_WAVEOP_TEST(WaveActiveMin, uint64_t);
22692389
HLK_WAVEOP_TEST(WaveActiveMax, uint64_t);
22702390
HLK_WAVEOP_TEST(WaveActiveProduct, uint64_t);
2391+
HLK_WAVEOP_TEST(WaveActiveBitAnd, uint64_t);
2392+
HLK_WAVEOP_TEST(WaveActiveBitOr, uint64_t);
2393+
HLK_WAVEOP_TEST(WaveActiveBitXor, uint64_t);
2394+
HLK_WAVEOP_TEST(WaveActiveAllEqual, uint64_t);
2395+
HLK_WAVEOP_TEST(WaveReadLaneAt, uint64_t);
2396+
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint64_t);
2397+
HLK_WAVEOP_TEST(WavePrefixSum, uint64_t);
2398+
HLK_WAVEOP_TEST(WavePrefixProduct, uint64_t);
22712399

22722400
HLK_WAVEOP_TEST(WaveActiveSum, HLSLHalf_t);
22732401
HLK_WAVEOP_TEST(WaveActiveMin, HLSLHalf_t);
22742402
HLK_WAVEOP_TEST(WaveActiveMax, HLSLHalf_t);
22752403
HLK_WAVEOP_TEST(WaveActiveProduct, HLSLHalf_t);
2404+
HLK_WAVEOP_TEST(WaveActiveAllEqual, HLSLHalf_t);
2405+
HLK_WAVEOP_TEST(WaveReadLaneAt, HLSLHalf_t);
2406+
HLK_WAVEOP_TEST(WaveReadLaneFirst, HLSLHalf_t);
2407+
HLK_WAVEOP_TEST(WavePrefixSum, HLSLHalf_t);
2408+
HLK_WAVEOP_TEST(WavePrefixProduct, HLSLHalf_t);
22762409
HLK_WAVEOP_TEST(WaveActiveSum, float);
22772410
HLK_WAVEOP_TEST(WaveActiveMin, float);
22782411
HLK_WAVEOP_TEST(WaveActiveMax, float);
22792412
HLK_WAVEOP_TEST(WaveActiveProduct, float);
2413+
HLK_WAVEOP_TEST(WaveActiveAllEqual, float);
2414+
HLK_WAVEOP_TEST(WaveReadLaneAt, float);
2415+
HLK_WAVEOP_TEST(WaveReadLaneFirst, float);
2416+
HLK_WAVEOP_TEST(WavePrefixSum, float);
2417+
HLK_WAVEOP_TEST(WavePrefixProduct, float);
22802418
HLK_WAVEOP_TEST(WaveActiveSum, double);
22812419
HLK_WAVEOP_TEST(WaveActiveMin, double);
22822420
HLK_WAVEOP_TEST(WaveActiveMax, double);
22832421
HLK_WAVEOP_TEST(WaveActiveProduct, double);
2422+
HLK_WAVEOP_TEST(WaveActiveAllEqual, double);
2423+
HLK_WAVEOP_TEST(WaveReadLaneAt, double);
2424+
HLK_WAVEOP_TEST(WaveReadLaneFirst, double);
2425+
HLK_WAVEOP_TEST(WavePrefixSum, double);
2426+
HLK_WAVEOP_TEST(WavePrefixProduct, double);
22842427

22852428
private:
22862429
bool Initialized = false;

0 commit comments

Comments
 (0)