11import { HalfFloatType , Vector2 , RenderTarget , RendererUtils , QuadMesh , NodeMaterial , TempNode , NodeUpdateType , Matrix4 , DepthTexture } from 'three/webgpu' ;
2- import { add , float , If , Loop , int , Fn , min , max , clamp , nodeObject , texture , uniform , uv , vec2 , vec4 , luminance , convertToTexture , passTexture , velocity , getViewPosition , length } from 'three/tsl' ;
2+ import { add , float , If , Loop , int , Fn , min , max , clamp , nodeObject , texture , uniform , uv , vec2 , vec4 , luminance , convertToTexture , passTexture , velocity , getViewPosition , viewZToPerspectiveDepth } from 'three/tsl' ;
33
44const _quadMesh = /*@__PURE__ */ new QuadMesh ( ) ;
55const _size = /*@__PURE__ */ new Vector2 ( ) ;
@@ -84,53 +84,36 @@ class TRAANode extends TempNode {
8484 this . camera = camera ;
8585
8686 /**
87- * The jitter index selects the current camera offset value.
87+ * When the difference between the current and previous depth goes above
88+ * this threshold, the history is considered invalid.
8889 *
89- * @private
9090 * @type {number }
91- * @default 0
92- */
93- this . _jitterIndex = 0 ;
94-
95- /**
96- * A uniform node holding the inverse resolution value.
97- *
98- * @private
99- * @type {UniformNode<vec2> }
100- */
101- this . _invSize = uniform ( new Vector2 ( ) ) ;
102-
103- /**
104- * A uniform node holding the camera world matrix.
105- *
106- * @private
107- * @type {UniformNode<mat4> }
10891 */
109- this . _cameraWorldMatrix = uniform ( new Matrix4 ( ) ) ;
92+ this . depthThreshold = 0.0001 ;
11093
11194 /**
112- * A uniform node holding the camera projection matrix inverse .
95+ * The depth difference within the 3×3 neighborhood to consider a pixel as an edge .
11396 *
114- * @private
115- * @type {UniformNode<mat4> }
97+ * @type {number }
11698 */
117- this . _cameraProjectionMatrixInverse = uniform ( new Matrix4 ( ) ) ;
99+ this . edgeDepthDiff = 0.0001 ;
118100
119101 /**
120- * A uniform node holding the previous frame's view matrix .
102+ * The jitter index selects the current camera offset value .
121103 *
122104 * @private
123- * @type {UniformNode<mat4> }
105+ * @type {number }
106+ * @default 0
124107 */
125- this . _previousCameraWorldMatrix = uniform ( new Matrix4 ( ) ) ;
108+ this . _jitterIndex = 0 ;
126109
127110 /**
128- * A uniform node holding the previous frame's projection matrix inverse .
111+ * A uniform node holding the inverse resolution value .
129112 *
130113 * @private
131- * @type {UniformNode<mat4 > }
114+ * @type {UniformNode<vec2 > }
132115 */
133- this . _previousCameraProjectionMatrixInverse = uniform ( new Matrix4 ( ) ) ;
116+ this . _invSize = uniform ( new Vector2 ( ) ) ;
134117
135118 /**
136119 * The render target that represents the history of frame data.
@@ -175,6 +158,54 @@ class TRAANode extends TempNode {
175158 */
176159 this . _originalProjectionMatrix = new Matrix4 ( ) ;
177160
161+ /**
162+ * A uniform node holding the camera's near and far.
163+ *
164+ * @private
165+ * @type {UniformNode<vec2> }
166+ */
167+ this . _cameraNearFar = uniform ( new Vector2 ( ) ) ;
168+
169+ /**
170+ * A uniform node holding the camera world matrix.
171+ *
172+ * @private
173+ * @type {UniformNode<mat4> }
174+ */
175+ this . _cameraWorldMatrix = uniform ( new Matrix4 ( ) ) ;
176+
177+ /**
178+ * A uniform node holding the camera world matrix inverse.
179+ *
180+ * @private
181+ * @type {UniformNode<mat4> }
182+ */
183+ this . _cameraWorldMatrixInverse = uniform ( new Matrix4 ( ) ) ;
184+
185+ /**
186+ * A uniform node holding the camera projection matrix inverse.
187+ *
188+ * @private
189+ * @type {UniformNode<mat4> }
190+ */
191+ this . _cameraProjectionMatrixInverse = uniform ( new Matrix4 ( ) ) ;
192+
193+ /**
194+ * A uniform node holding the previous frame's view matrix.
195+ *
196+ * @private
197+ * @type {UniformNode<mat4> }
198+ */
199+ this . _previousCameraWorldMatrix = uniform ( new Matrix4 ( ) ) ;
200+
201+ /**
202+ * A uniform node holding the previous frame's projection matrix inverse.
203+ *
204+ * @private
205+ * @type {UniformNode<mat4> }
206+ */
207+ this . _previousCameraProjectionMatrixInverse = uniform ( new Matrix4 ( ) ) ;
208+
178209 /**
179210 * A texture node for the previous depth buffer.
180211 *
@@ -293,7 +324,9 @@ class TRAANode extends TempNode {
293324
294325 // update camera matrices uniforms
295326
327+ this . _cameraNearFar . value . set ( this . camera . near , this . camera . far ) ;
296328 this . _cameraWorldMatrix . value . copy ( this . camera . matrixWorld ) ;
329+ this . _cameraWorldMatrixInverse . value . copy ( this . camera . matrixWorldInverse ) ;
297330 this . _cameraProjectionMatrixInverse . value . copy ( this . camera . projectionMatrixInverse ) ;
298331
299332 // keep the TRAA in sync with the dimensions of the beauty node
@@ -408,14 +441,24 @@ class TRAANode extends TempNode {
408441 const depthTexture = this . depthNode ;
409442 const velocityTexture = this . velocityNode ;
410443
444+ const samplePreviousDepth = ( uv ) => {
445+
446+ const depth = this . _previousDepthNode . sample ( uv ) . r ;
447+ const positionView = getViewPosition ( uv , depth , this . _previousCameraProjectionMatrixInverse ) ;
448+ const positionWorld = this . _previousCameraWorldMatrix . mul ( vec4 ( positionView , 1 ) ) . xyz ;
449+ const viewZ = this . _cameraWorldMatrixInverse . mul ( vec4 ( positionWorld , 1 ) ) . z ;
450+ return viewZToPerspectiveDepth ( viewZ , this . _cameraNearFar . x , this . _cameraNearFar . y ) ;
451+
452+ } ;
453+
411454 const resolve = Fn ( ( ) => {
412455
413456 const uvNode = uv ( ) ;
414457
415458 const minColor = vec4 ( 10000 ) . toVar ( ) ;
416459 const maxColor = vec4 ( - 10000 ) . toVar ( ) ;
417- const closestDepth = float ( 1 ) . toVar ( ) ;
418- const farthestDepth = float ( 0 ) . toVar ( ) ;
460+ const closestDepth = float ( 2 ) . toVar ( ) ;
461+ const farthestDepth = float ( - 1 ) . toVar ( ) ;
419462 const closestDepthPixelPosition = vec2 ( 0 ) . toVar ( ) ;
420463
421464 // sample a 3x3 neighborhood to create a box in color space
@@ -465,47 +508,24 @@ class TRAANode extends TempNode {
465508
466509 const clampedHistoryColor = clamp ( historyColor , minColor , maxColor ) ;
467510
468- // calculate current frame world position
511+ // sample the current and previous depths
469512
470513 const currentDepth = depthTexture . sample ( uvNode ) . r ;
471- const currentViewPosition = getViewPosition ( uvNode , currentDepth , this . _cameraProjectionMatrixInverse ) ;
472- const currentWorldPosition = this . _cameraWorldMatrix . mul ( vec4 ( currentViewPosition , 1.0 ) ) . xyz ;
473-
474- // calculate previous frame world position from history UV and previous depth
475-
476514 const historyUV = uvNode . sub ( offset ) ;
477- const previousDepth = this . _previousDepthNode . sample ( historyUV ) . r ;
478- const previousViewPosition = getViewPosition ( historyUV , previousDepth , this . _previousCameraProjectionMatrixInverse ) ;
479- const previousWorldPosition = this . _previousCameraWorldMatrix . mul ( vec4 ( previousViewPosition , 1.0 ) ) . xyz ;
480-
481- // calculate difference in world positions
515+ const previousDepth = samplePreviousDepth ( historyUV ) ;
482516
483- const worldPositionDifference = length ( currentWorldPosition . sub ( previousWorldPosition ) ) . toVar ( ) ;
484- worldPositionDifference . assign ( min ( max ( worldPositionDifference . sub ( 1.0 ) , 0.0 ) , 1.0 ) ) ;
517+ // disocclusion except on edges
485518
486- // Adaptive blend weights based on velocity magnitude suggested by CLAUDE in #32133
487- // Higher velocity or position difference = more weight on current frame to reduce ghosting
519+ const isEdge = farthestDepth . sub ( closestDepth ) . greaterThan ( this . edgeDepthDiff ) ;
520+ const isDisocclusion = currentDepth . sub ( previousDepth ) . greaterThan ( this . depthThreshold ) . and ( isEdge . not ( ) ) ;
488521
489- const velocityMagnitude = length ( offset ) . toConst ( ) ;
490- const motionFactor = max ( worldPositionDifference . mul ( 0.5 ) , velocityMagnitude . mul ( 10.0 ) ) . toVar ( ) ;
491- motionFactor . assign ( min ( motionFactor , 1.0 ) ) ;
522+ // higher velocity = more weight on current frame
523+ // zero out history weight where disocclusion
492524
493- const currentWeight = float ( 0.05 ) . add ( motionFactor . mul ( 0.25 ) ) . toVar ( ) ;
525+ const motionFactor = uvNode . sub ( historyUV ) . length ( ) . mul ( 10 ) ;
526+ const currentWeight = isDisocclusion . select ( 1 , float ( 0.05 ) . add ( motionFactor ) . saturate ( ) ) . toVar ( ) ;
494527 const historyWeight = currentWeight . oneMinus ( ) . toVar ( ) ;
495528
496- // zero out history weight if world positions are different (indicating motion) except on edges.
497- // note that the constants 0.00001 and 0.5 were suggested by CLAUDE in #32133
498-
499- const isEdge = farthestDepth . sub ( closestDepth ) . greaterThan ( 0.00001 ) ;
500- const strongDisocclusion = worldPositionDifference . greaterThan ( 0.5 ) . and ( isEdge . not ( ) ) ;
501-
502- If ( strongDisocclusion , ( ) => {
503-
504- currentWeight . assign ( 1.0 ) ;
505- historyWeight . assign ( 0.0 ) ;
506-
507- } ) ;
508-
509529 // flicker reduction based on luminance weighing
510530
511531 const compressedCurrent = currentColor . mul ( float ( 1 ) . div ( ( max ( currentColor . r , currentColor . g , currentColor . b ) . add ( 1.0 ) ) ) ) ;
@@ -514,8 +534,8 @@ class TRAANode extends TempNode {
514534 const luminanceCurrent = luminance ( compressedCurrent . rgb ) ;
515535 const luminanceHistory = luminance ( compressedHistory . rgb ) ;
516536
517- currentWeight . mulAssign ( float ( 1.0 ) . div ( luminanceCurrent . add ( 1 ) ) ) ;
518- historyWeight . mulAssign ( float ( 1.0 ) . div ( luminanceHistory . add ( 1 ) ) ) ;
537+ currentWeight . mulAssign ( float ( 1 ) . div ( luminanceCurrent . add ( 1 ) ) ) ;
538+ historyWeight . mulAssign ( float ( 1 ) . div ( luminanceHistory . add ( 1 ) ) ) ;
519539
520540 const smoothedOutput = add ( currentColor . mul ( currentWeight ) , clampedHistoryColor . mul ( historyWeight ) ) . div ( max ( currentWeight . add ( historyWeight ) , 0.00001 ) ) . toVar ( ) ;
521541
0 commit comments