Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/platform/graphics/graphics-device.js
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,15 @@ class GraphicsDevice extends EventHandler {
*/
supportsSubgroupId = false;

/**
* Maximum subgroup (warp/wavefront) size supported by the device. Zero if subgroups are
* not supported. Used internally to gate algorithms that assume a specific subgroup size.
*
* @type {number}
* @ignore
*/
maxSubgroupSize = 0;

/**
* Currently active render target.
*
Expand Down
10 changes: 10 additions & 0 deletions src/platform/graphics/webgpu/webgpu-graphics-device.js
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,15 @@ class WebgpuGraphicsDevice extends GraphicsDevice {
this.initCapsDefines();
}

initCapsDefines() {
super.initCapsDefines();

const vendor = this.gpuAdapter?.info?.vendor;
if (vendor) {
this.capsDefines.set(`VENDOR_${vendor.toUpperCase()}`, '');
}
}

async initWebGpu(glslangUrl, twgslUrl) {

if (!window.navigator.gpu) {
Expand Down Expand Up @@ -333,6 +342,7 @@ class WebgpuGraphicsDevice extends GraphicsDevice {
this.supportsTextureFormatTier1 ||= this.supportsTextureFormatTier2;
this.supportsPrimitiveIndex = requireFeature('primitive-index');
this.supportsSubgroups = requireFeature('subgroups');
this.maxSubgroupSize = this.supportsSubgroups ? (this.gpuAdapter?.limits?.maxSubgroupSize ?? 0) : 0;
Debug.log(`WEBGPU features [${bare ? 'bare' : 'full'}]: ${requiredFeatures.join(', ') || 'none'}`);

// copy all adapter limits to the requiredLimits object (skipped for bare mode to use spec defaults)
Expand Down
84 changes: 74 additions & 10 deletions src/scene/gsplat-unified/gsplat-compute-local-renderer.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import { computeGsplatLocalBucketSortSource } from '../shader-lib/wgsl/chunks/gs
import { computeGsplatLocalChunkSortSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-chunk-sort.js';
import { computeGsplatLocalCopySource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-copy.js';
import { computeGsplatLocalBitonicSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-bitonic.js';
import { computeGsplatLocalRadixSortSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-radix-sort.js';
import { computeGsplatLocalTileRadixSortSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-local-tile-radix-sort.js';
import { computeGsplatCommonSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-common.js';
import { computeGsplatTileIntersectSource } from '../shader-lib/wgsl/chunks/gsplat/compute-gsplat-tile-intersect.js';
import { GSplatTileComposite } from './gsplat-tile-composite.js';
Expand Down Expand Up @@ -314,6 +316,15 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
/** @type {BindGroupFormat} */
_chunkSortBindGroupFormat;

/** @type {boolean} */
_useRadixSort = false;

/** @type {Shader|undefined} */
_radixSortShader;

/** @type {BindGroupFormat|undefined} */
_radixSortBindGroupFormat;

/**
* @param {GraphicsDevice} device - The graphics device.
* @param {GraphNode} node - The graph node.
Expand Down Expand Up @@ -772,7 +783,8 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
set._totalChunksBuffer.clear();
set._chunkSortIndirectBuffer.clear();

const indirectSlot = device.getIndirectDispatchSlot(3);
const numIndirectSlots = this._useRadixSort ? 4 : 3;
const indirectSlot = device.getIndirectDispatchSlot(numIndirectSlots);
const drawSlot = device.getIndirectDrawSlot(1);

set.classifyCompute.setParameter('tileSplatCounts', set._tileSplatCountsBuffer);
Expand All @@ -783,6 +795,9 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
set.classifyCompute.setParameter('indirectDispatchArgs', device.indirectDispatchBuffer);
set.classifyCompute.setParameter('largeTileOverflowBases', set._largeTileOverflowBasesBuffer);
set.classifyCompute.setParameter('indirectDrawArgs', device.indirectDrawBuffer);
if (this._useRadixSort) {
set.classifyCompute.setParameter('radixTileList', set._radixTileListBuffer);
}
set.classifyCompute.setParameter('numTiles', numTiles);
set.classifyCompute.setParameter('dispatchSlotOffset', indirectSlot * 3);
set.classifyCompute.setParameter('bufferCapacity', maxEntries);
Expand Down Expand Up @@ -816,15 +831,27 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
set.copyCompute.setupDispatch(1, 1, 1);
device.computeDispatch([set.copyCompute], pickMode ? 'GSplatPickCopy' : 'GSplatLocalCopy');

// --- Pass 4a: Small tile sort ---
// --- Pass 4a-radix: Radix tile sort (≤1976 entries, when enabled) ---
if (this._useRadixSort) {
set.radixSortCompute.setParameter('tileEntries', this._tileEntriesBuffer);
set.radixSortCompute.setParameter('tileSplatCounts', set._tileSplatCountsBuffer);
set.radixSortCompute.setParameter('depthBuffer', this._depthBuffer);
set.radixSortCompute.setParameter('radixTileList', set._radixTileListBuffer);
set.radixSortCompute.setParameter('tileListCounts', set._tileListCountsBuffer);

set.radixSortCompute.setupIndirectDispatch(indirectSlot + 3);
device.computeDispatch([set.radixSortCompute], pickMode ? 'GSplatPickTileRadixSort' : 'GSplatLocalTileRadixSort');
}

// --- Pass 4a-bitonic: Bitonic tile sort ---
set.sortCompute.setParameter('tileEntries', this._tileEntriesBuffer);
set.sortCompute.setParameter('tileSplatCounts', set._tileSplatCountsBuffer);
set.sortCompute.setParameter('depthBuffer', this._depthBuffer);
set.sortCompute.setParameter('smallTileList', set._smallTileListBuffer);
set.sortCompute.setParameter('tileListCounts', set._tileListCountsBuffer);

set.sortCompute.setupIndirectDispatch(indirectSlot);
device.computeDispatch([set.sortCompute], pickMode ? 'GSplatPickTileSort' : 'GSplatLocalTileSort');
device.computeDispatch([set.sortCompute], pickMode ? 'GSplatPickTileBitonicSort' : 'GSplatLocalTileBitonicSort');

// --- Pass 4c: Chunk sort ---
set.chunkSortCompute.setParameter('tileEntries', this._tileEntriesBuffer);
Expand All @@ -834,7 +861,7 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
set.chunkSortCompute.setParameter('maxChunks', numTiles * MAX_CHUNKS_PER_TILE);

set.chunkSortCompute.setupIndirectDispatch(0, set._chunkSortIndirectBuffer);
device.computeDispatch([set.chunkSortCompute], pickMode ? 'GSplatPickChunkSort' : 'GSplatLocalChunkSort');
device.computeDispatch([set.chunkSortCompute], pickMode ? 'GSplatPickChunkBitonicSort' : 'GSplatLocalChunkBitonicSort');

// --- Pass 5: Rasterize ---
// Select the shader variant based on pick mode and depth availability. Depth testing
Expand Down Expand Up @@ -1082,6 +1109,13 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
});

// --- Classify ---
// Radix sort requires subgroups and uses a subgroupBallot .x-only optimization that
// is only correct when sgSize <= 32. Currently restricted to NVIDIA only:
// - NVIDIA: sgSize always 32, benchmarked faster than bitonic for larger tiles
// - Apple: sgSize 32 but benchmarks show a performance regression vs bitonic
// - AMD: sgSize can be 64 (wave64), .x-only ballot would produce incorrect results
// - Intel/Qualcomm/others: untested, excluded for safety
this._useRadixSort = device.supportsSubgroups && device.capsDefines.has('VENDOR_NVIDIA');
{
const ubf = new UniformBufferFormat(device, [
new UniformFormat('numTiles', UNIFORMTYPE_UINT),
Expand All @@ -1090,7 +1124,7 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
new UniformFormat('maxWorkgroupsPerDim', UNIFORMTYPE_UINT),
new UniformFormat('drawSlot', UNIFORMTYPE_UINT)
]);
this._classifyBindGroupFormat = new BindGroupFormat(device, [
const classifyBindings = [
new BindStorageBufferFormat('tileSplatCounts', SHADERSTAGE_COMPUTE, true),
new BindStorageBufferFormat('smallTileList', SHADERSTAGE_COMPUTE),
new BindStorageBufferFormat('largeTileList', SHADERSTAGE_COMPUTE),
Expand All @@ -1100,11 +1134,18 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
new BindStorageBufferFormat('largeTileOverflowBases', SHADERSTAGE_COMPUTE),
new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE),
new BindStorageBufferFormat('indirectDrawArgs', SHADERSTAGE_COMPUTE)
]);
];
if (this._useRadixSort) {
classifyBindings.push(new BindStorageBufferFormat('radixTileList', SHADERSTAGE_COMPUTE));
}
this._classifyBindGroupFormat = new BindGroupFormat(device, classifyBindings);

const cdefines = this._useRadixSort ? new Map([['USE_RADIX_SORT', '']]) : undefined;
this._classifyShader = new Shader(device, {
name: 'GSplatLocalClassify',
shaderLanguage: SHADERLANGUAGE_WGSL,
cshader: computeGsplatLocalClassifySource,
cdefines: cdefines,
computeBindGroupFormat: this._classifyBindGroupFormat,
computeUniformBufferFormats: { uniforms: ubf }
});
Expand All @@ -1119,13 +1160,31 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
new BindStorageBufferFormat('tileListCounts', SHADERSTAGE_COMPUTE, true)
]);
this._sortShader = new Shader(device, {
name: 'GSplatLocalTileSort',
name: 'GSplatLocalTileBitonicSort',
shaderLanguage: SHADERLANGUAGE_WGSL,
cshader: computeGsplatLocalTileSortSource,
cincludes: this._createBitonicIncludes(),
computeBindGroupFormat: this._sortBindGroupFormat
});

// --- Radix Sort (for tiles with ≤1976 entries, NVIDIA only) ---
if (this._useRadixSort) {
this._radixSortBindGroupFormat = new BindGroupFormat(device, [
new BindStorageBufferFormat('tileEntries', SHADERSTAGE_COMPUTE),
new BindStorageBufferFormat('tileSplatCounts', SHADERSTAGE_COMPUTE, true),
new BindStorageBufferFormat('depthBuffer', SHADERSTAGE_COMPUTE, true),
new BindStorageBufferFormat('radixTileList', SHADERSTAGE_COMPUTE, true),
new BindStorageBufferFormat('tileListCounts', SHADERSTAGE_COMPUTE, true)
]);
this._radixSortShader = new Shader(device, {
name: 'GSplatLocalTileRadixSort',
shaderLanguage: SHADERLANGUAGE_WGSL,
cshader: computeGsplatLocalTileRadixSortSource,
cincludes: new Map([['gsplatLocalRadixSortCS', computeGsplatLocalRadixSortSource]]),
computeBindGroupFormat: this._radixSortBindGroupFormat
});
}

// --- BucketSort ---
{
const ubf = new UniformBufferFormat(device, [
Expand Down Expand Up @@ -1185,7 +1244,7 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE)
]);
this._chunkSortShader = new Shader(device, {
name: 'GSplatLocalChunkSort',
name: 'GSplatLocalChunkBitonicSort',
shaderLanguage: SHADERLANGUAGE_WGSL,
cshader: computeGsplatLocalChunkSortSource,
cincludes: this._createBitonicIncludes(),
Expand Down Expand Up @@ -1312,7 +1371,12 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
set.classifyCompute = new Compute(device, this._classifyShader, pickMode ? 'GSplatPickClassify' : 'GSplatLocalClassify');

// Sort: shared shader
set.sortCompute = new Compute(device, this._sortShader, pickMode ? 'GSplatPickTileSort' : 'GSplatLocalTileSort');
set.sortCompute = new Compute(device, this._sortShader, pickMode ? 'GSplatPickTileBitonicSort' : 'GSplatLocalTileBitonicSort');

// RadixSort: shared shader (when enabled)
if (this._useRadixSort) {
set.radixSortCompute = new Compute(device, this._radixSortShader, pickMode ? 'GSplatPickTileRadixSort' : 'GSplatLocalTileRadixSort');
}

// BucketSort: shared shader
set.bucketSortCompute = new Compute(device, this._bucketSortShader, pickMode ? 'GSplatPickBucketSort' : 'GSplatLocalBucketSort');
Expand All @@ -1321,7 +1385,7 @@ class GSplatComputeLocalRenderer extends GSplatRenderer {
set.copyCompute = new Compute(device, this._copyShader, pickMode ? 'GSplatPickCopy' : 'GSplatLocalCopy');

// ChunkSort: shared shader
set.chunkSortCompute = new Compute(device, this._chunkSortShader, pickMode ? 'GSplatPickChunkSort' : 'GSplatLocalChunkSort');
set.chunkSortCompute = new Compute(device, this._chunkSortShader, pickMode ? 'GSplatPickChunkBitonicSort' : 'GSplatLocalChunkBitonicSort');

return set;
}
Expand Down
11 changes: 10 additions & 1 deletion src/scene/gsplat-unified/gsplat-local-dispatch-set.js
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class GSplatLocalDispatchSet {
/** @type {Compute} */
chunkSortCompute;

/** @type {Compute|null} */
radixSortCompute = null;

/** @type {Map<string, {shader: Shader, bindGroupFormat: BindGroupFormat, compute: Compute}>} */
_rasterizeVariants = new Map();

Expand All @@ -104,6 +107,9 @@ class GSplatLocalDispatchSet {
/** @type {StorageBuffer|null} */
_largeTileOverflowBasesBuffer = null;

/** @type {StorageBuffer|null} */
_radixTileListBuffer = null;

/** @type {StorageBuffer|null} */
_rasterizeTileListBuffer = null;

Expand Down Expand Up @@ -203,6 +209,7 @@ class GSplatLocalDispatchSet {
this._smallTileListBuffer?.destroy();
this._largeTileListBuffer?.destroy();
this._largeTileOverflowBasesBuffer?.destroy();
this._radixTileListBuffer?.destroy();
this._rasterizeTileListBuffer?.destroy();
this._tileListCountsBuffer?.destroy();
this._chunkRangesBuffer?.destroy();
Expand All @@ -214,8 +221,9 @@ class GSplatLocalDispatchSet {
this._smallTileListBuffer = new StorageBuffer(this.device, numTiles * 4);
this._largeTileListBuffer = new StorageBuffer(this.device, numTiles * 4);
this._largeTileOverflowBasesBuffer = new StorageBuffer(this.device, numTiles * 4);
this._radixTileListBuffer = new StorageBuffer(this.device, numTiles * 4);
this._rasterizeTileListBuffer = new StorageBuffer(this.device, numTiles * 4);
this._tileListCountsBuffer = new StorageBuffer(this.device, 4 * 4, BUFFERUSAGE_COPY_DST | BUFFERUSAGE_COPY_SRC);
this._tileListCountsBuffer = new StorageBuffer(this.device, 5 * 4, BUFFERUSAGE_COPY_DST | BUFFERUSAGE_COPY_SRC);

const maxChunks = numTiles * MAX_CHUNKS_PER_TILE;
this._chunkRangesBuffer = new StorageBuffer(this.device, maxChunks * 8);
Expand Down Expand Up @@ -397,6 +405,7 @@ class GSplatLocalDispatchSet {
this._smallTileListBuffer?.destroy();
this._largeTileListBuffer?.destroy();
this._largeTileOverflowBasesBuffer?.destroy();
this._radixTileListBuffer?.destroy();
this._rasterizeTileListBuffer?.destroy();
this._tileListCountsBuffer?.destroy();
this._chunkRangesBuffer?.destroy();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
// Tile classification: scans prefix-summed tile counts, builds small/large/rasterize
// Tile classification: scans prefix-summed tile counts, builds radix/bitonic/large/rasterize
// tile lists, writes indirect dispatch args for subsequent passes, and writes indirect
// draw args for the tile-based composite.
// For large tiles (>4096 entries), assigns compact overflow scratch offsets within
// the shared tileEntries buffer (overflow region starts at totalEntries).
// When USE_RADIX_SORT is defined: three sort tiers — radix (≤1976), bitonic (1977..4096),
// large (>4096 via bucket+chunk). Otherwise: two tiers — bitonic (≤4096), large (>4096).
// For large tiles, assigns compact overflow scratch offsets within the shared tileEntries
// buffer (overflow region starts at totalEntries).
// Single workgroup (256 threads) — each thread processes ceil(numTiles/256) tiles.
//
// tileListCounts layout:
// [0] = bitonic tile count (smallTileList)
// [1] = large tile count (largeTileList)
// [2] = rasterize tile count (rasterizeTileList)
// [3] = large tile overflow entries claimed
// [4] = radix tile count (radixTileList) — only used when USE_RADIX_SORT is defined

import indirectCoreCS from '../common/comp/indirect-core.js';
import dispatchCoreCS from '../common/comp/dispatch-core.js';
Expand All @@ -13,7 +22,10 @@ export const computeGsplatLocalClassifySource = /* wgsl */`
${indirectCoreCS}
${dispatchCoreCS}

const MAX_TILE_ENTRIES: u32 = 4096u;
#ifdef USE_RADIX_SORT
const RADIX_MAX_ENTRIES: u32 = 1976u;
#endif
const BITONIC_MAX_ENTRIES: u32 = 4096u;
const CLASSIFY_WORKGROUP: u32 = 256u;

@group(0) @binding(0) var<storage, read> tileSplatCounts: array<u32>;
Expand All @@ -24,6 +36,9 @@ const CLASSIFY_WORKGROUP: u32 = 256u;
@group(0) @binding(5) var<storage, read_write> indirectDispatchArgs: array<u32>;
@group(0) @binding(6) var<storage, read_write> largeTileOverflowBases: array<u32>;
@group(0) @binding(8) var<storage, read_write> indirectDrawArgs: array<DrawIndirectArgs>;
#ifdef USE_RADIX_SORT
@group(0) @binding(9) var<storage, read_write> radixTileList: array<u32>;
#endif

struct Uniforms {
numTiles: u32,
Expand Down Expand Up @@ -52,7 +67,14 @@ fn main(@builtin(local_invocation_index) localIdx: u32) {
let rIdx = atomicAdd(&tileListCounts[2], 1u);
rasterizeTileList[rIdx] = i;

if (count <= MAX_TILE_ENTRIES) {
#ifdef USE_RADIX_SORT
if (count <= RADIX_MAX_ENTRIES) {
let rxIdx = atomicAdd(&tileListCounts[4], 1u);
radixTileList[rxIdx] = i;
} else if (count <= BITONIC_MAX_ENTRIES) {
#else
if (count <= BITONIC_MAX_ENTRIES) {
#endif
let sIdx = atomicAdd(&tileListCounts[0], 1u);
smallTileList[sIdx] = i;
} else {
Expand All @@ -68,20 +90,20 @@ fn main(@builtin(local_invocation_index) localIdx: u32) {

workgroupBarrier();

// Thread 0 writes indirect dispatch args for passes 4a (small sort), 4b (bucket), 5 (rasterize).
// Thread 0 writes indirect dispatch args for sort and rasterize passes.
// Uses balanced 2D dispatch to stay within maxComputeWorkgroupsPerDimension with minimal waste:
// y = ceil(count / maxDim), x = ceil(count / y). Waste is at most y-1 workgroups (typically 0-1).
if (localIdx == 0u) {
let smallCount = atomicLoad(&tileListCounts[0]);
let bitonicCount = atomicLoad(&tileListCounts[0]);
let largeCount = atomicLoad(&tileListCounts[1]);
let rasterizeCount = atomicLoad(&tileListCounts[2]);
let off = uniforms.dispatchSlotOffset;
let maxDim = uniforms.maxWorkgroupsPerDim;

// Slot 0: small tile sort — 1 workgroup per tile
let smallDim = calcDispatch2D(smallCount, maxDim);
indirectDispatchArgs[off + 0u] = smallDim.x;
indirectDispatchArgs[off + 1u] = smallDim.y;
// Slot 0: bitonic tile sort
let bitonicDim = calcDispatch2D(bitonicCount, maxDim);
indirectDispatchArgs[off + 0u] = bitonicDim.x;
indirectDispatchArgs[off + 1u] = bitonicDim.y;
indirectDispatchArgs[off + 2u] = 1u;

// Slot 1: bucket pre-sort — 1 workgroup per large tile
Expand All @@ -96,6 +118,15 @@ fn main(@builtin(local_invocation_index) localIdx: u32) {
indirectDispatchArgs[off + 7u] = rasterDim.y;
indirectDispatchArgs[off + 8u] = 1u;

#ifdef USE_RADIX_SORT
// Slot 3: radix tile sort (≤1976 entries)
let radixCount = atomicLoad(&tileListCounts[4]);
let radixDim = calcDispatch2D(radixCount, maxDim);
indirectDispatchArgs[off + 9u] = radixDim.x;
indirectDispatchArgs[off + 10u] = radixDim.y;
indirectDispatchArgs[off + 11u] = 1u;
#endif

// Indirect draw args for tile-based composite: 6 vertices per tile quad
indirectDrawArgs[uniforms.drawSlot] = DrawIndirectArgs(rasterizeCount * 6u, 1u, 0u, 0u, 0u);
}
Expand Down
Loading