From 45f35f2d162d4d727222eca6039e1b0219536fda Mon Sep 17 00:00:00 2001 From: Martin Valigursky Date: Fri, 17 Apr 2026 11:17:42 +0100 Subject: [PATCH] perf: add radix sort for GSplat tile sorting on NVIDIA GPUs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a per-tile radix sort pass as a faster alternative to bitonic sort for tiles with ≤1976 entries on NVIDIA GPUs. Uses a 5-pass, 4-bit radix sort with subgroupBallot-based stable scatter, operating entirely in 16KB of workgroup shared memory. - Add workgroup-local radix sort shader using subgroupBallot with .x-only optimization (correct for sgSize <= 32) - Three-tier tile classification when radix is enabled: radix (<=1976), bitonic (1977-4096), large (>4096 via bucket+chunk) - Gate radix sort to NVIDIA only -- Apple shows perf regression vs bitonic; AMD wave64 (sgSize 64) would produce incorrect results - Add GPU vendor detection via capsDefines (VENDOR_NVIDIA, etc.) by overriding initCapsDefines() on WebgpuGraphicsDevice - Rename bitonic sort profiler labels for clarity --- src/platform/graphics/graphics-device.js | 9 + .../graphics/webgpu/webgpu-graphics-device.js | 10 + .../gsplat-compute-local-renderer.js | 84 +++++++- .../gsplat-local-dispatch-set.js | 11 +- .../gsplat/compute-gsplat-local-classify.js | 53 ++++-- .../gsplat/compute-gsplat-local-radix-sort.js | 180 ++++++++++++++++++ .../compute-gsplat-local-tile-radix-sort.js | 33 ++++ 7 files changed, 358 insertions(+), 22 deletions(-) create mode 100644 src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-radix-sort.js create mode 100644 src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-tile-radix-sort.js diff --git a/src/platform/graphics/graphics-device.js b/src/platform/graphics/graphics-device.js index fa8e4f352e6..a565b5477ba 100644 --- a/src/platform/graphics/graphics-device.js +++ b/src/platform/graphics/graphics-device.js @@ -290,6 +290,15 @@ class GraphicsDevice extends EventHandler { */ supportsSubgroupId = false; + /** + * Maximum subgroup (warp/wavefront) size supported by the device. Zero if subgroups are + * not supported. Used internally to gate algorithms that assume a specific subgroup size. + * + * @type {number} + * @ignore + */ + maxSubgroupSize = 0; + /** * Currently active render target. * diff --git a/src/platform/graphics/webgpu/webgpu-graphics-device.js b/src/platform/graphics/webgpu/webgpu-graphics-device.js index 042f652051f..ae1ae581df5 100644 --- a/src/platform/graphics/webgpu/webgpu-graphics-device.js +++ b/src/platform/graphics/webgpu/webgpu-graphics-device.js @@ -257,6 +257,15 @@ class WebgpuGraphicsDevice extends GraphicsDevice { this.initCapsDefines(); } + initCapsDefines() { + super.initCapsDefines(); + + const vendor = this.gpuAdapter?.info?.vendor; + if (vendor) { + this.capsDefines.set(`VENDOR_${vendor.toUpperCase()}`, ''); + } + } + async initWebGpu(glslangUrl, twgslUrl) { if (!window.navigator.gpu) { @@ -333,6 +342,7 @@ class WebgpuGraphicsDevice extends GraphicsDevice { this.supportsTextureFormatTier1 ||= this.supportsTextureFormatTier2; this.supportsPrimitiveIndex = requireFeature('primitive-index'); this.supportsSubgroups = requireFeature('subgroups'); + this.maxSubgroupSize = this.supportsSubgroups ? (this.gpuAdapter?.limits?.maxSubgroupSize ?? 0) : 0; Debug.log(`WEBGPU features [${bare ? 'bare' : 'full'}]: ${requiredFeatures.join(', ') || 'none'}`); // copy all adapter limits to the requiredLimits object (skipped for bare mode to use spec defaults) diff --git a/src/scene/gsplat-unified/gsplat-compute-local-renderer.js b/src/scene/gsplat-unified/gsplat-compute-local-renderer.js index 70a96faed0c..ee2faf36a41 100644 --- a/src/scene/gsplat-unified/gsplat-compute-local-renderer.js +++ b/src/scene/gsplat-unified/gsplat-compute-local-renderer.js @@ -32,6 +32,8 @@ import { computeGsplatLocalBucketSortSource } from '../shader-lib/wgsl/chunks/gs import { computeGsplatLocalChunkSortSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-chunk-sort.js'; import { computeGsplatLocalCopySource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-copy.js'; import { computeGsplatLocalBitonicSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-bitonic.js'; +import { computeGsplatLocalRadixSortSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-radix-sort.js'; +import { computeGsplatLocalTileRadixSortSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-tile-radix-sort.js'; import { computeGsplatCommonSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-common.js'; import { computeGsplatTileIntersectSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-tile-intersect.js'; import { GSplatTileComposite } from './gsplat-tile-composite.js'; @@ -289,6 +291,15 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { /** @type {BindGroupFormat} */ _chunkSortBindGroupFormat; + /** @type {boolean} */ + _useRadixSort = false; + + /** @type {Shader|undefined} */ + _radixSortShader; + + /** @type {BindGroupFormat|undefined} */ + _radixSortBindGroupFormat; + /** * @param {GraphicsDevice} device - The graphics device. * @param {GraphNode} node - The graph node. @@ -747,7 +758,8 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { set._totalChunksBuffer.clear(); set._chunkSortIndirectBuffer.clear(); - const indirectSlot = device.getIndirectDispatchSlot(3); + const numIndirectSlots = this._useRadixSort ? 4 : 3; + const indirectSlot = device.getIndirectDispatchSlot(numIndirectSlots); const drawSlot = device.getIndirectDrawSlot(1); set.classifyCompute.setParameter('tileSplatCounts', set._tileSplatCountsBuffer); @@ -758,6 +770,9 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { set.classifyCompute.setParameter('indirectDispatchArgs', device.indirectDispatchBuffer); set.classifyCompute.setParameter('largeTileOverflowBases', set._largeTileOverflowBasesBuffer); set.classifyCompute.setParameter('indirectDrawArgs', device.indirectDrawBuffer); + if (this._useRadixSort) { + set.classifyCompute.setParameter('radixTileList', set._radixTileListBuffer); + } set.classifyCompute.setParameter('numTiles', numTiles); set.classifyCompute.setParameter('dispatchSlotOffset', indirectSlot * 3); set.classifyCompute.setParameter('bufferCapacity', maxEntries); @@ -791,7 +806,19 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { set.copyCompute.setupDispatch(1, 1, 1); device.computeDispatch([set.copyCompute], pickMode ? 'GSplatPickCopy' : 'GSplatLocalCopy'); - // --- Pass 4a: Small tile sort --- + // --- Pass 4a-radix: Radix tile sort (≤1976 entries, when enabled) --- + if (this._useRadixSort) { + set.radixSortCompute.setParameter('tileEntries', this._tileEntriesBuffer); + set.radixSortCompute.setParameter('tileSplatCounts', set._tileSplatCountsBuffer); + set.radixSortCompute.setParameter('depthBuffer', this._depthBuffer); + set.radixSortCompute.setParameter('radixTileList', set._radixTileListBuffer); + set.radixSortCompute.setParameter('tileListCounts', set._tileListCountsBuffer); + + set.radixSortCompute.setupIndirectDispatch(indirectSlot + 3); + device.computeDispatch([set.radixSortCompute], pickMode ? 'GSplatPickTileRadixSort' : 'GSplatLocalTileRadixSort'); + } + + // --- Pass 4a-bitonic: Bitonic tile sort --- set.sortCompute.setParameter('tileEntries', this._tileEntriesBuffer); set.sortCompute.setParameter('tileSplatCounts', set._tileSplatCountsBuffer); set.sortCompute.setParameter('depthBuffer', this._depthBuffer); @@ -799,7 +826,7 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { set.sortCompute.setParameter('tileListCounts', set._tileListCountsBuffer); set.sortCompute.setupIndirectDispatch(indirectSlot); - device.computeDispatch([set.sortCompute], pickMode ? 'GSplatPickTileSort' : 'GSplatLocalTileSort'); + device.computeDispatch([set.sortCompute], pickMode ? 'GSplatPickTileBitonicSort' : 'GSplatLocalTileBitonicSort'); // --- Pass 4c: Chunk sort --- set.chunkSortCompute.setParameter('tileEntries', this._tileEntriesBuffer); @@ -809,7 +836,7 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { set.chunkSortCompute.setParameter('maxChunks', numTiles * MAX_CHUNKS_PER_TILE); set.chunkSortCompute.setupIndirectDispatch(0, set._chunkSortIndirectBuffer); - device.computeDispatch([set.chunkSortCompute], pickMode ? 'GSplatPickChunkSort' : 'GSplatLocalChunkSort'); + device.computeDispatch([set.chunkSortCompute], pickMode ? 'GSplatPickChunkBitonicSort' : 'GSplatLocalChunkBitonicSort'); // --- Pass 5: Rasterize --- // Select the shader variant based on pick mode and depth availability. Depth testing @@ -1057,6 +1084,13 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { }); // --- Classify --- + // Radix sort requires subgroups and uses a subgroupBallot .x-only optimization that + // is only correct when sgSize <= 32. Currently restricted to NVIDIA only: + // - NVIDIA: sgSize always 32, benchmarked faster than bitonic for larger tiles + // - Apple: sgSize 32 but benchmarks show a performance regression vs bitonic + // - AMD: sgSize can be 64 (wave64), .x-only ballot would produce incorrect results + // - Intel/Qualcomm/others: untested, excluded for safety + this._useRadixSort = device.supportsSubgroups && device.capsDefines.has('VENDOR_NVIDIA'); { const ubf = new UniformBufferFormat(device, [ new UniformFormat('numTiles', UNIFORMTYPE_UINT), @@ -1065,7 +1099,7 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { new UniformFormat('maxWorkgroupsPerDim', UNIFORMTYPE_UINT), new UniformFormat('drawSlot', UNIFORMTYPE_UINT) ]); - this._classifyBindGroupFormat = new BindGroupFormat(device, [ + const classifyBindings = [ new BindStorageBufferFormat('tileSplatCounts', SHADERSTAGE_COMPUTE, true), new BindStorageBufferFormat('smallTileList', SHADERSTAGE_COMPUTE), new BindStorageBufferFormat('largeTileList', SHADERSTAGE_COMPUTE), @@ -1075,11 +1109,18 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { new BindStorageBufferFormat('largeTileOverflowBases', SHADERSTAGE_COMPUTE), new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE), new BindStorageBufferFormat('indirectDrawArgs', SHADERSTAGE_COMPUTE) - ]); + ]; + if (this._useRadixSort) { + classifyBindings.push(new BindStorageBufferFormat('radixTileList', SHADERSTAGE_COMPUTE)); + } + this._classifyBindGroupFormat = new BindGroupFormat(device, classifyBindings); + + const cdefines = this._useRadixSort ? new Map([['USE_RADIX_SORT', '']]) : undefined; this._classifyShader = new Shader(device, { name: 'GSplatLocalClassify', shaderLanguage: SHADERLANGUAGE_WGSL, cshader: computeGsplatLocalClassifySource, + cdefines: cdefines, computeBindGroupFormat: this._classifyBindGroupFormat, computeUniformBufferFormats: { uniforms: ubf } }); @@ -1094,13 +1135,31 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { new BindStorageBufferFormat('tileListCounts', SHADERSTAGE_COMPUTE, true) ]); this._sortShader = new Shader(device, { - name: 'GSplatLocalTileSort', + name: 'GSplatLocalTileBitonicSort', shaderLanguage: SHADERLANGUAGE_WGSL, cshader: computeGsplatLocalTileSortSource, cincludes: this._createBitonicIncludes(), computeBindGroupFormat: this._sortBindGroupFormat }); + // --- Radix Sort (for tiles with ≤1976 entries, NVIDIA only) --- + if (this._useRadixSort) { + this._radixSortBindGroupFormat = new BindGroupFormat(device, [ + new BindStorageBufferFormat('tileEntries', SHADERSTAGE_COMPUTE), + new BindStorageBufferFormat('tileSplatCounts', SHADERSTAGE_COMPUTE, true), + new BindStorageBufferFormat('depthBuffer', SHADERSTAGE_COMPUTE, true), + new BindStorageBufferFormat('radixTileList', SHADERSTAGE_COMPUTE, true), + new BindStorageBufferFormat('tileListCounts', SHADERSTAGE_COMPUTE, true) + ]); + this._radixSortShader = new Shader(device, { + name: 'GSplatLocalTileRadixSort', + shaderLanguage: SHADERLANGUAGE_WGSL, + cshader: computeGsplatLocalTileRadixSortSource, + cincludes: new Map([['gsplatLocalRadixSortCS', computeGsplatLocalRadixSortSource]]), + computeBindGroupFormat: this._radixSortBindGroupFormat + }); + } + // --- BucketSort --- { const ubf = new UniformBufferFormat(device, [ @@ -1160,7 +1219,7 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE) ]); this._chunkSortShader = new Shader(device, { - name: 'GSplatLocalChunkSort', + name: 'GSplatLocalChunkBitonicSort', shaderLanguage: SHADERLANGUAGE_WGSL, cshader: computeGsplatLocalChunkSortSource, cincludes: this._createBitonicIncludes(), @@ -1287,7 +1346,12 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { set.classifyCompute = new Compute(device, this._classifyShader, pickMode ? 'GSplatPickClassify' : 'GSplatLocalClassify'); // Sort: shared shader - set.sortCompute = new Compute(device, this._sortShader, pickMode ? 'GSplatPickTileSort' : 'GSplatLocalTileSort'); + set.sortCompute = new Compute(device, this._sortShader, pickMode ? 'GSplatPickTileBitonicSort' : 'GSplatLocalTileBitonicSort'); + + // RadixSort: shared shader (when enabled) + if (this._useRadixSort) { + set.radixSortCompute = new Compute(device, this._radixSortShader, pickMode ? 'GSplatPickTileRadixSort' : 'GSplatLocalTileRadixSort'); + } // BucketSort: shared shader set.bucketSortCompute = new Compute(device, this._bucketSortShader, pickMode ? 'GSplatPickBucketSort' : 'GSplatLocalBucketSort'); @@ -1296,7 +1360,7 @@ class GSplatComputeLocalRenderer extends GSplatRenderer { set.copyCompute = new Compute(device, this._copyShader, pickMode ? 'GSplatPickCopy' : 'GSplatLocalCopy'); // ChunkSort: shared shader - set.chunkSortCompute = new Compute(device, this._chunkSortShader, pickMode ? 'GSplatPickChunkSort' : 'GSplatLocalChunkSort'); + set.chunkSortCompute = new Compute(device, this._chunkSortShader, pickMode ? 'GSplatPickChunkBitonicSort' : 'GSplatLocalChunkBitonicSort'); return set; } diff --git a/src/scene/gsplat-unified/gsplat-local-dispatch-set.js b/src/scene/gsplat-unified/gsplat-local-dispatch-set.js index 7afdf5bfa09..62aae94e6b4 100644 --- a/src/scene/gsplat-unified/gsplat-local-dispatch-set.js +++ b/src/scene/gsplat-unified/gsplat-local-dispatch-set.js @@ -86,6 +86,9 @@ class GSplatLocalDispatchSet { /** @type {Compute} */ chunkSortCompute; + /** @type {Compute|null} */ + radixSortCompute = null; + /** @type {Map} */ _rasterizeVariants = new Map(); @@ -104,6 +107,9 @@ class GSplatLocalDispatchSet { /** @type {StorageBuffer|null} */ _largeTileOverflowBasesBuffer = null; + /** @type {StorageBuffer|null} */ + _radixTileListBuffer = null; + /** @type {StorageBuffer|null} */ _rasterizeTileListBuffer = null; @@ -203,6 +209,7 @@ class GSplatLocalDispatchSet { this._smallTileListBuffer?.destroy(); this._largeTileListBuffer?.destroy(); this._largeTileOverflowBasesBuffer?.destroy(); + this._radixTileListBuffer?.destroy(); this._rasterizeTileListBuffer?.destroy(); this._tileListCountsBuffer?.destroy(); this._chunkRangesBuffer?.destroy(); @@ -214,8 +221,9 @@ class GSplatLocalDispatchSet { this._smallTileListBuffer = new StorageBuffer(this.device, numTiles * 4); this._largeTileListBuffer = new StorageBuffer(this.device, numTiles * 4); this._largeTileOverflowBasesBuffer = new StorageBuffer(this.device, numTiles * 4); + this._radixTileListBuffer = new StorageBuffer(this.device, numTiles * 4); this._rasterizeTileListBuffer = new StorageBuffer(this.device, numTiles * 4); - this._tileListCountsBuffer = new StorageBuffer(this.device, 4 * 4, BUFFERUSAGE_COPY_DST | BUFFERUSAGE_COPY_SRC); + this._tileListCountsBuffer = new StorageBuffer(this.device, 5 * 4, BUFFERUSAGE_COPY_DST | BUFFERUSAGE_COPY_SRC); const maxChunks = numTiles * MAX_CHUNKS_PER_TILE; this._chunkRangesBuffer = new StorageBuffer(this.device, maxChunks * 8); @@ -397,6 +405,7 @@ class GSplatLocalDispatchSet { this._smallTileListBuffer?.destroy(); this._largeTileListBuffer?.destroy(); this._largeTileOverflowBasesBuffer?.destroy(); + this._radixTileListBuffer?.destroy(); this._rasterizeTileListBuffer?.destroy(); this._tileListCountsBuffer?.destroy(); this._chunkRangesBuffer?.destroy(); diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-classify.js b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-classify.js index 32c5f7e0feb..013c47de319 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-classify.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-classify.js @@ -1,9 +1,18 @@ -// Tile classification: scans prefix-summed tile counts, builds small/large/rasterize +// Tile classification: scans prefix-summed tile counts, builds radix/bitonic/large/rasterize // tile lists, writes indirect dispatch args for subsequent passes, and writes indirect // draw args for the tile-based composite. -// For large tiles (>4096 entries), assigns compact overflow scratch offsets within -// the shared tileEntries buffer (overflow region starts at totalEntries). +// When USE_RADIX_SORT is defined: three sort tiers — radix (≤1976), bitonic (1977..4096), +// large (>4096 via bucket+chunk). Otherwise: two tiers — bitonic (≤4096), large (>4096). +// For large tiles, assigns compact overflow scratch offsets within the shared tileEntries +// buffer (overflow region starts at totalEntries). // Single workgroup (256 threads) — each thread processes ceil(numTiles/256) tiles. +// +// tileListCounts layout: +// [0] = bitonic tile count (smallTileList) +// [1] = large tile count (largeTileList) +// [2] = rasterize tile count (rasterizeTileList) +// [3] = large tile overflow entries claimed +// [4] = radix tile count (radixTileList) — only used when USE_RADIX_SORT is defined import indirectCoreCS from '../common/comp/indirect-core.js'; import dispatchCoreCS from '../common/comp/dispatch-core.js'; @@ -13,7 +22,10 @@ export const computeGsplatLocalClassifySource = /* wgsl */` ${indirectCoreCS} ${dispatchCoreCS} -const MAX_TILE_ENTRIES: u32 = 4096u; +#ifdef USE_RADIX_SORT +const RADIX_MAX_ENTRIES: u32 = 1976u; +#endif +const BITONIC_MAX_ENTRIES: u32 = 4096u; const CLASSIFY_WORKGROUP: u32 = 256u; @group(0) @binding(0) var tileSplatCounts: array; @@ -24,6 +36,9 @@ const CLASSIFY_WORKGROUP: u32 = 256u; @group(0) @binding(5) var indirectDispatchArgs: array; @group(0) @binding(6) var largeTileOverflowBases: array; @group(0) @binding(8) var indirectDrawArgs: array; +#ifdef USE_RADIX_SORT +@group(0) @binding(9) var radixTileList: array; +#endif struct Uniforms { numTiles: u32, @@ -52,7 +67,14 @@ fn main(@builtin(local_invocation_index) localIdx: u32) { let rIdx = atomicAdd(&tileListCounts[2], 1u); rasterizeTileList[rIdx] = i; - if (count <= MAX_TILE_ENTRIES) { +#ifdef USE_RADIX_SORT + if (count <= RADIX_MAX_ENTRIES) { + let rxIdx = atomicAdd(&tileListCounts[4], 1u); + radixTileList[rxIdx] = i; + } else if (count <= BITONIC_MAX_ENTRIES) { +#else + if (count <= BITONIC_MAX_ENTRIES) { +#endif let sIdx = atomicAdd(&tileListCounts[0], 1u); smallTileList[sIdx] = i; } else { @@ -68,20 +90,20 @@ fn main(@builtin(local_invocation_index) localIdx: u32) { workgroupBarrier(); - // Thread 0 writes indirect dispatch args for passes 4a (small sort), 4b (bucket), 5 (rasterize). + // Thread 0 writes indirect dispatch args for sort and rasterize passes. // Uses balanced 2D dispatch to stay within maxComputeWorkgroupsPerDimension with minimal waste: // y = ceil(count / maxDim), x = ceil(count / y). Waste is at most y-1 workgroups (typically 0-1). if (localIdx == 0u) { - let smallCount = atomicLoad(&tileListCounts[0]); + let bitonicCount = atomicLoad(&tileListCounts[0]); let largeCount = atomicLoad(&tileListCounts[1]); let rasterizeCount = atomicLoad(&tileListCounts[2]); let off = uniforms.dispatchSlotOffset; let maxDim = uniforms.maxWorkgroupsPerDim; - // Slot 0: small tile sort — 1 workgroup per tile - let smallDim = calcDispatch2D(smallCount, maxDim); - indirectDispatchArgs[off + 0u] = smallDim.x; - indirectDispatchArgs[off + 1u] = smallDim.y; + // Slot 0: bitonic tile sort + let bitonicDim = calcDispatch2D(bitonicCount, maxDim); + indirectDispatchArgs[off + 0u] = bitonicDim.x; + indirectDispatchArgs[off + 1u] = bitonicDim.y; indirectDispatchArgs[off + 2u] = 1u; // Slot 1: bucket pre-sort — 1 workgroup per large tile @@ -96,6 +118,15 @@ fn main(@builtin(local_invocation_index) localIdx: u32) { indirectDispatchArgs[off + 7u] = rasterDim.y; indirectDispatchArgs[off + 8u] = 1u; +#ifdef USE_RADIX_SORT + // Slot 3: radix tile sort (≤1976 entries) + let radixCount = atomicLoad(&tileListCounts[4]); + let radixDim = calcDispatch2D(radixCount, maxDim); + indirectDispatchArgs[off + 9u] = radixDim.x; + indirectDispatchArgs[off + 10u] = radixDim.y; + indirectDispatchArgs[off + 11u] = 1u; +#endif + // Indirect draw args for tile-based composite: 6 vertices per tile quad indirectDrawArgs[uniforms.drawSlot] = DrawIndirectArgs(rasterizeCount * 6u, 1u, 0u, 0u, 0u); } diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-radix-sort.js b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-radix-sort.js new file mode 100644 index 00000000000..d34b25419da --- /dev/null +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-radix-sort.js @@ -0,0 +1,180 @@ +// Per-tile radix sort for tiles with up to 1976 entries. +// Uses 5×4-bit radix sort (16 buckets per pass) in 16KB shared memory, +// with ping-pong buffers. 1 element per thread per step. +// Requires subgroup support for stable scatter via subgroupBallot. +// Ballot processing uses only .x component — correct only for sgSize <= 32 +// (NVIDIA, Intel, ARM Mali). Not safe for AMD wave64 or Qualcomm Adreno (sgSize 64+). +// +// Shared memory layout (16,384 bytes): +// sA[1976] = 7,904 bytes (ping buffer) +// sB[1976] = 7,904 bytes (pong buffer) +// histogram[16] = 64 bytes (atomic, reused for min/max, histogram, prefix sums) +// warpCounts[8×16=128] = 512 bytes (per-subgroup per-digit counts) +export const computeGsplatLocalRadixSortSource = /* wgsl */` + +const RADIX_MAX_ENTRIES: u32 = 1976u; +const RADIX_WG_SIZE: u32 = 256u; +const RADIX_BITS: u32 = 4u; +const NUM_BUCKETS: u32 = 16u; +const BUCKET_MASK: u32 = 0xFu; +const NUM_PASSES: u32 = 5u; +const MAX_SUBGROUPS: u32 = 8u; +const INDEX_BITS: u32 = 11u; +const INDEX_MASK: u32 = 0x7FFu; +const DEPTH_LEVELS: f32 = 1048575.0; + +var sA: array; +var sB: array; +var histogram: array, 16>; +var warpCounts: array; + +fn radixSortRange(localIdx: u32, sgInvId: u32, sgSize: u32, tStart: u32, count: u32) { + let clampedCount = min(count, RADIX_MAX_ENTRIES); + + if (clampedCount <= 1u) { + return; + } + + let sgId = localIdx / sgSize; + let numSgs = RADIX_WG_SIZE / sgSize; + let sgInvMask = (1u << sgInvId) - 1u; + + // Phase 1: Load f32 depths into sA + if (localIdx == 0u) { + atomicStore(&histogram[0], 0xFFFFFFFFu); + atomicStore(&histogram[1], 0u); + } + + for (var i: u32 = localIdx; i < clampedCount; i += RADIX_WG_SIZE) { + let entryIdx = tileEntries[tStart + i]; + sA[i] = depthBuffer[entryIdx]; + } + + workgroupBarrier(); + + // Phase 2: Min/max reduction via atomics + for (var i: u32 = localIdx; i < clampedCount; i += RADIX_WG_SIZE) { + atomicMin(&histogram[0], sA[i]); + atomicMax(&histogram[1], sA[i]); + } + + workgroupBarrier(); + + let depthMin = bitcast(atomicLoad(&histogram[0])); + let depthMax = bitcast(atomicLoad(&histogram[1])); + + let logMin = log(max(depthMin, 1e-6)); + let logRange = log(max(depthMax, 1e-6)) - logMin; + let invLogRange = select(DEPTH_LEVELS / logRange, 0.0, logRange < 1e-10); + + // Phase 3: Repack to (depth20 << 11 | localIndex11) + for (var i: u32 = localIdx; i < clampedCount; i += RADIX_WG_SIZE) { + let depth = bitcast(sA[i]); + let logDepth = log(max(depth, 1e-6)); + let depth20 = min(u32((logDepth - logMin) * invLogRange + 0.5), u32(DEPTH_LEVELS)); + sA[i] = (depth20 << INDEX_BITS) | i; + } + + workgroupBarrier(); + + // Phase 4: 5-pass radix sort (4 bits per pass, bits 11..30). + // After 5 passes (odd count), result is in sB. + for (var p: u32 = 0u; p < NUM_PASSES; p++) { + let bitOffset = INDEX_BITS + p * RADIX_BITS; + let even = (p % 2u == 0u); + + // 4a: Zero histogram + if (localIdx < NUM_BUCKETS) { + atomicStore(&histogram[localIdx], 0u); + } + workgroupBarrier(); + + // 4b: Build histogram + for (var i: u32 = localIdx; i < clampedCount; i += RADIX_WG_SIZE) { + var v: u32; + if (even) { v = sA[i]; } else { v = sB[i]; } + atomicAdd(&histogram[(v >> bitOffset) & BUCKET_MASK], 1u); + } + workgroupBarrier(); + + // 4c: Exclusive prefix sum (single thread — only 16 entries) + if (localIdx == 0u) { + var sum: u32 = 0u; + for (var d: u32 = 0u; d < NUM_BUCKETS; d++) { + let c = atomicLoad(&histogram[d]); + atomicStore(&histogram[d], sum); + sum += c; + } + } + workgroupBarrier(); + + // 4d: Stable scatter — 1 element per thread, using subgroupBallot. + // Ballot .x-only: correct only for sgSize <= 32 (NVIDIA, Intel, ARM Mali). + let numSteps = (clampedCount + RADIX_WG_SIZE - 1u) / RADIX_WG_SIZE; + + for (var step: u32 = 0u; step < numSteps; step++) { + let idx = step * RADIX_WG_SIZE + localIdx; + let valid = idx < clampedCount; + + var val: u32 = 0u; + var digit: u32 = 0u; + if (valid) { + if (even) { val = sA[idx]; } else { val = sB[idx]; } + digit = (val >> bitOffset) & BUCKET_MASK; + } + + var intraRank: u32 = 0u; + + for (var d: u32 = 0u; d < NUM_BUCKETS; d++) { + let ballot = subgroupBallot(valid && digit == d); + let cnt = countOneBits(ballot.x); + + if (valid && digit == d) { + intraRank = countOneBits(ballot.x & sgInvMask); + } + + if (sgInvId == d) { + warpCounts[sgId * NUM_BUCKETS + d] = cnt; + } + } + + workgroupBarrier(); + + if (valid) { + var rank: u32 = atomicLoad(&histogram[digit]); + for (var w: u32 = 0u; w < sgId; w++) { + rank += warpCounts[w * NUM_BUCKETS + digit]; + } + rank += intraRank; + if (even) { sB[rank] = val; } else { sA[rank] = val; } + } + + workgroupBarrier(); + + if (localIdx < NUM_BUCKETS) { + var total: u32 = 0u; + for (var w: u32 = 0u; w < numSgs; w++) { + total += warpCounts[w * NUM_BUCKETS + localIdx]; + } + atomicAdd(&histogram[localIdx], total); + } + + workgroupBarrier(); + } + } + + // After 5 passes (odd count), sorted data is in sB. + + // Phase 5: Extract local indices and write sorted global entries back. + for (var i: u32 = localIdx; i < clampedCount; i += RADIX_WG_SIZE) { + let localIndex = sB[i] & INDEX_MASK; + sA[i] = tileEntries[tStart + localIndex]; + } + + workgroupBarrier(); + + for (var i: u32 = localIdx; i < clampedCount; i += RADIX_WG_SIZE) { + tileEntries[tStart + i] = sA[i]; + } +} +`; diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-tile-radix-sort.js b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-tile-radix-sort.js new file mode 100644 index 00000000000..104354dba40 --- /dev/null +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-tile-radix-sort.js @@ -0,0 +1,33 @@ +// Per-tile radix sort for tiles with up to 1976 entries. +// Reads tile index from radixTileList and delegates to the shared radix sort logic. +// Requires subgroup support for stable scatter. +export const computeGsplatLocalTileRadixSortSource = /* wgsl */` + +#include "gsplatLocalRadixSortCS" + +@group(0) @binding(0) var tileEntries: array; +@group(0) @binding(1) var tileSplatCounts: array; +@group(0) @binding(2) var depthBuffer: array; +@group(0) @binding(3) var radixTileList: array; +@group(0) @binding(4) var tileListCounts: array; + +@compute @workgroup_size(256) +fn main( + @builtin(local_invocation_index) localIdx: u32, + @builtin(workgroup_id) wid: vec3u, + @builtin(num_workgroups) numWorkgroups: vec3u, + @builtin(subgroup_invocation_id) sgInvId: u32, + @builtin(subgroup_size) sgSize: u32 +) { + let workgroupIdx = wid.y * numWorkgroups.x + wid.x; + if (workgroupIdx >= tileListCounts[4]) { + return; + } + let tileIdx = radixTileList[workgroupIdx]; + let tStart = tileSplatCounts[tileIdx]; + let tEnd = tileSplatCounts[tileIdx + 1u]; + let count = tEnd - tStart; + + radixSortRange(localIdx, sgInvId, sgSize, tStart, count); +} +`;