Browse Source

Nodes: Add `AtomicFunctionNode` (#29385)

* add atomic operations

* add storeNode

* cleanup

---------
Christian Helgeson 1 year ago
parent
commit
38fd5e9659

+ 9 - 1
examples/webgpu_compute_sort_bitonic.html

@@ -54,7 +54,7 @@
 		<script type="module">
 
 			import * as THREE from 'three';
-			import { storageObject, If, vec3, not, uniform, uv, uint, float, Fn, vec2, abs, int, invocationLocalIndex, workgroupArray, uvec2, floor, instanceIndex, workgroupBarrier } from 'three/tsl';
+			import { storageObject, If, vec3, not, uniform, uv, uint, float, Fn, vec2, abs, int, invocationLocalIndex, workgroupArray, uvec2, floor, instanceIndex, workgroupBarrier, atomicAdd, atomicStore } from 'three/tsl';
 
 			import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
 
@@ -149,6 +149,9 @@
 				const highestBlockHeightBuffer = new THREE.StorageInstancedBufferAttribute( new Uint32Array( 1 ).fill( 2 ), 1 );
 				const highestBlockHeightStorage = storageObject( highestBlockHeightBuffer, 'uint', highestBlockHeightBuffer.count ).label( 'HighestBlockHeight' );
 
+				const counterBuffer = new THREE.StorageBufferAttribute( 1, 1 );
+				const counterStorage = storageObject( counterBuffer, 'uint', counterBuffer.count ).toAtomic().label( 'Counter' );
+
 				const array = new Uint32Array( Array.from( { length: size }, ( _, i ) => {
 
 					return i;
@@ -219,6 +222,7 @@
 
 					If( localStorage.element( idxAfter ).lessThan( localStorage.element( idxBefore ) ), () => {
 
+						atomicAdd( counterStorage.element( 0 ), 1 );
 						const temp = localStorage.element( idxBefore ).toVar();
 						localStorage.element( idxBefore ).assign( localStorage.element( idxAfter ) );
 						localStorage.element( idxAfter ).assign( temp );
@@ -233,6 +237,7 @@
 					If( currentElementsStorage.element( idxAfter ).lessThan( currentElementsStorage.element( idxBefore ) ), () => {
 
 						// Apply the swapped values to temporary storage.
+						atomicAdd( counterStorage.element( 0 ), 1 );
 						tempStorage.element( idxBefore ).assign( currentElementsStorage.element( idxAfter ) );
 						tempStorage.element( idxAfter ).assign( currentElementsStorage.element( idxBefore ) );
 
@@ -396,6 +401,7 @@
 					nextAlgoStorage.element( 0 ).assign( forceGlobalSwap ? StepType.FLIP_GLOBAL : StepType.FLIP_LOCAL );
 					nextBlockHeightStorage.element( 0 ).assign( 2 );
 					highestBlockHeightStorage.element( 0 ).assign( 2 );
+					atomicStore( counterStorage.element( 0 ), 0 );
 
 				} );
 
@@ -511,12 +517,14 @@
 
 					const algo = new Uint32Array( await renderer.getArrayBufferAsync( nextAlgoBuffer ) );
 					algo > StepType.DISPERSE_LOCAL ? ( nextStepGlobal = true ) : ( nextStepGlobal = false );
+					const totalSwaps = new Uint32Array( await renderer.getArrayBufferAsync( counterBuffer ) );
 			
 					renderer.render( scene, camera );
 
 					timestamps[ forceGlobalSwap ? 'global_swap' : 'local_swap' ].innerHTML = `
 
 							Compute ${forceGlobalSwap ? 'Global' : 'Local'}: ${renderer.info.compute.frameCalls} pass in ${renderer.info.compute.timestamp.toFixed( 6 )}ms<br>
+							Total Swaps: ${totalSwaps}<br>
 								<div style="display: flex; flex-direction:row; justify-content: center; align-items: center;">
 									${forceGlobalSwap ? 'Global Swaps' : 'Local Swaps'} Compare Region&nbsp;
 									<div style="background-color: ${ forceGlobalSwap ? globalColors[ 0 ] : localColors[ 0 ]}; width:12.5px; height: 1em; border-radius: 20%;"></div>

+ 1 - 0
src/nodes/TSL.js

@@ -144,6 +144,7 @@ export * from './geometry/RangeNode.js';
 export * from './gpgpu/ComputeNode.js';
 export * from './gpgpu/BarrierNode.js';
 export * from './gpgpu/WorkgroupInfoNode.js';
+export * from './gpgpu/AtomicFunctionNode.js';
 
 // lighting
 export * from './accessors/Lights.js';

+ 15 - 0
src/nodes/accessors/StorageBufferNode.js

@@ -19,6 +19,7 @@ class StorageBufferNode extends BufferNode {
 		this.isStorageBufferNode = true;
 
 		this.access = GPUBufferBindingType.Storage;
+		this.isAtomic = false;
 
 		this.bufferObject = false;
 		this.bufferCount = bufferCount;
@@ -97,6 +98,20 @@ class StorageBufferNode extends BufferNode {
 
 	}
 
+	setAtomic( value ) {
+
+		this.isAtomic = value;
+
+		return this;
+
+	}
+
+	toAtomic() {
+
+		return this.setAtomic( true );
+
+	}
+
 	generate( builder ) {
 
 		if ( builder.isAvailable( 'storageBuffer' ) ) {

+ 99 - 0
src/nodes/gpgpu/AtomicFunctionNode.js

@@ -0,0 +1,99 @@
+import TempNode from '../core/TempNode.js';
+import { nodeProxy } from '../tsl/TSLCore.js';
+
+class AtomicFunctionNode extends TempNode {
+
+	static get type() {
+
+		return 'AtomicFunctionNode';
+
+	}
+
+	constructor( method, pointerNode, valueNode, storeNode = null ) {
+
+		super( 'uint' );
+
+		this.method = method;
+
+		this.pointerNode = pointerNode;
+		this.valueNode = valueNode;
+		this.storeNode = storeNode;
+
+	}
+
+	getInputType( builder ) {
+
+		return this.pointerNode.getNodeType( builder );
+
+	}
+
+	getNodeType( builder ) {
+
+		return this.getInputType( builder );
+
+	}
+
+	generate( builder ) {
+
+		const method = this.method;
+
+		const type = this.getNodeType( builder );
+		const inputType = this.getInputType( builder );
+
+		const a = this.pointerNode;
+		const b = this.valueNode;
+
+		const params = [];
+
+		params.push( `&${ a.build( builder, inputType ) }` );
+		params.push( b.build( builder, inputType ) );
+
+		const methodSnippet = `${ builder.getMethod( method, type ) }( ${params.join( ', ' )} )`;
+
+		if ( this.storeNode !== null ) {
+
+			const varSnippet = this.storeNode.build( builder, inputType );
+
+			builder.addLineFlowCode( `${varSnippet} = ${methodSnippet}` );
+
+		} else {
+
+			builder.addLineFlowCode( methodSnippet );
+
+		}
+
+	}
+
+}
+
+AtomicFunctionNode.ATOMIC_LOAD = 'atomicLoad';
+AtomicFunctionNode.ATOMIC_STORE = 'atomicStore';
+AtomicFunctionNode.ATOMIC_ADD = 'atomicAdd';
+AtomicFunctionNode.ATOMIC_SUB = 'atomicSub';
+AtomicFunctionNode.ATOMIC_MAX = 'atomicMax';
+AtomicFunctionNode.ATOMIC_MIN = 'atomicMin';
+AtomicFunctionNode.ATOMIC_AND = 'atomicAnd';
+AtomicFunctionNode.ATOMIC_OR = 'atomicOr';
+AtomicFunctionNode.ATOMIC_XOR = 'atomicXor';
+
+export default AtomicFunctionNode;
+
+const atomicNode = nodeProxy( AtomicFunctionNode );
+
+export const atomicFunc = ( method, pointerNode, valueNode, storeNode ) => {
+
+	const node = atomicNode( method, pointerNode, valueNode, storeNode );
+	node.append();
+
+	return node;
+
+};
+
+export const atomicStore = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_STORE, pointerNode, valueNode, storeNode );
+export const atomicAdd = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_ADD, pointerNode, valueNode, storeNode );
+export const atomicSub = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_SUB, pointerNode, valueNode, storeNode );
+export const atomicMax = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_MAX, pointerNode, valueNode, storeNode );
+export const atomicMin = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_MIN, pointerNode, valueNode, storeNode );
+export const atomicAnd = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_AND, pointerNode, valueNode, storeNode );
+export const atomicOr = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_OR, pointerNode, valueNode, storeNode );
+export const atomicXor = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_XOR, pointerNode, valueNode, storeNode );

+ 2 - 1
src/renderers/webgpu/nodes/WGSLNodeBuilder.js

@@ -1052,7 +1052,8 @@ ${ flowData.code }
 				const bufferCount = bufferNode.bufferCount;
 
 				const bufferCountSnippet = bufferCount > 0 ? ', ' + bufferCount : '';
-				const bufferSnippet = `\t${ uniform.name } : array< ${ bufferType }${ bufferCountSnippet } >\n`;
+				const bufferTypeSnippet = bufferNode.isAtomic ? `atomic<${bufferType}>` : `${bufferType}`;
+				const bufferSnippet = `\t${ uniform.name } : array< ${ bufferTypeSnippet }${ bufferCountSnippet } >\n`;
 				const bufferAccessMode = bufferNode.isStorageBufferNode ? `storage, ${ this.getStorageAccess( bufferNode ) }` : 'uniform';
 
 				bufferSnippets.push( this._getWGSLStructBinding( 'NodeBuffer_' + bufferNode.id, bufferSnippet, bufferAccessMode, uniformIndexes.binding ++, uniformIndexes.group ) );

粤ICP备19079148号