// ClusteredLightsNode.js (17 KB)

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613
  1. import { DataTexture, FloatType, RGBAFormat, Vector2, Vector3, LightsNode, NodeUpdateType } from 'three/webgpu';
  2. import {
  3. attributeArray, nodeProxy, int, float, vec3, vec4, ivec2, ivec4, uniform, Break, Loop, positionView,
  4. Fn, If, Return, textureLoad, instanceIndex, screenCoordinate, directPointLight,
  5. renderGroup,
  6. min, max, pow, log, clamp, dot
  7. } from 'three/tsl';
  8. const _vector3 = /*@__PURE__*/ new Vector3();
  9. const _size = /*@__PURE__*/ new Vector2();
  10. /**
  11. * A custom version of `LightsNode` implementing Forward+ clustered shading:
  12. * the view frustum is subdivided into a 3D grid of clusters (X × Y screen tiles
  13. * times an exponentially-spaced set of Z depth slices), and each cluster holds
  14. * only the point lights whose spheres intersect it. At shading time each fragment
  15. * looks up its cluster and loops over just that cluster's lights. Unlike 2D tiled
  16. * lighting, clustered shading culls lights that share screen pixels but lie at
  17. * different depths — suitable for 3D scenes with real depth complexity.
  18. *
  19. * @augments LightsNode
  20. * @three_import import { clusteredLights } from 'three/addons/tsl/lighting/ClusteredLightsNode.js';
  21. */
  22. class ClusteredLightsNode extends LightsNode {
  23. static get type() {
  24. return 'ClusteredLightsNode';
  25. }
  26. /**
  27. * Constructs a new clustered lights node.
  28. *
  29. * @param {number} [maxLights=1024] - Maximum number of point lights.
  30. * @param {number} [tileSize=32] - Screen tile size in pixels (cluster XY size).
  31. * @param {number} [zSlices=24] - Number of exponential depth slices.
  32. * @param {number} [maxLightsPerCluster=64] - Per-cluster light-list capacity.
  33. */
  34. constructor( maxLights = 1024, tileSize = 32, zSlices = 24, maxLightsPerCluster = 64 ) {
  35. super();
  36. this.materialLights = [];
  37. this.clusteredLights = [];
  38. this.maxLights = maxLights;
  39. this.tileSize = tileSize;
  40. this.zSlices = zSlices;
  41. this.maxLightsPerCluster = maxLightsPerCluster;
  42. this._chunksPerCluster = Math.ceil( maxLightsPerCluster / 4 );
  43. this._bufferSize = null;
  44. this._lightIndexes = null;
  45. this._screenClusterIndex = null;
  46. this._compute = null;
  47. this._lightsTexture = null;
  48. this._zSliceRangesTexture = null;
  49. this._zSliceRangesData = null;
  50. this._lightViewZ = new Float32Array( maxLights );
  51. this._lightSortOrder = [];
  52. this._lightsCount = uniform( 0, 'int' );
  53. // Render-group uniforms: shared between compute and fragment passes,
  54. // updated manually each frame in updateBefore (compute lacks a camera context).
  55. this._cameraNear = uniform( 0 ).setName( 'clusteredCameraNear' ).setGroup( renderGroup );
  56. this._cameraFar = uniform( 0 ).setName( 'clusteredCameraFar' ).setGroup( renderGroup );
  57. this._cameraViewMatrix = uniform( 'mat4' ).setName( 'clusteredCameraViewMatrix' ).setGroup( renderGroup );
  58. this._cameraProjectionMatrix = uniform( 'mat4' ).setName( 'clusteredCameraProjectionMatrix' ).setGroup( renderGroup );
  59. this._gridDimensions = uniform( new Vector2() );
  60. this.updateBeforeType = NodeUpdateType.RENDER;
  61. }
  62. customCacheKey() {
  63. return ( this._compute ? this._compute.getCacheKey() : 0 ) + super.customCacheKey();
  64. }
  65. updateLightsTexture( camera ) {
  66. const { _lightsTexture: lightsTexture, clusteredLights } = this;
  67. const data = lightsTexture.image.data;
  68. const lineSize = lightsTexture.image.width * 4;
  69. const count = clusteredLights.length;
  70. this._lightsCount.value = count;
  71. // Sort lights by view-space depth for Z-culling
  72. const viewZ = this._lightViewZ;
  73. const order = this._lightSortOrder;
  74. for ( let i = 0; i < count; i ++ ) {
  75. _vector3.setFromMatrixPosition( clusteredLights[ i ].matrixWorld );
  76. _vector3.applyMatrix4( camera.matrixWorldInverse );
  77. viewZ[ i ] = _vector3.z;
  78. order[ i ] = i;
  79. }
  80. order.length = count;
  81. order.sort( ( a, b ) => viewZ[ a ] - viewZ[ b ] );
  82. // Write sorted lights to texture
  83. for ( let i = 0; i < count; i ++ ) {
  84. const light = clusteredLights[ order[ i ] ];
  85. _vector3.setFromMatrixPosition( light.matrixWorld );
  86. const offset = i * 4;
  87. data[ offset + 0 ] = _vector3.x;
  88. data[ offset + 1 ] = _vector3.y;
  89. data[ offset + 2 ] = _vector3.z;
  90. data[ offset + 3 ] = light.distance;
  91. data[ lineSize + offset + 0 ] = light.color.r * light.intensity;
  92. data[ lineSize + offset + 1 ] = light.color.g * light.intensity;
  93. data[ lineSize + offset + 2 ] = light.color.b * light.intensity;
  94. data[ lineSize + offset + 3 ] = light.decay;
  95. }
  96. lightsTexture.needsUpdate = true;
  97. // Compute per Z-slice light ranges
  98. const zRanges = this._zSliceRangesData;
  99. if ( zRanges === null ) return;
  100. const near = camera.near;
  101. const far = camera.far;
  102. const NZ = this.zSlices;
  103. for ( let z = 0; z < NZ; z ++ ) {
  104. // Exponential Z-slice bounds (view-space, negative values)
  105. const sliceNear = - ( near * Math.pow( far / near, z / NZ ) );
  106. const sliceFar = - ( near * Math.pow( far / near, ( z + 1 ) / NZ ) );
  107. let rangeStart = count;
  108. let rangeEnd = 0;
  109. for ( let i = 0; i < count; i ++ ) {
  110. const vz = viewZ[ order[ i ] ];
  111. const r = clusteredLights[ order[ i ] ].distance;
  112. const radius = r > 0 ? r : far;
  113. // Light sphere Z: [vz - radius, vz + radius]
  114. // Slice Z: [sliceFar, sliceNear] (both negative, sliceFar < sliceNear)
  115. if ( vz + radius >= sliceFar && vz - radius <= sliceNear ) {
  116. if ( i < rangeStart ) rangeStart = i;
  117. if ( i + 1 > rangeEnd ) rangeEnd = i + 1;
  118. }
  119. }
  120. if ( rangeStart >= count ) {
  121. rangeStart = 0;
  122. rangeEnd = 0;
  123. }
  124. zRanges[ z * 4 ] = rangeStart;
  125. zRanges[ z * 4 + 1 ] = rangeEnd;
  126. }
  127. this._zSliceRangesTexture.needsUpdate = true;
  128. }
  129. updateBefore( frame ) {
  130. const { renderer, camera } = frame;
  131. this.updateProgram( renderer );
  132. this.updateLightsTexture( camera );
  133. this._cameraNear.value = camera.near;
  134. this._cameraFar.value = camera.far;
  135. this._cameraViewMatrix.value = camera.matrixWorldInverse;
  136. this._cameraProjectionMatrix.value = camera.projectionMatrix;
  137. renderer.compute( this._compute );
  138. }
  139. setLights( lights ) {
  140. const { clusteredLights, materialLights } = this;
  141. let materialIndex = 0;
  142. let clusteredIndex = 0;
  143. for ( const light of lights ) {
  144. if ( light.isPointLight === true ) {
  145. clusteredLights[ clusteredIndex ++ ] = light;
  146. } else {
  147. materialLights[ materialIndex ++ ] = light;
  148. }
  149. }
  150. materialLights.length = materialIndex;
  151. clusteredLights.length = clusteredIndex;
  152. return super.setLights( materialLights );
  153. }
  154. getBlock() {
  155. return this._lightIndexes.element( this._screenClusterIndex.mul( int( this._chunksPerCluster ) ) );
  156. }
  157. getTile( element ) {
  158. element = int( element );
  159. const stride = int( 4 );
  160. const chunkOffset = element.div( stride );
  161. const idx = this._screenClusterIndex.mul( int( this._chunksPerCluster ) ).add( chunkOffset );
  162. return this._lightIndexes.element( idx ).element( element.mod( stride ) );
  163. }
  164. getClusterLightCount( zSliceNode ) {
  165. const getCount = Fn( ( [ zSliceNode ] ) => {
  166. const count = int( 0 ).toVar();
  167. const debugClusterIndex = this._screenClusterIndex.toVar();
  168. If( zSliceNode.greaterThanEqual( int( 0 ) ), () => {
  169. const tileSize = int( this.tileSize );
  170. const screenTile = screenCoordinate.div( tileSize ).floor();
  171. const NX = int( this._gridDimensions.x );
  172. const NY = int( this._gridDimensions.y );
  173. debugClusterIndex.assign(
  174. int( screenTile.x )
  175. .add( int( screenTile.y ).mul( NX ) )
  176. .add( zSliceNode.mul( NX.mul( NY ) ) )
  177. );
  178. } );
  179. Loop( this.maxLightsPerCluster, ( { i } ) => {
  180. const element = int( i );
  181. const stride = int( 4 );
  182. const chunkOffset = element.div( stride );
  183. const idx = debugClusterIndex.mul( int( this._chunksPerCluster ) ).add( chunkOffset );
  184. const lightIndex = this._lightIndexes.element( idx ).element( element.mod( stride ) );
  185. If( lightIndex.equal( int( 0 ) ), () => {
  186. Break();
  187. } );
  188. count.addAssign( int( 1 ) );
  189. } );
  190. return count;
  191. } );
  192. return getCount( zSliceNode );
  193. }
  194. getLightData( index ) {
  195. index = int( index );
  196. const dataA = textureLoad( this._lightsTexture, ivec2( index, 0 ) );
  197. const dataB = textureLoad( this._lightsTexture, ivec2( index, 1 ) );
  198. const position = dataA.xyz;
  199. const viewPosition = this._cameraViewMatrix.mul( vec4( position, 1.0 ) ).xyz;
  200. const distance = dataA.w;
  201. const color = dataB.rgb;
  202. const decay = dataB.w;
  203. return {
  204. position,
  205. viewPosition,
  206. distance,
  207. color,
  208. decay
  209. };
  210. }
  211. setupLights( builder, lightNodes ) {
  212. this.updateProgram( builder.renderer );
  213. //
  214. const lightingModel = builder.context.reflectedLight;
  215. lightingModel.directDiffuse.toStack();
  216. lightingModel.directSpecular.toStack();
  217. super.setupLights( builder, lightNodes );
  218. Fn( () => {
  219. Loop( this.maxLightsPerCluster, ( { i } ) => {
  220. const lightIndex = this.getTile( i );
  221. If( lightIndex.equal( int( 0 ) ), () => {
  222. Break();
  223. } );
  224. const { color, decay, viewPosition, distance } = this.getLightData( lightIndex.sub( 1 ) );
  225. const lightVector = viewPosition.sub( positionView );
  226. // Early-out: skip full BRDF if fragment is beyond the light's cutoff
  227. If( distance.equal( 0 ).or( dot( lightVector, lightVector ).lessThanEqual( distance.mul( distance ) ) ), () => {
  228. builder.lightsNode.setupDirectLight( builder, this, directPointLight( {
  229. color,
  230. lightVector,
  231. cutoffDistance: distance,
  232. decayExponent: decay
  233. } ) );
  234. } );
  235. } );
  236. }, 'void' )();
  237. }
  238. getBufferFitSize( value ) {
  239. const multiple = this.tileSize;
  240. return Math.ceil( value / multiple ) * multiple;
  241. }
  242. setSize( width, height ) {
  243. width = this.getBufferFitSize( width );
  244. height = this.getBufferFitSize( height );
  245. if ( ! this._bufferSize || this._bufferSize.width !== width || this._bufferSize.height !== height ) {
  246. this.create( width, height );
  247. }
  248. return this;
  249. }
  250. updateProgram( renderer ) {
  251. renderer.getDrawingBufferSize( _size );
  252. const width = this.getBufferFitSize( _size.width );
  253. const height = this.getBufferFitSize( _size.height );
  254. if ( this._bufferSize === null ) {
  255. this.create( width, height );
  256. } else if ( this._bufferSize.width !== width || this._bufferSize.height !== height ) {
  257. this.create( width, height );
  258. }
  259. }
  260. create( width, height ) {
  261. const { tileSize, maxLights, zSlices, maxLightsPerCluster, _chunksPerCluster: chunksPerCluster } = this;
  262. const bufferSize = new Vector2( width, height );
  263. const NX = Math.floor( bufferSize.width / tileSize );
  264. const NY = Math.floor( bufferSize.height / tileSize );
  265. const NZ = zSlices;
  266. const clusterCount = NX * NY * NZ;
  267. this._gridDimensions.value.set( NX, NY );
  268. // Lights data texture (same layout as TiledLightsNode)
  269. const lightsData = new Float32Array( maxLights * 4 * 2 );
  270. const lightsTexture = new DataTexture( lightsData, lightsData.length / 8, 2, RGBAFormat, FloatType );
  271. // Per Z-slice light range for Z-culling (CPU-sorted, uploaded each frame)
  272. const zSliceRangesData = new Float32Array( NZ * 4 );
  273. const zSliceRangesTexture = new DataTexture( zSliceRangesData, NZ, 1, RGBAFormat, FloatType );
  274. // Per-cluster light-index storage (ivec4 chunks)
  275. const lightIndexesArray = new Int32Array( clusterCount * chunksPerCluster * 4 );
  276. const lightIndexes = attributeArray( lightIndexesArray, 'ivec4' ).setName( 'lightIndexes' );
  277. // compute-side accessors (use instanceIndex)
  278. const getClusterChunk = ( chunkIdx ) => {
  279. const idx = instanceIndex.mul( int( chunksPerCluster ) ).add( int( chunkIdx ) );
  280. return lightIndexes.element( idx );
  281. };
  282. const getClusterSlot = ( slotIdx ) => {
  283. slotIdx = int( slotIdx );
  284. const stride = int( 4 );
  285. const chunkOffset = slotIdx.div( stride );
  286. const idx = instanceIndex.mul( int( chunksPerCluster ) ).add( chunkOffset );
  287. return lightIndexes.element( idx ).element( slotIdx.mod( stride ) );
  288. };
  289. // compute: one thread per cluster
  290. const compute = Fn( () => {
  291. // view-space scale factors derived from the projection matrix:
  292. // view_x = ndc_x * (-view_z) / focal_x = ndc_x * (-view_z) * invFocalX
  293. // view_y = ndc_y * (-view_z) / focal_y = ndc_y * (-view_z) * invFocalY
  294. // where focal_x = projMatrix[0][0] and focal_y = projMatrix[1][1].
  295. const invFocalX = float( 1 ).div( this._cameraProjectionMatrix.element( 0 ).element( 0 ) );
  296. const invFocalY = float( 1 ).div( this._cameraProjectionMatrix.element( 1 ).element( 1 ) );
  297. // 3D cluster coordinates from instanceIndex
  298. const cx = instanceIndex.mod( NX );
  299. const cy = instanceIndex.div( NX ).mod( NY );
  300. const cz = instanceIndex.div( NX * NY );
  301. // NDC X/Y bounds of the cluster.
  302. // Y is flipped: cy=0 is the top screen row (fragment y=0), which is NDC y=+1.
  303. const ndcXmin = float( cx ).mul( 2.0 / NX ).sub( 1.0 );
  304. const ndcXmax = float( cx.add( int( 1 ) ) ).mul( 2.0 / NX ).sub( 1.0 );
  305. const ndcYmax = float( 1 ).sub( float( cy ).mul( 2.0 / NY ) );
  306. const ndcYmin = float( 1 ).sub( float( cy.add( int( 1 ) ) ).mul( 2.0 / NY ) );
  307. // View-space Z bounds (negative, exponential slicing)
  308. const farOverNear = this._cameraFar.div( this._cameraNear );
  309. const zNearCluster = this._cameraNear.mul( pow( farOverNear, float( cz ).mul( 1.0 / NZ ) ) ).negate();
  310. const zFarCluster = this._cameraNear.mul( pow( farOverNear, float( cz.add( int( 1 ) ) ).mul( 1.0 / NZ ) ) ).negate();
  311. const scaleNearX = zNearCluster.negate().mul( invFocalX );
  312. const scaleFarX = zFarCluster.negate().mul( invFocalX );
  313. const scaleNearY = zNearCluster.negate().mul( invFocalY );
  314. const scaleFarY = zFarCluster.negate().mul( invFocalY );
  315. const xMinNear = ndcXmin.mul( scaleNearX );
  316. const xMaxNear = ndcXmax.mul( scaleNearX );
  317. const xMinFar = ndcXmin.mul( scaleFarX );
  318. const xMaxFar = ndcXmax.mul( scaleFarX );
  319. const yMinNear = ndcYmin.mul( scaleNearY );
  320. const yMaxNear = ndcYmax.mul( scaleNearY );
  321. const yMinFar = ndcYmin.mul( scaleFarY );
  322. const yMaxFar = ndcYmax.mul( scaleFarY );
  323. // AABB of the 8 view-space corners (tile boundaries can straddle the view axis)
  324. const aabbMinX = min( xMinNear, xMinFar );
  325. const aabbMaxX = max( xMaxNear, xMaxFar );
  326. const aabbMinY = min( yMinNear, yMinFar );
  327. const aabbMaxY = max( yMaxNear, yMaxFar );
  328. const aabbMin = vec3( aabbMinX, aabbMinY, zFarCluster );
  329. const aabbMax = vec3( aabbMaxX, aabbMaxY, zNearCluster );
  330. // clear stale data from previous frame
  331. Loop( chunksPerCluster, ( { i } ) => {
  332. getClusterChunk( i ).assign( ivec4( 0 ) );
  333. } );
  334. const index = int( 0 ).toVar();
  335. // Z-culling: only test lights that can reach this cluster's Z-slice
  336. const zRange = textureLoad( zSliceRangesTexture, ivec2( cz, 0 ) );
  337. const rangeStart = int( zRange.x );
  338. const rangeEnd = int( zRange.y );
  339. Loop( this.maxLights, ( { i } ) => {
  340. const lightIdx = rangeStart.add( i );
  341. If( index.greaterThanEqual( int( maxLightsPerCluster ) ).or( lightIdx.greaterThanEqual( rangeEnd ) ), () => {
  342. Return();
  343. } );
  344. const { viewPosition, distance } = this.getLightData( lightIdx );
  345. // sphere-AABB intersection in view space
  346. const pos = viewPosition.xyz;
  347. const closest = max( aabbMin, min( pos, aabbMax ) );
  348. const diff = pos.sub( closest );
  349. const distSq = dot( diff, diff );
  350. If( distSq.lessThanEqual( distance.mul( distance ) ), () => {
  351. getClusterSlot( index ).assign( lightIdx.add( int( 1 ) ) );
  352. index.addAssign( int( 1 ) );
  353. } );
  354. } );
  355. } )().compute( clusterCount ).setName( 'Update Clustered Lights' );
  356. // shading-side: fragment → cluster index
  357. const getScreenClusterIndex = Fn( () => {
  358. const screenTile = screenCoordinate.div( tileSize ).floor();
  359. // view-space depth from positionView (negative in front); take magnitude
  360. const viewDepth = positionView.z.negate();
  361. // exponential Z slice: tz = floor( log(depth/near) / log(far/near) * NZ )
  362. const invLogFarOverNear = float( 1 ).div( log( this._cameraFar.div( this._cameraNear ) ) );
  363. const sliceFloat = log( viewDepth.div( this._cameraNear ) ).mul( invLogFarOverNear ).mul( float( NZ ) );
  364. const zSlice = clamp( sliceFloat.floor(), float( 0 ), float( NZ - 1 ) );
  365. return int( screenTile.x )
  366. .add( int( screenTile.y ).mul( int( NX ) ) )
  367. .add( int( zSlice ).mul( int( NX * NY ) ) );
  368. } );
  369. const screenClusterIndex = getScreenClusterIndex().toVar();
  370. // assigns
  371. this._bufferSize = bufferSize;
  372. this._lightIndexes = lightIndexes;
  373. this._screenClusterIndex = screenClusterIndex;
  374. this._compute = compute;
  375. this._lightsTexture = lightsTexture;
  376. this._zSliceRangesTexture = zSliceRangesTexture;
  377. this._zSliceRangesData = zSliceRangesData;
  378. }
  379. get hasLights() {
  380. return super.hasLights || this.clusteredLights.length > 0;
  381. }
  382. }
  383. export default ClusteredLightsNode;
  384. /**
  385. * TSL function that creates a clustered lights node.
  386. *
  387. * @tsl
  388. * @function
  389. * @param {number} [maxLights=1024] - Maximum number of point lights.
  390. * @param {number} [tileSize=32] - Screen tile size in pixels.
  391. * @param {number} [zSlices=24] - Depth slice count.
  392. * @param {number} [maxLightsPerCluster=64] - Per-cluster light-list capacity.
  393. * @return {ClusteredLightsNode} The clustered lights node.
  394. */
  395. export const clusteredLights = /*@__PURE__*/ nodeProxy( ClusteredLightsNode );
粤ICP备19079148号