Browse Source

WebGPURenderer: Add compute shader bounds check (#33186)

Co-authored-by: sunag <sunagbrasil@gmail.com>
Marco Fugaro 3 weeks ago
parent
commit
82fa2369df

+ 1 - 17
examples/webgpu_compute_cloth.html

@@ -33,7 +33,7 @@
 
 			import * as THREE from 'three/webgpu';
 
-			import { Fn, If, Return, instancedArray, instanceIndex, uniform, select, attribute, uint, Loop, float, transformNormalToView, cross, triNoise3D, time } from 'three/tsl';
+			import { Fn, If, Return, instancedArray, instanceIndex, uniform, select, attribute, Loop, float, transformNormalToView, cross, triNoise3D, time } from 'three/tsl';
 
 			import { Inspector } from 'three/addons/inspector/Inspector.js';
 
@@ -307,14 +307,6 @@
 				// This shader computes a force for each spring, depending on the distance between the two vertices connected by that spring and the targeted rest length
 				computeSpringForces = Fn( () => {
 
-					If( instanceIndex.greaterThanEqual( uint( springCount ) ), () => {
-
-						// compute Shaders are executed in groups of 64, so instanceIndex might be bigger than the amount of springs.
-						// in that case, return.
-						Return();
-
-					} );
-
 					const vertexIds = springVertexIdBuffer.element( instanceIndex );
 					const restLength = springRestLengthBuffer.element( instanceIndex );
 
@@ -335,14 +327,6 @@
 				// In the end it adds the force to the vertex' position.
 				computeVertexForces = Fn( () => {
 
-					If( instanceIndex.greaterThanEqual( uint( vertexCount ) ), () => {
-
-						// compute Shaders are executed in groups of 64, so instanceIndex might be bigger than the amount of vertices.
-						// in that case, return.
-						Return();
-
-					} );
-
 					const params = vertexParamsBuffer.element( instanceIndex ).toVar();
 					const isFixed = params.x;
 					const springCount = params.y;

+ 4 - 27
examples/webgpu_compute_particles_fluid.html

@@ -33,7 +33,7 @@
 
 			import * as THREE from 'three/webgpu';
 
-			import { Fn, If, Return, instancedArray, instanceIndex, uniform, attribute, uint, float, clamp, struct, atomicStore, int, ivec3, array, vec3, atomicAdd, Loop, atomicLoad, max, pow, mat3, vec4, cross, step, storage } from 'three/tsl';
+			import { Fn, If, Return, instancedArray, instanceIndex, uniform, attribute, float, clamp, struct, atomicStore, int, ivec3, array, vec3, atomicAdd, Loop, atomicLoad, max, pow, mat3, vec4, cross, step, storage } from 'three/tsl';
 
 			import { Inspector } from 'three/addons/inspector/Inspector.js';
 
@@ -132,6 +132,9 @@
 				gui.add( params, 'particleCount', 4096, maxParticles, 4096 ).onChange( value => {
 
 					particleMesh.count = value;
+					p2g1Kernel.count = value;
+					p2g2Kernel.count = value;
+					g2pKernel.count = value;
 					particleCountUniform.value = value;
 
 				} );
@@ -219,12 +222,6 @@
 				const cellCount = gridSize.x * gridSize.y * gridSize.z;
 				clearGridKernel = Fn( () => {
 
-					If( instanceIndex.greaterThanEqual( uint( cellCount ) ), () => {
-
-						Return();
-			
-					} );
-
 					atomicStore( cellBuffer.element( instanceIndex ).get( 'x' ), 0 );
 					atomicStore( cellBuffer.element( instanceIndex ).get( 'y' ), 0 );
 					atomicStore( cellBuffer.element( instanceIndex ).get( 'z' ), 0 );
@@ -234,11 +231,6 @@
 
 				p2g1Kernel = Fn( () => {
 
-					If( instanceIndex.greaterThanEqual( particleCountUniform ), () => {
-
-						Return();
-			
-					} );
 					const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toConst( 'particlePosition' );
 					const particleVelocity = particleBuffer.element( instanceIndex ).get( 'velocity' ).toConst( 'particleVelocity' );
 					const C = particleBuffer.element( instanceIndex ).get( 'C' ).toConst( 'C' );
@@ -282,11 +274,6 @@
 
 				p2g2Kernel = Fn( () => {
 
-					If( instanceIndex.greaterThanEqual( particleCountUniform ), () => {
-
-						Return();
-			
-					} );
 					const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toConst( 'particlePosition' );
 					const gridPosition = particlePosition.mul( gridSizeUniform ).toVar();
 
@@ -353,11 +340,6 @@
 
 				updateGridKernel = Fn( () => {
 
-					If( instanceIndex.greaterThanEqual( uint( cellCount ) ), () => {
-
-						Return();
-			
-					} );
 					const cell = cellBuffer.element( instanceIndex );
 					const mass = decodeFixedPoint( atomicLoad( cell.get( 'mass' ) ) ).toConst();
 					If( mass.lessThanEqual( 0 ), () => {
@@ -412,11 +394,6 @@
 
 				g2pKernel = Fn( () => {
 
-					If( instanceIndex.greaterThanEqual( particleCountUniform ), () => {
-
-						Return();
-			
-					} );
 					const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toVar( 'particlePosition' );
 					const gridPosition = particlePosition.mul( gridSizeUniform ).toVar();
 					const particleVelocity = vec3( 0 ).toVar();

+ 2 - 1
src/nodes/core/IndexNode.js

@@ -1,5 +1,6 @@
 import Node from './Node.js';
-import { nodeImmutable, varying } from '../tsl/TSLBase.js';
+import { nodeImmutable } from '../tsl/TSLCore.js';
+import { varying } from './VaryingNode.js';
 
 /**
  * This class represents shader indices of different types. The following predefined node

+ 4 - 0
src/nodes/gpgpu/BarrierNode.js

@@ -21,6 +21,8 @@ class BarrierNode extends Node {
 
 		this.scope = scope;
 
+		this.isBarrierNode = true;
+
 	}
 
 	generate( builder ) {
@@ -28,6 +30,8 @@ class BarrierNode extends Node {
 		const { scope } = this;
 		const { renderer } = builder;
 
+		builder.allowEarlyReturns = false;
+
 		if ( renderer.backend.isWebGLBackend === true ) {
 
 			builder.addFlowCode( `\t// ${scope}Barrier \n` );

+ 69 - 44
src/nodes/gpgpu/ComputeNode.js

@@ -1,11 +1,13 @@
 import Node from '../core/Node.js';
+import { instanceIndex } from '../core/IndexNode.js';
 import StackTrace from '../core/StackTrace.js';
+import { uniform } from '../core/UniformNode.js';
 import { NodeUpdateType } from '../core/constants.js';
 import { addMethodChaining, nodeObject } from '../tsl/TSLCore.js';
 import { warn, error } from '../../utils.js';
 
 /**
- * TODO
+ * Represents a compute shader node.
  *
  * @augments Node
  */
@@ -20,8 +22,8 @@ class ComputeNode extends Node {
 	/**
 	 * Constructs a new compute node.
 	 *
-	 * @param {Node} computeNode - TODO
-	 * @param {Array<number>} workgroupSize - TODO.
+	 * @param {Node} computeNode - The node that defines the compute shader logic.
+	 * @param {Array<number>} workgroupSize - An array defining the X, Y, and Z dimensions of the workgroup for compute shader execution.
 	 */
 	constructor( computeNode, workgroupSize ) {
 
@@ -37,15 +39,14 @@ class ComputeNode extends Node {
 		this.isComputeNode = true;
 
 		/**
-		 * TODO
+		 * The node that defines the compute shader logic.
 		 *
 		 * @type {Node}
 		 */
 		this.computeNode = computeNode;
 
-
 		/**
-		 * TODO
+		 * An array defining the X, Y, and Z dimensions of the workgroup for compute shader execution.
 		 *
 		 * @type {Array<number>}
 		 * @default [ 64 ]
@@ -53,14 +54,23 @@ class ComputeNode extends Node {
 		this.workgroupSize = workgroupSize;
 
 		/**
-		 * TODO
+		 * The total number of threads (invocations) to execute. If it is a number, it will be used
+		 * to automatically generate bounds checking against `instanceIndex`.
 		 *
 		 * @type {number|Array<number>}
 		 */
 		this.count = null;
 
 		/**
-		 * TODO
+		 * The dispatch size for workgroups on X, Y, and Z axes.
+		 * Used directly if `count` is not provided.
+		 *
+		 * @type {number|Array<number>}
+		 */
+		this.dispatchSize = null;
+
+		/**
+		 * The version of the node.
 		 *
 		 * @type {number}
 		 */
@@ -84,36 +94,19 @@ class ComputeNode extends Node {
 		this.updateBeforeType = NodeUpdateType.OBJECT;
 
 		/**
-		 * TODO
+		 * A callback executed when the compute node finishes initialization.
 		 *
 		 * @type {?Function}
 		 */
 		this.onInitFunction = null;
 
-	}
-
-	/**
-	 * TODO
-	 *
-	 * @param {number|Array<number>} count - Array with [ x, y, z ] values for dispatch or a single number for the count
-	 * @return {ComputeNode}
-	 */
-	setCount( count ) {
-
-		this.count = count;
-
-		return this;
-
-	}
-
-	/**
-	 * TODO
-	 *
-	 * @return {number|Array<number>}
-	 */
-	getCount() {
-
-		return this.count;
+		/**
+		 * A uniform node holding the dispatch count for bounds checking.
+		 * Created automatically when `count` is a number.
+		 *
+		 * @type {?UniformNode}
+		 */
+		this.countNode = null;
 
 	}
 
@@ -156,9 +149,9 @@ class ComputeNode extends Node {
 	}
 
 	/**
-	 * TODO
+	 * Sets the callback to run during initialization.
 	 *
-	 * @param {Function} callback - TODO.
+	 * @param {Function} callback - The callback function.
 	 * @return {ComputeNode} A reference to this node.
 	 */
 	onInit( callback ) {
@@ -182,6 +175,12 @@ class ComputeNode extends Node {
 
 	setup( builder ) {
 
+		if ( this.count !== null && this.countNode === null ) {
+
+			this.countNode = uniform( this.count, 'uint' ).onObjectUpdate( () => this.count );
+
+		}
+
 		const result = this.computeNode.build( builder );
 
 		if ( result ) {
@@ -211,6 +210,16 @@ class ComputeNode extends Node {
 
 			}
 
+			if ( this.count !== null && builder.allowEarlyReturns === true ) {
+
+				const countSnippet = this.countNode.build( builder, 'uint' );
+				const indexSnippet = instanceIndex.build( builder, 'uint' );
+
+				builder.flow.code = `${ builder.tab }if ( ${ indexSnippet } >= ${ countSnippet } ) { return; }\n\n${ builder.flow.code }`;
+
+			}
+
+
 		} else {
 
 			const properties = builder.getNodeProperties( this );
@@ -235,9 +244,9 @@ export default ComputeNode;
  *
  * @tsl
  * @function
- * @param {Node} node - TODO
- * @param {Array<number>} [workgroupSize=[64]] - TODO.
- * @returns {AtomicFunctionNode}
+ * @param {Node} node - The TSL logic for the compute shader.
+ * @param {Array<number>} [workgroupSize=[64]] - The workgroup size.
+ * @returns {ComputeNode}
  */
 export const computeKernel = ( node, workgroupSize = [ 64 ] ) => {
 
@@ -274,12 +283,28 @@ export const computeKernel = ( node, workgroupSize = [ 64 ] ) => {
  *
  * @tsl
  * @function
- * @param {Node} node - TODO
- * @param {number|Array<number>} count - TODO.
- * @param {Array<number>} [workgroupSize=[64]] - TODO.
- * @returns {AtomicFunctionNode}
- */
-export const compute = ( node, count, workgroupSize ) => computeKernel( node, workgroupSize ).setCount( count );
+ * @param {Node} node - The TSL logic for the compute shader.
+ * @param {number|Array<number>} count - The compute count or dispatch size.
+ * @param {Array<number>} [workgroupSize=[64]] - The workgroup size.
+ * @returns {ComputeNode}
+,  */
+export const compute = ( node, count, workgroupSize ) => {
+
+	const computeNode = computeKernel( node, workgroupSize );
+
+	if ( typeof count === 'number' ) {
+
+		computeNode.count = count;
+
+	} else {
+
+		computeNode.dispatchSize = count;
+
+	}
+
+	return computeNode;
+
+};
 
 addMethodChaining( 'compute', compute );
 addMethodChaining( 'computeKernel', computeKernel );

+ 1 - 1
src/renderers/common/ComputePipeline.js

@@ -9,7 +9,7 @@ import Pipeline from './Pipeline.js';
 class ComputePipeline extends Pipeline {
 
 	/**
-	 * Constructs a new render pipeline.
+	 * Constructs a new compute pipeline.
 	 *
 	 * @param {string} cacheKey - The pipeline's cache key.
 	 * @param {ProgrammableStage} computeProgram - The pipeline's compute shader.

+ 2 - 2
src/renderers/webgpu/WebGPUBackend.js

@@ -1425,13 +1425,13 @@ class WebGPUBackend extends Backend {
 
 		if ( dispatchSize === null ) {
 
-			dispatchSize = computeNode.count;
+			dispatchSize = computeNode.dispatchSize || computeNode.count;
 
 		}
 
 		// When the dispatchSize is set with a StorageBuffer from the GPU.
 
-		if ( dispatchSize && typeof dispatchSize === 'object' && dispatchSize.isIndirectStorageBufferAttribute ) {
+		if ( dispatchSize && dispatchSize.isIndirectStorageBufferAttribute ) {
 
 			const dispatchBuffer = this.get( dispatchSize ).buffer;
 

+ 8 - 0
src/renderers/webgpu/nodes/WGSLNodeBuilder.js

@@ -230,6 +230,14 @@ class WGSLNodeBuilder extends NodeBuilder {
 		 */
 		this.scopedArrays = new Map();
 
+		/**
+		 * A flag that indicates that early returns are allowed.
+		 *
+		 * @type {boolean}
+		 * @default true
+		 */
+		this.allowEarlyReturns = true;
+
 	}
 
 	/**

粤ICP备19079148号