webgpu_compute_sort_bitonic.html 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. <html lang="en">
  2. <head>
  3. <title>three.js webgpu - compute sort bitonic</title>
  4. <meta charset="utf-8">
  5. <meta name="viewport" content="width=device-width, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0">
  6. <link type="text/css" rel="stylesheet" href="main.css">
  7. </head>
  8. <body>
  9. <div id="info">
  10. <a href="https://threejs.org" target="_blank" rel="noopener">three.js</a>
  11. <br /> This example demonstrates a bitonic sort running step by step in a compute shader.
  12. <br /> The left canvas swaps values within workgroup local arrays. The right swaps values within storage buffers.
  13. <br /> Reference implementation by <a href="https://poniesandlight.co.uk/reflect/bitonic_merge_sort/">Tim Gfrerer</a>
  14. <br />
  15. <div id="local_swap" style="
  16. position: absolute;
  17. top: 150px;
  18. left: 0;
  19. padding: 10px;
  20. background: rgba( 0, 0, 0, 0.5 );
  21. color: #fff;
  22. font-family: monospace;
  23. font-size: 12px;
  24. line-height: 1.5;
  25. pointer-events: none;
  26. text-align: left;
  27. "></div>
  28. <div id="global_swap" style="
  29. position: absolute;
  30. top: 150px;
  31. right: 0;
  32. padding: 10px;
  33. background: rgba( 0, 0, 0, 0.5 );
  34. color: #fff;
  35. font-family: monospace;
  36. font-size: 12px;
  37. line-height: 1.5;
  38. pointer-events: none;
  39. text-align: left;
  40. "></div>
  41. </div>
  42. <script type="importmap">
  43. {
  44. "imports": {
  45. "three": "../build/three.webgpu.js",
  46. "three/webgpu": "../build/three.webgpu.js",
  47. "three/tsl": "../build/three.tsl.js",
  48. "three/addons/": "./jsm/"
  49. }
  50. }
  51. </script>
  52. <script type="module">
  53. import * as THREE from 'three';
  54. import { storage, If, vec3, not, uniform, uv, uint, float, Fn, vec2, abs, int, invocationLocalIndex, workgroupArray, uvec2, floor, instanceIndex, workgroupBarrier, atomicAdd, atomicStore, workgroupId } from 'three/tsl';
  55. import WebGPU from 'three/addons/capabilities/WebGPU.js';
  56. import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
  /**
   * Identifies which kind of compare-and-swap pass the compute shader runs next.
   * The numeric value is written to the 'NextAlgo' storage buffer and compared
   * inside the shaders; values up to and including DISPERSE_LOCAL are handled
   * through the workgroup-local array, higher values through the storage buffers.
   * Frozen so the shared constant cannot be modified by accident.
   */
  const StepType = Object.freeze( {
    NONE: 0,
    // Swap values within workgroup local buffer.
    FLIP_LOCAL: 1,
    DISPERSE_LOCAL: 2,
    // Swap values within global data buffer.
    FLIP_GLOBAL: 3,
    DISPERSE_GLOBAL: 4,
  } );
  // Overlay elements that receive the per-canvas readout, keyed by their element id.
  const timestamps = Object.fromEntries(
    [ 'local_swap', 'global_swap' ].map( ( id ) => [ id, document.getElementById( id ) ] )
  );
  // Legend swatch colors: index 0 is the "compare region", index 1 the "to region".
  const localColors = [ 'rgb(203, 64, 203)', 'rgb(0, 215, 215)' ];
  const globalColors = [ 'rgb(1, 150, 1)', 'red' ];
  // Total number of elements and the dimensions of the (square) display grid.
  const size = 16384;
  const gridDim = Math.sqrt( size );
  /**
   * Number of compare-and-swap passes in a bitonic sort: with n = log2( count ),
   * the total is n * ( n + 1 ) / 2.
   *
   * @param {number} [count=size] - Number of elements being sorted.
   * @returns {number} Total number of passes.
   */
  const getNumSteps = ( count = size ) => {
    const stages = Math.log2( count );
    return ( stages * ( stages + 1 ) ) / 2;
  };
  // Total number of steps in a bitonic sort with 'size' elements.
  const MAX_STEPS = getNumSteps();
  const WORKGROUP_SIZE = [ 64 ];
  // Uniforms and GUI state read by the display shader of both canvases.
  const effectController = {
    // Each side of the grid is the square root of the element count (sqrt of 16384).
    gridWidth: uniform( gridDim ),
    gridHeight: uniform( gridDim ),
    // 1 tints the regions being compared, 0 shows the plain element values.
    highlight: uniform( 1 ),
    'Display Mode': 'Swap Zone Highlight'
  };
  const gui = new GUI();
  const displayModes = [ 'Elements', 'Swap Zone Highlight' ];
  // Mirror the selected display mode into the 'highlight' uniform.
  const syncHighlight = () => {
    const showElementsOnly = effectController[ 'Display Mode' ] === 'Elements';
    effectController.highlight.value = showElementsOnly ? 0 : 1;
  };
  gui.add( effectController, 'Display Mode', displayModes ).onChange( syncHighlight );
  // Without WebGPU there is nothing to show: display the stock error message and stop.
  const webgpuAvailable = WebGPU.isAvailable();
  if ( webgpuAvailable === false ) {
    document.body.appendChild( WebGPU.getErrorMessage() );
    throw new Error( 'No WebGPU support' );
  }
  // Start two instances: the first allows workgroup array swaps,
  // the second performs global swaps only.
  for ( const globalOnly of [ false, true ] ) {
    init( globalOnly );
  }
  105. // When forceGlobalSwap is true, force all valid local swaps to be global swaps.
  106. async function init( forceGlobalSwap = false ) {
  // Number of sort passes executed so far in the current run.
  let currentStep = 0;
  // Whether the pass about to run writes into the temp buffer and therefore needs
  // the align pass afterwards. In forced-global mode the very first pass is already
  // FLIP_GLOBAL, so this must start out true there; otherwise its result would never
  // be copied back into the elements buffer.
  let nextStepGlobal = forceGlobalSwap;
  // Each instance renders into one half of the window.
  const aspect = ( window.innerWidth / 2 ) / window.innerHeight;
  const camera = new THREE.OrthographicCamera( - aspect, aspect, 1, - 1, 0, 2 );
  camera.position.z = 1;
  const scene = new THREE.Scene();
  // Builds a one-element uint instanced storage attribute holding `value`.
  const createUintScalar = ( value ) => new THREE.StorageInstancedBufferAttribute( Uint32Array.of( value ), 1 );
  // Algorithm (a StepType value) used by the next sort pass.
  const nextAlgoBuffer = createUintScalar( forceGlobalSwap ? StepType.FLIP_GLOBAL : StepType.FLIP_LOCAL );
  const nextAlgoStorage = storage( nextAlgoBuffer, 'uint', nextAlgoBuffer.count )
    .setPBO( true )
    .label( 'NextAlgo' );
  // Block height used by the next sort pass, with a read-only view over the same buffer.
  const nextBlockHeightBuffer = createUintScalar( 2 );
  const nextBlockHeightStorage = storage( nextBlockHeightBuffer, 'uint', nextBlockHeightBuffer.count )
    .setPBO( true )
    .label( 'NextBlockHeight' );
  const nextBlockHeightRead = storage( nextBlockHeightBuffer, 'uint', nextBlockHeightBuffer.count )
    .setPBO( true )
    .label( 'NextBlockHeight' )
    .toReadOnly();
  // Largest block height reached so far.
  const highestBlockHeightBuffer = createUintScalar( 2 );
  const highestBlockHeightStorage = storage( highestBlockHeightBuffer, 'uint', highestBlockHeightBuffer.count )
    .setPBO( true )
    .label( 'HighestBlockHeight' );
  // Atomic counter incremented once for every swap performed.
  const counterBuffer = new THREE.StorageBufferAttribute( 1, 1 );
  const counterStorage = storage( counterBuffer, 'uint', counterBuffer.count )
    .setPBO( true )
    .toAtomic()
    .label( 'Counter' );
  // Element values 0 .. size - 1 in ascending order (shuffled afterwards).
  const array = Uint32Array.from( { length: size }, ( _, i ) => i );
  // Shuffles `array` in place: walking from the end towards the front, each slot is
  // swapped with a randomly chosen slot from the not-yet-fixed prefix (itself included).
  const randomizeDataArray = () => {
    for ( let remaining = array.length; remaining !== 0; ) {
      const pick = Math.floor( Math.random() * remaining );
      remaining -= 1;
      const held = array[ remaining ];
      array[ remaining ] = array[ pick ];
      array[ pick ] = held;
    }
  };
  randomizeDataArray();
  // Wraps an attribute as a labelled uint storage node with `size` elements.
  const createElementStorage = ( attribute, label ) => storage( attribute, 'uint', size ).setPBO( true ).label( label );
  // Elements being sorted; this buffer is what the display shader reads.
  const currentElementsBuffer = new THREE.StorageInstancedBufferAttribute( array, 1 );
  const currentElementsStorage = createElementStorage( currentElementsBuffer, 'Elements' );
  // Scratch buffer written by the global swaps (constructed from the same typed array).
  const tempBuffer = new THREE.StorageInstancedBufferAttribute( array, 1 );
  const tempStorage = createElementStorage( tempBuffer, 'Temp' );
  // Copy of the shuffled order, used to restore the elements after a completed sort.
  const randomizedElementsBuffer = new THREE.StorageInstancedBufferAttribute( size, 1 );
  const randomizedElementsStorage = createElementStorage( randomizedElementsBuffer, 'RandomizedElements' );
  // Returns the pair of element indices compared by invocation `index` in a "flip"
  // pass: position p in the first half of a block is paired with the mirrored
  // position ( blockHeight - 1 - p ) of the same block.
  // NOTE(review): relies on integer division of uint operands to find the block start.
  const getFlipIndices = ( index, blockHeight ) => {
    const halfHeight = blockHeight.div( 2 );
    const blockOffset = index.mul( 2 ).div( blockHeight ).mul( blockHeight );
    const nearEnd = index.mod( halfHeight );
    const farEnd = blockHeight.sub( index.mod( halfHeight ) ).sub( 1 );
    const pair = uvec2( nearEnd, farEnd );
    pair.x.addAssign( blockOffset );
    pair.y.addAssign( blockOffset );
    return pair;
  };
  // Returns the pair of element indices compared by invocation `index` in a
  // "disperse" pass: position p in the first half of a block is paired with
  // position ( p + blockHeight / 2 ) of the same block.
  // NOTE(review): relies on integer division of uint operands to find the block start.
  const getDisperseIndices = ( index, blockHeight ) => {
    const halfHeight = blockHeight.div( 2 );
    const blockOffset = index.mul( 2 ).div( blockHeight ).mul( blockHeight );
    const lower = index.mod( halfHeight );
    const upper = index.mod( halfHeight ).add( halfHeight );
    const pair = uvec2( lower, upper );
    pair.x.addAssign( blockOffset );
    pair.y.addAssign( blockOffset );
    return pair;
  };
  // Workgroup-local scratch array. Every invocation loads two elements, so it holds
  // twice the workgroup size; derived from WORKGROUP_SIZE instead of a literal 64 so
  // the two cannot drift apart.
  // Note: inside init() this name shadows the browser's window.localStorage.
  const localStorage = workgroupArray( 'uint', WORKGROUP_SIZE[ 0 ] * 2 );
  // Compare two slots of the workgroup-local array and swap them when the later
  // slot holds the smaller value; every swap increments the swap counter.
  const localCompareAndSwap = ( idxBefore, idxAfter ) => {
    If( localStorage.element( idxAfter ).lessThan( localStorage.element( idxBefore ) ), () => {
      atomicAdd( counterStorage.element( 0 ), 1 );
      // Snapshot the earlier value before it is overwritten.
      const temp = localStorage.element( idxBefore ).toVar();
      localStorage.element( idxBefore ).assign( localStorage.element( idxAfter ) );
      localStorage.element( idxAfter ).assign( temp );
    } );
  };
  // Compare two elements of the elements buffer and write the ordered pair into the
  // temp buffer (the elements buffer itself is not modified here). A swap also
  // increments the swap counter.
  const globalCompareAndSwap = ( idxBefore, idxAfter ) => {
    const outOfOrder = currentElementsStorage.element( idxAfter ).lessThan( currentElementsStorage.element( idxBefore ) );
    const writeSwapped = () => {
      atomicAdd( counterStorage.element( 0 ), 1 );
      tempStorage.element( idxBefore ).assign( currentElementsStorage.element( idxAfter ) );
      tempStorage.element( idxAfter ).assign( currentElementsStorage.element( idxBefore ) );
    };
    const writeUnchanged = () => {
      tempStorage.element( idxBefore ).assign( currentElementsStorage.element( idxBefore ) );
      tempStorage.element( idxAfter ).assign( currentElementsStorage.element( idxAfter ) );
    };
    If( outOfOrder, writeSwapped ).Else( writeUnchanged );
  };
  // Copies each element of the elements buffer into the randomized-elements buffer.
  const computeInitFn = Fn( () => {
    const source = currentElementsStorage.element( instanceIndex );
    randomizedElementsStorage.element( instanceIndex ).assign( source );
  } );
  // One pass of the bitonic sort. Depending on the algorithm stored in 'NextAlgo',
  // each invocation compares one pair of elements either inside the workgroup-local
  // array (local steps) or directly in the storage buffers (global steps).
  const computeBitonicStepFn = Fn( () => {
    // Snapshot the shared pass parameters into shader variables.
    const blockHeight = nextBlockHeightStorage.element( 0 ).toVar();
    const algo = nextAlgoStorage.element( 0 ).toVar();
    // First element covered by this workgroup (two elements per invocation).
    const workgroupBase = uint( WORKGROUP_SIZE[ 0 ] ).mul( 2 ).mul( workgroupId.x ).toVar();
    // The two local-array slots this invocation loads and stores.
    const slotA = invocationLocalIndex.mul( 2 );
    const slotB = invocationLocalIndex.mul( 2 ).add( 1 );
    // True for step types up to and including DISPERSE_LOCAL.
    const usesLocalArray = () => algo.lessThanEqual( uint( StepType.DISPERSE_LOCAL ) );
    // Load this workgroup's slice into the local array when a local step will run.
    If( usesLocalArray(), () => {
      localStorage.element( slotA ).assign( currentElementsStorage.element( workgroupBase.add( slotA ) ) );
      localStorage.element( slotB ).assign( currentElementsStorage.element( workgroupBase.add( slotB ) ) );
    } );
    workgroupBarrier();
    // TODO: Convert to switch block.
    If( algo.equal( uint( StepType.FLIP_LOCAL ) ), () => {
      const pair = getFlipIndices( invocationLocalIndex, blockHeight );
      localCompareAndSwap( pair.x, pair.y );
    } ).ElseIf( algo.equal( uint( StepType.DISPERSE_LOCAL ) ), () => {
      const pair = getDisperseIndices( invocationLocalIndex, blockHeight );
      localCompareAndSwap( pair.x, pair.y );
    } ).ElseIf( algo.equal( uint( StepType.FLIP_GLOBAL ) ), () => {
      const pair = getFlipIndices( instanceIndex, blockHeight );
      globalCompareAndSwap( pair.x, pair.y );
    } ).ElseIf( algo.equal( uint( StepType.DISPERSE_GLOBAL ) ), () => {
      const pair = getDisperseIndices( instanceIndex, blockHeight );
      globalCompareAndSwap( pair.x, pair.y );
    } );
    // Wait until every invocation of the workgroup has finished its swap.
    workgroupBarrier();
    // Write the local array back to the elements buffer after a local step.
    If( usesLocalArray(), () => {
      currentElementsStorage.element( workgroupBase.add( slotA ) ).assign( localStorage.element( slotA ) );
      currentElementsStorage.element( workgroupBase.add( slotB ) ).assign( localStorage.element( slotB ) );
    } );
  } );
  // Advances the sort state after a pass: halves the block height and chooses the
  // algorithm for the next pass. When the block height reaches 1 the current merge
  // is finished, so the highest block height is doubled and a new flip is started,
  // or the sort is marked as done once that height reaches size * 2.
  const computeSetAlgoFn = Fn( () => {
    const nextBlockHeight = nextBlockHeightStorage.element( 0 ).toVar();
    // Not a variable: assignments write straight into the 'NextAlgo' buffer.
    const nextAlgo = nextAlgoStorage.element( 0 );
    const highestBlockHeight = highestBlockHeightStorage.element( 0 ).toVar();
    // Largest block height a single workgroup can process in its local array.
    const localSpan = WORKGROUP_SIZE[ 0 ] * 2;
    nextBlockHeight.divAssign( 2 );
    If( nextBlockHeight.equal( 1 ), () => {
      highestBlockHeight.mulAssign( 2 );
      const sortComplete = highestBlockHeight.equal( size * 2 );
      // Marks the sort as finished.
      const finishSort = () => {
        nextAlgo.assign( StepType.NONE );
        nextBlockHeight.assign( 0 );
      };
      // Returns a branch body that starts a flip of the given type at the new highest block height.
      const startFlip = ( flipType ) => () => {
        nextAlgo.assign( flipType );
        nextBlockHeight.assign( highestBlockHeight );
      };
      if ( forceGlobalSwap ) {
        If( sortComplete, finishSort )
          .Else( startFlip( StepType.FLIP_GLOBAL ) );
      } else {
        // Blocks too large for the workgroup-local array must be flipped globally.
        If( sortComplete, finishSort )
          .ElseIf( highestBlockHeight.greaterThan( localSpan ), startFlip( StepType.FLIP_GLOBAL ) )
          .Else( startFlip( StepType.FLIP_LOCAL ) );
      }
    } ).Else( () => {
      if ( forceGlobalSwap ) {
        nextAlgo.assign( StepType.DISPERSE_GLOBAL );
      } else {
        nextAlgo.assign( nextBlockHeight.greaterThan( localSpan ).select( StepType.DISPERSE_GLOBAL, StepType.DISPERSE_LOCAL ) );
      }
    } );
    // Persist the updated heights for the next pass.
    nextBlockHeightStorage.element( 0 ).assign( nextBlockHeight );
    highestBlockHeightStorage.element( 0 ).assign( highestBlockHeight );
  } );
  // Copies the temp buffer (results of a global swap pass) into the elements buffer.
  const computeAlignCurrentFn = Fn( () => {
    const swapped = tempStorage.element( instanceIndex );
    currentElementsStorage.element( instanceIndex ).assign( swapped );
  } );
  // Restores the elements buffer to the stored shuffled order.
  const computeResetBuffersFn = Fn( () => {
    const shuffled = randomizedElementsStorage.element( instanceIndex );
    currentElementsStorage.element( instanceIndex ).assign( shuffled );
  } );
  // Resets algorithm, block heights and the swap counter to their starting values.
  const computeResetAlgoFn = Fn( () => {
    const firstAlgo = forceGlobalSwap ? StepType.FLIP_GLOBAL : StepType.FLIP_LOCAL;
    nextAlgoStorage.element( 0 ).assign( firstAlgo );
    nextBlockHeightStorage.element( 0 ).assign( 2 );
    highestBlockHeightStorage.element( 0 ).assign( 2 );
    atomicStore( counterStorage.element( 0 ), 0 );
  } );
  // Dispatch helpers: one invocation per element, or a single invocation.
  const perElement = ( shaderFn ) => shaderFn().compute( size );
  const singleInvocation = ( shaderFn ) => shaderFn().compute( 1 );
  // Store the initial (shuffled) order of the elements buffer.
  const computeInit = perElement( computeInitFn );
  // One invocation per compared pair, hence half as many invocations as elements.
  const computeBitonicStep = computeBitonicStepFn().compute( size / 2 );
  // Choose algorithm and block height for the next pass.
  const computeSetAlgo = singleInvocation( computeSetAlgoFn );
  // Copy the temp buffer into the elements buffer after a global pass.
  const computeAlignCurrent = perElement( computeAlignCurrentFn );
  // Restore buffers and algorithm state once a full sort has completed.
  const computeResetBuffers = perElement( computeResetBuffersFn );
  const computeResetAlgo = singleInvocation( computeResetAlgoFn );
  const material = new THREE.MeshBasicNodeMaterial( { color: 0x00ff00 } );
  // Fragment color: maps the plane's UV onto the element grid and shades each cell
  // by its element value. In highlight mode, while a sort is running, the two
  // halves of every block are tinted differently.
  const display = Fn( () => {
    const { gridWidth, gridHeight, highlight } = effectController;
    // Scale UV to grid units and truncate to the integer cell coordinate.
    const cellUV = uv().mul( vec2( gridWidth, gridHeight ) );
    const cell = uvec2( uint( floor( cellUV.x ) ), uint( floor( cellUV.y ) ) );
    const elementIndex = uint( gridWidth ).mul( cell.y ).add( cell.x );
    const elementValue = currentElementsStorage.element( elementIndex );
    // Element value divided by the number of grid cells; larger values render darker.
    const normalized = float( elementValue ).div( gridWidth.mul( gridHeight ) );
    const color = vec3( normalized.oneMinus() ).toVar();
    const sortRunning = not( nextAlgoStorage.element( 0 ).equal( StepType.NONE ) );
    If( highlight.equal( 1 ).and( sortRunning ), () => {
      // 1 when the element lies in the first half of its block, 0 otherwise.
      const inFirstHalf = int( elementIndex.mod( nextBlockHeightRead.element( 0 ) ).lessThan( nextBlockHeightRead.element( 0 ).div( 2 ) ) );
      // Blue channel marks step types up to and including DISPERSE_LOCAL.
      color.z.assign( nextAlgoStorage.element( 0 ).lessThanEqual( StepType.DISPERSE_LOCAL ) );
      color.x.mulAssign( inFirstHalf );
      color.y.mulAssign( abs( inFirstHalf.sub( 1 ) ) );
    } );
    return color;
  } );
  material.colorNode = display();
  const plane = new THREE.Mesh( new THREE.PlaneGeometry( 1, 1 ), material );
  scene.add( plane );
  // Each instance owns a renderer whose canvas fills one half of the window.
  const renderer = new THREE.WebGPURenderer( { antialias: false, trackTimestamp: true } );
  renderer.setPixelRatio( window.devicePixelRatio );
  renderer.setSize( window.innerWidth / 2, window.innerHeight );
  const animate = () => {
    renderer.render( scene, camera );
  };
  renderer.setAnimationLoop( animate );
  const canvas = renderer.domElement;
  document.body.appendChild( canvas );
  // Local-swap canvas on the left half, global-swap canvas on the right half.
  Object.assign( canvas.style, {
    position: 'absolute',
    top: '0',
    left: forceGlobalSwap ? '50%' : '0',
    width: '50%',
    height: '100%'
  } );
  scene.background = new THREE.Color( forceGlobalSwap ? 0x212121 : 0x313131 );
  // Store the shuffled order before the first sort pass runs.
  await renderer.computeAsync( computeInit );
  // Info counters are reset manually at the start of every step.
  renderer.info.autoReset = false;
  // Runs one step of the demo: either a single sort pass (plus the align pass after
  // a global swap) or, once all passes are done, a reset to the shuffled state.
  // Then refreshes the readout overlay and schedules the next step.
  const stepAnimation = async function () {
    renderer.info.reset();
    if ( currentStep !== MAX_STEPS ) {
      renderer.compute( computeBitonicStep );
      // Global passes write to the temp buffer; copy the result back.
      if ( nextStepGlobal ) {
        renderer.compute( computeAlignCurrent );
      }
      renderer.compute( computeSetAlgo );
      currentStep ++;
    } else {
      renderer.compute( computeResetBuffers );
      renderer.compute( computeResetAlgo );
      currentStep = 0;
    }
    renderer.resolveTimestampsAsync( THREE.TimestampQuery.COMPUTE );
    // Read back the algorithm chosen for the next pass. The buffer holds a single
    // uint, so index it explicitly instead of comparing the typed array itself.
    const algo = new Uint32Array( await renderer.getArrayBufferAsync( nextAlgoBuffer ) );
    nextStepGlobal = algo[ 0 ] > StepType.DISPERSE_LOCAL;
    const totalSwaps = new Uint32Array( await renderer.getArrayBufferAsync( counterBuffer ) );
    renderer.render( scene, camera );
    renderer.resolveTimestampsAsync( THREE.TimestampQuery.RENDER );
    timestamps[ forceGlobalSwap ? 'global_swap' : 'local_swap' ].innerHTML = `
    Compute ${forceGlobalSwap ? 'Global' : 'Local'}: ${renderer.info.compute.frameCalls} pass in ${renderer.info.compute.timestamp.toFixed( 6 )}ms<br>
    Total Swaps: ${totalSwaps[ 0 ]}<br>
    <div style="display: flex; flex-direction:row; justify-content: center; align-items: center;">
    ${forceGlobalSwap ? 'Global Swaps' : 'Local Swaps'} Compare Region&nbsp;
    <div style="background-color: ${ forceGlobalSwap ? globalColors[ 0 ] : localColors[ 0 ]}; width:12.5px; height: 1em; border-radius: 20%;"></div>
    &nbsp;to Region&nbsp;
    <div style="background-color: ${ forceGlobalSwap ? globalColors[ 1 ] : localColors[ 1 ]}; width:12.5px; height: 1em; border-radius: 20%;"></div>
    </div>`;
    // Hold the fully sorted result on screen for a second before resetting.
    if ( currentStep === MAX_STEPS ) {
      setTimeout( stepAnimation, 1000 );
    } else {
      setTimeout( stepAnimation, 50 );
    }
  };
  stepAnimation();
  // Keeps this instance's canvas at half the window width and re-fits the
  // orthographic frustum to the new aspect ratio, then redraws.
  const onWindowResize = () => {
    const halfWidth = window.innerWidth / 2;
    renderer.setSize( halfWidth, window.innerHeight );
    const ratio = halfWidth / window.innerHeight;
    const halfSpan = ( camera.top - camera.bottom ) * ratio / 2;
    camera.left = - halfSpan;
    camera.right = halfSpan;
    camera.updateProjectionMatrix();
    renderer.render( scene, camera );
  };
  window.addEventListener( 'resize', onWindowResize );
  375. }
  376. </script>
  377. </body>
  378. </html>
粤ICP备19079148号