1- #include "nbl/builtin/hlsl/cpp_compat.hlsl"
2- #include <nbl/builtin/hlsl/colorspace/encodeCIEXYZ.hlsl>
1+ #include <nbl/builtin/hlsl/rwmc/rwmc.hlsl>
32#include "resolve_common.hlsl"
43#include "rwmc_global_settings_common.hlsl"
4+ #ifdef PERSISTENT_WORKGROUPS
5+ #include "nbl/builtin/hlsl/math/morton.hlsl"
6+ #endif
57
// Push-constant block; main() reads pc.base, pc.sampleCount, pc.minReliableLuma and pc.kappa from it.
68[[vk::push_constant]] ResolvePushConstants pc;
// Resolved output image (rgba16f): main() writes one texel per invocation, getCoordinates() queries its extent.
79[[vk::image_format ("rgba16f" )]] [[vk::binding (0 , 0 )]] RWTexture2D <float32_t4> outImage;
@@ -14,174 +16,33 @@ NBL_CONSTEXPR uint32_t WorkgroupSize = 512;
// NOTE(review): neither constant is referenced in the visible part of this shader;
// presumably log2 limits on path depth and sample count -- confirm they are still needed here.
1416NBL_CONSTEXPR uint32_t MAX_DEPTH_LOG2 = 4 ;
1517NBL_CONSTEXPR uint32_t MAX_SAMPLES_LOG2 = 10 ;
1618
// Constants consumed by RWMCReweight; filled in once by computeReweightingParameters.
struct RWMCReweightingParameters
{
    // index of the final cascade layer (CascadeSize - 1)
    uint32_t lastCascadeIndex;
    // a minimum image brightness that we always consider reliable
    float32_t initialEmin;
    // 1 / base
    float32_t reciprocalBase;
    // 1 / sampleCount
    float32_t reciprocalN;
    // 1 / kappa
    float32_t reciprocalKappa;
    // base + (1 - base) / kappa
    float32_t colorReliabilityFactor;
    // sampleCount / kappa
    float32_t NOverKappa;
};
27-
// Derives the reweighting constants from the user-facing settings.
//   base            - cascade base
//   sampleCount     - number of samples accumulated so far (N)
//   minReliableLuma - stored as-is in initialEmin
//   kappa           - only its reciprocal is kept
RWMCReweightingParameters computeReweightingParameters(float base, uint32_t sampleCount, float minReliableLuma, float kappa)
{
    const float sampleCountF = float(sampleCount);
    const float invKappa = 1.f / kappa;

    RWMCReweightingParameters params;
    params.lastCascadeIndex = CascadeSize - 1u;
    params.initialEmin = minReliableLuma;
    params.reciprocalBase = 1.f / base;
    params.reciprocalN = 1.f / sampleCountF;
    params.reciprocalKappa = invKappa;
    // if not interested in exact expected value estimation (kappa!=1.f), can usually accept a bit more
    // variance relative to the image brightness we already have:
    // allow up to ~<cascadeBase> more energy in one sample to lessen bias in some cases
    params.colorReliabilityFactor = base + (1.f - base) * invKappa;
    params.NOverKappa = sampleCountF * invKappa;
    return params;
}
44-
// What RWMCSampleCascade gathers for one pixel in one cascade layer.
struct RWMCCascadeSample
{
    // RGB of the texel itself
    float32_t3 centerValue;
    // luma of the texel, scaled by the layer's reciprocalBaseI
    float32_t normalizedCenterLuma;
    // scaled luma averaged over the 3x3 block around (and including) the texel
    float32_t normalizedNeighbourhoodAverageLuma;
};
51-
// Loads the RGB of cascade layer `cascadeIndex` at `currentCoord + offset`.
// Texels with a negative x or y coordinate read as black.
// TODO: figure out what values should pixels outside have, 0.0f is incorrect
// NOTE(review): only the lower image edge is guarded here; coordinates past the right/bottom edge go
// straight to cascade.Load -- confirm that is intended.
float32_t3 RWMCsampleCascadeTexel(int32_t2 currentCoord, int32_t2 offset, uint32_t cascadeIndex)
{
    const int32_t2 texelCoord = currentCoord + offset;
    const bool beforeOrigin = texelCoord.x < 0 || texelCoord.y < 0;
    if (beforeOrigin)
        return float32_t3(0.0f, 0.0f, 0.0f);

    return cascade.Load(int32_t3(texelCoord, int32_t(cascadeIndex))).rgb;
}
62-
// Luma of `col`: dot product with row 1 of the transposed scRGBtoXYZ matrix
// (presumably the Y weights -- depends on how colorspace::scRGBtoXYZ is laid out).
float32_t calcLuma(in float32_t3 col)
{
    const float32_t3 lumaWeights = hlsl::transpose(colorspace::scRGBtoXYZ)[1];
    return hlsl::dot<float32_t3>(lumaWeights, col);
}
67-
// Samples the 3x3 block around `coord` in cascade layer `cascadeIndex` and returns the centre colour,
// the centre luma and the block-average luma, both lumas scaled by `reciprocalBaseI`.
RWMCCascadeSample RWMCSampleCascade(in int32_t2 coord, in uint cascadeIndex, in float reciprocalBaseI)
{
    // row-major 3x3 window: index 0 is offset (-1,-1), index 4 the centre, index 8 offset (1,1)
    float32_t3 window[9];
    [unroll]
    for (int32_t dy = -1; dy <= 1; dy++)
    {
        [unroll]
        for (int32_t dx = -1; dx <= 1; dx++)
            window[(dy + 1) * 3 + (dx + 1)] = RWMCsampleCascadeTexel(coord, int32_t2(dx, dy), cascadeIndex);
    }

    // numerical robustness: the eight non-centre texels are summed pairwise, in this exact order
    const float32_t3 firstHalf = (window[0] + window[1]) + (window[2] + window[3]);
    const float32_t3 secondHalf = (window[5] + window[6]) + (window[7] + window[8]);
    const float32_t3 ringSum = firstHalf + secondHalf;

    const float centerLuma = calcLuma(window[4]) * reciprocalBaseI;

    RWMCCascadeSample result;
    result.centerValue = window[4];
    result.normalizedCenterLuma = centerLuma;
    result.normalizedNeighbourhoodAverageLuma = (calcLuma(ringSum) * reciprocalBaseI + centerLuma) / 9.f;
    return result;
}
91-
// Resolves one pixel: walks every cascade layer from 0 to params.lastCascadeIndex and accumulates each
// layer's centre colour weighted by a reliability in [0,1].
//   params - constants from computeReweightingParameters
//   coord  - pixel to resolve
// Returns the reweighted RGB sum over all layers.
float32_t3 RWMCReweight(in RWMCReweightingParameters params, in int32_t2 coord)
{
    float reciprocalBaseI = 1.f;
    RWMCCascadeSample curr = RWMCSampleCascade(coord, 0u, reciprocalBaseI);

    float32_t3 accumulation = float32_t3(0.0f, 0.0f, 0.0f);
    float Emin = params.initialEmin;

    // only read from the second layer onwards; zero-initialised so no indeterminate value is ever live
    float prevNormalizedCenterLuma = 0.f;
    float prevNormalizedNeighbourhoodAverageLuma = 0.f;
    for (uint i = 0u; i <= params.lastCascadeIndex; i++)
    {
        const bool notFirstCascade = i != 0u;
        const bool notLastCascade = i != params.lastCascadeIndex;

        // zero-initialised because it is copied into `curr` at the end of every iteration,
        // including the last one where no further layer is sampled
        RWMCCascadeSample next = (RWMCCascadeSample)0;
        if (notLastCascade)
        {
            reciprocalBaseI *= params.reciprocalBase;
            next = RWMCSampleCascade(coord, i + 1u, reciprocalBaseI);
        }

        float reliability = 1.f;
        // sample counting-based reliability estimation
        if (params.reciprocalKappa <= 1.f)
        {
            float localReliability = curr.normalizedCenterLuma;
            // reliability in 3x3 pixel block (see robustness)
            float globalReliability = curr.normalizedNeighbourhoodAverageLuma;
            if (notFirstCascade)
            {
                localReliability += prevNormalizedCenterLuma;
                globalReliability += prevNormalizedNeighbourhoodAverageLuma;
            }
            if (notLastCascade)
            {
                localReliability += next.normalizedCenterLuma;
                globalReliability += next.normalizedNeighbourhoodAverageLuma;
            }
            // check if above minimum sampling threshold (avg 9 sample occurrences in 3x3 neighbourhood),
            // then use per-pixel reliability (NOTE: ternary op is in reverse)
            reliability = globalReliability < params.reciprocalN ? globalReliability : localReliability;
            {
                // Emin only ever grows: it tracks the luma accumulated so far
                const float accumLuma = calcLuma(accumulation);
                if (accumLuma > Emin)
                    Emin = accumLuma;

                // reciprocalBaseI has already been advanced to layer i+1 here, except on the last layer
                const float colorReliability = Emin * reciprocalBaseI * params.colorReliabilityFactor;

                reliability += colorReliability;
                reliability *= params.NOverKappa;
                reliability -= params.reciprocalKappa;
                reliability = clamp(reliability * 0.5f, 0.f, 1.f);
            }
        }
        accumulation += curr.centerValue * reliability;

        // slide the (prev, curr, next) window one layer up
        prevNormalizedCenterLuma = curr.normalizedCenterLuma;
        prevNormalizedNeighbourhoodAverageLuma = curr.normalizedNeighbourhoodAverageLuma;
        curr = next;
    }

    return accumulation;
}
154-
// Maps the 1D global invocation index onto outImage in row-major order:
// x = index % width, y = index / width.
int32_t2 getCoordinates()
{
    uint32_t imageWidth, imageHeight;
    outImage.GetDimensions(imageWidth, imageHeight);

    const uint32_t linearIndex = glsl::gl_GlobalInvocationID().x;
    return int32_t2(linearIndex % imageWidth, linearIndex / imageWidth);
}
16125
// this function is for testing purposes
// simply adds every cascade layer, output should be nearly the same as output of default accumulator (RWMC off)
float32_t3 sumCascade(in const int32_t2 coords)
{
    float32_t3 accumulation = float32_t3(0.0f, 0.0f, 0.0f);

    // iterate over the actual number of layers rather than a hard-coded count
    for (uint32_t layer = 0u; layer < CascadeSize; ++layer)
    {
        const float32_t4 cascadeLevel = cascade.Load(int32_t3(coords, int32_t(layer)));
        accumulation += float32_t3(cascadeLevel.r, cascadeLevel.g, cascadeLevel.b);
    }

    return accumulation;
}
176-
// Resolve entry point: reweights the cascades for one texel and writes the result (alpha = 1) to outImage.
//
// Default path: one invocation per texel, addressed row-major through getCoordinates().
// PERSISTENT_WORKGROUPS path: each workgroup strides over the whole image; texels are visited in
// Morton (Z-curve) order. `threadID` is not used by either path.
[numthreads(WorkgroupSize, 1, 1)]
void main(uint32_t3 threadID : SV_DispatchThreadID)
{
    // depends only on push constants, so compute it once rather than once per visited texel
    rwmc::ReweightingParameters reweightingParameters = rwmc::computeReweightingParameters(pc.base, pc.sampleCount, pc.minReliableLuma, pc.kappa, CascadeSize);

#ifdef PERSISTENT_WORKGROUPS
    uint32_t width, height;
    outImage.GetDimensions(width, height);

    // A contiguous range of Morton codes [0, S*S) covers exactly an S x S square when S is a power of two.
    // Iterating only width*height codes therefore does NOT cover a width x height rectangle, so the range
    // is the smallest power-of-two square enclosing the image and codes landing outside it are skipped.
    // (S*S fits in 32 bits for any S up to 32768.)
    const uint32_t maxDim = max(max(width, height), 1u);
    const uint32_t mortonSide = maxDim > 1u ? (1u << (uint32_t(firstbithigh(maxDim - 1u)) + 1u)) : 1u;
    const uint32_t mortonCodeCount = mortonSide * mortonSide;
    const uint32_t virtualThreadStride = glsl::gl_NumWorkGroups().x * WorkgroupSize;

    [loop]
    for (uint32_t virtualThreadBase = glsl::gl_WorkGroupID().x * WorkgroupSize; virtualThreadBase < mortonCodeCount; virtualThreadBase += virtualThreadStride)
    {
        const uint32_t virtualThreadIndex = virtualThreadBase + glsl::gl_LocalInvocationIndex().x;
        const int32_t2 coords = (int32_t2)math::Morton<uint32_t>::decode2d(virtualThreadIndex);
        // padding texels of the enclosing square
        if (coords.x >= int32_t(width) || coords.y >= int32_t(height))
            continue;
#else
    {
        const int32_t2 coords = getCoordinates();
#endif

        float32_t3 color = rwmc::reweight(reweightingParameters, cascade, coords);

        outImage[coords] = float32_t4(color, 1.0f);
    }
}
0 commit comments