11import { HalfFloatType , Vector2 , RenderTarget , RendererUtils , QuadMesh , NodeMaterial , TempNode , NodeUpdateType , Matrix4 , DepthTexture } from 'three/webgpu' ;
2- import { add , float , If , Loop , int , Fn , min , max , clamp , nodeObject , texture , uniform , uv , vec2 , vec4 , luminance , convertToTexture , passTexture , velocity , getViewPosition , viewZToPerspectiveDepth } from 'three/tsl' ;
2+ import { add , float , If , Fn , max , nodeObject , texture , uniform , uv , vec2 , vec4 , luminance , convertToTexture , passTexture , velocity , getViewPosition , viewZToPerspectiveDepth , struct , ivec2 , mix } from 'three/tsl' ;
33
44const _quadMesh = /*@__PURE__ */ new QuadMesh ( ) ;
55const _size = /*@__PURE__ */ new Vector2 ( ) ;
@@ -77,26 +77,45 @@ class TRAANode extends TempNode {
7777 this . velocityNode = velocityNode ;
7878
7979 /**
80- * The camera the scene is rendered with.
80+ * The camera the scene is rendered with.
8181 *
8282 * @type {Camera }
8383 */
8484 this . camera = camera ;
8585
8686 /**
87- * When the difference between the current and previous depth goes above
88- * this threshold, the history is considered invalid.
87+ * When the difference between the current and previous depth goes above this threshold,
88+ * the history is considered invalid.
8989 *
9090 * @type {number }
91+ * @default 0.0005
9192 */
92- this . depthThreshold = 0.0001 ;
93+ this . depthThreshold = 0.0005 ;
9394
9495 /**
9596 * The depth difference within the 3×3 neighborhood to consider a pixel as an edge.
9697 *
9798 * @type {number }
99+ * @default 0.001
98100 */
99- this . edgeDepthDiff = 0.0001 ;
101+ this . edgeDepthDiff = 0.001 ;
102+
103+ /**
104+ * The history becomes invalid as the pixel length of the velocity approaches this value.
105+ *
106+ * @type {number }
107+ * @default 128
108+ */
109+ this . maxVelocityLength = 128 ;
110+
111+ /**
112+ * Whether to decrease the weight on the current frame when the velocity is more subpixel.
113+ * This reduces blurriness under motion, but can introduce a square pattern artifact.
114+ *
115+ * @type {boolean }
116+ * @default true
117+ */
118+ this . useSubpixelCorrection = true ;
100119
101120 /**
102121 * The jitter index selects the current camera offset value.
@@ -436,11 +455,50 @@ class TRAANode extends TempNode {
436455
437456 }
438457
439- const historyTexture = texture ( this . _historyRenderTarget . texture ) ;
440- const sampleTexture = this . beautyNode ;
441- const depthTexture = this . depthNode ;
442- const velocityTexture = this . velocityNode ;
458+ const currentDepthStruct = struct ( {
459+
460+ closestDepth : 'float' ,
461+ closestPositionTexel : 'vec2' ,
462+ farthestDepth : 'float' ,
463+
464+ } ) ;
465+
466+ // Samples 3×3 neighborhood pixels and returns the closest and farthest depths.
467+ const sampleCurrentDepth = Fn ( ( [ positionTexel ] ) => {
468+
469+ const closestDepth = float ( 2 ) . toVar ( ) ;
470+ const closestPositionTexel = vec2 ( 0 ) . toVar ( ) ;
471+ const farthestDepth = float ( - 1 ) . toVar ( ) ;
472+
473+ for ( let x = - 1 ; x <= 1 ; ++ x ) {
474+
475+ for ( let y = - 1 ; y <= 1 ; ++ y ) {
476+
477+ const neighbor = positionTexel . add ( vec2 ( x , y ) ) . toVar ( ) ;
478+ const depth = this . depthNode . load ( neighbor ) . r . toVar ( ) ;
479+
480+ If ( depth . lessThan ( closestDepth ) , ( ) => {
481+
482+ closestDepth . assign ( depth ) ;
483+ closestPositionTexel . assign ( neighbor ) ;
484+
485+ } ) ;
486+
487+ If ( depth . greaterThan ( farthestDepth ) , ( ) => {
488+
489+ farthestDepth . assign ( depth ) ;
490+
491+ } ) ;
492+
493+ }
494+
495+ }
443496
497+ return currentDepthStruct ( closestDepth , closestPositionTexel , farthestDepth ) ;
498+
499+ } ) ;
500+
501+ // Samples a previous depth and reproject it using the current camera matrices.
444502 const samplePreviousDepth = ( uv ) => {
445503
446504 const depth = this . _previousDepthNode . sample ( uv ) . r ;
@@ -451,95 +509,163 @@ class TRAANode extends TempNode {
451509
452510 } ;
453511
454- const resolve = Fn ( ( ) => {
512+ // Optimized version of AABB clipping.
513+ // Reference: https://github.com/playdeadgames/temporal
514+ const clipAABB = Fn ( ( [ currentColor , historyColor , minColor , maxColor ] ) => {
515+
516+ const pClip = maxColor . rgb . add ( minColor . rgb ) . mul ( 0.5 ) ;
517+ const eClip = maxColor . rgb . sub ( minColor . rgb ) . mul ( 0.5 ) . add ( 1e-7 ) ;
518+ const vClip = historyColor . sub ( vec4 ( pClip , currentColor . a ) ) ;
519+ const vUnit = vClip . xyz . div ( eClip ) ;
520+ const absUnit = vUnit . abs ( ) ;
521+ const maxUnit = max ( absUnit . x , absUnit . y , absUnit . z ) ;
522+ return maxUnit . greaterThan ( 1 ) . select (
523+ vec4 ( pClip , currentColor . a ) . add ( vClip . div ( maxUnit ) ) ,
524+ historyColor
525+ ) ;
526+
527+ } ) . setLayout ( {
528+ name : 'clipAABB' ,
529+ type : 'vec4' ,
530+ inputs : [
531+ { name : 'currentColor' , type : 'vec4' } ,
532+ { name : 'historyColor' , type : 'vec4' } ,
533+ { name : 'minColor' , type : 'vec4' } ,
534+ { name : 'maxColor' , type : 'vec4' }
535+ ]
536+ } ) ;
455537
456- const uvNode = uv ( ) ;
538+ // Performs variance clipping.
539+ // See: https://developer.download.nvidia.com/gameworks/events/GDC2016/msalvi_temporal_supersampling.pdf
540+ const varianceClipping = Fn ( ( [ positionTexel , currentColor , historyColor , gamma ] ) => {
457541
458- const minColor = vec4 ( 10000 ) . toVar ( ) ;
459- const maxColor = vec4 ( - 10000 ) . toVar ( ) ;
460- const closestDepth = float ( 2 ) . toVar ( ) ;
461- const farthestDepth = float ( - 1 ) . toVar ( ) ;
462- const closestDepthPixelPosition = vec2 ( 0 ) . toVar ( ) ;
542+ const offsets = [
543+ [ - 1 , - 1 ] ,
544+ [ - 1 , 1 ] ,
545+ [ 1 , - 1 ] ,
546+ [ 1 , 1 ] ,
547+ [ 1 , 0 ] ,
548+ [ 0 , - 1 ] ,
549+ [ 0 , 1 ] ,
550+ [ - 1 , 0 ]
551+ ] ;
463552
464- // sample a 3x3 neighborhood to create a box in color space
465- // clamping the history color with the resulting min/max colors mitigates ghosting
553+ const moment1 = currentColor . toVar ( ) ;
554+ const moment2 = currentColor . pow2 ( ) . toVar ( ) ;
466555
467- Loop ( { start : int ( - 1 ) , end : int ( 1 ) , type : 'int' , condition : '<=' , name : 'x' } , ( { x } ) => {
556+ for ( const [ x , y ] of offsets ) {
468557
469- Loop ( { start : int ( - 1 ) , end : int ( 1 ) , type : 'int' , condition : '<=' , name : 'y' } , ( { y } ) => {
558+ // Use max() to prevent NaN values from propagating.
559+ const neighbor = this . beautyNode . offset ( ivec2 ( x , y ) ) . load ( positionTexel ) . max ( 0 ) ;
560+ moment1 . addAssign ( neighbor ) ;
561+ moment2 . addAssign ( neighbor . pow2 ( ) ) ;
470562
471- const uvNeighbor = uvNode . add ( vec2 ( float ( x ) , float ( y ) ) . mul ( this . _invSize ) ) . toVar ( ) ;
472- const colorNeighbor = max ( vec4 ( 0 ) , sampleTexture . sample ( uvNeighbor ) ) . toVar ( ) ; // use max() to avoid propagate garbage values
563+ }
473564
474- minColor . assign ( min ( minColor , colorNeighbor ) ) ;
475- maxColor . assign ( max ( maxColor , colorNeighbor ) ) ;
565+ const N = float ( offsets . length + 1 ) ;
566+ const mean = moment1 . div ( N ) ;
567+ const variance = moment2 . div ( N ) . sub ( mean . pow2 ( ) ) . max ( 0 ) . sqrt ( ) . mul ( gamma ) ;
568+ const minColor = mean . sub ( variance ) ;
569+ const maxColor = mean . add ( variance ) ;
476570
477- const currentDepth = depthTexture . sample ( uvNeighbor ) . r . toVar ( ) ;
571+ return clipAABB ( mean . clamp ( minColor , maxColor ) , historyColor , minColor , maxColor ) ;
478572
479- // find the sample position of the closest depth in the neighborhood (used for velocity)
573+ } ) ;
480574
481- If ( currentDepth . lessThan ( closestDepth ) , ( ) => {
575+ // Returns the amount of subpixel (expressed within [0, 1]) in the velocity.
576+ const subpixelCorrection = Fn ( ( [ velocityUV , textureSize ] ) => {
577+
578+ const velocityTexel = velocityUV . mul ( textureSize ) ;
579+ const phase = velocityTexel . fract ( ) . abs ( ) ;
580+ const weight = max ( phase , phase . oneMinus ( ) ) ;
581+ return weight . x . mul ( weight . y ) . oneMinus ( ) . div ( 0.75 ) ;
582+
583+ } ) . setLayout ( {
584+ name : 'subpixelCorrection' ,
585+ type : 'float' ,
586+ inputs : [
587+ { name : 'velocityUV' , type : 'vec2' } ,
588+ { name : 'textureSize' , type : 'ivec2' }
589+ ]
590+ } ) ;
482591
483- closestDepth . assign ( currentDepth ) ;
484- closestDepthPixelPosition . assign ( uvNeighbor ) ;
592+ // Flicker reduction based on luminance weighing.
593+ const flickerReduction = Fn ( ( [ currentColor , historyColor , currentWeight ] ) => {
485594
486- } ) ;
595+ const historyWeight = currentWeight . oneMinus ( ) ;
596+ const compressedCurrent = currentColor . mul ( float ( 1 ) . div ( ( max ( currentColor . r , currentColor . g , currentColor . b ) . add ( 1 ) ) ) ) ;
597+ const compressedHistory = historyColor . mul ( float ( 1 ) . div ( ( max ( historyColor . r , historyColor . g , historyColor . b ) . add ( 1 ) ) ) ) ;
487598
488- // find the farthest depth in the neighborhood (used to preserve edge anti-aliasing)
599+ const luminanceCurrent = luminance ( compressedCurrent . rgb ) ;
600+ const luminanceHistory = luminance ( compressedHistory . rgb ) ;
489601
490- If ( currentDepth . greaterThan ( farthestDepth ) , ( ) => {
602+ currentWeight . mulAssign ( float ( 1 ) . div ( luminanceCurrent . add ( 1 ) ) ) ;
603+ historyWeight . mulAssign ( float ( 1 ) . div ( luminanceHistory . add ( 1 ) ) ) ;
491604
492- farthestDepth . assign ( currentDepth ) ;
605+ return add ( currentColor . mul ( currentWeight ) , historyColor . mul ( historyWeight ) ) . div ( max ( currentWeight . add ( historyWeight ) , 0.00001 ) ) . toVar ( ) ;
493606
494- } ) ;
607+ } ) ;
495608
496- } ) ;
609+ const historyNode = texture ( this . _historyRenderTarget . texture ) ;
497610
498- } ) ;
611+ const resolve = Fn ( ( ) => {
499612
500- // sampling/reprojection
613+ const uvNode = uv ( ) ;
614+ const textureSize = this . beautyNode . size ( ) ; // Assumes all the buffers share the same size.
615+ const positionTexel = uvNode . mul ( textureSize ) ;
501616
502- const offset = velocityTexture . sample ( closestDepthPixelPosition ) . xy . mul ( vec2 ( 0.5 , - 0.5 ) ) ; // NDC to uv offset
617+ // sample the closest and farthest depths in the current buffer
503618
504- const currentColor = sampleTexture . sample ( uvNode ) ;
505- const historyColor = historyTexture . sample ( uvNode . sub ( offset ) ) ;
619+ const currentDepth = sampleCurrentDepth ( positionTexel ) ;
620+ const closestDepth = currentDepth . get ( 'closestDepth' ) ;
621+ const closestPositionTexel = currentDepth . get ( 'closestPositionTexel' ) ;
622+ const farthestDepth = currentDepth . get ( 'farthestDepth' ) ;
506623
507- // clamping
624+ // convert the NDC offset to UV offset
508625
509- const clampedHistoryColor = clamp ( historyColor , minColor , maxColor ) ;
626+ const offsetUV = this . velocityNode . load ( closestPositionTexel ) . xy . mul ( vec2 ( 0.5 , - 0.5 ) ) ;
510627
511- // sample the current and previous depths
628+ // sample the previous depth
512629
513- const currentDepth = depthTexture . sample ( uvNode ) . r ;
514- const historyUV = uvNode . sub ( offset ) ;
630+ const historyUV = uvNode . sub ( offsetUV ) ;
515631 const previousDepth = samplePreviousDepth ( historyUV ) ;
516632
517- // disocclusion except on edges
633+ // history is considered valid when the UV is in range and there's no disocclusion except on edges
518634
635+ const isValidUV = historyUV . greaterThanEqual ( 0 ) . all ( ) . and ( historyUV . lessThanEqual ( 1 ) . all ( ) ) ;
519636 const isEdge = farthestDepth . sub ( closestDepth ) . greaterThan ( this . edgeDepthDiff ) ;
520- const isDisocclusion = currentDepth . sub ( previousDepth ) . greaterThan ( this . depthThreshold ) . and ( isEdge . not ( ) ) ;
637+ const isDisocclusion = closestDepth . sub ( previousDepth ) . greaterThan ( this . depthThreshold ) ;
638+ const hasValidHistory = isValidUV . and ( isEdge . or ( isDisocclusion . not ( ) ) ) ;
521639
522- // higher velocity = more weight on current frame
523- // zero out history weight where disocclusion
640+ // sample the current and previous colors
524641
525- const motionFactor = uvNode . sub ( historyUV ) . length ( ) . mul ( 10 ) ;
526- const currentWeight = isDisocclusion . select ( 1 , float ( 0.05 ) . add ( motionFactor ) . saturate ( ) ) . toVar ( ) ;
527- const historyWeight = currentWeight . oneMinus ( ) . toVar ( ) ;
642+ const currentColor = this . beautyNode . sample ( uvNode ) ;
643+ const historyColor = historyNode . sample ( uvNode . sub ( offsetUV ) ) ;
528644
529- // flicker reduction based on luminance weighing
645+ // increase the weight towards the current frame under motion
530646
531- const compressedCurrent = currentColor . mul ( float ( 1 ) . div ( ( max ( currentColor . r , currentColor . g , currentColor . b ) . add ( 1.0 ) ) ) ) ;
532- const compressedHistory = clampedHistoryColor . mul ( float ( 1 ) . div ( ( max ( clampedHistoryColor . r , clampedHistoryColor . g , clampedHistoryColor . b ) . add ( 1.0 ) ) ) ) ;
647+ const motionFactor = uvNode . sub ( historyUV ) . mul ( textureSize ) . length ( ) . div ( this . maxVelocityLength ) . saturate ( ) ;
648+ const currentWeight = float ( 0.05 ) . toVar ( ) ; // A minimum weight
533649
534- const luminanceCurrent = luminance ( compressedCurrent . rgb ) ;
535- const luminanceHistory = luminance ( compressedHistory . rgb ) ;
650+ if ( this . useSubpixelCorrection ) {
536651
537- currentWeight . mulAssign ( float ( 1 ) . div ( luminanceCurrent . add ( 1 ) ) ) ;
538- historyWeight . mulAssign ( float ( 1 ) . div ( luminanceHistory . add ( 1 ) ) ) ;
652+ // Increase the minimum weight towards the current frame when the velocity is more subpixel.
653+ currentWeight . addAssign ( subpixelCorrection ( offsetUV , textureSize ) . mul ( 0.25 ) ) ;
654+
655+ }
656+
657+ currentWeight . assign ( hasValidHistory . select ( currentWeight . add ( motionFactor ) . saturate ( ) , 1 ) ) ;
658+
659+ // Perform neighborhood clipping/clamping. We use variance clipping here.
660+
661+ const varianceGamma = mix ( 0.5 , 1 , motionFactor . oneMinus ( ) . pow2 ( ) ) ; // Reasonable gamma range is [0.75, 2]
662+ const clippedHistoryColor = varianceClipping ( positionTexel , currentColor , historyColor , varianceGamma ) ;
663+
664+ // flicker reduction based on luminance weighing
539665
540- const smoothedOutput = add ( currentColor . mul ( currentWeight ) , clampedHistoryColor . mul ( historyWeight ) ) . div ( max ( currentWeight . add ( historyWeight ) , 0.00001 ) ) . toVar ( ) ;
666+ const output = flickerReduction ( currentColor , clippedHistoryColor , currentWeight ) ;
541667
542- return smoothedOutput ;
668+ return output ;
543669
544670 } ) ;
545671
0 commit comments