Skip to content

Commit 32c2d19

Browse files
authored
TRAANode: Improve motion factor and disocclusion (#32296)
1 parent 889a39d commit 32c2d19

File tree

1 file changed

+87
-67
lines changed

1 file changed

+87
-67
lines changed

examples/jsm/tsl/display/TRAANode.js

Lines changed: 87 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { HalfFloatType, Vector2, RenderTarget, RendererUtils, QuadMesh, NodeMaterial, TempNode, NodeUpdateType, Matrix4, DepthTexture } from 'three/webgpu';
2-
import { add, float, If, Loop, int, Fn, min, max, clamp, nodeObject, texture, uniform, uv, vec2, vec4, luminance, convertToTexture, passTexture, velocity, getViewPosition, length } from 'three/tsl';
2+
import { add, float, If, Loop, int, Fn, min, max, clamp, nodeObject, texture, uniform, uv, vec2, vec4, luminance, convertToTexture, passTexture, velocity, getViewPosition, viewZToPerspectiveDepth } from 'three/tsl';
33

44
const _quadMesh = /*@__PURE__*/ new QuadMesh();
55
const _size = /*@__PURE__*/ new Vector2();
@@ -84,53 +84,36 @@ class TRAANode extends TempNode {
8484
this.camera = camera;
8585

8686
/**
87-
* The jitter index selects the current camera offset value.
87+
* When the difference between the current and previous depth goes above
88+
* this threshold, the history is considered invalid.
8889
*
89-
* @private
9090
* @type {number}
91-
* @default 0
92-
*/
93-
this._jitterIndex = 0;
94-
95-
/**
96-
* A uniform node holding the inverse resolution value.
97-
*
98-
* @private
99-
* @type {UniformNode<vec2>}
100-
*/
101-
this._invSize = uniform( new Vector2() );
102-
103-
/**
104-
* A uniform node holding the camera world matrix.
105-
*
106-
* @private
107-
* @type {UniformNode<mat4>}
10891
*/
109-
this._cameraWorldMatrix = uniform( new Matrix4() );
92+
this.depthThreshold = 0.0001;
11093

11194
/**
112-
* A uniform node holding the camera projection matrix inverse.
95+
* The depth difference within the 3×3 neighborhood to consider a pixel as an edge.
11396
*
114-
* @private
115-
* @type {UniformNode<mat4>}
97+
* @type {number}
11698
*/
117-
this._cameraProjectionMatrixInverse = uniform( new Matrix4() );
99+
this.edgeDepthDiff = 0.0001;
118100

119101
/**
120-
* A uniform node holding the previous frame's view matrix.
102+
* The jitter index selects the current camera offset value.
121103
*
122104
* @private
123-
* @type {UniformNode<mat4>}
105+
* @type {number}
106+
* @default 0
124107
*/
125-
this._previousCameraWorldMatrix = uniform( new Matrix4() );
108+
this._jitterIndex = 0;
126109

127110
/**
128-
* A uniform node holding the previous frame's projection matrix inverse.
111+
* A uniform node holding the inverse resolution value.
129112
*
130113
* @private
131-
* @type {UniformNode<mat4>}
114+
* @type {UniformNode<vec2>}
132115
*/
133-
this._previousCameraProjectionMatrixInverse = uniform( new Matrix4() );
116+
this._invSize = uniform( new Vector2() );
134117

135118
/**
136119
* The render target that represents the history of frame data.
@@ -175,6 +158,54 @@ class TRAANode extends TempNode {
175158
*/
176159
this._originalProjectionMatrix = new Matrix4();
177160

161+
/**
162+
* A uniform node holding the camera's near and far.
163+
*
164+
* @private
165+
* @type {UniformNode<vec2>}
166+
*/
167+
this._cameraNearFar = uniform( new Vector2() );
168+
169+
/**
170+
* A uniform node holding the camera world matrix.
171+
*
172+
* @private
173+
* @type {UniformNode<mat4>}
174+
*/
175+
this._cameraWorldMatrix = uniform( new Matrix4() );
176+
177+
/**
178+
* A uniform node holding the camera world matrix inverse.
179+
*
180+
* @private
181+
* @type {UniformNode<mat4>}
182+
*/
183+
this._cameraWorldMatrixInverse = uniform( new Matrix4() );
184+
185+
/**
186+
* A uniform node holding the camera projection matrix inverse.
187+
*
188+
* @private
189+
* @type {UniformNode<mat4>}
190+
*/
191+
this._cameraProjectionMatrixInverse = uniform( new Matrix4() );
192+
193+
/**
194+
* A uniform node holding the previous frame's view matrix.
195+
*
196+
* @private
197+
* @type {UniformNode<mat4>}
198+
*/
199+
this._previousCameraWorldMatrix = uniform( new Matrix4() );
200+
201+
/**
202+
* A uniform node holding the previous frame's projection matrix inverse.
203+
*
204+
* @private
205+
* @type {UniformNode<mat4>}
206+
*/
207+
this._previousCameraProjectionMatrixInverse = uniform( new Matrix4() );
208+
178209
/**
179210
* A texture node for the previous depth buffer.
180211
*
@@ -293,7 +324,9 @@ class TRAANode extends TempNode {
293324

294325
// update camera matrices uniforms
295326

327+
this._cameraNearFar.value.set( this.camera.near, this.camera.far );
296328
this._cameraWorldMatrix.value.copy( this.camera.matrixWorld );
329+
this._cameraWorldMatrixInverse.value.copy( this.camera.matrixWorldInverse );
297330
this._cameraProjectionMatrixInverse.value.copy( this.camera.projectionMatrixInverse );
298331

299332
// keep the TRAA in sync with the dimensions of the beauty node
@@ -408,14 +441,24 @@ class TRAANode extends TempNode {
408441
const depthTexture = this.depthNode;
409442
const velocityTexture = this.velocityNode;
410443

444+
const samplePreviousDepth = ( uv ) => {
445+
446+
const depth = this._previousDepthNode.sample( uv ).r;
447+
const positionView = getViewPosition( uv, depth, this._previousCameraProjectionMatrixInverse );
448+
const positionWorld = this._previousCameraWorldMatrix.mul( vec4( positionView, 1 ) ).xyz;
449+
const viewZ = this._cameraWorldMatrixInverse.mul( vec4( positionWorld, 1 ) ).z;
450+
return viewZToPerspectiveDepth( viewZ, this._cameraNearFar.x, this._cameraNearFar.y );
451+
452+
};
453+
411454
const resolve = Fn( () => {
412455

413456
const uvNode = uv();
414457

415458
const minColor = vec4( 10000 ).toVar();
416459
const maxColor = vec4( - 10000 ).toVar();
417-
const closestDepth = float( 1 ).toVar();
418-
const farthestDepth = float( 0 ).toVar();
460+
const closestDepth = float( 2 ).toVar();
461+
const farthestDepth = float( - 1 ).toVar();
419462
const closestDepthPixelPosition = vec2( 0 ).toVar();
420463

421464
// sample a 3x3 neighborhood to create a box in color space
@@ -465,47 +508,24 @@ class TRAANode extends TempNode {
465508

466509
const clampedHistoryColor = clamp( historyColor, minColor, maxColor );
467510

468-
// calculate current frame world position
511+
// sample the current and previous depths
469512

470513
const currentDepth = depthTexture.sample( uvNode ).r;
471-
const currentViewPosition = getViewPosition( uvNode, currentDepth, this._cameraProjectionMatrixInverse );
472-
const currentWorldPosition = this._cameraWorldMatrix.mul( vec4( currentViewPosition, 1.0 ) ).xyz;
473-
474-
// calculate previous frame world position from history UV and previous depth
475-
476514
const historyUV = uvNode.sub( offset );
477-
const previousDepth = this._previousDepthNode.sample( historyUV ).r;
478-
const previousViewPosition = getViewPosition( historyUV, previousDepth, this._previousCameraProjectionMatrixInverse );
479-
const previousWorldPosition = this._previousCameraWorldMatrix.mul( vec4( previousViewPosition, 1.0 ) ).xyz;
480-
481-
// calculate difference in world positions
515+
const previousDepth = samplePreviousDepth( historyUV );
482516

483-
const worldPositionDifference = length( currentWorldPosition.sub( previousWorldPosition ) ).toVar();
484-
worldPositionDifference.assign( min( max( worldPositionDifference.sub( 1.0 ), 0.0 ), 1.0 ) );
517+
// disocclusion except on edges
485518

486-
// Adaptive blend weights based on velocity magnitude suggested by CLAUDE in #32133
487-
// Higher velocity or position difference = more weight on current frame to reduce ghosting
519+
const isEdge = farthestDepth.sub( closestDepth ).greaterThan( this.edgeDepthDiff );
520+
const isDisocclusion = currentDepth.sub( previousDepth ).greaterThan( this.depthThreshold ).and( isEdge.not() );
488521

489-
const velocityMagnitude = length( offset ).toConst();
490-
const motionFactor = max( worldPositionDifference.mul( 0.5 ), velocityMagnitude.mul( 10.0 ) ).toVar();
491-
motionFactor.assign( min( motionFactor, 1.0 ) );
522+
// higher velocity = more weight on current frame
523+
// zero out history weight where disocclusion
492524

493-
const currentWeight = float( 0.05 ).add( motionFactor.mul( 0.25 ) ).toVar();
525+
const motionFactor = uvNode.sub( historyUV ).length().mul( 10 );
526+
const currentWeight = isDisocclusion.select( 1, float( 0.05 ).add( motionFactor ).saturate() ).toVar();
494527
const historyWeight = currentWeight.oneMinus().toVar();
495528

496-
// zero out history weight if world positions are different (indicating motion) except on edges.
497-
// note that the constants 0.00001 and 0.5 were suggested by CLAUDE in #32133
498-
499-
const isEdge = farthestDepth.sub( closestDepth ).greaterThan( 0.00001 );
500-
const strongDisocclusion = worldPositionDifference.greaterThan( 0.5 ).and( isEdge.not() );
501-
502-
If( strongDisocclusion, () => {
503-
504-
currentWeight.assign( 1.0 );
505-
historyWeight.assign( 0.0 );
506-
507-
} );
508-
509529
// flicker reduction based on luminance weighing
510530

511531
const compressedCurrent = currentColor.mul( float( 1 ).div( ( max( currentColor.r, currentColor.g, currentColor.b ).add( 1.0 ) ) ) );
@@ -514,8 +534,8 @@ class TRAANode extends TempNode {
514534
const luminanceCurrent = luminance( compressedCurrent.rgb );
515535
const luminanceHistory = luminance( compressedHistory.rgb );
516536

517-
currentWeight.mulAssign( float( 1.0 ).div( luminanceCurrent.add( 1 ) ) );
518-
historyWeight.mulAssign( float( 1.0 ).div( luminanceHistory.add( 1 ) ) );
537+
currentWeight.mulAssign( float( 1 ).div( luminanceCurrent.add( 1 ) ) );
538+
historyWeight.mulAssign( float( 1 ).div( luminanceHistory.add( 1 ) ) );
519539

520540
const smoothedOutput = add( currentColor.mul( currentWeight ), clampedHistoryColor.mul( historyWeight ) ).div( max( currentWeight.add( historyWeight ), 0.00001 ) ).toVar();
521541

0 commit comments

Comments
 (0)