||
- <html lang="en">
- <head>
- <meta charset="utf-8">
- <meta name="viewport" content="width=device-width, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0">
- <title>three.js webgpu - compute reduction</title>
- <!-- External stylesheet; the reduction panel's own rules live in the <style> element at the top of <body>. -->
- <link rel="stylesheet" href="main.css">
- </head>
- <body>
- <style>
- /* Panel pinned to the bottom of the viewport that hosts the subgroup reduction walkthrough */
- #reduction-panel {
- background-color: #111;
- width: 100%;
- display: flex;
- position: fixed;
- height: auto;
- bottom: 0px;
- z-index: 99;
- flex-direction: column;
- justify-content: center;
- align-items: center;
- border-left: 2px solid #222;
- text-align: center;
- }
- #panel-title {
- width: fit-content;
- }
- .thread-row {
- display: flex;
- flex-direction: row;
- align-items: center;
- margin: 4px 0;
- position: relative;
- }
- .stage-display {
- display: flex;
- flex-direction: column;
- justify-content: center;
- margin-bottom: 5px;
- }
- .stage-label {
- font-size: 1.2em;
- color: #aaa;
- /* Bold text is controlled by font-weight (font-style only accepts normal/italic/oblique) */
- font-weight: bold;
- margin-top: 6px;
- margin-bottom: 20px;
- }
- /* One cell per thread; used for both the workgroup's threads and the per-subgroup reduction results */
- .thread {
- display: flex;
- justify-content: center;
- align-items: center;
- width: 40px;
- height: 40px;
- margin: 2px;
- border: 1px solid rgba(255, 255, 255, 0.2);
- border-radius: 4px;
- /* Solid fallback in case the gradient below is not supported */
- background-color: #444;
- background: linear-gradient(180deg, rgba(255,255,255,0.05), rgba(0,0,0,0.2));
- box-shadow: inset 0 0 2px rgba(255,255,255,0.1);
- font-family: monospace;
- color: white;
- transition: background-color 0.5s, transform 0.5s;
- }
- /* Value text inside a thread cell: scales with viewport width and is ellipsized if it does not fit */
- .thread_data {
- display: block;
- max-width: 100%;
- padding: 0 2px;
- white-space: nowrap;
- overflow: hidden;
- text-overflow: ellipsis;
- font-size: clamp(8px, 2vw, 14px);
- text-align: center;
- }
- /* Container for one subgroup's thread cells (created by the script); relative so the ::before/::after labels can be positioned absolutely */
- .subgroup {
- display: flex;
- position: relative;
- margin-left: 10px;
- margin-right: 10px;
- }
- .subgroup::before {
- /* label text for each subgroup label */
- content: "subgroupAdd()";
- position: absolute;
- top: -20px;
- /* Hide until animation is displayed */
- opacity: 0;
- z-index: 100;
- transition: opacity 0.5s ease;
- font-weight: bold;
- color: white;
- width: 100%;
- }
- /* Caption below each subgroup showing the thread range stored in its data-label attribute */
- .subgroup::after {
- content: attr(data-label);
- position: absolute;
- bottom: -20px;
- opacity: 1;
- z-index: 100;
- color: gray;
- width: 100%;
- }
- .reduction-stage {
- margin-bottom: 20px;
- }
- /* The "subgroupAdd()" label fades in from above, holds, lifts slightly, then drops down while fading out */
- @keyframes labelAbsorb {
- 0% {
- opacity: 0;
- transform: translateY(-50%);
- }
- 40% {
- opacity: 1;
- transform: translateY(0%);
- }
- 60% {
- opacity: 1;
- transform: translateY(0%);
- }
- 80% {
- opacity: 1;
- transform: translate(0%, -20%);
- }
- 100% {
- opacity: 0;
- transform: translate(0%, 100%);
- }
- }
- /* The script adds the `anim` class on hover and removes it after 1500 ms, matching the 1.5s animation below */
- .subgroup.anim::before {
- opacity: 0;
- animation-name: labelAbsorb;
- animation-duration: 1.5s;
- transition:
- transform 0.6s ease-out,
- opacity 0.3s ease-in 0.3s;
- }
- </style>
- <div id="info">
- <a href="https://threejs.org" target="_blank" rel="noopener">three.js</a>
- <br /> This example demonstrates the performance of various simple parallel reduction kernels.
- <br /> Reference implementations are translated from the CUDA/WGSL code present in the following books/repos:
- <br /> Impl. 0 - 2: <a href="https://www.cambridge.org/core/books/programming-in-parallel-with-cuda/C43652A69033C25AD6933368CDBE084C"><i>Programming in Parallel with CUDA</i></a> by <a href="https://people.bss.phy.cam.ac.uk/~rea1/">Richard Ansorge</a>
- <br /> Impl. 3: <a href="https://github.com/frost-beta/betann/blob/main/betann/wgsl/reduce_all.wgsl"><i>betann reduce_all kernel</i></a> by <a href="https://github.com/zcbenz">zcbenz</a>
- <br /> Impl. 4: <a href="https://github.com/b0nes164/GPUPrefixSums/blob/main/GPUPrefixSumsWebGPUapis/SharedShaders/rts.wgsl"><i>GPUPrefixSums reduction approach</i></a> by <a href="https://github.com/b0nes164">b0nes164</a>
- <!-- Per-side overlays, looked up by id in the script's `timestamps` map -->
- <div id="left_side_display" style="position: absolute;top: 150px;left: 0;padding: 10px;background: rgba( 0, 0, 0, 0.5 );color: #fff;font-family: monospace;font-size: 12px;line-height: 1.5;pointer-events: none;text-align: left;"></div>
- <div id="right_side_display" style="position: absolute;top: 150px;right: 0;padding: 10px;background: rgba( 0, 0, 0, 0.5 );color: #fff;font-family: monospace;font-size: 12px;line-height: 1.5;pointer-events: none;text-align: left;"></div>
- </div>
- <!-- Subgroup reduction walkthrough. The script appends the subgroup/thread cells to #workgroup_threads and the
- per-subgroup result cells to #subgroup_reduction, and plays the animation while #subgroup-reduction-stage is hovered. -->
- <div id="reduction-panel">
- <h3 id="panel-title" style="flex: 0 0 auto;">Subgroup Reduction Explanation</h3>
- <div class="reduction-stage" id="subgroup-reduction-stage">
- <div class="stage-label">Use subgroupAdd() to capture reduction of each workgroup's subgroups (Hover for animation)</div>
- <div class="stage-display">
- <div id="workgroup_threads" style="display: flex; justify-content: center; margin-bottom: 20px;"></div>
- <div id="subgroup_reduction" style="display: flex; justify-content: center; margin-bottom: 5px;"></div>
- </div>
- </div>
- </div>
- <script type="importmap">
- {
- "imports": {
- "three": "../build/three.webgpu.js",
- "three/webgpu": "../build/three.webgpu.js",
- "three/tsl": "../build/three.tsl.js",
- "three/addons/": "./jsm/"
- }
- }
- </script>
- <script type="module">
- import * as THREE from 'three/webgpu';
- import { instancedArray, Loop, If, vec3, dot, clamp, storage, uvec4, subgroupAdd, uniform, uv, uint, float, Fn, vec2, invocationLocalIndex, invocationSubgroupIndex, uvec2, floor, instanceIndex, workgroupId, workgroupBarrier, workgroupArray, subgroupSize, select, countTrailingZeros } from 'three/tsl';
- import WebGPU from 'three/addons/capabilities/WebGPU.js';
- import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
- const timestamps = {
- left_side_display: document.getElementById( 'left_side_display' ),
- right_side_display: document.getElementById( 'right_side_display' )
- };
- const divRoundUp = ( size, part_size ) => {
- return Math.floor( ( size + part_size - 1 ) / part_size );
- };
- const cssSubgroupSize = 4;
- const cssWorkgroupSize = 16;
- const workgroupThreadsContainer = document.getElementById( 'workgroup_threads' );
- const subgroupReductionContainer = document.getElementById( 'subgroup_reduction' );
- document.getElementById( 'panel-title' ).textContent += ` (Subgroup Size: ${cssSubgroupSize}, Workgroup Size: ${cssWorkgroupSize})`;
- const createThreadWithData = ( data ) => {
- const threadEle = document.createElement( 'div' );
- threadEle.className = 'thread';
- const threadData = document.createElement( 'span' );
- threadData.textContent = data; // safer than innerHTML for just text
- threadData.className = 'thread_data';
- threadEle.append( threadData );
- return threadEle;
- };
- // Create thread elements
- const workgroupThreads = [];
- const initialSubgroups = [];
- const initialData = [];
- let currentSubgroupDiv = null;
- for ( let i = 0; i < cssWorkgroupSize; i ++ ) {
- if ( i % cssSubgroupSize === 0 ) {
- const currentSubgroupIndex = Math.floor( i / cssSubgroupSize );
- const subgroupReductionThread = createThreadWithData( 0 );
- subgroupReductionThread.id = `subgroup_reduction_element_${currentSubgroupIndex}`;
- subgroupReductionContainer.appendChild( subgroupReductionThread );
- currentSubgroupDiv = document.createElement( 'div' );
- currentSubgroupDiv.className = 'subgroup';
- currentSubgroupDiv.setAttribute( 'data-label', `Threads ${currentSubgroupIndex * cssSubgroupSize}-${( currentSubgroupIndex + 1 ) * cssSubgroupSize - 1}` );
- initialSubgroups.push( currentSubgroupDiv );
- workgroupThreadsContainer.appendChild( currentSubgroupDiv );
- }
- const data = Math.floor( Math.random() * 9 ) + 1;
- initialData.push( data );
- const thread = createThreadWithData( data );
- workgroupThreads.push( thread );
- currentSubgroupDiv.appendChild( thread );
- }
- const deactivateLabelAnimation = ( subgroupDiv, idx ) => {
- subgroupDiv.classList.remove( 'anim' );
- const subgroupReductionBufferElement = document.getElementById( `subgroup_reduction_element_${idx}` ).querySelector( '.thread_data' );
- subgroupReductionBufferElement.innerHTML = 0;
- };
- const activateLabelAnimation = ( subgroupDiv, idx ) => {
- const threads = Array.from( subgroupDiv.children );
- let total = 0;
- for ( let i = idx * cssSubgroupSize; i < idx * cssSubgroupSize + cssSubgroupSize; i ++ ) {
- total += initialData[ i ];
- }
- subgroupDiv.classList.add( 'anim' );
- setTimeout( () => {
- threads.forEach( t => {
- t.querySelector( '.thread_data' ).textContent = total;
- } );
- const subgroupReductionBufferElement = document.getElementById( `subgroup_reduction_element_${idx}` ).querySelector( '.thread_data' );
- subgroupReductionBufferElement.innerHTML = total;
- }, 1000 );
- // Remove the class after the animation ends so it can be triggered again
- setTimeout( () => {
- subgroupDiv.classList.remove( 'anim' );
- }, 1500 ); // matches animation duration in CSS
- };
- document.getElementById( 'subgroup-reduction-stage' ).addEventListener( 'mouseenter', () => {
- initialSubgroups.forEach( ( subgroupDiv, idx ) => {
- activateLabelAnimation( subgroupDiv, idx );
- } );
- } );
- document.getElementById( 'subgroup-reduction-stage' ).addEventListener( 'mouseleave', () => {
- initialSubgroups.forEach( ( subgroupDiv, idx ) => {
- deactivateLabelAnimation( subgroupDiv, idx );
- } );
- workgroupThreads.forEach( ( thread, idx ) => {
- thread.querySelector( '.thread_data' ).textContent = initialData[ idx ];
- } );
- } );
- if ( WebGPU.isAvailable() === false ) {
- document.body.appendChild( WebGPU.getErrorMessage() );
- throw new Error( 'No WebGPU support' );
- }
- // Total number of elements and the dimensions of the display grid.
- const size = 262144;
- const vecSize = divRoundUp( size, 4 );
- // Grid display is gridDim x gridDim
- const gridDim = Math.sqrt( size );
- let maxWorkgroupSize = 64;
- // Algorithm speed increase as you iterate through algorithms array
- const algorithms = [
- 'Reduce 0 (N/2)',
- 'Reduce 1 (Naive Accumulate)',
- 'Reduce 2 (Workgroup Reduction)',
- 'Reduce 3 (Subgroup Reduce)',
- 'Reduce 4 (Subgroup Optimized)',
- 'Incorrect Baseline',
- ];
- // Input Grid: Displays input data in a grid format
- // Input Log2: Displays input grid data's logarithmic indices horizontally (1, 2, 4, 8, 16, ..., size)
- // Input Element 0: Displays clamped input[0]
- const displayModes = [ 'Input Grid', 'Input Log2', 'Input Element 0', 'Workgroup Sum Grid' ];
- // Holds uniforms for both displays as well as debug information
- const unifiedEffectController = {
- // Number of elements in the grid
- gridElementWidth: uniform( gridDim ),
- gridElementHeight: uniform( gridDim ),
- // Number of elements in the grid being displayed
- gridDisplayWidth: uniform( gridDim ),
- gridDisplayHeight: uniform( gridDim ),
- // How to display end result of reduction.
- // Ideally this is unique to the reduction method being deployed
- 'Display Mode': 'Input Log2',
- loggedBuffer: 'Input Buffer',
- elementsReduced: size,
- };
- const leftEffectController = {
- // Current reduction algorithm being executed by this side
- algo: 'Reduce 0 (N/2)',
- // Flag indicating whether to highlight element in validation check
- highlight: uniform( 0 ),
- // Uniform that corresponds to the index of the current algorithm within the algorithms array
- currentAlgo: uniform( 0 ),
- // Current state of reduction (Running, validating, resetting)
- state: 'Run Algo',
- // Current display mode
- displayMode: 'Input Log2',
- // Reduce 0 specific uniform
- numThreadsDispatched: uniform( size / 2 ),
- // The subgroup size used by this side's device
- };
- const rightEffectController = {
- algo: 'Reduce 4 (Subgroup Optimized)',
- currentAlgo: uniform( 3 ),
- highlight: uniform( 0 ),
- displayMode: 'Input Element 0',
- state: 'Run Algo',
- numThreadsDispatched: uniform( size / 2 )
- };
- const leftMaterial = new THREE.MeshBasicNodeMaterial( { color: 0x00ff00 } );
- const rightMaterial = new THREE.MeshBasicNodeMaterial( { color: 0x00ff00 } );
- const leftDisplayColorNodes = {};
- const rightDisplayColorNodes = {};
- const gui = new GUI();
- gui.add( leftEffectController, 'algo', algorithms ).onChange( () => {
- leftEffectController.currentAlgo.value = algorithms.findIndex( val => val === leftEffectController.algo );
- } );
- gui.add( rightEffectController, 'algo', algorithms ).onChange( () => {
- rightEffectController.currentAlgo.value = algorithms.findIndex( val => val === rightEffectController.algo );
- } );
- gui.add( leftEffectController, 'displayMode', displayModes ).name( 'Left Display Mode' ).onChange( () => {
- leftMaterial.colorNode = leftDisplayColorNodes[ leftEffectController.displayMode ];
- leftMaterial.needsUpdate = true;
- } );
- gui.add( rightEffectController, 'displayMode', displayModes ).name( 'Right Display Mode' ).onChange( () => {
- rightMaterial.colorNode = rightDisplayColorNodes[ rightEffectController.displayMode ];
- rightMaterial.needsUpdate = true;
- } );
- const debugFolder = gui.addFolder( 'Debug' );
- const elementsReducedController = debugFolder.add( unifiedEffectController, 'elementsReduced' ).name( 'Elements Reduced' );
- elementsReducedController.disable();
- const stateLeftController = debugFolder.add( leftEffectController, 'state' ).name( 'Left Display State' );
- const stateRightController = debugFolder.add( rightEffectController, 'state' ).name( 'Right Display State' );
- stateLeftController.disable();
- stateRightController.disable();
- debugFolder.add( unifiedEffectController, 'loggedBuffer', [ 'Input Buffer', 'Input Vectorized Buffer', 'Workgroup Sums Buffer', 'Debug Buffer' ] ).name( 'Buffer to Log' );
- debugFolder.close();
- // HELPER FUNCTIONS
- /**
- * TSL helper returning the smallest power of two that is >= x, for a uint x.
- * The early branch is meant to return 1 for x = 0; without it, x - 1 would wrap to 0xFFFFFFFF and the
- * final + 1 would wrap the result to 0. For x > 2^31 the result still wraps to 0.
- *
- * @param {uint} x - Input value.
- * @returns {uint} Next power of two (see wrap-around caveats above).
- */
- const pow2Ceil = Fn( ( [ x ] ) => {
- If( x.equal( uint( 0 ) ), () => {
- return uint( 1 );
- } );
- // Smear the highest set bit of ( x - 1 ) into every lower bit, giving 2^k - 1; adding 1 then yields 2^k
- const val = x.sub( 1 ).toVar( 'val' );
- val.assign( val.bitOr( val.shiftRight( 1 ) ) );
- val.assign( val.bitOr( val.shiftRight( 2 ) ) );
- val.assign( val.bitOr( val.shiftRight( 4 ) ) );
- val.assign( val.bitOr( val.shiftRight( 8 ) ) );
- val.assign( val.bitOr( val.shiftRight( 16 ) ) );
- return val.add( 1 );
- } ).setLayout( {
- name: 'pow2Ceil',
- type: 'uint',
- inputs: [
- { name: 'x', type: 'uint' }
- ]
- } );
- // ALGORITHM CONSTRUCTORS
- // REDUCE 1
- // Thanks to Sam0oneau of Graphics Programming Discord for the explanation.
- // (Graphics Programming Discord Message Link): https://discord.com/channels/318590007881236480/374061825454768129/1391248956171882597
- /* Reduce 1 Example (Assume Workgroup Size 256, numElements: 262144) -> Initial currentBuffer State: | 1, 1, 1, 1, ... |
- *
- * KERNEL 1:
- * Executes 256 threads by 256 workgroups. Each thread loops 4 times and accesses elements
- * at the indices below.
- * Thread 1 Thread 2 Thread 3
- * | 0, 65536, ..., n * 65536 | 1, 65537, .... (n * 65536) + 1 | 1, 65538, .... (n * 65536) + 2 | etc
- * Buffer Values: | 4, 4, 4, 4, ...|
- *
- * KERNEL 2:
- * Executes 256 threads by one workgroup. Each thread loops 1024 times
- * Thread 1 Thread 2 Thread 3
- * | 0, 256, ...., n * 256 | 1, 257, ... (n * 256) + 1 | 2, 258, ... (n * 256) + 3 | etc
- * Buffer Values: | 1024, 1024, 1024, 1024, ... |
- *
- * KERNEL 3:
- * Executes 1 thread by one workgroup. Single thread loops 256 times
- * Thread 1
- * | 0, 1, 2, 3, 4, 5, 6 ... etc|
- * Buffer Values: [262144, 1024, 1024]
- */
- const createReduce1Fn = ( createReduce1FnProps ) => {
- const { dispatchSize, numElements, inputBuffer, workgroupSize } = createReduce1FnProps;
- const fnDef = Fn( () => {
- const dispatch = uint( dispatchSize ).toVar( 'dispatchSize' );
- const tSum = uint( 0 ).toVar();
- const k = instanceIndex.toVar( 'k' );
- Loop( k.lessThan( uint( numElements ) ), ( ) => {
- tSum.addAssign( inputBuffer.element( k ) );
- k.addAssign( uint( dispatch ) );
- } );
- inputBuffer.element( instanceIndex ).assign( tSum );
- } )().compute( dispatchSize, [ workgroupSize ] );
- return fnDef;
- };
- // REDUCE 2
- // For non power of 2 # of workgroups
- const createReduce2Fn = ( createReduce2FnProps ) => {
- const { workgroupSize, dispatchSize, numElements, inputBuffer } = createReduce2FnProps;
- const fnDef = Fn( () => {
- const tSum = workgroupArray( 'uint', workgroupSize );
- const k = instanceIndex.toVar( 'k' );
- tSum.element( invocationLocalIndex ).assign( uint( 0 ) );
- Loop( k.lessThan( uint( numElements ) ), () => {
- tSum.element( invocationLocalIndex ).addAssign( inputBuffer.element( k ) );
- k.addAssign( uint( dispatchSize ) );
- } );
- workgroupBarrier();
- // Reset the loop condition (account for numWorkgroups % 2 != 0)
- k.assign( pow2Ceil( uint( workgroupSize ) ).div( 2 ) );
- Loop( k.greaterThan( 0 ), () => {
- If( invocationLocalIndex.lessThan( k ).and( invocationLocalIndex.add( k ).lessThan( workgroupSize ) ), () => {
- tSum.element( invocationLocalIndex ).addAssign( tSum.element( invocationLocalIndex.add( k ) ) );
- } );
- workgroupBarrier();
- k.divAssign( 2 );
- } );
- If( invocationLocalIndex.equal( uint( 0 ) ), () => {
- inputBuffer.element( workgroupId.x ).assign( tSum.element( uint( 0 ) ) );
- } );
- } )().compute( dispatchSize, [ workgroupSize ] );
- return fnDef;
- };
- // REDUCE 3
- /* Create array with enough indices for worst-case subgroup size */
- const createSubgroupArray = ( type, workgroupSize, minSubgroupSize = 4 ) => {
- return workgroupArray( 'uint', workgroupSize / minSubgroupSize );
- };
- // zcbenz implementation
- // https://github.com/frost-beta/betann/blob/8aa2701caf63fb29bd4cd2454e656973342c1588/betann/wgsl/reduce_ops.wgsl#L71
- const RowReduce = ( rowReduceProps ) => {
- const { workgroupSize, inputBuffer, total, rowOffset, currentRowSize, workPerThread, vectorized } = rowReduceProps;
- // Number of unvectorized elements each workgroup can ingest
- // At workgroupSize of 256, blockSize will be 1024
- const blockSize = uint( workgroupSize ).mul( workPerThread );
- const block = uint( 0 ).toVar( 'block' );
- // At rowSize of 2048, there will be two blocks
- const blockLimiter = currentRowSize.div( blockSize ).toVar( 'blockLimiter' );
- Loop( block.lessThan( blockLimiter ), () => {
- const blockOffset = block.mul( blockSize );
- const startThread = blockOffset.add( invocationLocalIndex.mul( workPerThread ) );
- const localThreadOffset = uint( 0 ).toVar( 'localThreadOffset' );
- Loop( localThreadOffset.lessThan( workPerThread ), () => {
- const inputElement = inputBuffer.element( rowOffset.add( startThread ).addLocal );
- if ( vectorized ) {
- const value = dot( inputElement, uvec4( 1 ) );
- total.addAssign( value );
- } else {
- const inputElement = inputBuffer.element( rowOffset.add( startThread ).add( localThreadOffset ) );
- total.addAssign( inputElement );
- }
- // Increment up a thread
- localThreadOffset.addAssign( 1 );
- } );
- // Increment up a block
- block.addAssign( 1 );
- } );
- // Ignoring left over check for this example, since we know ahead of time the value of leftover (2048 % 1024 === 0)
- };
- const WorkgroupReduce = ( workgroupReduceProps ) => {
- const { total, workgroupSize } = workgroupReduceProps;
- const subgroupSums = createSubgroupArray( 'uint', workgroupSize );
- // Assign sum of all values in subgroup to total
- total.assign( subgroupAdd( total ) );
- const delta = uint( workgroupSize ).div( subgroupSize ).toVar( 'delta' );
- const subgroupMetaRank = invocationLocalIndex.div( subgroupSize );
- Loop( float( delta ).greaterThan( 1.0 ), () => {
- If( invocationSubgroupIndex.equal( 0 ), () => {
- // Each subgroup will populate the subgroupSums array
- subgroupSums.element( subgroupMetaRank ).assign( total );
- } );
- // Ensure that all subgroups in the workgroup have populated the workgroup memory array
- workgroupBarrier();
- // Thread 0 - subgroupsInWorkgroup will assign a value to total
- total.assign( select( invocationLocalIndex.lessThan( delta ), subgroupSums.element( invocationLocalIndex ), 0 ).uniformFlow() );
- // # of subgroups in workgroup is invariably less than # of threads in subgroup, so subgroupAdd will still sync here
- total.assign( subgroupAdd( total ) );
- delta.divAssign( subgroupSize );
- } );
- };
- const createReduce3Fn = ( createReduce3FnProps ) => {
- const { workgroupSize, workPerThread, inputBuffer, intermediateBuffer, rowSize } = createReduce3FnProps;
- const fnDef = Fn( () => {
- const inputSize = uint( inputBuffer.bufferCount.length );
- const rowOffset = workgroupId.x.mul( rowSize );
- // If the current rows elements exceed the bounds of the input
- // Select either 0 or number of elements left,
- // otherwise, select existing ROW_SIZE
- const currentRowSize = select(
- ( rowOffset.add( rowSize ) ).greaterThan( inputSize ),
- select( inputSize.greaterThan( rowOffset ), inputSize.sub( rowOffset ), 0 ).uniformFlow(),
- rowSize,
- ).uniformFlow();
- const total = uint( 0 ).toVar( 'total' );
- RowReduce( {
- inputBuffer: inputBuffer,
- total: total,
- rowOffset: rowOffset,
- currentRowSize: currentRowSize,
- workPerThread: workPerThread,
- workgroupSize: workgroupSize,
- } );
- WorkgroupReduce( {
- total: total,
- workgroupSize: workgroupSize,
- } );
- // Populate each workgroup with its reduction
- If( invocationLocalIndex.equal( 0 ), () => {
- intermediateBuffer.element( workgroupId.x ).assign( total );
- } );
- } )();
- return fnDef;
- };
- // REDUCE 4
- // b0nes164 inspired implementation with vec4
- const createReduce4Fn = ( props ) => {
- // Can't pass in subgroup size since we can't always be certain what size is at runtime
- const { size, workPerThread, workgroupSize, inputBuffer, intermediateBuffer } = props;
- const ELEMENTS_PER_VEC4 = 4;
- // The number of individual elements a single workgroup will access
- const partitionSize = workgroupSize * workPerThread * ELEMENTS_PER_VEC4;
- const vecSize = divRoundUp( size, ELEMENTS_PER_VEC4 );
- // Can also be calculated using divRoundUp( vecSize, workgroupSize * workPerThread );
- const numWorkgroups = divRoundUp( size, partitionSize );
- // Currently no way to specify dispatch size in increments of workgroups, so we convert to numInvocations
- const numInvocations = numWorkgroups * workgroupSize;
- const fnDef = Fn( () => {
- const perSubgroupReductionArray = createSubgroupArray( 'uint', workgroupSize );
- // Get the index of the subgroup within the workgroup
- const subgroupMetaRank = invocationLocalIndex.div( subgroupSize );
- // Each subgroup block scans across 4 subgroups. So when we move into a new subgroup,
- // align that subgroups' accesses to the next 4 subgroups
- const subgroupOffset = subgroupMetaRank.mul( subgroupSize ).mul( workPerThread );
- subgroupOffset.addAssign( invocationSubgroupIndex );
- // Per workgroup, offset by number of vectorized elements scanned per workgroup
- const workgroupOffset = workgroupId.x.mul( uint( maxWorkgroupSize ).mul( workPerThread ) );
- const startThread = subgroupOffset.add( workgroupOffset );
- const subgroupReduction = uint( 0 );
- // Each thread will accumulate values from across 'workPerThread' subgroups
- If( workgroupId.x.lessThan( uint( numWorkgroups ).sub( 1 ) ), () => {
- Loop( {
- start: uint( 0 ),
- end: workPerThread,
- type: 'uint',
- condition: '<',
- name: 'currentSubgroupInBlock'
- }, () => {
- // Get vectorized element from input array
- const val = inputBuffer.element( startThread );
- // Sum values within vec4 together by using result of dot product
- subgroupReduction.addAssign( dot( uvec4( 1 ), val ) );
- // Increment so thread will scan value in next subgroup
- startThread.addAssign( subgroupSize );
- } );
- } );
- // Ensure that the last workgroup does not access out of bounds indices
- If( workgroupId.x.equal( uint( numWorkgroups ).sub( 1 ) ), () => {
- Loop( {
- start: uint( 0 ),
- end: workPerThread,
- type: 'uint',
- condition: '<',
- name: 'currentSubgroupInBlock'
- }, () => {
- // Ensure index is less than number of available vectors in inputBuffer
- const val = select( startThread.lessThan( uint( vecSize ) ), inputBuffer.element( startThread ), uvec4( 0 ) ).uniformFlow();
- subgroupReduction.addAssign( dot( val, uvec4( 1 ) ) );
- startThread.addAssign( subgroupSize );
- } );
- } );
- subgroupReduction.assign( subgroupAdd( subgroupReduction ) );
- // Assuming that each element in the input buffer is 1, we generally expect each invocation's subgroupReduction
- // value to be ELEMENTS_PER_VEC4 * workPerThread * subgroupSize
- // Delegate one thread per subgroup to assign each subgroup's reduction to the workgroup array
- If( invocationSubgroupIndex.equal( uint( 0 ) ), () => {
- perSubgroupReductionArray.element( subgroupMetaRank ).assign( subgroupReduction );
- } );
- // Ensure that each workgroup has populated the perSubgroupReductionArray with data
- // from each of it's subgroups
- workgroupBarrier();
- if ( props.debugBuffer ) {
- If( invocationLocalIndex.equal( uint( 0 ) ), () => {
- props.debugBuffer.element( workgroupId.x ).assign( subgroupReduction );
- } );
- workgroupBarrier();
- }
- // WORKGROUP LEVEL REDUCE
- // Multiple approaches here
- // log2(subgroupSize) -> TSL log2 function
- // countTrailingZeros/findLSB(subgroupSize) -> TSL function that counts trailing zeros in number bit representation
- // Can technically petition GPU for subgroupSize in shader and calculate logs on CPU at cost of shader being generalizable across devices
- // May also break if subgroupSize changes when device is lost or if program is rerun on lower power device
- const subgroupSizeLog = countTrailingZeros( subgroupSize ).toVar( 'subgroupSizeLog' );
- const spineSize = uint( workgroupSize ).shiftRight( subgroupSizeLog );
- const spineSizeLog = countTrailingZeros( spineSize ).toVar( 'spineSizeLog' );
- // Align size to powers of subgroupSize
- const squaredSubgroupLog = ( spineSizeLog.add( subgroupSizeLog ).sub( 1 ) );
- squaredSubgroupLog.divAssign( subgroupSizeLog );
- squaredSubgroupLog.mulAssign( subgroupSizeLog );
- const alignedSize = ( uint( 1 ).shiftLeft( squaredSubgroupLog ) ).toVar( 'alignedSize' );
- // aligned size 2 * 4
- const offset = uint( 0 );
- // In cases where the number of subgroups in a workgroup is greater than the subgroup size itself,
- // we need to iterate over the array again to capture all the data in the workgroup array buffer
- Loop( { start: subgroupSize, end: alignedSize, condition: '<=', name: 'j', type: 'uint', update: '<<= subgroupSizeLog' }, () => {
- const subgroupIndex = ( ( invocationLocalIndex.add( 1 ) ).shiftLeft( offset ) ).sub( 1 );
- const isValidSubgroupIndex = subgroupIndex.lessThan( spineSize ).toVar( 'isValidSubgroupIndex' );
- // Reduce values within the local workgroup memory.
- // Set toVar to ensure subgroupAdd executes before (not within) the if statement.
- const t = subgroupAdd(
- select(
- isValidSubgroupIndex,
- perSubgroupReductionArray.element( subgroupIndex ),
- 0
- ).uniformFlow()
- ).toVar( 't' );
- // Can assign back to workgroupArray since all
- // subgroup threads work in lockstop for subgroupAdd
- If( isValidSubgroupIndex, () => {
- perSubgroupReductionArray.element( subgroupIndex ).assign( t );
- } );
- // Ensure all threads have completed work
- workgroupBarrier();
- offset.addAssign( subgroupSizeLog );
- } );
- // Assign single thread from workgroup to assign workgroup reduction
- If( invocationLocalIndex.equal( uint( 0 ) ), () => {
- const reducedWorkgroupSum = perSubgroupReductionArray.element( uint( spineSize ).sub( 1 ) );
- intermediateBuffer.element( workgroupId.x ).assign( reducedWorkgroupSum );
- } );
- } )().compute( numInvocations, [ maxWorkgroupSize ] );
- return fnDef;
- };
- // INCORRECT BASELINE
- const createIncorrectBaselineFn = ( incorrectBaselineProps ) => {
- const { inputBuffer } = incorrectBaselineProps;
- const fnDef = Fn( () => {
- inputBuffer.element( instanceIndex ).assign( 99999 );
- } )();
- return fnDef;
- };
- // Build the left display first (init's default), then the right one.
- [ true, false ].forEach( ( isLeftSide ) => init( isLeftSide ) );
- async function init( leftSideDisplay = true ) {
- const effectController = leftSideDisplay ? leftEffectController : rightEffectController;
- const aspect = ( window.innerWidth / 2 ) / window.innerHeight;
- const camera = new THREE.OrthographicCamera( - aspect, aspect, 1, - 1, 0, 2 );
- camera.position.z = 1;
- const scene = new THREE.Scene();
- const array = new Uint32Array( Array.from( { length: size }, () => {
- return 1;
- } ) );
- // Represents array of data as uints in compute shader.
- const inputStorage = instancedArray( array, 'uint' ).setPBO( true ).setName( `Current_${leftSideDisplay ? 'Left' : 'Right'}` );
- // Represents array of data as vec4s in compute shader;
- const inputVec4BufferAttribute = new THREE.StorageInstancedBufferAttribute( array, 4 );
- const inputVectorizedStorage = storage( inputVec4BufferAttribute, 'uvec4', vecSize ).setPBO( true ).setName( `CurrentVectorized_${leftSideDisplay ? 'Left' : 'Right'}` );
- // Reduce 3 Calculations
- const workPerThread = 4;
- const numRows = workPerThread * 32;
- const rowSize = divRoundUp( size, numRows );
- const workgroupSumsArray = new Uint32Array( numRows );
- const workgroupSumsStorage = instancedArray( workgroupSumsArray, 'uint' ).setPBO( true ).setName( `WorkgroupSums_${leftSideDisplay ? 'Left' : 'Right'}` );
- const debugArray = new Uint32Array( 1024 );
- const debugStorage = instancedArray( debugArray, 'uint' ).setPBO( true ).setName( `Debug_${leftSideDisplay ? 'Left' : 'Right'}` );
- const buffers = {
- 'Input Buffer': inputStorage,
- 'Input Vectorized Buffer': inputVectorizedStorage,
- 'Workgroup Sums Buffer': workgroupSumsStorage,
- 'Debug Buffer': debugStorage,
- };
- const logFunctionName = `Log ${leftSideDisplay ? 'Left' : 'Right'} Side`;
- const functionObj = {};
- functionObj[ logFunctionName ] = async() => {
- const selectedBuffer = buffers[ unifiedEffectController.loggedBuffer ];
- console.log( new Uint32Array( await renderer.getArrayBufferAsync( selectedBuffer.value ) ) );
- };
- debugFolder.add( functionObj, `Log ${leftSideDisplay ? 'Left' : 'Right'} Side` );
- const computeResetBufferFn = Fn( () => {
- inputStorage.element( instanceIndex ).assign( 1 );
- } );
- const computeResetWorkgroupSumsFn = Fn( () => {
- workgroupSumsStorage.element( instanceIndex ).assign( 0 );
- } );
- // Re-initialize compute buffer
- const computeResetBuffer = computeResetBufferFn().compute( size );
- const computeResetWorkgroupSums = computeResetWorkgroupSumsFn().compute( 256 );
- const renderer = new THREE.WebGPURenderer( { antialias: false, trackTimestamp: true } );
- renderer.setPixelRatio( window.devicePixelRatio );
- renderer.setSize( window.innerWidth / 2, window.innerHeight );
- await renderer.init();
- // Unfortunately, need to arbitrarily run compute shader to get access to device limits
- renderer.compute( computeResetBuffer );
- if ( renderer.backend.device !== null ) {
- maxWorkgroupSize = renderer.backend.device.limits.maxComputeWorkgroupSizeX;
- }
- // Create and store dispatches of reduction of certain size. Map each set of dispatches to algorithm name.
- const computeReduce0Fn = Fn( () => {
- const { numThreadsDispatched } = effectController;
- inputStorage.element( instanceIndex ).addAssign( inputStorage.element( instanceIndex.add( numThreadsDispatched ) ) );
- } )();
- const reduce0Calls = [];
- for ( let i = size / 2; i >= 1; i /= 2 ) {
- const reduce0 = computeReduce0Fn.compute( i, [ maxWorkgroupSize ] );
- reduce0Calls.push( reduce0 );
- }
- const reduce1Calls = [
- // Accumulation
- createReduce1Fn( {
- dispatchSize: maxWorkgroupSize * maxWorkgroupSize,
- workgroupSize: maxWorkgroupSize,
- numElements: size,
- inputBuffer: inputStorage,
- } ),
- // 1 Block accumulation
- createReduce1Fn( {
- dispatchSize: maxWorkgroupSize,
- numElements: maxWorkgroupSize * maxWorkgroupSize,
- workgroupSize: maxWorkgroupSize,
- inputBuffer: inputStorage,
- } ),
- // Final result
- createReduce1Fn( {
- dispatchSize: 1,
- numElements: maxWorkgroupSize,
- workgroupSize: 1,
- inputBuffer: inputStorage
- } ),
- ];
- const reduce2Calls = [
- // Accumulate within workgroups
- createReduce2Fn( {
- workgroupSize: maxWorkgroupSize,
- dispatchSize: maxWorkgroupSize * maxWorkgroupSize,
- numElements: size,
- inputBuffer: inputStorage,
- } ),
- // 1 Block accumulation
- createReduce2Fn( {
- workgroupSize: maxWorkgroupSize,
- dispatchSize: maxWorkgroupSize,
- numElements: maxWorkgroupSize,
- inputBuffer: inputStorage,
- } ),
- ];
- const reduce3Calls = [
- createReduce3Fn( {
- inputBuffer: inputStorage,
- intermediateBuffer: workgroupSumsStorage,
- workgroupSize: maxWorkgroupSize,
- workPerThread: 4,
- rowSize: rowSize,
- vectorized: false,
- } ).compute( maxWorkgroupSize * numRows, [ maxWorkgroupSize ] ),
- createReduce3Fn( {
- inputBuffer: workgroupSumsStorage,
- intermediateBuffer: inputStorage,
- workgroupSize: 32,
- workPerThread: 4,
- rowSize: rowSize,
- vectorized: false
- } ).compute( 32, [ 32 ] )
- ];
- const reduce4Calls = [
- createReduce4Fn( {
- size: size,
- inputBuffer: inputVectorizedStorage,
- intermediateBuffer: workgroupSumsStorage,
- workgroupSize: maxWorkgroupSize,
- workPerThread: 4,
- } ),
- createReduce3Fn( {
- inputBuffer: workgroupSumsStorage,
- intermediateBuffer: inputStorage,
- workgroupSize: 32,
- workPerThread: 4,
- rowSize: rowSize,
- vectorized: false
- } ).compute( 32, [ 32 ] )
- ];
- const incorrectBaselineCalls = [
- createIncorrectBaselineFn( {
- inputBuffer: inputStorage,
- } ).compute( size ),
- ];
- const calls = {
- 'Reduce 0 (N/2)': reduce0Calls,
- 'Reduce 1 (Naive Accumulate)': reduce1Calls,
- 'Reduce 2 (Workgroup Reduction)': reduce2Calls,
- 'Reduce 3 (Subgroup Reduce)': reduce3Calls,
- 'Reduce 4 (Subgroup Optimized)': reduce4Calls,
- 'Incorrect Baseline': incorrectBaselineCalls
- };
- const getColor = ( bufferToCheck, colorChanger, width, height ) => {
- const subtracter = float( colorChanger ).div( width.mul( height ) );
- const color = vec3( subtracter.oneMinus() ).toVar();
- const { highlight } = effectController;
- // Validate that element 0 is equal to expected result of reduction
- If( highlight.equal( 1 ), () => {
- If( ( bufferToCheck.element( 0 ) ).equal( size ), () => {
- color.assign( vec3( 0.0, subtracter.oneMinus(), 0.0 ) );
- } ).Else( () => {
- color.assign( vec3( subtracter.oneMinus(), 0.0, 0.0 ) );
- } );
- } );
- return color;
- };
- const displayNodes = leftSideDisplay ? leftDisplayColorNodes : rightDisplayColorNodes;
- displayNodes[ 'Input Grid' ] = Fn( () => {
- const { gridElementWidth, gridElementHeight, gridDisplayWidth, gridDisplayHeight } = unifiedEffectController;
- const newUV = uv().mul( vec2( gridDisplayWidth, gridDisplayHeight ) );
- const pixel = uvec2( uint( floor( newUV.x ) ), uint( floor( newUV.y ) ) );
- const elementIndex = uint( gridDisplayWidth ).mul( pixel.y ).add( pixel.x );
- const colorChanger = uint( 0 ).toVar();
- const color = vec3( 0 ).toVar( 'color' );
- colorChanger.assign( inputStorage.element( elementIndex ) );
- color.assign( getColor( inputStorage, colorChanger, gridElementWidth, gridElementHeight ) );
- return color;
- } )();
- displayNodes[ 'Input Log2' ] = Fn( () => {
- const { gridElementWidth, gridElementHeight } = unifiedEffectController;
- const newUV = uv().mul( vec2( Math.log2( size ) ), 1 );
- const colorChanger = uint( 0 ).toVar();
- const color = vec3( 0 ).toVar( 'color' );
- colorChanger.assign( inputStorage.element( uint( 1 ).shiftLeft( newUV.x ) ) );
- color.assign( getColor( inputStorage, colorChanger, gridElementWidth, gridElementHeight ) );
- return color;
- } )();
- displayNodes[ 'Input Element 0' ] = Fn( () => {
- const { gridElementWidth, gridElementHeight } = unifiedEffectController;
- const colorChanger = uint( 0 ).toVar();
- const color = vec3( 0 ).toVar( 'color' );
- // Clamp display of single element to shade where green is still readable
- colorChanger.assign( clamp( inputStorage.element( 0 ), 0, size / 2 ) );
- color.assign( getColor( inputStorage, colorChanger, gridElementWidth, gridElementHeight ) );
- return color;
- } )();
- displayNodes[ 'Workgroup Sum Grid' ] = Fn( () => {
- const width = uint( 8 );
- const height = uint( 16 );
- const newUV = uv().mul( vec2( width, height ) );
- const pixel = uvec2( uint( floor( newUV.x ) ), uint( floor( newUV.y ) ) );
- const elementIndex = uint( width ).mul( pixel.y ).add( pixel.x );
- const colorChanger = uint( 0 ).toVar();
- const color = vec3( 0 ).toVar( 'color' );
- colorChanger.assign( workgroupSumsStorage.element( elementIndex ) );
- color.assign( getColor( inputStorage, colorChanger, width, height ) );
- return color;
- } )();
- ( leftSideDisplay ? leftMaterial : rightMaterial ).colorNode = displayNodes[ effectController.displayMode ];
- ( leftSideDisplay ? leftMaterial : rightMaterial ).needsUpdate = true;
- const plane = new THREE.Mesh( new THREE.PlaneGeometry( 1, 1 ), ( leftSideDisplay ? leftMaterial : rightMaterial ) );
- scene.add( plane );
- const animate = () => {
- renderer.render( scene, camera );
- };
- renderer.setAnimationLoop( animate );
- document.body.appendChild( renderer.domElement );
- renderer.domElement.style.position = 'absolute';
- renderer.domElement.style.top = '0';
- renderer.domElement.style.left = '0';
- renderer.domElement.style.width = '50%';
- renderer.domElement.style.height = '100%';
- if ( ! leftSideDisplay ) {
- renderer.domElement.style.left = '50%';
- scene.background = new THREE.Color( 0x212121 );
- } else {
- scene.background = new THREE.Color( 0x313131 );
- }
- renderer.info.autoReset = false;
- const stepAnimation = async function () {
- const currentAlgorithm = effectController.algo;
- const state = effectController.state;
- const stateController = leftSideDisplay ? stateLeftController : stateRightController;
- if ( state === 'Reset' ) {
- renderer.compute( computeResetBuffer );
- renderer.compute( computeResetWorkgroupSums );
- } else if ( state === 'Run Algo' ) {
- renderer.info.reset();
- const cpuTime = 0;
- switch ( currentAlgorithm ) {
- case 'Reduce 0 (N/2)': {
- let m = size / 2;
- for ( let i = 0; i < reduce0Calls.length; i ++ ) {
- effectController.numThreadsDispatched.value = m;
- const reduce0 = reduce0Calls[ i ];
- // Do a reduction step
- renderer.compute( reduce0 );
- renderer.resolveTimestampsAsync( THREE.TimestampQuery.COMPUTE );
- m /= 2;
- }
- break;
- }
- default: {
- const currentAlgoCalls = calls[ currentAlgorithm ];
- for ( let i = 0; i < currentAlgoCalls.length; i ++ ) {
- renderer.compute( currentAlgoCalls[ i ] );
- renderer.resolveTimestampsAsync( THREE.TimestampQuery.COMPUTE );
- }
- break;
- }
- }
- // DEBUG: const reductionResult = new Uint32Array( await renderer.getArrayBufferAsync( currentBuffer ) )[0];
- let passInfoString = '';
- if ( effectController.algo.substring( 0, 3 ) === 'CPU' ) {
- passInfoString = `Ran in ${cpuTime}ms<br>`;
- } else {
- passInfoString = `${renderer.info.compute.frameCalls} pass in ${renderer.info.compute.timestamp.toFixed( 6 )}ms<br>`;
- }
- timestamps[ leftSideDisplay ? 'left_side_display' : 'right_side_display' ].innerHTML = `
- Compute ${effectController.algo}: ${passInfoString}`;
- }
- renderer.render( scene, camera );
- renderer.resolveTimestampsAsync( THREE.TimestampQuery.RENDER );
- // Validate next state
- if ( state === 'Run Algo' ) {
- stateController.setValue( 'Validate' );
- effectController.highlight.value = 1;
- } else if ( state === 'Validate' ) {
- stateController.setValue( 'Reset' );
- effectController.highlight.value = 0;
- } else if ( state === 'Reset' ) {
- stateController.setValue( 'Run Algo' );
- }
- setTimeout( stepAnimation, 1000 );
- };
- window.addEventListener( 'resize', onWindowResize );
- function onWindowResize() {
- renderer.setSize( window.innerWidth / 2, window.innerHeight );
- const aspect = ( window.innerWidth / 2 ) / window.innerHeight;
- const frustumHeight = camera.top - camera.bottom;
- camera.left = - frustumHeight * aspect / 2;
- camera.right = frustumHeight * aspect / 2;
- camera.updateProjectionMatrix();
- renderer.render( scene, camera );
- }
- setTimeout( stepAnimation, 1000 );
- }
- </script>
- </body>
- </html>
|