webgpu_compute_reduce.html 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397
  1. <html lang="en">
  2. <head>
  3. <title>three.js webgpu - compute reduction</title>
  4. <meta charset="utf-8">
  5. <meta name="viewport" content="width=device-width, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0">
  6. <meta property="og:title" content="three.js webgpu - compute reduction">
  7. <meta property="og:type" content="website">
  8. <meta property="og:url" content="https://threejs.org/examples/webgpu_compute_reduce.html">
  9. <meta property="og:image" content="https://threejs.org/examples/screenshots/webgpu_compute_reduce.jpg">
  10. <link type="text/css" rel="stylesheet" href="main.css">
  11. </head>
  12. <body>
		<style>

			/* Fixed panel along the bottom edge that hosts the subgroup reduction explanation. */
			#reduction-panel {
				background-color: #111;
				width: 100%;
				display: flex;
				position: fixed;
				height: auto;
				bottom: 0px;
				z-index: 99;
				flex-direction: column;
				justify-content: center;
				align-items: center;
				border-left: 2px solid #222;
				text-align: center;
			}

			#panel-title {
				width: fit-content;
			}

			.thread-row {
				display: flex;
				flex-direction: row;
				align-items: center;
				margin: 4px 0;
				position: relative;
			}

			/*
			 * One cell per thread. The stylesheet used to declare `.thread` twice; the later
			 * rule overrode the earlier width/height/margin and reset its background colour
			 * through the `background` shorthand. This is the merged, effective rule.
			 */
			.thread {
				display: flex;
				justify-content: center;
				align-items: center;
				width: 40px;
				height: 40px;
				margin: 2px;
				border: 1px solid rgba(255, 255, 255, 0.2);
				border-radius: 4px;
				background: linear-gradient(180deg, rgba(255,255,255,0.05), rgba(0,0,0,0.2));
				box-shadow: inset 0 0 2px rgba(255,255,255,0.1);
				font-family: monospace;
				color: white;
				transition: background-color 0.5s, transform 0.5s;
			}

			.stage-display {
				display: flex;
				flex-direction: column;
				justify-content: center;
				margin-bottom: 5px;
			}

			.stage-label {
				font-size: 1.2em;
				color: #aaa;
				/* `font-style` has no `bold` value; weight is the property that makes text bold. */
				font-weight: bold;
				margin-top: 6px;
				margin-bottom: 20px;
			}

			/* Value shown inside a thread cell; long numbers are truncated with an ellipsis. */
			.thread_data {
				display: block;
				max-width: 100%;
				padding: 0 2px;
				white-space: nowrap;
				overflow: hidden;
				text-overflow: ellipsis;
				font-size: clamp(8px, 2vw, 14px);
				text-align: center;
			}

			.subgroup {
				display: flex;
				position: relative;
				margin-left: 10px;
				margin-right: 10px;
			}

			.subgroup::before {
				/* Label text shown above each subgroup */
				content: "subgroupAdd()";
				position: absolute;
				top: -20px;
				/* Hidden until the animation is displayed */
				opacity: 0;
				z-index: 100;
				transition: opacity 0.5s ease;
				font-weight: bold;
				color: white;
				width: 100%;
			}

			/* Caption below each subgroup, taken from the element's data-label attribute. */
			.subgroup::after {
				content: attr(data-label);
				position: absolute;
				bottom: -20px;
				opacity: 1;
				z-index: 100;
				color: gray;
				width: 100%;
			}

			.reduction-stage {
				margin-bottom: 20px;
			}

			/* The label fades in from above, holds, then drops down and fades out. */
			@keyframes labelAbsorb {
				0% {
					opacity: 0;
					transform: translateY(-50%);
				}
				40% {
					opacity: 1;
					transform: translateY(0%);
				}
				60% {
					opacity: 1;
					transform: translateY(0%);
				}
				80% {
					opacity: 1;
					transform: translate(0%, -20%);
				}
				100% {
					opacity: 0;
					transform: translate(0%, 100%);
				}
			}

			/* The script adds and removes the `anim` class to play the label animation. */
			.subgroup.anim::before {
				opacity: 0;
				animation-name: labelAbsorb;
				animation-duration: 1.5s;
				transition:
					transform 0.6s ease-out,
					opacity 0.3s ease-in 0.3s;
			}

		</style>
		<!-- Example description, reference credits, and the two per-side text overlays filled in by the script. -->
		<div id="info">
			<a href="https://threejs.org" target="_blank" rel="noopener">three.js</a>
			<br /> This example demonstrates the performance of various simple parallel reduction kernels.
			<br /> Reference implementations are translated from the CUDA/WGSL code present in the following books/repos:
			<br /> Impl. 0 - 2: <a href="https://www.cambridge.org/core/books/programming-in-parallel-with-cuda/C43652A69033C25AD6933368CDBE084C"><i>Programming in Parallel with CUDA</i></a> by <a href="https://people.bss.phy.cam.ac.uk/~rea1/">Richard Ansorge</a>
			<br /> Impl. 3: <a href="https://github.com/frost-beta/betann/blob/main/betann/wgsl/reduce_all.wgsl"><i>betann reduce_all kernel</i></a> by <a href="https://github.com/zcbenz">zcbenz</a>
			<br /> Impl. 4: <a href="https://github.com/b0nes164/GPUPrefixSums/blob/main/GPUPrefixSumsWebGPUapis/SharedShaders/rts.wgsl"><i>GPUPrefixSums reduction approach</i></a> by <a href="https://github.com/b0nes164">b0nes164</a>

			<!-- Left-hand overlay -->
			<div id="left_side_display" style="position: absolute;top: 150px;left: 0;padding: 10px;background: rgba( 0, 0, 0, 0.5 );color: #fff;font-family: monospace;font-size: 12px;line-height: 1.5;pointer-events: none;text-align: left;"></div>

			<!-- Right-hand overlay -->
			<div id="right_side_display" style="position: absolute;top: 150px;right: 0;padding: 10px;background: rgba( 0, 0, 0, 0.5 );color: #fff;font-family: monospace;font-size: 12px;line-height: 1.5;pointer-events: none;text-align: left;"></div>
		</div>
		<!-- Explanation panel; the script fills the two empty containers with thread cells. -->
		<div id="reduction-panel">
			<h3 id="panel-title" style="flex: 0 0 auto;">Subgroup Reduction Explanation</h3>
			<div class="reduction-stage" id="subgroup-reduction-stage">
				<div class="stage-label">Use subgroupAdd() to capture reduction of each workgroup's subgroups (Hover for animation)</div>
				<div class="stage-display">
					<!-- One cell per thread, grouped by subgroup -->
					<div id="workgroup_threads" style="display: flex; justify-content: center; margin-bottom: 20px;"></div>
					<!-- One cell per subgroup holding that subgroup's sum -->
					<div id="subgroup_reduction" style="display: flex; justify-content: center; margin-bottom: 5px;"></div>
				</div>
			</div>
		</div>
  164. <script type="importmap">
  165. {
  166. "imports": {
  167. "three": "../build/three.webgpu.js",
  168. "three/webgpu": "../build/three.webgpu.js",
  169. "three/tsl": "../build/three.tsl.js",
  170. "three/addons/": "./jsm/"
  171. }
  172. }
  173. </script>
  174. <script type="module">
  175. import * as THREE from 'three/webgpu';
  176. import { instancedArray, Loop, If, vec3, dot, clamp, storage, uvec4, subgroupAdd, uniform, uv, uint, float, Fn, vec2, invocationLocalIndex, invocationSubgroupIndex, uvec2, floor, instanceIndex, workgroupId, workgroupBarrier, workgroupArray, subgroupSize, select, countTrailingZeros } from 'three/tsl';
  177. import WebGPU from 'three/addons/capabilities/WebGPU.js';
  178. import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
  179. const timestamps = {
  180. left_side_display: document.getElementById( 'left_side_display' ),
  181. right_side_display: document.getElementById( 'right_side_display' )
  182. };
  183. const divRoundUp = ( size, part_size ) => {
  184. return Math.floor( ( size + part_size - 1 ) / part_size );
  185. };
  186. const cssSubgroupSize = 4;
  187. const cssWorkgroupSize = 16;
  188. const workgroupThreadsContainer = document.getElementById( 'workgroup_threads' );
  189. const subgroupReductionContainer = document.getElementById( 'subgroup_reduction' );
  190. document.getElementById( 'panel-title' ).textContent += ` (Subgroup Size: ${cssSubgroupSize}, Workgroup Size: ${cssWorkgroupSize})`;
  191. const createThreadWithData = ( data ) => {
  192. const threadEle = document.createElement( 'div' );
  193. threadEle.className = 'thread';
  194. const threadData = document.createElement( 'span' );
  195. threadData.textContent = data; // safer than innerHTML for just text
  196. threadData.className = 'thread_data';
  197. threadEle.append( threadData );
  198. return threadEle;
  199. };
  200. // Create thread elements
  201. const workgroupThreads = [];
  202. const initialSubgroups = [];
  203. const initialData = [];
  204. let currentSubgroupDiv = null;
  205. for ( let i = 0; i < cssWorkgroupSize; i ++ ) {
  206. if ( i % cssSubgroupSize === 0 ) {
  207. const currentSubgroupIndex = Math.floor( i / cssSubgroupSize );
  208. const subgroupReductionThread = createThreadWithData( 0 );
  209. subgroupReductionThread.id = `subgroup_reduction_element_${currentSubgroupIndex}`;
  210. subgroupReductionContainer.appendChild( subgroupReductionThread );
  211. currentSubgroupDiv = document.createElement( 'div' );
  212. currentSubgroupDiv.className = 'subgroup';
  213. currentSubgroupDiv.setAttribute( 'data-label', `Threads ${currentSubgroupIndex * cssSubgroupSize}-${( currentSubgroupIndex + 1 ) * cssSubgroupSize - 1}` );
  214. initialSubgroups.push( currentSubgroupDiv );
  215. workgroupThreadsContainer.appendChild( currentSubgroupDiv );
  216. }
  217. const data = Math.floor( Math.random() * 9 ) + 1;
  218. initialData.push( data );
  219. const thread = createThreadWithData( data );
  220. workgroupThreads.push( thread );
  221. currentSubgroupDiv.appendChild( thread );
  222. }
  223. const deactivateLabelAnimation = ( subgroupDiv, idx ) => {
  224. subgroupDiv.classList.remove( 'anim' );
  225. const subgroupReductionBufferElement = document.getElementById( `subgroup_reduction_element_${idx}` ).querySelector( '.thread_data' );
  226. subgroupReductionBufferElement.innerHTML = 0;
  227. };
  228. const activateLabelAnimation = ( subgroupDiv, idx ) => {
  229. const threads = Array.from( subgroupDiv.children );
  230. let total = 0;
  231. for ( let i = idx * cssSubgroupSize; i < idx * cssSubgroupSize + cssSubgroupSize; i ++ ) {
  232. total += initialData[ i ];
  233. }
  234. subgroupDiv.classList.add( 'anim' );
  235. setTimeout( () => {
  236. threads.forEach( t => {
  237. t.querySelector( '.thread_data' ).textContent = total;
  238. } );
  239. const subgroupReductionBufferElement = document.getElementById( `subgroup_reduction_element_${idx}` ).querySelector( '.thread_data' );
  240. subgroupReductionBufferElement.innerHTML = total;
  241. }, 1000 );
  242. // Remove the class after the animation ends so it can be triggered again
  243. setTimeout( () => {
  244. subgroupDiv.classList.remove( 'anim' );
  245. }, 1500 ); // matches animation duration in CSS
  246. };
  247. document.getElementById( 'subgroup-reduction-stage' ).addEventListener( 'mouseenter', () => {
  248. initialSubgroups.forEach( ( subgroupDiv, idx ) => {
  249. activateLabelAnimation( subgroupDiv, idx );
  250. } );
  251. } );
  252. document.getElementById( 'subgroup-reduction-stage' ).addEventListener( 'mouseleave', () => {
  253. initialSubgroups.forEach( ( subgroupDiv, idx ) => {
  254. deactivateLabelAnimation( subgroupDiv, idx );
  255. } );
  256. workgroupThreads.forEach( ( thread, idx ) => {
  257. thread.querySelector( '.thread_data' ).textContent = initialData[ idx ];
  258. } );
  259. } );
  260. if ( WebGPU.isAvailable() === false ) {
  261. document.body.appendChild( WebGPU.getErrorMessage() );
  262. throw new Error( 'No WebGPU support' );
  263. }
  264. // Total number of elements and the dimensions of the display grid.
  265. const size = 262144;
  266. const vecSize = divRoundUp( size, 4 );
  267. // Grid display is gridDim x gridDim
  268. const gridDim = Math.sqrt( size );
  269. let maxWorkgroupSize = 64;
  270. // Algorithm speed increase as you iterate through algorithms array
  271. const algorithms = [
  272. 'Reduce 0 (N/2)',
  273. 'Reduce 1 (Naive Accumulate)',
  274. 'Reduce 2 (Workgroup Reduction)',
  275. 'Reduce 3 (Subgroup Reduce)',
  276. 'Reduce 4 (Subgroup Optimized)',
  277. 'Incorrect Baseline',
  278. ];
  279. // Input Grid: Displays input data in a grid format
  280. // Input Log2: Displays input grid data's logarithmic indices horizontally (1, 2, 4, 8, 16, ..., size)
  281. // Input Element 0: Displays clamped input[0]
  282. const displayModes = [ 'Input Grid', 'Input Log2', 'Input Element 0', 'Workgroup Sum Grid' ];
  283. // Holds uniforms for both displays as well as debug information
  284. const unifiedEffectController = {
  285. // Number of elements in the grid
  286. gridElementWidth: uniform( gridDim ),
  287. gridElementHeight: uniform( gridDim ),
  288. // Number of elements in the grid being displayed
  289. gridDisplayWidth: uniform( gridDim ),
  290. gridDisplayHeight: uniform( gridDim ),
  291. // How to display end result of reduction.
  292. // Ideally this is unique to the reduction method being deployed
  293. 'Display Mode': 'Input Log2',
  294. loggedBuffer: 'Input Buffer',
  295. elementsReduced: size,
  296. };
  297. const leftEffectController = {
  298. // Current reduction algorithm being executed by this side
  299. algo: 'Reduce 0 (N/2)',
  300. // Flag indicating whether to highlight element in validation check
  301. highlight: uniform( 0 ),
  302. // Uniform that corresponds to the index of the current algorithm within the algorithms array
  303. currentAlgo: uniform( 0 ),
  304. // Current state of reduction (Running, validating, resetting)
  305. state: 'Run Algo',
  306. // Current display mode
  307. displayMode: 'Input Log2',
  308. // Reduce 0 specific uniform
  309. numThreadsDispatched: uniform( size / 2 ),
  310. // The subgroup size used by this side's device
  311. };
  312. const rightEffectController = {
  313. algo: 'Reduce 4 (Subgroup Optimized)',
  314. currentAlgo: uniform( 3 ),
  315. highlight: uniform( 0 ),
  316. displayMode: 'Input Element 0',
  317. state: 'Run Algo',
  318. numThreadsDispatched: uniform( size / 2 )
  319. };
  320. const leftMaterial = new THREE.MeshBasicNodeMaterial( { color: 0x00ff00 } );
  321. const rightMaterial = new THREE.MeshBasicNodeMaterial( { color: 0x00ff00 } );
  322. const leftDisplayColorNodes = {};
  323. const rightDisplayColorNodes = {};
  324. const gui = new GUI();
  325. gui.add( leftEffectController, 'algo', algorithms ).onChange( () => {
  326. leftEffectController.currentAlgo.value = algorithms.findIndex( val => val === leftEffectController.algo );
  327. } );
  328. gui.add( rightEffectController, 'algo', algorithms ).onChange( () => {
  329. rightEffectController.currentAlgo.value = algorithms.findIndex( val => val === rightEffectController.algo );
  330. } );
  331. gui.add( leftEffectController, 'displayMode', displayModes ).name( 'Left Display Mode' ).onChange( () => {
  332. leftMaterial.colorNode = leftDisplayColorNodes[ leftEffectController.displayMode ];
  333. leftMaterial.needsUpdate = true;
  334. } );
  335. gui.add( rightEffectController, 'displayMode', displayModes ).name( 'Right Display Mode' ).onChange( () => {
  336. rightMaterial.colorNode = rightDisplayColorNodes[ rightEffectController.displayMode ];
  337. rightMaterial.needsUpdate = true;
  338. } );
  339. const debugFolder = gui.addFolder( 'Debug' );
  340. const elementsReducedController = debugFolder.add( unifiedEffectController, 'elementsReduced' ).name( 'Elements Reduced' );
  341. elementsReducedController.disable();
  342. const stateLeftController = debugFolder.add( leftEffectController, 'state' ).name( 'Left Display State' );
  343. const stateRightController = debugFolder.add( rightEffectController, 'state' ).name( 'Right Display State' );
  344. stateLeftController.disable();
  345. stateRightController.disable();
  346. debugFolder.add( unifiedEffectController, 'loggedBuffer', [ 'Input Buffer', 'Input Vectorized Buffer', 'Workgroup Sums Buffer', 'Debug Buffer' ] ).name( 'Buffer to Log' );
  347. debugFolder.close();
  348. // HELPER FUNCTIONS
// TSL function that rounds a uint up to a power of two (layout: uint pow2Ceil( uint x )).
// It subtracts 1, ORs the value with itself shifted right by 1, 2, 4, 8 and 16 so that
// every bit below the highest set bit becomes 1, then adds 1. A value that already is
// a power of two is therefore returned unchanged.
  349. const pow2Ceil = Fn( ( [ x ] ) => {
// Zero is special-cased to 1; without it, x - 1 wraps around and the final add yields 0.
// NOTE(review): this relies on a `return` inside an If() callback acting as an early
// return of the Fn - confirm that TSL generates it that way.
  350. If( x.equal( uint( 0 ) ), () => {
  351. return uint( 1 );
  352. } );
  353. const val = x.sub( 1 ).toVar( 'val' );
  354. val.assign( val.bitOr( val.shiftRight( 1 ) ) );
  355. val.assign( val.bitOr( val.shiftRight( 2 ) ) );
  356. val.assign( val.bitOr( val.shiftRight( 4 ) ) );
  357. val.assign( val.bitOr( val.shiftRight( 8 ) ) );
  358. val.assign( val.bitOr( val.shiftRight( 16 ) ) );
  359. return val.add( 1 );
  360. } ).setLayout( {
  361. name: 'pow2Ceil',
  362. type: 'uint',
  363. inputs: [
  364. { name: 'x', type: 'uint' }
  365. ]
  366. } );
  367. // ALGORITHM CONSTRUCTORS
  368. // REDUCE 1
  369. // Thanks to Sam0oneau of Graphics Programming Discord for the explanation.
  370. // (Graphics Programming Discord Message Link): https://discord.com/channels/318590007881236480/374061825454768129/1391248956171882597
  371. /* Reduce 1 Example (Assume Workgroup Size 256, numElements: 262144) -> Initial currentBuffer State: | 1, 1, 1, 1, ... |
  372. *
  373. * KERNEL 1:
  374. * Executes 256 threads by 256 workgroups. Each thread loops 4 times and accesses elements
  375. * at the indices below.
  376. * Thread 1 Thread 2 Thread 3
  377. * | 0, 65536, ..., n * 65536 | 1, 65537, .... (n * 65536) + 1 | 2, 65538, .... (n * 65536) + 2 | etc
  378. * Buffer Values: | 4, 4, 4, 4, ...|
  379. *
  380. * KERNEL 2:
  381. * Executes 256 threads by one workgroup. Each thread loops 1024 times
  382. * Thread 1 Thread 2 Thread 3
  383. * | 0, 256, ...., n * 256 | 1, 257, ... (n * 256) + 1 | 2, 258, ... (n * 256) + 2 | etc
  384. * Buffer Values: | 1024, 1024, 1024, 1024, ... |
  385. *
  386. * KERNEL 3:
  387. * Executes 1 thread by one workgroup. Single thread loops 256 times
  388. * Thread 1
  389. * | 0, 1, 2, 3, 4, 5, 6 ... etc|
  390. * Buffer Values: [262144, 1024, 1024]
  391. */
  392. const createReduce1Fn = ( createReduce1FnProps ) => {
  393. const { dispatchSize, numElements, inputBuffer, workgroupSize } = createReduce1FnProps;
  394. const fnDef = Fn( () => {
  395. const dispatch = uint( dispatchSize ).toVar( 'dispatchSize' );
  396. const tSum = uint( 0 ).toVar();
  397. const k = instanceIndex.toVar( 'k' );
  398. Loop( k.lessThan( uint( numElements ) ), ( ) => {
  399. tSum.addAssign( inputBuffer.element( k ) );
  400. k.addAssign( uint( dispatch ) );
  401. } );
  402. inputBuffer.element( instanceIndex ).assign( tSum );
  403. } )().compute( dispatchSize, [ workgroupSize ] );
  404. return fnDef;
  405. };
  406. // REDUCE 2
  407. // For non power of 2 # of workgroups
  408. const createReduce2Fn = ( createReduce2FnProps ) => {
  409. const { workgroupSize, dispatchSize, numElements, inputBuffer } = createReduce2FnProps;
  410. const fnDef = Fn( () => {
  411. const tSum = workgroupArray( 'uint', workgroupSize );
  412. const k = instanceIndex.toVar( 'k' );
  413. tSum.element( invocationLocalIndex ).assign( uint( 0 ) );
  414. Loop( k.lessThan( uint( numElements ) ), () => {
  415. tSum.element( invocationLocalIndex ).addAssign( inputBuffer.element( k ) );
  416. k.addAssign( uint( dispatchSize ) );
  417. } );
  418. workgroupBarrier();
  419. // Reset the loop condition (account for numWorkgroups % 2 != 0)
  420. k.assign( pow2Ceil( uint( workgroupSize ) ).div( 2 ) );
  421. Loop( k.greaterThan( 0 ), () => {
  422. If( invocationLocalIndex.lessThan( k ).and( invocationLocalIndex.add( k ).lessThan( workgroupSize ) ), () => {
  423. tSum.element( invocationLocalIndex ).addAssign( tSum.element( invocationLocalIndex.add( k ) ) );
  424. } );
  425. workgroupBarrier();
  426. k.divAssign( 2 );
  427. } );
  428. If( invocationLocalIndex.equal( uint( 0 ) ), () => {
  429. inputBuffer.element( workgroupId.x ).assign( tSum.element( uint( 0 ) ) );
  430. } );
  431. } )().compute( dispatchSize, [ workgroupSize ] );
  432. return fnDef;
  433. };
  434. // REDUCE 3
  435. /* Create array with enough indices for worst-case subgroup size */
  436. const createSubgroupArray = ( type, workgroupSize, minSubgroupSize = 4 ) => {
  437. return workgroupArray( 'uint', workgroupSize / minSubgroupSize );
  438. };
  439. // zcbenz implementation
  440. // https://github.com/frost-beta/betann/blob/8aa2701caf63fb29bd4cd2454e656973342c1588/betann/wgsl/reduce_ops.wgsl#L71
  441. const RowReduce = ( rowReduceProps ) => {
  442. const { workgroupSize, inputBuffer, total, rowOffset, currentRowSize, workPerThread, vectorized } = rowReduceProps;
  443. // Number of unvectorized elements each workgroup can ingest
  444. // At workgroupSize of 256, blockSize will be 1024
  445. const blockSize = uint( workgroupSize ).mul( workPerThread );
  446. const block = uint( 0 ).toVar( 'block' );
  447. // At rowSize of 2048, there will be two blocks
  448. const blockLimiter = currentRowSize.div( blockSize ).toVar( 'blockLimiter' );
  449. Loop( block.lessThan( blockLimiter ), () => {
  450. const blockOffset = block.mul( blockSize );
  451. const startThread = blockOffset.add( invocationLocalIndex.mul( workPerThread ) );
  452. const localThreadOffset = uint( 0 ).toVar( 'localThreadOffset' );
  453. Loop( localThreadOffset.lessThan( workPerThread ), () => {
  454. const inputElement = inputBuffer.element( rowOffset.add( startThread ).addLocal );
  455. if ( vectorized ) {
  456. const value = dot( inputElement, uvec4( 1 ) );
  457. total.addAssign( value );
  458. } else {
  459. const inputElement = inputBuffer.element( rowOffset.add( startThread ).add( localThreadOffset ) );
  460. total.addAssign( inputElement );
  461. }
  462. // Increment up a thread
  463. localThreadOffset.addAssign( 1 );
  464. } );
  465. // Increment up a block
  466. block.addAssign( 1 );
  467. } );
  468. // Ignoring left over check for this example, since we know ahead of time the value of leftover (2048 % 1024 === 0)
  469. };
// Emits the workgroup-level reduction of `total`. Each invocation's `total` is first
// replaced by subgroupAdd( total ). While more than one partial sum remains, the first
// invocation of every subgroup stores its value in a workgroup array, a barrier makes
// the stores visible, and the stored values are loaded back and combined with another
// subgroupAdd.
//
// workgroupReduceProps.total         - accumulator variable node, updated in place
// workgroupReduceProps.workgroupSize - invocations per workgroup (plain number)
  470. const WorkgroupReduce = ( workgroupReduceProps ) => {
  471. const { total, workgroupSize } = workgroupReduceProps;
// One slot per subgroup, sized by createSubgroupArray's default minimum subgroup size.
  472. const subgroupSums = createSubgroupArray( 'uint', workgroupSize );
  473. // Assign sum of all values in subgroup to total
  474. total.assign( subgroupAdd( total ) );
// delta = number of partial sums still to combine; starts at the subgroups-per-workgroup count.
  475. const delta = uint( workgroupSize ).div( subgroupSize ).toVar( 'delta' );
// Index of this invocation's subgroup within the workgroup.
  476. const subgroupMetaRank = invocationLocalIndex.div( subgroupSize );
  477. Loop( float( delta ).greaterThan( 1.0 ), () => {
  478. If( invocationSubgroupIndex.equal( 0 ), () => {
  479. // Each subgroup will populate the subgroupSums array
  480. subgroupSums.element( subgroupMetaRank ).assign( total );
  481. } );
  482. // Ensure that all subgroups in the workgroup have populated the workgroup memory array
  483. workgroupBarrier();
  484. // Thread 0 - subgroupsInWorkgroup will assign a value to total
// Invocations at or beyond delta contribute 0.
// NOTE(review): subgroupSums.element( invocationLocalIndex ) is built for every
// invocation, including indices past the array length - confirm select() never
// evaluates that read for the invocations it masks out.
  485. total.assign( select( invocationLocalIndex.lessThan( delta ), subgroupSums.element( invocationLocalIndex ), 0 ).uniformFlow() );
  486. // # of subgroups in workgroup is invariably less than # of threads in subgroup, so subgroupAdd will still sync here
  487. total.assign( subgroupAdd( total ) );
// NOTE(review): if the loop runs a second time, the stores at its top are not separated
// by a barrier from the reads above - confirm more than one iteration cannot occur or is safe.
  488. delta.divAssign( subgroupSize );
  489. } );
  490. };
  491. const createReduce3Fn = ( createReduce3FnProps ) => {
  492. const { workgroupSize, workPerThread, inputBuffer, intermediateBuffer, rowSize } = createReduce3FnProps;
  493. const fnDef = Fn( () => {
  494. const inputSize = uint( inputBuffer.bufferCount );
  495. const rowOffset = workgroupId.x.mul( rowSize );
  496. // If the current rows elements exceed the bounds of the input
  497. // Select either 0 or number of elements left,
  498. // otherwise, select existing ROW_SIZE
  499. const currentRowSize = select(
  500. ( rowOffset.add( rowSize ) ).greaterThan( inputSize ),
  501. select( inputSize.greaterThan( rowOffset ), inputSize.sub( rowOffset ), 0 ).uniformFlow(),
  502. rowSize,
  503. ).uniformFlow();
  504. const total = uint( 0 ).toVar( 'total' );
  505. RowReduce( {
  506. inputBuffer: inputBuffer,
  507. total: total,
  508. rowOffset: rowOffset,
  509. currentRowSize: currentRowSize,
  510. workPerThread: workPerThread,
  511. workgroupSize: workgroupSize,
  512. } );
  513. WorkgroupReduce( {
  514. total: total,
  515. workgroupSize: workgroupSize,
  516. } );
  517. // Populate each workgroup with its reduction
  518. If( invocationLocalIndex.equal( 0 ), () => {
  519. intermediateBuffer.element( workgroupId.x ).assign( total );
  520. } );
  521. } )();
  522. return fnDef;
  523. };
  524. // REDUCE 4
  525. // b0nes164 inspired implementation with vec4
  526. const createReduce4Fn = ( props ) => {
  527. // Can't pass in subgroup size since we can't always be certain what size is at runtime
  528. const { size, workPerThread, workgroupSize, inputBuffer, intermediateBuffer } = props;
  529. const ELEMENTS_PER_VEC4 = 4;
  530. // The number of individual elements a single workgroup will access
  531. const partitionSize = workgroupSize * workPerThread * ELEMENTS_PER_VEC4;
  532. const vecSize = divRoundUp( size, ELEMENTS_PER_VEC4 );
  533. // Can also be calculated using divRoundUp( vecSize, workgroupSize * workPerThread );
  534. const numWorkgroups = divRoundUp( size, partitionSize );
  535. // Currently no way to specify dispatch size in increments of workgroups, so we convert to numInvocations
  536. const numInvocations = numWorkgroups * workgroupSize;
  537. const fnDef = Fn( () => {
  538. const perSubgroupReductionArray = createSubgroupArray( 'uint', workgroupSize );
  539. // Get the index of the subgroup within the workgroup
  540. const subgroupMetaRank = invocationLocalIndex.div( subgroupSize );
  541. // Each subgroup block scans across 4 subgroups. So when we move into a new subgroup,
  542. // align that subgroups' accesses to the next 4 subgroups
  543. const subgroupOffset = subgroupMetaRank.mul( subgroupSize ).mul( workPerThread );
  544. subgroupOffset.addAssign( invocationSubgroupIndex );
  545. // Per workgroup, offset by number of vectorized elements scanned per workgroup
  546. const workgroupOffset = workgroupId.x.mul( uint( maxWorkgroupSize ).mul( workPerThread ) );
  547. const startThread = subgroupOffset.add( workgroupOffset );
  548. const subgroupReduction = uint( 0 );
  549. // Each thread will accumulate values from across 'workPerThread' subgroups
  550. If( workgroupId.x.lessThan( uint( numWorkgroups ).sub( 1 ) ), () => {
  551. Loop( {
  552. start: uint( 0 ),
  553. end: workPerThread,
  554. type: 'uint',
  555. condition: '<',
  556. name: 'currentSubgroupInBlock'
  557. }, () => {
  558. // Get vectorized element from input array
  559. const val = inputBuffer.element( startThread );
  560. // Sum values within vec4 together by using result of dot product
  561. subgroupReduction.addAssign( dot( uvec4( 1 ), val ) );
  562. // Increment so thread will scan value in next subgroup
  563. startThread.addAssign( subgroupSize );
  564. } );
  565. } );
  566. // Ensure that the last workgroup does not access out of bounds indices
  567. If( workgroupId.x.equal( uint( numWorkgroups ).sub( 1 ) ), () => {
  568. Loop( {
  569. start: uint( 0 ),
  570. end: workPerThread,
  571. type: 'uint',
  572. condition: '<',
  573. name: 'currentSubgroupInBlock'
  574. }, () => {
  575. // Ensure index is less than number of available vectors in inputBuffer
  576. const val = select( startThread.lessThan( uint( vecSize ) ), inputBuffer.element( startThread ), uvec4( 0 ) ).uniformFlow();
  577. subgroupReduction.addAssign( dot( val, uvec4( 1 ) ) );
  578. startThread.addAssign( subgroupSize );
  579. } );
  580. } );
  581. subgroupReduction.assign( subgroupAdd( subgroupReduction ) );
  582. // Assuming that each element in the input buffer is 1, we generally expect each invocation's subgroupReduction
  583. // value to be ELEMENTS_PER_VEC4 * workPerThread * subgroupSize
  584. // Delegate one thread per subgroup to assign each subgroup's reduction to the workgroup array
  585. If( invocationSubgroupIndex.equal( uint( 0 ) ), () => {
  586. perSubgroupReductionArray.element( subgroupMetaRank ).assign( subgroupReduction );
  587. } );
  588. // Ensure that each workgroup has populated the perSubgroupReductionArray with data
  589. // from each of its subgroups
  590. workgroupBarrier();
  591. if ( props.debugBuffer ) {
  592. If( invocationLocalIndex.equal( uint( 0 ) ), () => {
  593. props.debugBuffer.element( workgroupId.x ).assign( subgroupReduction );
  594. } );
  595. workgroupBarrier();
  596. }
  597. // WORKGROUP LEVEL REDUCE
  598. // Multiple approaches here
  599. // log2(subgroupSize) -> TSL log2 function
  600. // countTrailingZeros/findLSB(subgroupSize) -> TSL function that counts trailing zeros in number bit representation
  601. // Can technically petition GPU for subgroupSize in shader and calculate logs on CPU at cost of shader being generalizable across devices
  602. // May also break if subgroupSize changes when device is lost or if program is rerun on lower power device
  603. const subgroupSizeLog = countTrailingZeros( subgroupSize ).toVar( 'subgroupSizeLog' );
  604. const spineSize = uint( workgroupSize ).shiftRight( subgroupSizeLog );
  605. const spineSizeLog = countTrailingZeros( spineSize ).toVar( 'spineSizeLog' );
  606. // Align size to powers of subgroupSize
  607. const squaredSubgroupLog = ( spineSizeLog.add( subgroupSizeLog ).sub( 1 ) );
  608. squaredSubgroupLog.divAssign( subgroupSizeLog );
  609. squaredSubgroupLog.mulAssign( subgroupSizeLog );
  610. const alignedSize = ( uint( 1 ).shiftLeft( squaredSubgroupLog ) ).toVar( 'alignedSize' );
  611. // aligned size 2 * 4
  612. const offset = uint( 0 );
  613. // In cases where the number of subgroups in a workgroup is greater than the subgroup size itself,
  614. // we need to iterate over the array again to capture all the data in the workgroup array buffer
  615. Loop( { start: subgroupSize, end: alignedSize, condition: '<=', name: 'j', type: 'uint', update: '<<= subgroupSizeLog' }, () => {
  616. const subgroupIndex = ( ( invocationLocalIndex.add( 1 ) ).shiftLeft( offset ) ).sub( 1 );
  617. const isValidSubgroupIndex = subgroupIndex.lessThan( spineSize ).toVar( 'isValidSubgroupIndex' );
  618. // Reduce values within the local workgroup memory.
  619. // Set toVar to ensure subgroupAdd executes before (not within) the if statement.
  620. const t = subgroupAdd(
  621. select(
  622. isValidSubgroupIndex,
  623. perSubgroupReductionArray.element( subgroupIndex ),
  624. 0
  625. ).uniformFlow()
  626. ).toVar( 't' );
  627. // Can assign back to workgroupArray since all
  628. // subgroup threads work in lockstep for subgroupAdd
  629. If( isValidSubgroupIndex, () => {
  630. perSubgroupReductionArray.element( subgroupIndex ).assign( t );
  631. } );
  632. // Ensure all threads have completed work
  633. workgroupBarrier();
  634. offset.addAssign( subgroupSizeLog );
  635. } );
  636. // Assign single thread from workgroup to assign workgroup reduction
  637. If( invocationLocalIndex.equal( uint( 0 ) ), () => {
  638. const reducedWorkgroupSum = perSubgroupReductionArray.element( uint( spineSize ).sub( 1 ) );
  639. intermediateBuffer.element( workgroupId.x ).assign( reducedWorkgroupSum );
  640. } );
  641. } )().compute( numInvocations, [ maxWorkgroupSize ] );
  642. return fnDef;
  643. };
  // INCORRECT BASELINE
  /**
   * Builds a TSL function node that overwrites the element at `instanceIndex`
   * of the supplied buffer with the constant 99999. It performs no reduction,
   * so it serves as a deliberately wrong reference algorithm.
   *
   * @param {Object} incorrectBaselineProps - Options bag.
   * @param {Object} incorrectBaselineProps.inputBuffer - Storage buffer node that gets overwritten.
   * @returns {Object} The invoked `Fn` node; the caller chains `.compute()` on it.
   */
  const createIncorrectBaselineFn = ( incorrectBaselineProps ) => {

    // Read the target once, up front, so a malformed props object fails here.
    const targetBuffer = incorrectBaselineProps.inputBuffer;

    const overwriteWithSentinel = () => {

      targetBuffer.element( instanceIndex ).assign( 99999 );

    };

    return Fn( overwriteWithSentinel )();

  };
  // Start both halves of the comparison view: the left display (the default
  // argument of init, spelled out here) and then the right display.
  for ( const isLeftSide of [ true, false ] ) {

    init( isLeftSide );

  }
  /**
   * Sets up one half of the side-by-side reduction demo.
   *
   * Creates the storage buffers, a dedicated WebGPURenderer, the compute
   * dispatches for every reduction algorithm, the display color nodes, and a
   * timer-driven state machine that cycles 'Run Algo' -> 'Validate' -> 'Reset'.
   *
   * Side effects: appends a canvas to document.body, registers a window
   * 'resize' listener, adds a logging entry to `debugFolder`, and assigns the
   * file-level `maxWorkgroupSize` from the device limits when a device exists.
   *
   * @param {boolean} [leftSideDisplay=true] - true builds the left half, false the right half.
   * @returns {Promise<void>} Settles once setup is done and the first step is scheduled.
   */
  async function init( leftSideDisplay = true ) {

    const sideName = leftSideDisplay ? 'Left' : 'Right';
    const effectController = leftSideDisplay ? leftEffectController : rightEffectController;
    const sideMaterial = leftSideDisplay ? leftMaterial : rightMaterial;

    const aspect = ( window.innerWidth / 2 ) / window.innerHeight;
    const camera = new THREE.OrthographicCamera( - aspect, aspect, 1, - 1, 0, 2 );
    camera.position.z = 1;

    const scene = new THREE.Scene();

    // Every input element starts at 1, so a correct reduction leaves `size` in element 0.
    // Filling the typed array directly avoids building a temporary JS array of `size` entries.
    const array = new Uint32Array( size ).fill( 1 );

    // Represents array of data as uints in compute shader.
    const inputStorage = instancedArray( array, 'uint' ).setPBO( true ).setName( `Current_${sideName}` );

    // Represents array of data as vec4s in compute shader.
    const inputVec4BufferAttribute = new THREE.StorageInstancedBufferAttribute( array, 4 );
    const inputVectorizedStorage = storage( inputVec4BufferAttribute, 'uvec4', vecSize ).setPBO( true ).setName( `CurrentVectorized_${sideName}` );

    // Reduce 3 Calculations
    const workPerThread = 4;
    const numRows = workPerThread * 32;
    const rowSize = divRoundUp( size, numRows );

    const workgroupSumsArray = new Uint32Array( numRows );
    const workgroupSumsStorage = instancedArray( workgroupSumsArray, 'uint' ).setPBO( true ).setName( `WorkgroupSums_${sideName}` );

    const debugArray = new Uint32Array( 1024 );
    const debugStorage = instancedArray( debugArray, 'uint' ).setPBO( true ).setName( `Debug_${sideName}` );

    // Buffers selectable for logging, keyed by unifiedEffectController.loggedBuffer.
    const buffers = {
      'Input Buffer': inputStorage,
      'Input Vectorized Buffer': inputVectorizedStorage,
      'Workgroup Sums Buffer': workgroupSumsStorage,
      'Debug Buffer': debugStorage,
    };

    const logFunctionName = `Log ${sideName} Side`;
    const functionObj = {};

    // Reads the selected buffer back from the GPU and prints it to the console.
    // `renderer` is declared further down; it is only dereferenced when this runs.
    functionObj[ logFunctionName ] = async () => {

      const selectedBuffer = buffers[ unifiedEffectController.loggedBuffer ];
      const readbackBuffer = new THREE.ReadbackBuffer( selectedBuffer.value );

      try {

        const result = new Uint32Array( await renderer.getArrayBufferAsync( readbackBuffer ) );
        console.log( result );

      } finally {

        // Remove GPU/CPU readback buffer from memory, even when the readback fails
        readbackBuffer.dispose();

      }

    };

    debugFolder.add( functionObj, logFunctionName );

    const computeResetBufferFn = Fn( () => {

      inputStorage.element( instanceIndex ).assign( 1 );

    } );

    const computeResetWorkgroupSumsFn = Fn( () => {

      workgroupSumsStorage.element( instanceIndex ).assign( 0 );

    } );

    // Re-initialize compute buffer
    const computeResetBuffer = computeResetBufferFn().compute( size );
    const computeResetWorkgroupSums = computeResetWorkgroupSumsFn().compute( 256 );

    const renderer = new THREE.WebGPURenderer( { antialias: false, trackTimestamp: true } );
    renderer.setPixelRatio( window.devicePixelRatio );
    renderer.setSize( window.innerWidth / 2, window.innerHeight );

    await renderer.init();

    // Unfortunately, need to arbitrarily run compute shader to get access to device limits
    renderer.compute( computeResetBuffer );

    // Guard against both null and undefined so a backend without a device
    // does not throw while reading `.limits`.
    const gpuDevice = renderer.backend.device;

    if ( gpuDevice !== null && gpuDevice !== undefined ) {

      maxWorkgroupSize = gpuDevice.limits.maxComputeWorkgroupSizeX;

    }

    // Create and store dispatches of reduction of certain size. Map each set of dispatches to algorithm name.
    const computeReduce0Fn = Fn( () => {

      const { numThreadsDispatched } = effectController;

      inputStorage.element( instanceIndex ).addAssign( inputStorage.element( instanceIndex.add( numThreadsDispatched ) ) );

    } )();

    // One dispatch per halving step: size / 2, size / 4, ..., 1 invocations.
    const reduce0Calls = [];

    for ( let i = size / 2; i >= 1; i /= 2 ) {

      const reduce0 = computeReduce0Fn.compute( i, [ maxWorkgroupSize ] );
      reduce0Calls.push( reduce0 );

    }

    const reduce1Calls = [
      // Accumulation
      createReduce1Fn( {
        dispatchSize: maxWorkgroupSize * maxWorkgroupSize,
        workgroupSize: maxWorkgroupSize,
        numElements: size,
        inputBuffer: inputStorage,
      } ),
      // 1 Block accumulation
      createReduce1Fn( {
        dispatchSize: maxWorkgroupSize,
        numElements: maxWorkgroupSize * maxWorkgroupSize,
        workgroupSize: maxWorkgroupSize,
        inputBuffer: inputStorage,
      } ),
      // Final result
      createReduce1Fn( {
        dispatchSize: 1,
        numElements: maxWorkgroupSize,
        workgroupSize: 1,
        inputBuffer: inputStorage
      } ),
    ];

    const reduce2Calls = [
      // Accumulate within workgroups
      createReduce2Fn( {
        workgroupSize: maxWorkgroupSize,
        dispatchSize: maxWorkgroupSize * maxWorkgroupSize,
        numElements: size,
        inputBuffer: inputStorage,
      } ),
      // 1 Block accumulation
      createReduce2Fn( {
        workgroupSize: maxWorkgroupSize,
        dispatchSize: maxWorkgroupSize,
        numElements: maxWorkgroupSize,
        inputBuffer: inputStorage,
      } ),
    ];

    // Both Reduce 3 stages use the same workPerThread that numRows and rowSize were derived from.
    const reduce3Calls = [
      createReduce3Fn( {
        inputBuffer: inputStorage,
        intermediateBuffer: workgroupSumsStorage,
        workgroupSize: maxWorkgroupSize,
        workPerThread: workPerThread,
        rowSize: rowSize,
        vectorized: false,
      } ).compute( maxWorkgroupSize * numRows, [ maxWorkgroupSize ] ),
      createReduce3Fn( {
        inputBuffer: workgroupSumsStorage,
        intermediateBuffer: inputStorage,
        workgroupSize: 32,
        workPerThread: workPerThread,
        rowSize: rowSize,
        vectorized: false
      } ).compute( 32, [ 32 ] )
    ];

    // Reduce 4 sums the vectorized input into the workgroup sums buffer, then
    // reuses the second Reduce 3 stage to write the result into inputStorage.
    const reduce4Calls = [
      createReduce4Fn( {
        size: size,
        inputBuffer: inputVectorizedStorage,
        intermediateBuffer: workgroupSumsStorage,
        workgroupSize: maxWorkgroupSize,
        workPerThread: 4,
      } ),
      createReduce3Fn( {
        inputBuffer: workgroupSumsStorage,
        intermediateBuffer: inputStorage,
        workgroupSize: 32,
        workPerThread: workPerThread,
        rowSize: rowSize,
        vectorized: false
      } ).compute( 32, [ 32 ] )
    ];

    const incorrectBaselineCalls = [
      createIncorrectBaselineFn( {
        inputBuffer: inputStorage,
      } ).compute( size ),
    ];

    // Dispatch lists keyed by the algorithm name stored in effectController.algo.
    const calls = {
      'Reduce 0 (N/2)': reduce0Calls,
      'Reduce 1 (Naive Accumulate)': reduce1Calls,
      'Reduce 2 (Workgroup Reduction)': reduce2Calls,
      'Reduce 3 (Subgroup Reduce)': reduce3Calls,
      'Reduce 4 (Subgroup Optimized)': reduce4Calls,
      'Incorrect Baseline': incorrectBaselineCalls
    };

    // Shades a value as grey (1 - value / ( width * height )). While the highlight
    // uniform is 1 the shade is tinted green when element 0 of bufferToCheck
    // equals `size`, and red otherwise.
    const getColor = ( bufferToCheck, colorChanger, width, height ) => {

      const subtracter = float( colorChanger ).div( width.mul( height ) );
      const color = vec3( subtracter.oneMinus() ).toVar();

      const { highlight } = effectController;

      // Validate that element 0 is equal to expected result of reduction
      If( highlight.equal( 1 ), () => {

        If( ( bufferToCheck.element( 0 ) ).equal( size ), () => {

          color.assign( vec3( 0.0, subtracter.oneMinus(), 0.0 ) );

        } ).Else( () => {

          color.assign( vec3( subtracter.oneMinus(), 0.0, 0.0 ) );

        } );

      } );

      return color;

    };

    const displayNodes = leftSideDisplay ? leftDisplayColorNodes : rightDisplayColorNodes;

    // Shows the input buffer as a gridDisplayWidth x gridDisplayHeight grid.
    displayNodes[ 'Input Grid' ] = Fn( () => {

      const { gridElementWidth, gridElementHeight, gridDisplayWidth, gridDisplayHeight } = unifiedEffectController;

      const newUV = uv().mul( vec2( gridDisplayWidth, gridDisplayHeight ) );
      const pixel = uvec2( uint( floor( newUV.x ) ), uint( floor( newUV.y ) ) );
      const elementIndex = uint( gridDisplayWidth ).mul( pixel.y ).add( pixel.x );

      const colorChanger = uint( 0 ).toVar();
      const color = vec3( 0 ).toVar( 'color' );

      colorChanger.assign( inputStorage.element( elementIndex ) );
      color.assign( getColor( inputStorage, colorChanger, gridElementWidth, gridElementHeight ) );

      return color;

    } )();

    // Shows the input elements at power-of-two indices along the x axis.
    displayNodes[ 'Input Log2' ] = Fn( () => {

      const { gridElementWidth, gridElementHeight } = unifiedEffectController;

      // The y scale of 1 belongs inside vec2; only newUV.x is read below.
      const newUV = uv().mul( vec2( Math.log2( size ), 1 ) );

      const colorChanger = uint( 0 ).toVar();
      const color = vec3( 0 ).toVar( 'color' );

      colorChanger.assign( inputStorage.element( uint( 1 ).shiftLeft( newUV.x ) ) );
      color.assign( getColor( inputStorage, colorChanger, gridElementWidth, gridElementHeight ) );

      return color;

    } )();

    // Shows only element 0 of the input buffer across the whole plane.
    displayNodes[ 'Input Element 0' ] = Fn( () => {

      const { gridElementWidth, gridElementHeight } = unifiedEffectController;

      const colorChanger = uint( 0 ).toVar();
      const color = vec3( 0 ).toVar( 'color' );

      // Clamp display of single element to shade where green is still readable
      colorChanger.assign( clamp( inputStorage.element( 0 ), 0, size / 2 ) );
      color.assign( getColor( inputStorage, colorChanger, gridElementWidth, gridElementHeight ) );

      return color;

    } )();

    // Shows the workgroup sums buffer as an 8 x 16 grid.
    displayNodes[ 'Workgroup Sum Grid' ] = Fn( () => {

      const width = uint( 8 );
      const height = uint( 16 );

      const newUV = uv().mul( vec2( width, height ) );
      const pixel = uvec2( uint( floor( newUV.x ) ), uint( floor( newUV.y ) ) );
      const elementIndex = uint( width ).mul( pixel.y ).add( pixel.x );

      const colorChanger = uint( 0 ).toVar();
      const color = vec3( 0 ).toVar( 'color' );

      colorChanger.assign( workgroupSumsStorage.element( elementIndex ) );
      color.assign( getColor( inputStorage, colorChanger, width, height ) );

      return color;

    } )();

    sideMaterial.colorNode = displayNodes[ effectController.displayMode ];
    sideMaterial.needsUpdate = true;

    const plane = new THREE.Mesh( new THREE.PlaneGeometry( 1, 1 ), sideMaterial );
    scene.add( plane );

    const animate = () => {

      renderer.render( scene, camera );

    };

    renderer.setAnimationLoop( animate );

    // Each side owns a canvas covering half of the window.
    document.body.appendChild( renderer.domElement );
    renderer.domElement.style.position = 'absolute';
    renderer.domElement.style.top = '0';
    renderer.domElement.style.left = '0';
    renderer.domElement.style.width = '50%';
    renderer.domElement.style.height = '100%';

    if ( ! leftSideDisplay ) {

      renderer.domElement.style.left = '50%';
      scene.background = new THREE.Color( 0x212121 );

    } else {

      scene.background = new THREE.Color( 0x313131 );

    }

    // Info counters are reset manually at the start of each 'Run Algo' step.
    renderer.info.autoReset = false;

    // Runs one step of the state machine, then re-schedules itself after one second.
    const stepAnimation = async function () {

      const currentAlgorithm = effectController.algo;
      const state = effectController.state;
      const stateController = leftSideDisplay ? stateLeftController : stateRightController;

      if ( state === 'Reset' ) {

        renderer.compute( computeResetBuffer );
        renderer.compute( computeResetWorkgroupSums );

      } else if ( state === 'Run Algo' ) {

        renderer.info.reset();

        const cpuTime = 0;

        switch ( currentAlgorithm ) {

          case 'Reduce 0 (N/2)': {

            let m = size / 2;

            for ( let i = 0; i < reduce0Calls.length; i ++ ) {

              // Each pass adds element [ index + m ] into element [ index ].
              effectController.numThreadsDispatched.value = m;

              // Do a reduction step
              renderer.compute( reduce0Calls[ i ] );

              m /= 2;

            }

            break;

          }

          default: {

            const currentAlgoCalls = calls[ currentAlgorithm ];

            for ( let i = 0; i < currentAlgoCalls.length; i ++ ) {

              renderer.compute( currentAlgoCalls[ i ] );

            }

            break;

          }

        }

        // Resolve once, after every pass is submitted, and wait for it, so the
        // timestamp read below is from this run and not from an earlier resolve.
        await renderer.resolveTimestampsAsync( THREE.TimestampQuery.COMPUTE );

        // DEBUG: const reductionResult = new Uint32Array( await renderer.getArrayBufferAsync( currentBuffer ) )[0];

        let passInfoString = '';

        // Use the algorithm captured before the await, in case the selection changed meanwhile.
        if ( currentAlgorithm.substring( 0, 3 ) === 'CPU' ) {

          passInfoString = `Ran in ${cpuTime}ms<br>`;

        } else {

          passInfoString = `${renderer.info.compute.frameCalls} pass in ${renderer.info.compute.timestamp.toFixed( 6 )}ms<br>`;

        }

        timestamps[ leftSideDisplay ? 'left_side_display' : 'right_side_display' ].innerHTML = `
          Compute ${currentAlgorithm}: ${passInfoString}`;

      }

      renderer.render( scene, camera );

      // The render timing is not read in this function, so it is not awaited.
      renderer.resolveTimestampsAsync( THREE.TimestampQuery.RENDER );

      // Validate next state
      if ( state === 'Run Algo' ) {

        stateController.setValue( 'Validate' );
        effectController.highlight.value = 1;

      } else if ( state === 'Validate' ) {

        stateController.setValue( 'Reset' );
        effectController.highlight.value = 0;

      } else if ( state === 'Reset' ) {

        stateController.setValue( 'Run Algo' );

      }

      setTimeout( stepAnimation, 1000 );

    };

    window.addEventListener( 'resize', onWindowResize );

    // Keeps the canvas at half the window width and rescales the frustum to the new aspect.
    function onWindowResize() {

      renderer.setSize( window.innerWidth / 2, window.innerHeight );

      const aspect = ( window.innerWidth / 2 ) / window.innerHeight;
      const frustumHeight = camera.top - camera.bottom;

      camera.left = - frustumHeight * aspect / 2;
      camera.right = frustumHeight * aspect / 2;
      camera.updateProjectionMatrix();

      renderer.render( scene, camera );

    }

    setTimeout( stepAnimation, 1000 );

  }
  953. </script>
  954. </body>
  955. </html>
粤ICP备19079148号