@@ -1300,7 +1300,7 @@ template <typename T> struct ExpectedBuilder<OpType::ModF, T> {
13001300// Wave Ops
13011301//
13021302
1303- #define WAVE_ACTIVE_OP (OP, IMPL ) \
1303+ #define WAVE_OP (OP, IMPL ) \
13041304 template <typename T> struct Op <OP, T, 1 > : DefaultValidation<T> { \
13051305 T operator ()(T A, UINT WaveSize) { return IMPL; } \
13061306 };
@@ -1310,7 +1310,7 @@ template <typename T> T waveActiveSum(T A, UINT WaveSize) {
13101310 return A * WaveSizeT;
13111311}
13121312
1313- WAVE_ACTIVE_OP (OpType::WaveActiveSum, (waveActiveSum(A, WaveSize)));
1313+ WAVE_OP (OpType::WaveActiveSum, (waveActiveSum(A, WaveSize)));
13141314
13151315template <typename T> T waveActiveMin (T A, UINT WaveSize) {
13161316 std::vector<T> Values;
@@ -1320,7 +1320,7 @@ template <typename T> T waveActiveMin(T A, UINT WaveSize) {
13201320 return *std::min_element (Values.begin (), Values.end ());
13211321}
13221322
1323- WAVE_ACTIVE_OP (OpType::WaveActiveMin, (waveActiveMin(A, WaveSize)));
1323+ WAVE_OP (OpType::WaveActiveMin, (waveActiveMin(A, WaveSize)));
13241324
13251325template <typename T> T waveActiveMax (T A, UINT WaveSize) {
13261326 std::vector<T> Values;
@@ -1330,7 +1330,7 @@ template <typename T> T waveActiveMax(T A, UINT WaveSize) {
13301330 return *std::max_element (Values.begin (), Values.end ());
13311331}
13321332
1333- WAVE_ACTIVE_OP (OpType::WaveActiveMax, (waveActiveMax(A, WaveSize)));
1333+ WAVE_OP (OpType::WaveActiveMax, (waveActiveMax(A, WaveSize)));
13341334
13351335template <typename T> T waveActiveProduct (T A, UINT WaveSize) {
13361336 // We want to avoid overflow of a large product. So, the WaveActiveProdFn has
@@ -1339,9 +1339,100 @@ template <typename T> T waveActiveProduct(T A, UINT WaveSize) {
13391339 return A * static_cast <T>(WaveSize - 1 );
13401340}
13411341
1342- WAVE_ACTIVE_OP (OpType::WaveActiveProduct, (waveActiveProduct(A, WaveSize)));
1342+ WAVE_OP (OpType::WaveActiveProduct, (waveActiveProduct(A, WaveSize)));
13431343
1344- #undef WAVE_ACTIVE_OP
1344+ template <typename T> T waveActiveBitAnd (T A, UINT) {
1345+ // We set the LSB to 0 in one of the lanes.
1346+ return static_cast <T>(A & ~static_cast <T>(1 ));
1347+ }
1348+
1349+ WAVE_OP (OpType::WaveActiveBitAnd, (waveActiveBitAnd(A, WaveSize)));
1350+
1351+ template <typename T> T waveActiveBitOr (T A, UINT) {
1352+ // We set the LSB to 0 in one of the lanes.
1353+ return static_cast <T>(A | static_cast <T>(1 ));
1354+ }
1355+
1356+ WAVE_OP (OpType::WaveActiveBitOr, (waveActiveBitOr(A, WaveSize)));
1357+
1358+ template <typename T> T waveActiveBitXor (T A, UINT) {
1359+ // We clear the LSB in every lane except the last lane which sets it to 1.
1360+ return static_cast <T>(A | static_cast <T>(1 ));
1361+ }
1362+
1363+ WAVE_OP (OpType::WaveActiveBitXor, (waveActiveBitXor(A, WaveSize)));
1364+
1365+ template <typename T>
1366+ struct Op <OpType::WaveActiveAllEqual, T, 1 > : StrictValidation {};
1367+
1368+ template <typename T> struct ExpectedBuilder <OpType::WaveActiveAllEqual, T> {
1369+ static std::vector<HLSLBool_t>
1370+ buildExpected (Op<OpType::WaveActiveAllEqual, T, 1 > &,
1371+ const InputSets<T> &Inputs, UINT) {
1372+ DXASSERT_NOMSG (Inputs.size () == 1 );
1373+
1374+ std::vector<HLSLBool_t> Expected;
1375+ const size_t VectorSize = Inputs[0 ].size ();
1376+ Expected.assign (VectorSize - 1 , static_cast <HLSLBool_t>(true ));
1377+ // We set the last element to a different value on a single lane.
1378+ Expected[VectorSize - 1 ] = static_cast <HLSLBool_t>(false );
1379+
1380+ return Expected;
1381+ }
1382+ };
1383+
1384+ template <typename T>
1385+ struct Op <OpType::WaveReadLaneAt, T, 1 > : StrictValidation {};
1386+
1387+ template <typename T> struct ExpectedBuilder <OpType::WaveReadLaneAt, T> {
1388+ static std::vector<T> buildExpected (Op<OpType::WaveReadLaneAt, T, 1 > &,
1389+ const InputSets<T> &Inputs, UINT) {
1390+ DXASSERT_NOMSG (Inputs.size () == 1 );
1391+
1392+ std::vector<T> Expected;
1393+ const size_t VectorSize = Inputs[0 ].size ();
1394+ // Simple test, on the lane that we read we also fill the vector with the
1395+ // value of the first element.
1396+ Expected.assign (VectorSize, Inputs[0 ][0 ]);
1397+
1398+ return Expected;
1399+ }
1400+ };
1401+
1402+ template <typename T>
1403+ struct Op <OpType::WaveReadLaneFirst, T, 1 > : StrictValidation {};
1404+
1405+ template <typename T> struct ExpectedBuilder <OpType::WaveReadLaneFirst, T> {
1406+ static std::vector<T> buildExpected (Op<OpType::WaveReadLaneFirst, T, 1 > &,
1407+ const InputSets<T> &Inputs, UINT) {
1408+ DXASSERT_NOMSG (Inputs.size () == 1 );
1409+
1410+ std::vector<T> Expected;
1411+ const size_t VectorSize = Inputs[0 ].size ();
1412+ // Simple test, on the lane that we read we also fill the vector with the
1413+ // value of the first element.
1414+ Expected.assign (VectorSize, Inputs[0 ][0 ]);
1415+
1416+ return Expected;
1417+ }
1418+ };
1419+
1420+ WAVE_OP (OpType::WavePrefixSum, (wavePrefixSum(A, WaveSize)));
1421+
1422+ template <typename T> T wavePrefixSum (T A, UINT WaveSize) {
1423+ // We test the prefix sume in the 'middle' lane. This choice is arbitrary.
1424+ return static_cast <T>(A * static_cast <T>(WaveSize / 2 ));
1425+ }
1426+
1427+ WAVE_OP (OpType::WavePrefixProduct, (wavePrefixProduct(A, WaveSize)));
1428+
1429+ template <typename T> T wavePrefixProduct (T A, UINT) {
1430+ // We test the the prefix product in the 3rd lane to avoid overflow issues.
1431+ // So the result is A * A.
1432+ return static_cast <T>(A * A);
1433+ }
1434+
1435+ #undef WAVE_OP
13451436
13461437//
13471438// dispatchTest
@@ -1384,9 +1475,6 @@ template <OpType OP, typename T> struct ExpectedBuilder {
13841475
13851476 return Expected;
13861477 }
1387- };
1388-
1389- template <OpType OP, typename T> struct WaveOpExpectedBuilder {
13901478
13911479 static auto buildExpected (Op<OP, T, 1 > Op, const InputSets<T> &Inputs,
13921480 UINT WaveSize) {
@@ -1466,8 +1554,7 @@ void dispatchWaveOpTest(ID3D12Device *D3DDevice, bool VerboseLogging,
14661554 std::vector<std::vector<T>> Inputs =
14671555 buildTestInputs<T>(VectorSize, Operation.InputSets , Operation.Arity );
14681556
1469- auto Expected =
1470- WaveOpExpectedBuilder<OP, T>::buildExpected (Op, Inputs, WaveSize);
1557+ auto Expected = ExpectedBuilder<OP, T>::buildExpected (Op, Inputs, WaveSize);
14711558
14721559 runAndVerify (D3DDevice, VerboseLogging, Operation, Inputs, Expected,
14731560 Op.ValidationConfig , AdditionalCompilerOptions);
@@ -2243,44 +2330,100 @@ class DxilConf_SM69_Vectorized {
22432330 HLK_TEST (LoadAndStore_RD_SB_SRV, double );
22442331 HLK_TEST (LoadAndStore_RD_SB_UAV, double );
22452332
2333+ HLK_WAVEOP_TEST (WaveActiveAllEqual, HLSLBool_t);
2334+ HLK_WAVEOP_TEST (WaveReadLaneAt, HLSLBool_t);
2335+ HLK_WAVEOP_TEST (WaveReadLaneFirst, HLSLBool_t);
2336+
22462337 HLK_WAVEOP_TEST (WaveActiveSum, int16_t );
22472338 HLK_WAVEOP_TEST (WaveActiveMin, int16_t );
22482339 HLK_WAVEOP_TEST (WaveActiveMax, int16_t );
22492340 HLK_WAVEOP_TEST (WaveActiveProduct, int16_t );
2341+ HLK_WAVEOP_TEST (WaveActiveAllEqual, int16_t );
2342+ HLK_WAVEOP_TEST (WaveReadLaneAt, int16_t );
2343+ HLK_WAVEOP_TEST (WaveReadLaneFirst, int16_t );
2344+ HLK_WAVEOP_TEST (WavePrefixSum, int16_t );
2345+ HLK_WAVEOP_TEST (WavePrefixProduct, int16_t );
22502346 HLK_WAVEOP_TEST (WaveActiveSum, int32_t );
22512347 HLK_WAVEOP_TEST (WaveActiveMin, int32_t );
22522348 HLK_WAVEOP_TEST (WaveActiveMax, int32_t );
22532349 HLK_WAVEOP_TEST (WaveActiveProduct, int32_t );
2350+ HLK_WAVEOP_TEST (WaveActiveAllEqual, int32_t );
2351+ HLK_WAVEOP_TEST (WaveReadLaneAt, int32_t );
2352+ HLK_WAVEOP_TEST (WaveReadLaneFirst, int32_t );
2353+ HLK_WAVEOP_TEST (WavePrefixSum, int32_t );
2354+ HLK_WAVEOP_TEST (WavePrefixProduct, int32_t );
22542355 HLK_WAVEOP_TEST (WaveActiveSum, int64_t );
22552356 HLK_WAVEOP_TEST (WaveActiveMin, int64_t );
22562357 HLK_WAVEOP_TEST (WaveActiveMax, int64_t );
22572358 HLK_WAVEOP_TEST (WaveActiveProduct, int64_t );
2359+ HLK_WAVEOP_TEST (WaveActiveAllEqual, int64_t );
2360+ HLK_WAVEOP_TEST (WaveReadLaneAt, int64_t );
2361+ HLK_WAVEOP_TEST (WaveReadLaneFirst, int64_t );
2362+ HLK_WAVEOP_TEST (WavePrefixSum, int64_t );
2363+ HLK_WAVEOP_TEST (WavePrefixProduct, int64_t );
22582364
22592365 HLK_WAVEOP_TEST (WaveActiveSum, uint16_t );
22602366 HLK_WAVEOP_TEST (WaveActiveMin, uint16_t );
22612367 HLK_WAVEOP_TEST (WaveActiveMax, uint16_t );
22622368 HLK_WAVEOP_TEST (WaveActiveProduct, uint16_t );
2369+ HLK_WAVEOP_TEST (WaveActiveAllEqual, uint16_t );
2370+ HLK_WAVEOP_TEST (WaveReadLaneAt, uint16_t );
2371+ HLK_WAVEOP_TEST (WaveReadLaneFirst, uint16_t );
2372+ HLK_WAVEOP_TEST (WavePrefixSum, uint16_t );
2373+ HLK_WAVEOP_TEST (WavePrefixProduct, uint16_t );
22632374 HLK_WAVEOP_TEST (WaveActiveSum, uint32_t );
22642375 HLK_WAVEOP_TEST (WaveActiveMin, uint32_t );
22652376 HLK_WAVEOP_TEST (WaveActiveMax, uint32_t );
22662377 HLK_WAVEOP_TEST (WaveActiveProduct, uint32_t );
2378+ // Note: WaveActiveBit* ops don't support uint16_t in HLSL
2379+ HLK_WAVEOP_TEST (WaveActiveBitAnd, uint32_t );
2380+ HLK_WAVEOP_TEST (WaveActiveBitOr, uint32_t );
2381+ HLK_WAVEOP_TEST (WaveActiveBitXor, uint32_t );
2382+ HLK_WAVEOP_TEST (WaveActiveAllEqual, uint32_t );
2383+ HLK_WAVEOP_TEST (WaveReadLaneAt, uint32_t );
2384+ HLK_WAVEOP_TEST (WaveReadLaneFirst, uint32_t );
2385+ HLK_WAVEOP_TEST (WavePrefixSum, uint32_t );
2386+ HLK_WAVEOP_TEST (WavePrefixProduct, uint32_t );
22672387 HLK_WAVEOP_TEST (WaveActiveSum, uint64_t );
22682388 HLK_WAVEOP_TEST (WaveActiveMin, uint64_t );
22692389 HLK_WAVEOP_TEST (WaveActiveMax, uint64_t );
22702390 HLK_WAVEOP_TEST (WaveActiveProduct, uint64_t );
2391+ HLK_WAVEOP_TEST (WaveActiveBitAnd, uint64_t );
2392+ HLK_WAVEOP_TEST (WaveActiveBitOr, uint64_t );
2393+ HLK_WAVEOP_TEST (WaveActiveBitXor, uint64_t );
2394+ HLK_WAVEOP_TEST (WaveActiveAllEqual, uint64_t );
2395+ HLK_WAVEOP_TEST (WaveReadLaneAt, uint64_t );
2396+ HLK_WAVEOP_TEST (WaveReadLaneFirst, uint64_t );
2397+ HLK_WAVEOP_TEST (WavePrefixSum, uint64_t );
2398+ HLK_WAVEOP_TEST (WavePrefixProduct, uint64_t );
22712399
22722400 HLK_WAVEOP_TEST (WaveActiveSum, HLSLHalf_t);
22732401 HLK_WAVEOP_TEST (WaveActiveMin, HLSLHalf_t);
22742402 HLK_WAVEOP_TEST (WaveActiveMax, HLSLHalf_t);
22752403 HLK_WAVEOP_TEST (WaveActiveProduct, HLSLHalf_t);
2404+ HLK_WAVEOP_TEST (WaveActiveAllEqual, HLSLHalf_t);
2405+ HLK_WAVEOP_TEST (WaveReadLaneAt, HLSLHalf_t);
2406+ HLK_WAVEOP_TEST (WaveReadLaneFirst, HLSLHalf_t);
2407+ HLK_WAVEOP_TEST (WavePrefixSum, HLSLHalf_t);
2408+ HLK_WAVEOP_TEST (WavePrefixProduct, HLSLHalf_t);
22762409 HLK_WAVEOP_TEST (WaveActiveSum, float );
22772410 HLK_WAVEOP_TEST (WaveActiveMin, float );
22782411 HLK_WAVEOP_TEST (WaveActiveMax, float );
22792412 HLK_WAVEOP_TEST (WaveActiveProduct, float );
2413+ HLK_WAVEOP_TEST (WaveActiveAllEqual, float );
2414+ HLK_WAVEOP_TEST (WaveReadLaneAt, float );
2415+ HLK_WAVEOP_TEST (WaveReadLaneFirst, float );
2416+ HLK_WAVEOP_TEST (WavePrefixSum, float );
2417+ HLK_WAVEOP_TEST (WavePrefixProduct, float );
22802418 HLK_WAVEOP_TEST (WaveActiveSum, double );
22812419 HLK_WAVEOP_TEST (WaveActiveMin, double );
22822420 HLK_WAVEOP_TEST (WaveActiveMax, double );
22832421 HLK_WAVEOP_TEST (WaveActiveProduct, double );
2422+ HLK_WAVEOP_TEST (WaveActiveAllEqual, double );
2423+ HLK_WAVEOP_TEST (WaveReadLaneAt, double );
2424+ HLK_WAVEOP_TEST (WaveReadLaneFirst, double );
2425+ HLK_WAVEOP_TEST (WavePrefixSum, double );
2426+ HLK_WAVEOP_TEST (WavePrefixProduct, double );
22842427
22852428private:
22862429 bool Initialized = false ;
0 commit comments