BitonicSort.js 17 KB


  1. import { Fn, uvec2, If, instancedArray, instanceIndex, invocationLocalIndex, Loop, workgroupArray, workgroupBarrier, workgroupId, uint, select, min, max } from 'three/tsl';
  /**
   * Identifiers for the kind of swap the sort performs next. The value is stored in
   * `infoStorage.element( 0 )` and read by the shader that selects the following step.
   * Frozen so the shared values cannot be changed at runtime.
   */
  const StepType = Object.freeze( {
    // Placeholder value; none of the shaders in this file writes it.
    NONE: 0,
    // Swap all values within the local range of workgroupSize * 2 (the initial local sort).
    SWAP_LOCAL: 1,
    // Disperse swaps performed entirely in workgroup local storage.
    DISPERSE_LOCAL: 2,
    // Flip swap over the global data buffer.
    FLIP_GLOBAL: 3,
    // Disperse swap over the global data buffer.
    DISPERSE_GLOBAL: 4,
  } );
  /**
   * Returns the indices that will be compared in a bitonic flip operation.
   *
   * The data is split into blocks of `blockHeight` elements, with `blockHeight / 2`
   * invocations per block. Invocation `index` compares the element at offset
   * `k = index % ( blockHeight / 2 )` of its block with the element mirrored across the
   * block's centre, at offset `blockHeight - 1 - k`.
   *
   * @tsl
   * @private
   * @param {Node<uint>} index - The compute thread's invocation id.
   * @param {Node<uint>} blockHeight - The height of the block within which elements are being swapped.
   * @returns {Node<uvec2>} The indices of the elements in the data buffer being compared.
   */
  export const getBitonicFlipIndices = /*@__PURE__*/ Fn( ( [ index, blockHeight ] ) => {

    // First element index of the block this invocation works on.
    const blockOffset = index.mul( 2 ).div( blockHeight ).mul( blockHeight );

    // Offset of the lower element within the first half of the block.
    const offsetInHalf = index.mod( blockHeight.div( 2 ) );

    // Build the result directly rather than assigning into swizzles of an
    // undeclared uvec2 expression node.
    return uvec2(
      blockOffset.add( offsetInHalf ),
      blockOffset.add( blockHeight.sub( offsetInHalf ).sub( 1 ) )
    );

  } ).setLayout( {
    name: 'getBitonicFlipIndices',
    type: 'uvec2',
    inputs: [
      { name: 'index', type: 'uint' },
      { name: 'blockHeight', type: 'uint' }
    ]
  } );
  /**
   * Returns the indices that will be compared in a bitonic sort's disperse operation.
   *
   * The data is split into blocks of `swapSpan` elements, with `swapSpan / 2` invocations
   * per block. Invocation `index` compares the element at offset `k = index % ( swapSpan / 2 )`
   * of its block with the element exactly half a span further, at offset `k + swapSpan / 2`.
   *
   * @tsl
   * @private
   * @param {Node<uint>} index - The compute thread's invocation id.
   * @param {Node<uint>} swapSpan - The maximum span over which elements are being swapped.
   * @returns {Node<uvec2>} The indices of the elements in the data buffer being compared.
   */
  export const getBitonicDisperseIndices = /*@__PURE__*/ Fn( ( [ index, swapSpan ] ) => {

    // First element index of the span this invocation works on.
    const blockOffset = index.mul( 2 ).div( swapSpan ).mul( swapSpan );
    const halfSpan = swapSpan.div( 2 );

    // Lower element of the pair; its partner sits half a span above it.
    const lower = blockOffset.add( index.mod( halfSpan ) );

    return uvec2( lower, lower.add( halfSpan ) );

  } ).setLayout( {
    name: 'getBitonicDisperseIndices',
    type: 'uvec2',
    inputs: [
      { name: 'index', type: 'uint' },
      // Named after the JS parameter so the generated signature matches the docs.
      { name: 'swapSpan', type: 'uint' }
    ]
  } );
  // TODO: Add parameters for computing a buffer larger than vec4
  /**
   * Sorts the elements of a storage buffer in ascending order on the GPU with an in-place
   * bitonic sort that is split into a sequence of compute dispatches.
   *
   * The first step sorts every `workgroupSize * 2` chunk in workgroup local storage. Each
   * following span is processed by one global flip, zero or more global disperses and a
   * final local disperse. A small state buffer (`infoStorage`) tells the global shaders
   * which swap span to use.
   *
   * The step bookkeeping is derived from `Math.log2( count )`, so the element count is
   * expected to be a power of two.
   */
  export class BitonicSort {

    /**
     * Constructs a new bitonic sort.
     *
     * @param {Renderer} renderer - The current scene's renderer.
     * @param {StorageBufferNode} dataBuffer - The storage buffer node holding the data to sort.
     * Its `value.count` is the number of sorted elements and its `nodeType` the element type.
     * @param {Object} [options={}] - Optional settings.
     * @param {number} [options.workgroupSize=64] - The requested workgroup size of the sort's
     * compute shaders. It is clamped to half the element count; falsy values fall back to 64.
     */
    constructor( renderer, dataBuffer, options = {} ) {

      /**
       * A reference to the renderer.
       *
       * @type {Renderer}
       */
      this.renderer = renderer;

      /**
       * A reference to the StorageBufferNode holding the data that will be sorted.
       *
       * @type {StorageBufferNode}
       */
      this.dataBuffer = dataBuffer;

      /**
       * The number of elements in the data buffer.
       *
       * @type {number}
       */
      this.count = dataBuffer.value.count;

      /**
       * The number of invocations per compute dispatch; each invocation handles one pair of elements.
       *
       * @type {number}
       */
      this.dispatchSize = this.count / 2;

      /**
       * The workgroup size of the compute shaders executed during the sort.
       *
       * @type {number}
       */
      this.workgroupSize = Math.min( this.dispatchSize, options.workgroupSize || 64 );

      /**
       * A node representing a workgroup scoped buffer that holds locally sorted elements
       * (`workgroupSize * 2` elements, two per invocation).
       *
       * @type {WorkgroupInfoNode}
       */
      this.localStorage = workgroupArray( dataBuffer.nodeType, this.workgroupSize * 2 );

      // Not read by any method of this class; kept in case external code relies on it.
      // Typed arrays are zero-initialized, so no explicit fill is needed.
      this._tempArray = new Uint32Array( this.count );

      /**
       * A node representing a storage buffer used for transferring the result of the global sort back to the original data buffer.
       *
       * @type {StorageBufferNode}
       */
      this.tempBuffer = instancedArray( this.count, dataBuffer.nodeType ).setName( 'TempStorage' );

      /**
       * A node containing the current algorithm type, the current swap span, and the highest swap span.
       * It starts in the same state that `resetFn` restores: SWAP_LOCAL with a swap span of 2.
       *
       * @type {StorageBufferNode}
       */
      this.infoStorage = instancedArray( new Uint32Array( [ StepType.SWAP_LOCAL, 2, 2 ] ), 'uint' ).setName( 'BitonicSortInfo' );

      /**
       * The number of distinct swap operations ('flips' and 'disperses') executed in an in-place
       * bitonic sort of the current data buffer.
       *
       * @type {number}
       */
      this.swapOpCount = this._getSwapOpCount();

      /**
       * The number of steps (i.e. prepping and/or executing a swap) needed to fully execute an in-place bitonic sort of the current data buffer.
       *
       * @type {number}
       */
      this.stepCount = this._getStepCount();

      /**
       * A compute shader that executes a 'flip' swap within a global address space on elements in the data buffer.
       *
       * @type {ComputeNode}
       */
      this.flipGlobalFn = this._getFlipGlobal();

      /**
       * A compute shader that executes a 'disperse' swap within a global address space on elements in the data buffer.
       *
       * @type {ComputeNode}
       */
      this.disperseGlobalFn = this._getDisperseGlobal();

      /**
       * A compute shader that executes a sequence of flip and disperse swaps within a local address space on elements in the data buffer.
       *
       * @type {ComputeNode}
       */
      this.swapLocalFn = this._getSwapLocal();

      /**
       * A compute shader that executes a sequence of disperse swaps within a local address space on elements in the data buffer.
       *
       * @type {ComputeNode}
       */
      this.disperseLocalFn = this._getDisperseLocal();

      // Utility functions

      /**
       * A compute shader that sets up the algorithm and the swap span for the next swap operation.
       *
       * @type {ComputeNode}
       */
      this.setAlgoFn = this._getSetAlgoFn();

      /**
       * A compute shader that aligns the result of the global swap operation with the current buffer.
       *
       * @type {ComputeNode}
       */
      this.alignFn = this._getAlignFn();

      /**
       * A compute shader that resets the algorithm and swap span information.
       *
       * @type {ComputeNode}
       */
      this.resetFn = this._getResetFn();

      /**
       * The current compute shader dispatch within the list of dispatches needed to complete the sort.
       *
       * @type {number}
       */
      this.currentDispatch = 0;

      /**
       * The number of global swap operations that must be executed before the sort
       * can swap in local address space.
       *
       * @type {number}
       */
      this.globalOpsRemaining = 0;

      /**
       * The total number of global operations needed to sort elements within the current swap span.
       *
       * @type {number}
       */
      this.globalOpsInSpan = 0;

    }

    /**
     * Get total number of distinct swaps that occur in a bitonic sort: n * ( n + 1 ) / 2
     * with n = log2( count ).
     *
     * @private
     * @returns {number} The number of swap operations.
     */
    _getSwapOpCount() {

      const n = Math.log2( this.count );
      return ( n * ( n + 1 ) ) / 2;

    }

    /**
     * Get the number of steps it takes to execute a complete bitonic sort.
     *
     * @private
     * @returns {number} The number of `computeStep()` calls in one full sort.
     */
    _getStepCount() {

      const logElements = Math.log2( this.count );
      const logSwapSpan = Math.log2( this.workgroupSize * 2 );

      // Number of spans that are too large to be handled within a single workgroup.
      const numGlobalFlips = logElements - logSwapSpan;

      // Start with 1 for the initial sort over all local elements.
      let numSteps = 1;

      for ( let span = 1; span <= numGlobalFlips; span ++ ) {

        // Each global span needs one global flip, ( span - 1 ) global disperses
        // and one local disperse once all global swaps are finished.
        numSteps += 1 + ( span - 1 ) + 1;

      }

      return numSteps;

    }

    /**
     * Compares two data points in the data buffer within the global address space and
     * writes the smaller one to `idxBefore` and the larger one to `idxAfter` of the temp buffer.
     *
     * @private
     */
    _globalCompareAndSwapTSL( idxBefore, idxAfter, dataBuffer, tempBuffer ) {

      const data1 = dataBuffer.element( idxBefore );
      const data2 = dataBuffer.element( idxAfter );

      tempBuffer.element( idxBefore ).assign( min( data1, data2 ) );
      tempBuffer.element( idxAfter ).assign( max( data1, data2 ) );

    }

    /**
     * Compares and swaps two data points in workgroup local storage, leaving the smaller one at `idxBefore`.
     *
     * @private
     */
    _localCompareAndSwapTSL( idxBefore, idxAfter ) {

      const { localStorage } = this;

      // Copy both values first: the pair is read and written in the same buffer.
      const data1 = localStorage.element( idxBefore ).toVar();
      const data2 = localStorage.element( idxAfter ).toVar();

      localStorage.element( idxBefore ).assign( min( data1, data2 ) );
      localStorage.element( idxAfter ).assign( max( data1, data2 ) );

    }

    /**
     * Create a compute shader that performs one global swap over the current swap span
     * (`infoStorage.element( 1 )`), writing the result into the temp buffer.
     *
     * @private
     * @param {Function} getIndices - TSL function mapping ( index, swapSpan ) to the compared index pair.
     * @returns {ComputeNode} The compute node.
     */
    _getGlobalSwapFn( getIndices ) {

      const { infoStorage, tempBuffer, dataBuffer } = this;

      const currentSwapSpan = infoStorage.element( 1 );

      return Fn( () => {

        const idx = getIndices( instanceIndex, currentSwapSpan );
        this._globalCompareAndSwapTSL( idx.x, idx.y, dataBuffer, tempBuffer );

      } )().compute( this.dispatchSize, [ this.workgroupSize ] );

    }

    /**
     * Create the compute shader that performs a global disperse swap on the data buffer.
     *
     * @private
     */
    _getDisperseGlobal() {

      return this._getGlobalSwapFn( getBitonicDisperseIndices );

    }

    /**
     * Create the compute shader that performs a global flip swap on the data buffer.
     *
     * @private
     */
    _getFlipGlobal() {

      return this._getGlobalSwapFn( getBitonicFlipIndices );

    }

    /**
     * Create the compute shader that performs a complete local swap on the data buffer.
     *
     * @private
     */
    _getSwapLocal() {

      const { localStorage, dataBuffer, workgroupSize } = this;

      const fnDef = Fn( () => {

        // Get ids of indices needed to populate workgroup local buffer.
        // Use .toVar() to prevent these values from being recalculated multiple times.
        const localOffset = uint( workgroupSize ).mul( 2 ).mul( workgroupId.x ).toVar();

        const localID1 = invocationLocalIndex.mul( 2 );
        const localID2 = invocationLocalIndex.mul( 2 ).add( 1 );

        localStorage.element( localID1 ).assign( dataBuffer.element( localOffset.add( localID1 ) ) );
        localStorage.element( localID2 ).assign( dataBuffer.element( localOffset.add( localID2 ) ) );

        // Ensure that all local data has been populated
        workgroupBarrier();

        // Perform a chunk of the sort in a single pass that operates entirely in workgroup local space.
        // SWAP_LOCAL will always be first pass, so we start with known block height of 2.
        // Declared with .toVar() because it is mutated at the end of every outer iteration.
        const flipBlockHeight = uint( 2 ).toVar();

        Loop( { start: uint( 2 ), end: uint( workgroupSize * 2 ), type: 'uint', condition: '<=', update: '<<= 1' }, () => {

          // Ensure that last dispatch block executed
          workgroupBarrier();

          const flipIdx = getBitonicFlipIndices( invocationLocalIndex, flipBlockHeight );

          this._localCompareAndSwapTSL( flipIdx.x, flipIdx.y );

          // Mutable: halved after every disperse of this block.
          const localBlockHeight = flipBlockHeight.div( 2 ).toVar();

          Loop( { start: localBlockHeight, end: uint( 1 ), type: 'uint', condition: '>', update: '>>= 1' }, () => {

            // Ensure that last dispatch op executed
            workgroupBarrier();

            const disperseIdx = getBitonicDisperseIndices( invocationLocalIndex, localBlockHeight );
            this._localCompareAndSwapTSL( disperseIdx.x, disperseIdx.y );

            localBlockHeight.divAssign( 2 );

          } );

          // Keep flipBlockHeight in step with the outer loop's '<<= 1' update.
          flipBlockHeight.shiftLeftAssign( 1 );

        } );

        // Ensure that all invocations have swapped their own regions of data
        workgroupBarrier();

        dataBuffer.element( localOffset.add( localID1 ) ).assign( localStorage.element( localID1 ) );
        dataBuffer.element( localOffset.add( localID2 ) ).assign( localStorage.element( localID2 ) );

      } )().compute( this.dispatchSize, [ this.workgroupSize ] );

      return fnDef;

    }

    /**
     * Create the compute shader that performs a local disperse swap on the data buffer.
     *
     * @private
     */
    _getDisperseLocal() {

      const { localStorage, dataBuffer, workgroupSize } = this;

      const fnDef = Fn( () => {

        // Get ids of indices needed to populate workgroup local buffer.
        // Use .toVar() to prevent these values from being recalculated multiple times.
        const localOffset = uint( workgroupSize ).mul( 2 ).mul( workgroupId.x ).toVar();

        const localID1 = invocationLocalIndex.mul( 2 );
        const localID2 = invocationLocalIndex.mul( 2 ).add( 1 );

        localStorage.element( localID1 ).assign( dataBuffer.element( localOffset.add( localID1 ) ) );
        localStorage.element( localID2 ).assign( dataBuffer.element( localOffset.add( localID2 ) ) );

        // Ensure that all local data has been populated
        workgroupBarrier();

        // Mutable: halved after every disperse, so it must be a declared shader variable.
        const localBlockHeight = uint( workgroupSize * 2 ).toVar();

        Loop( { start: localBlockHeight, end: uint( 1 ), type: 'uint', condition: '>', update: '>>= 1' }, () => {

          // Ensure that last dispatch op executed
          workgroupBarrier();

          const disperseIdx = getBitonicDisperseIndices( invocationLocalIndex, localBlockHeight );
          this._localCompareAndSwapTSL( disperseIdx.x, disperseIdx.y );

          localBlockHeight.divAssign( 2 );

        } );

        // Ensure that all invocations have swapped their own regions of data
        workgroupBarrier();

        dataBuffer.element( localOffset.add( localID1 ) ).assign( localStorage.element( localID1 ) );
        dataBuffer.element( localOffset.add( localID2 ) ).assign( localStorage.element( localID2 ) );

      } )().compute( this.dispatchSize, [ this.workgroupSize ] );

      return fnDef;

    }

    /**
     * Create the compute shader that resets the sort's algorithm information
     * (SWAP_LOCAL, swap span 2, highest swap span 2).
     *
     * @private
     */
    _getResetFn() {

      const { infoStorage } = this;

      const fnDef = Fn( () => {

        const currentAlgo = infoStorage.element( 0 );
        const currentSwapSpan = infoStorage.element( 1 );
        const maxSwapSpan = infoStorage.element( 2 );

        currentAlgo.assign( StepType.SWAP_LOCAL );
        currentSwapSpan.assign( 2 );
        maxSwapSpan.assign( 2 );

      } )().compute( 1 );

      return fnDef;

    }

    /**
     * Create the compute shader that copies the state of the global swap to the data buffer.
     *
     * @private
     */
    _getAlignFn() {

      const { dataBuffer, tempBuffer } = this;

      // TODO: Only do this in certain instances by ping-ponging which buffer gets sorted
      // And only aligning if numDispatches % 2 === 1
      const fnDef = Fn( () => {

        dataBuffer.element( instanceIndex ).assign( tempBuffer.element( instanceIndex ) );

      } )().compute( this.count, [ this.workgroupSize ] );

      return fnDef;

    }

    /**
     * Create the compute shader that advances the algorithm's information to the next step.
     *
     * @private
     */
    _getSetAlgoFn() {

      const { infoStorage, workgroupSize } = this;

      const fnDef = Fn( () => {

        const currentAlgo = infoStorage.element( 0 );
        const currentSwapSpan = infoStorage.element( 1 );
        const maxSwapSpan = infoStorage.element( 2 );

        If( currentAlgo.equal( StepType.SWAP_LOCAL ), () => {

          // The first global span is twice the span sorted locally.
          const nextHighestSwapSpan = uint( workgroupSize * 4 );

          currentAlgo.assign( StepType.FLIP_GLOBAL );
          currentSwapSpan.assign( nextHighestSwapSpan );
          maxSwapSpan.assign( nextHighestSwapSpan );

        } ).ElseIf( currentAlgo.equal( StepType.DISPERSE_LOCAL ), () => {

          // A span is finished; start the next one with a flip over twice the highest span.
          currentAlgo.assign( StepType.FLIP_GLOBAL );

          const nextHighestSwapSpan = maxSwapSpan.mul( 2 );

          currentSwapSpan.assign( nextHighestSwapSpan );
          maxSwapSpan.assign( nextHighestSwapSpan );

        } ).Else( () => {

          // After a global flip or disperse, halve the span; once it fits in a
          // workgroup, the remaining disperses run locally.
          // Order matters: the algo is chosen from the span before it is overwritten.
          const nextSwapSpan = currentSwapSpan.div( 2 );

          currentAlgo.assign(
            select(
              nextSwapSpan.lessThanEqual( uint( workgroupSize * 2 ) ),
              StepType.DISPERSE_LOCAL,
              StepType.DISPERSE_GLOBAL
            ).uniformFlow()
          );

          currentSwapSpan.assign( nextSwapSpan );

        } );

      } )().compute( 1 );

      return fnDef;

    }

    /**
     * Executes a step of the bitonic sort operation.
     *
     * @param {Renderer} renderer - The current scene's renderer.
     */
    async computeStep( renderer ) {

      if ( this.currentDispatch === 0 ) {

        // Swap local only runs once, as the first step of every sort.
        await renderer.computeAsync( this.swapLocalFn );

        this.globalOpsRemaining = 1;
        this.globalOpsInSpan = 1;

      } else if ( this.globalOpsRemaining > 0 ) {

        // The first global op of a span is a flip; the rest are disperses.
        const isFlip = this.globalOpsRemaining === this.globalOpsInSpan;

        await renderer.computeAsync( isFlip ? this.flipGlobalFn : this.disperseGlobalFn );

        // Global swaps write into tempBuffer; copy the result back into dataBuffer.
        await renderer.computeAsync( this.alignFn );

        this.globalOpsRemaining -= 1;

      } else {

        // Then run local disperses when we've finished all global swaps.
        await renderer.computeAsync( this.disperseLocalFn );

        // The next span needs one more global op than the previous one.
        this.globalOpsInSpan += 1;
        this.globalOpsRemaining = this.globalOpsInSpan;

      }

      this.currentDispatch += 1;

      if ( this.currentDispatch === this.stepCount ) {

        // Just reset the algorithm information
        await renderer.computeAsync( this.resetFn );

        this.currentDispatch = 0;
        this.globalOpsRemaining = 0;
        this.globalOpsInSpan = 0;

      } else {

        // Otherwise, determine what next swap span is
        await renderer.computeAsync( this.setAlgoFn );

      }

    }

    /**
     * Executes a complete bitonic sort on the data buffer.
     *
     * @param {Renderer} renderer - The current scene's renderer.
     */
    async compute( renderer ) {

      this.globalOpsRemaining = 0;
      this.globalOpsInSpan = 0;
      this.currentDispatch = 0;

      for ( let i = 0; i < this.stepCount; i ++ ) {

        await this.computeStep( renderer );

      }

    }

  }
粤ICP备19079148号