PoissonDenoiseShader.js 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. import {
  2. Matrix4,
  3. Vector2,
  4. Vector3,
  5. } from 'three';
  6. /**
  7. * @module PoissonDenoiseShader
  8. * @three_import import { PoissonDenoiseShader } from 'three/addons/shaders/PoissonDenoiseShader.js';
  9. */
  10. /**
  11. * Poisson Denoise Shader.
  12. *
  13. * References:
  14. * - [Self-Supervised Poisson-Gaussian Denoising](https://openaccess.thecvf.com/content/WACV2021/papers/Khademi_Self-Supervised_Poisson-Gaussian_Denoising_WACV_2021_paper.pdf).
  15. * - [Poisson2Sparse: Self-Supervised Poisson Denoising From a Single Image](https://arxiv.org/pdf/2206.01856.pdf)
  16. *
  17. * @constant
  18. * @type {ShaderMaterial~Shader}
  19. */
  20. const PoissonDenoiseShader = {
  21. name: 'PoissonDenoiseShader',
  22. defines: {
  23. 'SAMPLES': 16,
  24. 'SAMPLE_VECTORS': generatePdSamplePointInitializer( 16, 2, 1 ),
  25. 'NORMAL_VECTOR_TYPE': 1,
  26. 'DEPTH_VALUE_SOURCE': 0,
  27. },
  28. uniforms: {
  29. 'tDiffuse': { value: null },
  30. 'tNormal': { value: null },
  31. 'tDepth': { value: null },
  32. 'tNoise': { value: null },
  33. 'resolution': { value: new Vector2() },
  34. 'cameraProjectionMatrixInverse': { value: new Matrix4() },
  35. 'lumaPhi': { value: 5. },
  36. 'depthPhi': { value: 5. },
  37. 'normalPhi': { value: 5. },
  38. 'radius': { value: 4. },
  39. 'index': { value: 0 }
  40. },
  41. vertexShader: /* glsl */`
  42. varying vec2 vUv;
  43. void main() {
  44. vUv = uv;
  45. gl_Position = projectionMatrix * modelViewMatrix * vec4( position, 1.0 );
  46. }`,
  47. fragmentShader: /* glsl */`
  48. varying vec2 vUv;
  49. uniform sampler2D tDiffuse;
  50. uniform sampler2D tNormal;
  51. uniform sampler2D tDepth;
  52. uniform sampler2D tNoise;
  53. uniform vec2 resolution;
  54. uniform mat4 cameraProjectionMatrixInverse;
  55. uniform float lumaPhi;
  56. uniform float depthPhi;
  57. uniform float normalPhi;
  58. uniform float radius;
  59. uniform int index;
  60. #include <common>
  61. #include <packing>
  62. #ifndef SAMPLE_LUMINANCE
  63. #define SAMPLE_LUMINANCE dot(vec3(0.2125, 0.7154, 0.0721), a)
  64. #endif
  65. #ifndef FRAGMENT_OUTPUT
  66. #define FRAGMENT_OUTPUT vec4(denoised, 1.)
  67. #endif
  68. float getLuminance(const in vec3 a) {
  69. return SAMPLE_LUMINANCE;
  70. }
  71. const vec3 poissonDisk[SAMPLES] = SAMPLE_VECTORS;
  72. vec3 getViewPosition( const in vec2 screenPosition, const in float depth ) {
  73. #ifdef USE_REVERSED_DEPTH_BUFFER
  74. vec4 clipSpacePosition = vec4( vec2( screenPosition ) * 2.0 - 1.0, depth, 1.0 );
  75. #else
  76. vec4 clipSpacePosition = vec4( vec3( screenPosition, depth ) * 2.0 - 1.0, 1.0 );
  77. #endif
  78. vec4 viewSpacePosition = cameraProjectionMatrixInverse * clipSpacePosition;
  79. return viewSpacePosition.xyz / viewSpacePosition.w;
  80. }
  81. float getDepth(const vec2 uv) {
  82. #if DEPTH_VALUE_SOURCE == 1
  83. return textureLod(tDepth, uv.xy, 0.0).a;
  84. #else
  85. return textureLod(tDepth, uv.xy, 0.0).r;
  86. #endif
  87. }
  88. float fetchDepth(const ivec2 uv) {
  89. #if DEPTH_VALUE_SOURCE == 1
  90. return texelFetch(tDepth, uv.xy, 0).a;
  91. #else
  92. return texelFetch(tDepth, uv.xy, 0).r;
  93. #endif
  94. }
  95. vec3 computeNormalFromDepth(const vec2 uv) {
  96. vec2 size = vec2(textureSize(tDepth, 0));
  97. ivec2 p = ivec2(uv * size);
  98. float c0 = fetchDepth(p);
  99. float l2 = fetchDepth(p - ivec2(2, 0));
  100. float l1 = fetchDepth(p - ivec2(1, 0));
  101. float r1 = fetchDepth(p + ivec2(1, 0));
  102. float r2 = fetchDepth(p + ivec2(2, 0));
  103. float b2 = fetchDepth(p - ivec2(0, 2));
  104. float b1 = fetchDepth(p - ivec2(0, 1));
  105. float t1 = fetchDepth(p + ivec2(0, 1));
  106. float t2 = fetchDepth(p + ivec2(0, 2));
  107. float dl = abs((2.0 * l1 - l2) - c0);
  108. float dr = abs((2.0 * r1 - r2) - c0);
  109. float db = abs((2.0 * b1 - b2) - c0);
  110. float dt = abs((2.0 * t1 - t2) - c0);
  111. vec3 ce = getViewPosition(uv, c0).xyz;
  112. vec3 dpdx = (dl < dr) ? ce - getViewPosition((uv - vec2(1.0 / size.x, 0.0)), l1).xyz
  113. : -ce + getViewPosition((uv + vec2(1.0 / size.x, 0.0)), r1).xyz;
  114. vec3 dpdy = (db < dt) ? ce - getViewPosition((uv - vec2(0.0, 1.0 / size.y)), b1).xyz
  115. : -ce + getViewPosition((uv + vec2(0.0, 1.0 / size.y)), t1).xyz;
  116. return normalize(cross(dpdx, dpdy));
  117. }
  118. vec3 getViewNormal(const vec2 uv) {
  119. #if NORMAL_VECTOR_TYPE == 2
  120. return normalize(textureLod(tNormal, uv, 0.).rgb);
  121. #elif NORMAL_VECTOR_TYPE == 1
  122. return unpackRGBToNormal(textureLod(tNormal, uv, 0.).rgb);
  123. #else
  124. return computeNormalFromDepth(uv);
  125. #endif
  126. }
  127. void denoiseSample(in vec3 center, in vec3 viewNormal, in vec3 viewPos, in vec2 sampleUv, inout vec3 denoised, inout float totalWeight) {
  128. vec4 sampleTexel = textureLod(tDiffuse, sampleUv, 0.0);
  129. float sampleDepth = getDepth(sampleUv);
  130. vec3 sampleNormal = getViewNormal(sampleUv);
  131. vec3 neighborColor = sampleTexel.rgb;
  132. vec3 viewPosSample = getViewPosition(sampleUv, sampleDepth);
  133. float normalDiff = dot(viewNormal, sampleNormal);
  134. float normalSimilarity = pow(max(normalDiff, 0.), normalPhi);
  135. float lumaDiff = abs(getLuminance(neighborColor) - getLuminance(center));
  136. float lumaSimilarity = max(1.0 - lumaDiff / lumaPhi, 0.0);
  137. float depthDiff = abs(dot(viewPos - viewPosSample, viewNormal));
  138. float depthSimilarity = max(1. - depthDiff / depthPhi, 0.);
  139. float w = lumaSimilarity * depthSimilarity * normalSimilarity;
  140. denoised += w * neighborColor;
  141. totalWeight += w;
  142. }
  143. void main() {
  144. float depth = getDepth(vUv.xy);
  145. vec3 viewNormal = getViewNormal(vUv);
  146. if (depth == 1. || dot(viewNormal, viewNormal) == 0.) {
  147. discard;
  148. return;
  149. }
  150. vec4 texel = textureLod(tDiffuse, vUv, 0.0);
  151. vec3 center = texel.rgb;
  152. vec3 viewPos = getViewPosition(vUv, depth);
  153. vec2 noiseResolution = vec2(textureSize(tNoise, 0));
  154. vec2 noiseUv = vUv * resolution / noiseResolution;
  155. vec4 noiseTexel = textureLod(tNoise, noiseUv, 0.0);
  156. vec2 noiseVec = vec2(sin(noiseTexel[index % 4] * 2. * PI), cos(noiseTexel[index % 4] * 2. * PI));
  157. mat2 rotationMatrix = mat2(noiseVec.x, -noiseVec.y, noiseVec.x, noiseVec.y);
  158. float totalWeight = 1.0;
  159. vec3 denoised = texel.rgb;
  160. for (int i = 0; i < SAMPLES; i++) {
  161. vec3 sampleDir = poissonDisk[i];
  162. vec2 offset = rotationMatrix * (sampleDir.xy * (1. + sampleDir.z * (radius - 1.)) / resolution);
  163. vec2 sampleUv = vUv + offset;
  164. denoiseSample(center, viewNormal, viewPos, sampleUv, denoised, totalWeight);
  165. }
  166. if (totalWeight > 0.) {
  167. denoised /= totalWeight;
  168. }
  169. gl_FragColor = FRAGMENT_OUTPUT;
  170. }`
  171. };
  172. function generatePdSamplePointInitializer( samples, rings, radiusExponent ) {
  173. const poissonDisk = generateDenoiseSamples(
  174. samples,
  175. rings,
  176. radiusExponent,
  177. );
  178. let glslCode = 'vec3[SAMPLES](';
  179. for ( let i = 0; i < samples; i ++ ) {
  180. const sample = poissonDisk[ i ];
  181. glslCode += `vec3(${sample.x}, ${sample.y}, ${sample.z})${( i < samples - 1 ) ? ',' : ')'}`;
  182. }
  183. return glslCode;
  184. }
  185. function generateDenoiseSamples( numSamples, numRings, radiusExponent ) {
  186. const samples = [];
  187. for ( let i = 0; i < numSamples; i ++ ) {
  188. const angle = 2 * Math.PI * numRings * i / numSamples;
  189. const radius = Math.pow( i / ( numSamples - 1 ), radiusExponent );
  190. samples.push( new Vector3( Math.cos( angle ), Math.sin( angle ), radius ) );
  191. }
  192. return samples;
  193. }
  194. export { generatePdSamplePointInitializer, PoissonDenoiseShader };
粤ICP备19079148号