Christian Helgeson 2 месяцев назад
Родитель
Сommit
6a84a807bc

+ 3 - 19
examples/jsm/tsl/display/SSGINode.js

@@ -1,5 +1,5 @@
 import { RenderTarget, Vector2, TempNode, QuadMesh, NodeMaterial, RendererUtils, MathUtils } from 'three/webgpu';
-import { clamp, normalize, reference, nodeObject, Fn, NodeUpdateType, uniform, vec4, passTexture, uv, logarithmicDepthToViewZ, viewZToPerspectiveDepth, getViewPosition, screenCoordinate, float, sub, fract, dot, vec2, rand, vec3, Loop, mul, PI, cos, sin, uint, cross, acos, sign, pow, luminance, If, max, abs, Break, sqrt, HALF_PI, div, ceil, shiftRight, convertToTexture, bool, getNormalFromDepth, interleavedGradientNoise } from 'three/tsl';
+import { clamp, normalize, reference, nodeObject, Fn, NodeUpdateType, uniform, vec4, passTexture, uv, logarithmicDepthToViewZ, viewZToPerspectiveDepth, getViewPosition, screenCoordinate, float, sub, fract, dot, vec2, rand, vec3, Loop, mul, PI, cos, sin, uint, cross, acos, sign, pow, luminance, If, max, abs, Break, sqrt, HALF_PI, div, ceil, shiftRight, convertToTexture, bool, getNormalFromDepth, countOneBits, interleavedGradientNoise } from 'three/tsl';
 
 const _quadMesh = /*@__PURE__*/ new QuadMesh();
 const _size = /*@__PURE__*/ new Vector2();
@@ -435,22 +435,6 @@ class SSGINode extends TempNode {
 			]
 		} );
 
-		const bitCount = Fn( ( [ value ] ) => {
-
-			const v = uint( value );
-			v.assign( v.sub( v.shiftRight( uint( 1 ) ).bitAnd( uint( 0x55555555 ) ) ) );
-			v.assign( v.bitAnd( uint( 0x33333333 ) ).add( v.shiftRight( uint( 2 ) ).bitAnd( uint( 0x33333333 ) ) ) );
-
-			return v.add( v.shiftRight( uint( 4 ) ) ).bitAnd( uint( 0xF0F0F0F ) ).mul( uint( 0x1010101 ) ).shiftRight( uint( 24 ) );
-
-		} ).setLayout( {
-			name: 'bitCount',
-			type: 'uint',
-			inputs: [
-				{ name: 'value', type: 'uint' }
-			]
-		} );
-
 		const horizonSampling = Fn( ( [ directionIsRight, RADIUS, viewPosition, slideDirTexelSize, initialRayStep, uvNode, viewDir, viewNormal, n ] ) => {
 
 			const STEP_COUNT = this.stepCount.toConst();
@@ -513,7 +497,7 @@ class SSGINode extends TempNode {
 				currentOccludedBitfield = currentOccludedBitfield.bitAnd( globalOccludedBitfield.bitNot() );
 
 				globalOccludedBitfield.assign( globalOccludedBitfield.bitOr( currentOccludedBitfield ) );
-				const numOccludedZones = bitCount( currentOccludedBitfield );
+				const numOccludedZones = countOneBits( currentOccludedBitfield );
 
 				//
 
@@ -597,7 +581,7 @@ class SSGINode extends TempNode {
 				color.addAssign( horizonSampling( bool( true ), RADIUS, viewPosition, slideDirTexelSize, initialRayStep, uvNode, viewDir, viewNormal, n ) );
 				color.addAssign( horizonSampling( bool( false ), RADIUS, viewPosition, slideDirTexelSize, initialRayStep, uvNode, viewDir, viewNormal, n ) );
 
-				ao.addAssign( float( bitCount( globalOccludedBitfield ) ).div( float( MAX_RAY ) ) );
+				ao.addAssign( float( countOneBits( globalOccludedBitfield ) ).div( float( MAX_RAY ) ) );
 
 			} );
 

+ 4 - 4
examples/webgpu_compute_reduce.html

@@ -190,7 +190,7 @@
 		<script type="module">
 
 			import * as THREE from 'three/webgpu';
-			import { instancedArray, Loop, If, vec3, dot, clamp, storage, uvec4, subgroupAdd, uniform, uv, uint, float, Fn, vec2, invocationLocalIndex, invocationSubgroupIndex, uvec2, floor, instanceIndex, workgroupId, workgroupBarrier, workgroupArray, subgroupSize, select, log2 } from 'three/tsl';
+			import { instancedArray, Loop, If, vec3, dot, clamp, storage, uvec4, subgroupAdd, uniform, uv, uint, float, Fn, vec2, invocationLocalIndex, invocationSubgroupIndex, uvec2, floor, instanceIndex, workgroupId, workgroupBarrier, workgroupArray, subgroupSize, select, countTrailingZeros } from 'three/tsl';
 
 			import WebGPU from 'three/addons/capabilities/WebGPU.js';
 
@@ -831,12 +831,12 @@
 
 					// Multiple approaches here
 					// log2(subgroupSize) -> TSL log2 function
-					// countTrailingZeros/findLSB(subgroupSize) -> Currently unsupported function in TSL that counts trailing zeros in number bit representation
+					// countTrailingZeros/findLSB(subgroupSize) -> TSL function that counts trailing zeros in number bit representation
 					// Can technically petition GPU for subgroupSize in shader and calculate logs on CPU at cost of shader being generalizable across devices
 					// May also break if subgroupSize changes when device is lost or if program is rerun on lower power device
-					const subgroupSizeLog = uint( log2( float( subgroupSize ) ) ).toVar( 'subgroupSizeLog' );
+					const subgroupSizeLog = countTrailingZeros( subgroupSize ).toVar( 'subgroupSizeLog' );
 					const spineSize = uint( workgroupSize ).shiftRight( subgroupSizeLog );
-					const spineSizeLog = uint( log2( float( spineSize ) ) ).toVar( 'spineSizeLog' );
+					const spineSizeLog = countTrailingZeros( spineSize ).toVar( 'spineSizeLog' );
 
 
 					// Align size to powers of subgroupSize

+ 3 - 0
src/Three.TSL.js

@@ -136,6 +136,9 @@ export const context = TSL.context;
 export const convert = TSL.convert;
 export const convertColorSpace = TSL.convertColorSpace;
 export const convertToTexture = TSL.convertToTexture;
+export const countLeadingZeros = TSL.countLeadingZeros;
+export const countOneBits = TSL.countOneBits;
+export const countTrailingZeros = TSL.countTrailingZeros;
 export const cos = TSL.cos;
 export const cross = TSL.cross;
 export const cubeTexture = TSL.cubeTexture;

+ 1 - 0
src/nodes/TSL.js

@@ -20,6 +20,7 @@ export * from './core/MRTNode.js';
 
 // math
 export * from './math/BitcastNode.js';
+export * from './math/BitcountNode.js';
 export * from './math/Hash.js';
 export * from './math/MathUtils.js';
 export * from './math/TriNoise3D.js';

+ 433 - 0
src/nodes/math/BitcountNode.js

@@ -0,0 +1,433 @@
+import { float, Fn, If, nodeProxyIntent, uint, int, uvec2, uvec3, uvec4, ivec2, ivec3, ivec4 } from '../tsl/TSLCore.js';
+import { bitcast, floatBitsToUint } from './BitcastNode.js';
+import MathNode, { negate } from './MathNode.js';
+
+const registeredBitcountFunctions = {};
+
+/**
+ * This node represents an operation that counts the bits of a piece of shader data.
+ *
+ * @augments MathNode
+ */
+class BitcountNode extends MathNode {
+
+	static get type() {
+
+		return 'BitcountNode';
+
+	}
+
+	/**
+	 * Constructs a new math node.
+	 *
+	 * @param {'countTrailingZeros'|'countLeadingZeros'|'countOneBits'} method - The method name.
+	 * @param {Node} aNode - The first input.
+	 */
+	constructor( method, aNode ) {
+
+		super( method, aNode );
+
+		/**
+		 * This flag can be used for type testing.
+		 *
+		 * @type {boolean}
+		 * @readonly
+		 * @default true
+		 */
+		this.isBitcountNode = true;
+
+	}
+
+	/**
+	 * Casts the input value of the function to an integer if necessary.
+	 *
+	 * @private
+	 * @param {Node<uint>|Node<int>} inputNode - The input value.
+	 * @param {Node<uint>} outputNode - The output value.
+	 * @param {string} elementType - The type of the input value.
+	 */
+	_resolveElementType( inputNode, outputNode, elementType ) {
+
+		if ( elementType === 'int' ) {
+
+			outputNode.assign( bitcast( inputNode, 'uint' ) );
+
+		} else {
+
+			outputNode.assign( inputNode );
+
+		}
+
+	}
+
+	_returnDataNode( inputType ) {
+
+		switch ( inputType ) {
+
+			case 'uint': {
+
+				return uint;
+
+			}
+
+			case 'int': {
+
+				return int;
+
+			}
+
+			case 'uvec2': {
+
+				return uvec2;
+
+			}
+
+			case 'uvec3': {
+
+				return uvec3;
+
+			}
+
+			case 'uvec4': {
+
+				return uvec4;
+
+			}
+
+			case 'ivec2': {
+
+				return ivec2;
+
+			}
+
+			case 'ivec3': {
+
+				return ivec3;
+
+			}
+
+			case 'ivec4': {
+
+				return ivec4;
+
+			}
+
+		}
+
+	}
+
+	/**
+	 * Creates and registers a reusable GLSL function that emulates the behavior of countTrailingZeros.
+	 *
+	 * @private
+	 * @param {string} method - The name of the function to create.
+	 * @param {string} elementType - The type of the input value.
+	 * @returns {Function} - The generated function
+	 */
+	_createTrailingZerosBaseLayout( method, elementType ) {
+
+		const outputConvertNode = this._returnDataNode( elementType );
+
+		const fnDef = Fn( ( [ value ] ) => {
+
+			const v = uint( 0.0 );
+
+			this._resolveElementType( value, v, elementType );
+
+			const f = float( v.bitAnd( negate( v ) ) );
+			const uintBits = floatBitsToUint( f );
+
+			const numTrailingZeros = ( uintBits.shiftRight( 23 ) ).sub( 127 );
+
+			return outputConvertNode( numTrailingZeros );
+
+		} ).setLayout( {
+			name: method,
+			type: elementType,
+			inputs: [
+				{ name: 'value', type: elementType }
+			]
+		} );
+
+		return fnDef;
+
+	}
+
+	/**
+	 * Creates and registers a reusable GLSL function that emulates the behavior of countLeadingZeros.
+	 *
+	 * @private
+	 * @param {string} method - The name of the function to create.
+	 * @param {string} elementType - The type of the input value.
+	 * @returns {Function} - The generated function
+	 */
+	_createLeadingZerosBaseLayout( method, elementType ) {
+
+		const outputConvertNode = this._returnDataNode( elementType );
+
+		const fnDef = Fn( ( [ value ] ) => {
+
+			If( value.equal( uint( 0 ) ), () => {
+
+				return uint( 32 );
+
+			} );
+
+			const v = uint( 0 );
+			const n = uint( 0 );
+			this._resolveElementType( value, v, elementType );
+
+			If( v.shiftRight( 16 ).equal( 0 ), () => {
+
+				n.addAssign( 16 );
+				v.shiftLeftAssign( 16 );
+
+			} );
+
+			If( v.shiftRight( 24 ).equal( 0 ), () => {
+
+				n.addAssign( 8 );
+				v.shiftLeftAssign( 8 );
+
+			} );
+
+			If( v.shiftRight( 28 ).equal( 0 ), () => {
+
+				n.addAssign( 4 );
+				v.shiftLeftAssign( 4 );
+
+			} );
+
+			If( v.shiftRight( 30 ).equal( 0 ), () => {
+
+				n.addAssign( 2 );
+				v.shiftLeftAssign( 2 );
+
+			} );
+
+			If( v.shiftRight( 31 ).equal( 0 ), () => {
+
+				n.addAssign( 1 );
+
+			} );
+
+			return outputConvertNode( n );
+
+		} ).setLayout( {
+			name: method,
+			type: elementType,
+			inputs: [
+				{ name: 'value', type: elementType }
+			]
+		} );
+
+		return fnDef;
+
+	}
+
+	/**
+	 * Creates and registers a reusable GLSL function that emulates the behavior of countOneBits.
+	 *
+	 * @private
+	 * @param {string} method - The name of the function to create.
+	 * @param {string} elementType - The type of the input value.
+	 * @returns {Function} - The generated function
+	 */
+	_createOneBitsBaseLayout( method, elementType ) {
+
+		const outputConvertNode = this._returnDataNode( elementType );
+
+		const fnDef = Fn( ( [ value ] ) => {
+
+			const v = uint( 0.0 );
+
+			this._resolveElementType( value, v, elementType );
+
+			v.assign( v.sub( v.shiftRight( uint( 1 ) ).bitAnd( uint( 0x55555555 ) ) ) );
+			v.assign( v.bitAnd( uint( 0x33333333 ) ).add( v.shiftRight( uint( 2 ) ).bitAnd( uint( 0x33333333 ) ) ) );
+
+			const numBits = v.add( v.shiftRight( uint( 4 ) ) ).bitAnd( uint( 0xF0F0F0F ) ).mul( uint( 0x1010101 ) ).shiftRight( uint( 24 ) );
+
+			return outputConvertNode( numBits );
+
+		} ).setLayout( {
+			name: method,
+			type: elementType,
+			inputs: [
+				{ name: 'value', type: elementType }
+			]
+		} );
+
+		return fnDef;
+
+	}
+
+	/**
+	 * Creates and registers a reusable GLSL function that emulates the behavior of the specified bitcount function.
+	 * including considerations for component-wise bitcounts on vector type inputs.
+	 *
+	 * @private
+	 * @param {string} method - The name of the function to create.
+	 * @param {string} inputType - The type of the input value.
+	 * @param {number} typeLength - The vec length of the input value.
+	 * @param {Function} baseFn - The base function that operates on an individual component of the vector.
+	 * @returns {Function} - The alias function for the specified bitcount method.
+	 */
+	_createMainLayout( method, inputType, typeLength, baseFn ) {
+
+		const outputConvertNode = this._returnDataNode( inputType );
+
+		const fnDef = Fn( ( [ value ] ) => {
+
+			if ( typeLength === 1 ) {
+
+				return outputConvertNode( baseFn( value ) );
+
+			} else {
+
+				const vec = outputConvertNode( 0 );
+
+				const components = [ 'x', 'y', 'z', 'w' ];
+				for ( let i = 0; i < typeLength; i ++ ) {
+
+					const component = components[ i ];
+
+					vec[ component ].assign( baseFn( value[ component ] ) );
+
+				}
+
+				return vec;
+
+			}
+
+		} ).setLayout( {
+			name: method,
+			type: inputType,
+			inputs: [
+				{ name: 'value', type: inputType }
+			]
+		} );
+
+		return fnDef;
+
+	}
+
+	setup( builder ) {
+
+		const { method, aNode } = this;
+
+		const { renderer } = builder;
+
+		if ( renderer.backend.isWebGPUBackend ) {
+
+			// use built-in WGSL functions for WebGPU
+
+			return super.setup( builder );
+
+		}
+
+		const inputType = this.getInputType( builder );
+		const elementType = builder.getElementType( inputType );
+
+		const typeLength = builder.getTypeLength( inputType );
+
+		const baseMethod = `${method}_base_${elementType}`;
+		const newMethod = `${method}_${inputType}`;
+
+		let baseFn = registeredBitcountFunctions[ baseMethod ];
+
+		if ( baseFn === undefined ) {
+
+			switch ( method ) {
+
+				case BitcountNode.COUNT_LEADING_ZEROS: {
+
+					baseFn = this._createLeadingZerosBaseLayout( baseMethod, elementType );
+					break;
+
+				}
+
+				case BitcountNode.COUNT_TRAILING_ZEROS: {
+
+					baseFn = this._createTrailingZerosBaseLayout( baseMethod, elementType );
+					break;
+
+				}
+
+				case BitcountNode.COUNT_ONE_BITS: {
+
+					baseFn = this._createOneBitsBaseLayout( baseMethod, elementType );
+					break;
+
+				}
+
+			}
+
+			registeredBitcountFunctions[ baseMethod ] = baseFn;
+
+		}
+
+		let fn = registeredBitcountFunctions[ newMethod ];
+
+		if ( fn === undefined ) {
+
+			fn = this._createMainLayout( newMethod, inputType, typeLength, baseFn );
+			registeredBitcountFunctions[ newMethod ] = fn;
+
+		}
+
+		const output = Fn( () => {
+
+			return fn(
+				aNode,
+			);
+
+		} );
+
+		return output();
+
+	}
+
+}
+
+export default BitcountNode;
+
+BitcountNode.COUNT_TRAILING_ZEROS = 'countTrailingZeros';
+BitcountNode.COUNT_LEADING_ZEROS = 'countLeadingZeros';
+BitcountNode.COUNT_ONE_BITS = 'countOneBits';
+
+/**
+ * Finds the number of consecutive 0 bits from the least significant bit of the input value,
+ * which is also the index of the least significant bit of the input value.
+ *
+ * Can only be used with {@link WebGPURenderer} and a WebGPU backend.
+ *
+ * @tsl
+ * @function
+ * @param {Node | number} x - The input value.
+ * @returns {Node}
+ */
+export const countTrailingZeros = /*@__PURE__*/ nodeProxyIntent( BitcountNode, BitcountNode.COUNT_TRAILING_ZEROS ).setParameterLength( 1 );
+
+/**
+ * Finds the number of consecutive 0 bits starting from the most significant bit of the input value.
+ *
+ * Can only be used with {@link WebGPURenderer} and a WebGPU backend.
+ *
+ * @tsl
+ * @function
+ * @param {Node | number} x - The input value.
+ * @returns {Node}
+ */
+export const countLeadingZeros = /*@__PURE__*/ nodeProxyIntent( BitcountNode, BitcountNode.COUNT_LEADING_ZEROS ).setParameterLength( 1 );
+
+/**
+ * Finds the number of '1' bits set in the input value
+ *
+ * Can only be used with {@link WebGPURenderer} and a WebGPU backend.
+ *
+ * @tsl
+ * @function
+ * @returns {Node}
+ */
+export const countOneBits = /*@__PURE__*/ nodeProxyIntent( BitcountNode, BitcountNode.COUNT_ONE_BITS ).setParameterLength( 1 );

+ 2 - 2
src/renderers/webgl-fallback/nodes/GLSLNodeBuilder.js

@@ -10,8 +10,8 @@ import { DataTexture } from '../../../textures/DataTexture.js';
 import { error } from '../../../utils.js';
 
 const glslPolyfills = {
-	bitcast_int_uint: new CodeNode( /* glsl */'uint tsl_bitcast_uint_to_int ( int x ) { return floatBitsToInt( uintBitsToFloat( x ) ); }' ),
-	bitcast_uint_int: new CodeNode( /* glsl */'uint tsl_bitcast_int_to_uint ( int x ) { return floatBitsToUint( intBitsToFloat ( x ) ); }' )
+	bitcast_int_uint: new CodeNode( /* glsl */'uint tsl_bitcast_int_to_uint ( int x ) { return floatBitsToUint( intBitsToFloat ( x ) ); }' ),
+	bitcast_uint_int: new CodeNode( /* glsl */'uint tsl_bitcast_uint_to_int ( uint x ) { return floatBitsToInt( uintBitsToFloat ( x ) ); }' )
 };
 
 const glslMethods = {

粤ICP备19079148号