@@ -3,20 +3,26 @@ RWStructuredBuffer<uint> Out : register(u0);
33
// Exercises wave-operation convergence across two nested non-uniform branches.
// Each thread accumulates the results of WaveActiveSum() calls made at
// different points of divergence/reconvergence and stores the total in
// Out[TID.x]. Every operand is divided by the measured active-lane count, so
// each call's result reflects the fraction of lanes that took part in it.
//
// When all 8 threads of the group run in a single wave, the contributions are:
//   TID.x % 4 == 0 (2 lanes) : 2 * 4/8 = 1
//   other even lanes (2)     : 2 * 8/8 = 2
//   all even lanes (4)       : 4 * 2/8 = 1
//   all lanes (8)            : 8 * 1/8 = 1
// giving 3 for TID.x in {0, 4}, 4 for TID.x in {2, 6}, and 1 for odd lanes.
[numthreads(8,1,1)]
void main(uint3 TID : SV_GroupThreadID) {
  // Use `WaveActiveSum()` to count wave size, since
  // `WaveGetLaneCount` is not implemented yet in clang compiler.
  // The count is capped at 8, the number of threads in the group.
  float maxActiveLaneCount = min(WaveActiveSum(1), 8);
  float result = 0;

  // First non-uniform branch: only even-indexed threads enter.
  if (TID.x % 2 == 0) {
    // Second non-uniform branch: splits the even-indexed threads again.
    if (TID.x % 4 == 0) {
      result += WaveActiveSum(4.0/maxActiveLaneCount);
    } else {
      result += WaveActiveSum(8.0/maxActiveLaneCount);
    }
    // Must reconverge here with maximal reconvergence
    // (all even-indexed threads participate in this sum).
    result += WaveActiveSum(2.0/maxActiveLaneCount);
  }

  // Must reconverge here with maximal reconvergence
  // (every active thread participates in this sum).
  result += WaveActiveSum(1.0/maxActiveLaneCount);

  // Truncate the accumulated float to the buffer's uint element type.
  Out[TID.x] = uint(result);
}
2127
2228//--- pipeline.yaml
@@ -56,4 +62,4 @@ DescriptorSets:
5662
5763# CHECK: Name: Out
5864# CHECK: Format: UInt32
59- # CHECK: Data: [ 12, 13, 12, 13, 14, 7, 7, 7 ]
65+ # CHECK: Data: [ 3, 1, 4, 1, 3, 1, 4, 1 ]
0 commit comments