Просмотр исходного кода

TSL: Introduce `computeKernel()` (#31402)

* Allows the specification of workgroups for compute shaders and the specification of the dispatchSize

* Allows the specification of workgroups for compute shaders and the specification of the dispatchSize

* Allows the specification of workgroups for compute shaders and the specification of the dispatchSize

* Trigger GitHub to update PR

* rebase branch

* rebase branch

* rebase branch

* rebase branch

* update

* update

* Extend compute with flexible workgroupSize and dispatchSize validation

* Extend compute with flexible workgroupSize and dispatchSize validation

* Extend compute with flexible workgroupSize and dispatchSize validation

* introduce `computeKernel`

* add dispatchSize for `computeAsync()`

* optional

* description

* cleanup

---------

Co-authored-by: Attila Schroeder <attila-schroeder.79@gmail.com>
Co-authored-by: sunag <sunagbrasil@gmail.com>
Spiri0 5 месяцев назад
Родитель
Сommit
f3769ebafb

+ 1 - 0
src/Three.TSL.js

@@ -118,6 +118,7 @@ export const color = TSL.color;
 export const colorSpaceToWorking = TSL.colorSpaceToWorking;
 export const colorToDirection = TSL.colorToDirection;
 export const compute = TSL.compute;
+export const computeKernel = TSL.computeKernel;
 export const computeSkinning = TSL.computeSkinning;
 export const cond = TSL.cond;
 export const Const = TSL.Const;

+ 58 - 29
src/nodes/gpgpu/ComputeNode.js

@@ -19,10 +19,9 @@ class ComputeNode extends Node {
 	 * Constructs a new compute node.
 	 *
 	 * @param {Node} computeNode - TODO
-	 * @param {number} count - TODO.
-	 * @param {Array<number>} [workgroupSize=[64]] - TODO.
+	 * @param {Array<number>} workgroupSize - TODO.
 	 */
-	constructor( computeNode, count, workgroupSize = [ 64 ] ) {
+	constructor( computeNode, workgroupSize ) {
 
 		super( 'void' );
 
@@ -42,18 +41,12 @@ class ComputeNode extends Node {
 		 */
 		this.computeNode = computeNode;
 
-		/**
-		 * TODO
-		 *
-		 * @type {number}
-		 */
-		this.count = count;
 
 		/**
 		 * TODO
 		 *
 		 * @type {Array<number>}
-		 * @default [64]
+		 * @default [ 64 ]
 		 */
 		this.workgroupSize = workgroupSize;
 
@@ -62,7 +55,7 @@ class ComputeNode extends Node {
 		 *
 		 * @type {number}
 		 */
-		this.dispatchCount = 0;
+		this.count = null;
 
 		/**
 		 * TODO
@@ -95,7 +88,19 @@ class ComputeNode extends Node {
 		 */
 		this.onInitFunction = null;
 
-		this.updateDispatchCount();
+	}
+
+	setCount( count ) {
+
+		this.count = count;
+
+		return this;
+
+	}
+
+	getCount() {
+
+		return this.count;
 
 	}
 
@@ -122,22 +127,6 @@ class ComputeNode extends Node {
 
 	}
 
-	/**
-	 * TODO
-	 */
-	updateDispatchCount() {
-
-		const { count, workgroupSize } = this;
-
-		let size = workgroupSize[ 0 ];
-
-		for ( let i = 1; i < workgroupSize.length; i ++ )
-			size *= workgroupSize[ i ];
-
-		this.dispatchCount = Math.ceil( count / size );
-
-	}
-
 	/**
 	 * TODO
 	 *
@@ -213,6 +202,45 @@ class ComputeNode extends Node {
 
 export default ComputeNode;
 
+/**
+ * TSL function for creating a compute kernel node.
+ *
+ * @tsl
+ * @function
+ * @param {Node} node - TODO
+ * @param {Array<number>} [workgroupSize=[64]] - TODO.
+ * @returns {AtomicFunctionNode}
+ */
+export const computeKernel = ( node, workgroupSize = [ 64 ] ) => {
+
+	if ( workgroupSize.length === 0 || workgroupSize.length > 3 ) {
+
+		console.error( 'THREE.TSL: compute() workgroupSize must have 1, 2, or 3 elements' );
+
+	}
+
+	for ( let i = 0; i < workgroupSize.length; i ++ ) {
+
+		const val = workgroupSize[ i ];
+
+		if ( typeof val !== 'number' || val <= 0 || ! Number.isInteger( val ) ) {
+
+			console.error( `THREE.TSL: compute() workgroupSize element at index [ ${ i } ] must be a positive integer` );
+
+		}
+
+	}
+
+	// Implicit fill-up to [ x, y, z ] with 1s, just like WGSL treats @workgroup_size when fewer dimensions are specified
+
+	while ( workgroupSize.length < 3 ) workgroupSize.push( 1 );
+
+	//
+
+	return nodeObject( new ComputeNode( nodeObject( node ), workgroupSize ) );
+
+};
+
 /**
  * TSL function for creating a compute node.
  *
@@ -223,6 +251,7 @@ export default ComputeNode;
  * @param {Array<number>} [workgroupSize=[64]] - TODO.
  * @returns {AtomicFunctionNode}
  */
-export const compute = ( node, count, workgroupSize ) => nodeObject( new ComputeNode( nodeObject( node ), count, workgroupSize ) );
+export const compute = ( node, count, workgroupSize ) => computeKernel( node, workgroupSize ).setCount( count );
 
 addMethodChaining( 'compute', compute );
+addMethodChaining( 'computeKernel', computeKernel );

+ 6 - 4
src/renderers/common/Renderer.js

@@ -2308,9 +2308,10 @@ class Renderer {
 	 * if the renderer has been initialized.
 	 *
 	 * @param {Node|Array<Node>} computeNodes - The compute node(s).
+	 * @param {Array<number>|number} [dispatchSizeOrCount=null] - Array with [ x, y, z ] values for dispatch or a single number for the count.
 	 * @return {Promise|undefined} A Promise that resolve when the compute has finished. Only returned when the renderer has not been initialized.
 	 */
-	compute( computeNodes ) {
+	compute( computeNodes, dispatchSizeOrCount = null ) {
 
 		if ( this._isDeviceLost === true ) return;
 
@@ -2389,7 +2390,7 @@ class Renderer {
 			const computeBindings = bindings.getForCompute( computeNode );
 			const computePipeline = pipelines.getForCompute( computeNode, computeBindings );
 
-			backend.compute( computeNodes, computeNode, computeBindings, computePipeline );
+			backend.compute( computeNodes, computeNode, computeBindings, computePipeline, dispatchSizeOrCount );
 
 		}
 
@@ -2406,13 +2407,14 @@ class Renderer {
 	 *
 	 * @async
 	 * @param {Node|Array<Node>} computeNodes - The compute node(s).
+	 * @param {Array<number>|number} [dispatchSizeOrCount=null] - Array with [ x, y, z ] values for dispatch or a single number for the count.
 	 * @return {Promise} A Promise that resolve when the compute has finished.
 	 */
-	async computeAsync( computeNodes ) {
+	async computeAsync( computeNodes, dispatchSizeOrCount = null ) {
 
 		if ( this._initialized === false ) await this.init();
 
-		this.compute( computeNodes );
+		this.compute( computeNodes, dispatchSizeOrCount );
 
 	}
 

+ 14 - 3
src/renderers/webgl-fallback/WebGLBackend.js

@@ -915,8 +915,9 @@ class WebGLBackend extends Backend {
 	 * @param {Node} computeNode - The compute node.
 	 * @param {Array<BindGroup>} bindings - The bindings.
 	 * @param {ComputePipeline} pipeline - The compute pipeline.
+	 * @param {number|null} [count=null] - The count of compute invocations. If `null`, the count is determined by the compute node.
 	 */
-	compute( computeGroup, computeNode, bindings, pipeline ) {
+	compute( computeGroup, computeNode, bindings, pipeline, count = null ) {
 
 		const { state, gl } = this;
 
@@ -953,13 +954,23 @@ class WebGLBackend extends Backend {
 		gl.bindTransformFeedback( gl.TRANSFORM_FEEDBACK, transformFeedbackGPU );
 		gl.beginTransformFeedback( gl.POINTS );
 
+		count = ( count !== null ) ? count : computeNode.count;
+
+		if ( Array.isArray( count ) ) {
+
+			warnOnce( 'WebGLBackend.compute(): The count parameter must be a single number, not an array.' );
+
+			count = count[ 0 ];
+
+		}
+
 		if ( attributes[ 0 ].isStorageInstancedBufferAttribute ) {
 
-			gl.drawArraysInstanced( gl.POINTS, 0, 1, computeNode.count );
+			gl.drawArraysInstanced( gl.POINTS, 0, 1, count );
 
 		} else {
 
-			gl.drawArrays( gl.POINTS, 0, computeNode.count );
+			gl.drawArrays( gl.POINTS, 0, count );
 
 		}
 

+ 52 - 13
src/renderers/webgpu/WebGPUBackend.js

@@ -1298,7 +1298,6 @@ class WebGPUBackend extends Backend {
 
 		const groupGPU = this.get( computeGroup );
 
-
 		const descriptor = {
 			label: 'computeGroup_' + computeGroup.id
 		};
@@ -1318,9 +1317,11 @@ class WebGPUBackend extends Backend {
 	 * @param {Node} computeNode - The compute node.
 	 * @param {Array<BindGroup>} bindings - The bindings.
 	 * @param {ComputePipeline} pipeline - The compute pipeline.
+	 * @param {Array<number>|number} [dispatchSizeOrCount=null] - Array with [ x, y, z ] values for dispatch or a single number for the count.
 	 */
-	compute( computeGroup, computeNode, bindings, pipeline ) {
+	compute( computeGroup, computeNode, bindings, pipeline, dispatchSizeOrCount = null ) {
 
+		const computeNodeData = this.get( computeNode );
 		const { passEncoderGPU } = this.get( computeGroup );
 
 		// pipeline
@@ -1340,29 +1341,67 @@ class WebGPUBackend extends Backend {
 
 		}
 
-		const maxComputeWorkgroupsPerDimension = this.device.limits.maxComputeWorkgroupsPerDimension;
+		let dispatchSize;
 
-		const computeNodeData = this.get( computeNode );
+		if ( dispatchSizeOrCount === null ) {
+
+			dispatchSizeOrCount = computeNode.count;
+
+		}
+
+		if ( typeof dispatchSizeOrCount === 'number' ) {
+
+			// If a single number is given, we calculate the dispatch size based on the workgroup size
+
+			const count = dispatchSizeOrCount;
+
+			if ( computeNodeData.dispatchSize === undefined || computeNodeData.count !== count ) {
+
+				// cache dispatch size to avoid recalculating it every time
+
+				computeNodeData.dispatchSize = [ 0, 1, 1 ];
+				computeNodeData.count = count;
+
+				const workgroupSize = computeNode.workgroupSize;
 
-		if ( computeNodeData.dispatchSize === undefined ) computeNodeData.dispatchSize = { x: 0, y: 1, z: 1 };
+				let size = workgroupSize[ 0 ];
 
-		const { dispatchSize } = computeNodeData;
+				for ( let i = 1; i < workgroupSize.length; i ++ )
+					size *= workgroupSize[ i ];
 
-		if ( computeNode.dispatchCount > maxComputeWorkgroupsPerDimension ) {
+				const dispatchCount = Math.ceil( count / size );
 
-			dispatchSize.x = Math.min( computeNode.dispatchCount, maxComputeWorkgroupsPerDimension );
-			dispatchSize.y = Math.ceil( computeNode.dispatchCount / maxComputeWorkgroupsPerDimension );
+				//
+
+				const maxComputeWorkgroupsPerDimension = this.device.limits.maxComputeWorkgroupsPerDimension;
+
+				dispatchSize = [ dispatchCount, 1, 1 ];
+
+				if ( dispatchCount > maxComputeWorkgroupsPerDimension ) {
+
+					dispatchSize[ 0 ] = Math.min( dispatchCount, maxComputeWorkgroupsPerDimension );
+					dispatchSize[ 1 ] = Math.ceil( dispatchCount / maxComputeWorkgroupsPerDimension );
+
+				}
+
+				computeNodeData.dispatchSize = dispatchSize;
+
+			}
+
+			dispatchSize = computeNodeData.dispatchSize;
 
 		} else {
 
-			dispatchSize.x = computeNode.dispatchCount;
+			dispatchSize = dispatchSizeOrCount;
 
 		}
 
+		//
+
 		passEncoderGPU.dispatchWorkgroups(
-			dispatchSize.x,
-			dispatchSize.y,
-			dispatchSize.z
+			dispatchSize[ 0 ],
+			dispatchSize[ 1 ] || 1,
+			dispatchSize[ 2 ] || 1
 		);
 
 	}

+ 19 - 11
src/renderers/webgpu/nodes/WGSLNodeBuilder.js

@@ -1851,7 +1851,11 @@ ${ flowData.code }
 
 		} else {
 
-			this.computeShader = this._getWGSLComputeCode( shadersData.compute, ( this.object.workgroupSize || [ 64 ] ).join( ', ' ) );
+			// Early strictly validated in computeNode
+
+			const workgroupSize = this.object.workgroupSize;
+
+			this.computeShader = this._getWGSLComputeCode( shadersData.compute, workgroupSize );
 
 		}
 
@@ -2056,36 +2060,40 @@ fn main( ${shaderData.varyings} ) -> ${shaderData.returnType} {
 	 */
 	_getWGSLComputeCode( shaderData, workgroupSize ) {
 
+		const [ workgroupSizeX, workgroupSizeY, workgroupSizeZ ] = workgroupSize;
+
 		return `${ this.getSignature() }
 // directives
-${shaderData.directives}
+${ shaderData.directives }
 
 // system
 var<private> instanceIndex : u32;
 
 // locals
-${shaderData.scopedArrays}
+${ shaderData.scopedArrays }
 
 // structs
-${shaderData.structs}
+${ shaderData.structs }
 
 // uniforms
-${shaderData.uniforms}
+${ shaderData.uniforms }
 
 // codes
-${shaderData.codes}
+${ shaderData.codes }
 
-@compute @workgroup_size( ${workgroupSize} )
-fn main( ${shaderData.attributes} ) {
+@compute @workgroup_size( ${ workgroupSizeX }, ${ workgroupSizeY }, ${ workgroupSizeZ } )
+fn main( ${ shaderData.attributes } ) {
 
 	// system
-	instanceIndex = globalId.x + globalId.y * numWorkgroups.x * u32(${workgroupSize}) + globalId.z * numWorkgroups.x * numWorkgroups.y * u32(${workgroupSize});
+	instanceIndex = globalId.x
+    	+ globalId.y * ( ${ workgroupSizeX } * numWorkgroups.x )
+    	+ globalId.z * ( ${ workgroupSizeX } * numWorkgroups.x ) * ( ${ workgroupSizeY } * numWorkgroups.y );
 
 	// vars
-	${shaderData.vars}
+	${ shaderData.vars }
 
 	// flow
-	${shaderData.flow}
+	${ shaderData.flow }
 
 }
 `;

粤ICP备19079148号