|
|
@@ -1,11 +1,13 @@
|
|
|
import Node from '../core/Node.js';
|
|
|
+import { instanceIndex } from '../core/IndexNode.js';
|
|
|
import StackTrace from '../core/StackTrace.js';
|
|
|
+import { uniform } from '../core/UniformNode.js';
|
|
|
import { NodeUpdateType } from '../core/constants.js';
|
|
|
import { addMethodChaining, nodeObject } from '../tsl/TSLCore.js';
|
|
|
import { warn, error } from '../../utils.js';
|
|
|
|
|
|
/**
|
|
|
- * TODO
|
|
|
+ * Represents a compute shader node.
|
|
|
*
|
|
|
* @augments Node
|
|
|
*/
|
|
|
@@ -20,8 +22,8 @@ class ComputeNode extends Node {
|
|
|
/**
|
|
|
* Constructs a new compute node.
|
|
|
*
|
|
|
- * @param {Node} computeNode - TODO
|
|
|
- * @param {Array<number>} workgroupSize - TODO.
|
|
|
+ * @param {Node} computeNode - The node that defines the compute shader logic.
|
|
|
+ * @param {Array<number>} workgroupSize - An array defining the X, Y, and Z dimensions of the workgroup for compute shader execution.
|
|
|
*/
|
|
|
constructor( computeNode, workgroupSize ) {
|
|
|
|
|
|
@@ -37,15 +39,14 @@ class ComputeNode extends Node {
|
|
|
this.isComputeNode = true;
|
|
|
|
|
|
/**
|
|
|
- * TODO
|
|
|
+ * The node that defines the compute shader logic.
|
|
|
*
|
|
|
* @type {Node}
|
|
|
*/
|
|
|
this.computeNode = computeNode;
|
|
|
|
|
|
-
|
|
|
/**
|
|
|
- * TODO
|
|
|
+ * An array defining the X, Y, and Z dimensions of the workgroup for compute shader execution.
|
|
|
*
|
|
|
* @type {Array<number>}
|
|
|
* @default [ 64 ]
|
|
|
@@ -53,14 +54,23 @@ class ComputeNode extends Node {
|
|
|
this.workgroupSize = workgroupSize;
|
|
|
|
|
|
/**
|
|
|
- * TODO
|
|
|
+ * The total number of threads (invocations) to execute. If it is a number, it will be used
|
|
|
+ * to automatically generate bounds checking against `instanceIndex`.
|
|
|
*
|
|
|
* @type {number|Array<number>}
|
|
|
*/
|
|
|
this.count = null;
|
|
|
|
|
|
/**
|
|
|
- * TODO
|
|
|
+ * The dispatch size for workgroups on X, Y, and Z axes.
|
|
|
+ * Used directly if `count` is not provided.
|
|
|
+ *
|
|
|
+ * @type {number|Array<number>}
|
|
|
+ */
|
|
|
+ this.dispatchSize = null;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * The version of the node.
|
|
|
*
|
|
|
* @type {number}
|
|
|
*/
|
|
|
@@ -84,36 +94,19 @@ class ComputeNode extends Node {
|
|
|
this.updateBeforeType = NodeUpdateType.OBJECT;
|
|
|
|
|
|
/**
|
|
|
- * TODO
|
|
|
+ * A callback executed when the compute node finishes initialization.
|
|
|
*
|
|
|
* @type {?Function}
|
|
|
*/
|
|
|
this.onInitFunction = null;
|
|
|
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * TODO
|
|
|
- *
|
|
|
- * @param {number|Array<number>} count - Array with [ x, y, z ] values for dispatch or a single number for the count
|
|
|
- * @return {ComputeNode}
|
|
|
- */
|
|
|
- setCount( count ) {
|
|
|
-
|
|
|
- this.count = count;
|
|
|
-
|
|
|
- return this;
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * TODO
|
|
|
- *
|
|
|
- * @return {number|Array<number>}
|
|
|
- */
|
|
|
- getCount() {
|
|
|
-
|
|
|
- return this.count;
|
|
|
+ /**
|
|
|
+ * A uniform node holding the dispatch count for bounds checking.
|
|
|
+ * Created automatically when `count` is a number.
|
|
|
+ *
|
|
|
+ * @type {?UniformNode}
|
|
|
+ */
|
|
|
+ this.countNode = null;
|
|
|
|
|
|
}
|
|
|
|
|
|
@@ -156,9 +149,9 @@ class ComputeNode extends Node {
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * TODO
|
|
|
+ * Sets the callback to run during initialization.
|
|
|
*
|
|
|
- * @param {Function} callback - TODO.
|
|
|
+ * @param {Function} callback - The callback function.
|
|
|
* @return {ComputeNode} A reference to this node.
|
|
|
*/
|
|
|
onInit( callback ) {
|
|
|
@@ -182,6 +175,12 @@ class ComputeNode extends Node {
|
|
|
|
|
|
setup( builder ) {
|
|
|
|
|
|
+ if ( this.count !== null && this.countNode === null ) {
|
|
|
+
|
|
|
+ this.countNode = uniform( this.count, 'uint' ).onObjectUpdate( () => this.count );
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
const result = this.computeNode.build( builder );
|
|
|
|
|
|
if ( result ) {
|
|
|
@@ -211,6 +210,16 @@ class ComputeNode extends Node {
|
|
|
|
|
|
}
|
|
|
|
|
|
+ if ( this.count !== null && builder.allowEarlyReturns === true ) {
|
|
|
+
|
|
|
+ const countSnippet = this.countNode.build( builder, 'uint' );
|
|
|
+ const indexSnippet = instanceIndex.build( builder, 'uint' );
|
|
|
+
|
|
|
+ builder.flow.code = `${ builder.tab }if ( ${ indexSnippet } >= ${ countSnippet } ) { return; }\n\n${ builder.flow.code }`;
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
} else {
|
|
|
|
|
|
const properties = builder.getNodeProperties( this );
|
|
|
@@ -235,9 +244,9 @@ export default ComputeNode;
|
|
|
*
|
|
|
* @tsl
|
|
|
* @function
|
|
|
- * @param {Node} node - TODO
|
|
|
- * @param {Array<number>} [workgroupSize=[64]] - TODO.
|
|
|
- * @returns {AtomicFunctionNode}
|
|
|
+ * @param {Node} node - The TSL logic for the compute shader.
|
|
|
+ * @param {Array<number>} [workgroupSize=[64]] - The workgroup size.
|
|
|
+ * @returns {ComputeNode}
|
|
|
*/
|
|
|
export const computeKernel = ( node, workgroupSize = [ 64 ] ) => {
|
|
|
|
|
|
@@ -274,12 +283,28 @@ export const computeKernel = ( node, workgroupSize = [ 64 ] ) => {
|
|
|
*
|
|
|
* @tsl
|
|
|
* @function
|
|
|
- * @param {Node} node - TODO
|
|
|
- * @param {number|Array<number>} count - TODO.
|
|
|
- * @param {Array<number>} [workgroupSize=[64]] - TODO.
|
|
|
- * @returns {AtomicFunctionNode}
|
|
|
- */
|
|
|
-export const compute = ( node, count, workgroupSize ) => computeKernel( node, workgroupSize ).setCount( count );
|
|
|
+ * @param {Node} node - The TSL logic for the compute shader.
|
|
|
+ * @param {number|Array<number>} count - The compute count or dispatch size.
|
|
|
+ * @param {Array<number>} [workgroupSize=[64]] - The workgroup size.
|
|
|
+ * @returns {ComputeNode}
|
|
|
+, */
|
|
|
+export const compute = ( node, count, workgroupSize ) => {
|
|
|
+
|
|
|
+ const computeNode = computeKernel( node, workgroupSize );
|
|
|
+
|
|
|
+ if ( typeof count === 'number' ) {
|
|
|
+
|
|
|
+ computeNode.count = count;
|
|
|
+
|
|
|
+ } else {
|
|
|
+
|
|
|
+ computeNode.dispatchSize = count;
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ return computeNode;
|
|
|
+
|
|
|
+};
|
|
|
|
|
|
addMethodChaining( 'compute', compute );
|
|
|
addMethodChaining( 'computeKernel', computeKernel );
|