Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ class GaussianSplattingMaterialDefines extends MaterialDefines {
public IS_COMPOUND = false;
/** Defines the maximum number of parts (computed from engine caps at runtime) */
public MAX_PART_COUNT = GaussianSplattingMaxPartCount;
/** Defines whether SOG raw-texture in-shader dequantization is enabled */
public USE_SOG = false;
/** Defines whether SOG v2 (codebook) dequantization is enabled */
public USE_SOG_V2 = false;

/**
* Constructor of the defines.
Expand Down Expand Up @@ -184,6 +188,10 @@ export class GaussianSplattingMaterial extends PushMaterial {
"shTexture3",
"shTexture4",
"partIndicesTexture",
"sogQuatsTexture",
"sogShNCentroidsTexture",
"sogShNLabelsTexture",
"sogCodebookTexture",
];
protected static _UniformBuffers = ["Scene", "Mesh"];
protected static _VoxelUniforms = [
Expand Down Expand Up @@ -216,6 +224,16 @@ export class GaussianSplattingMaterial extends PushMaterial {
"depthValues",
"partWorld",
"partVisibility",
"sogMeansMin",
"sogMeansMax",
"sogScalesMin",
"sogScalesMax",
"sogSh0Min",
"sogSh0Max",
"sogShnMin",
"sogShnMax",
"sogShCoeffCount",
"sogShCentroidsWidth",
Comment thread
CedricGuillemet marked this conversation as resolved.
Outdated
];
private _sourceMesh: GaussianSplattingMesh | null = null;
/**
Expand Down Expand Up @@ -304,6 +322,8 @@ export class GaussianSplattingMaterial extends PushMaterial {

defines["IS_COMPOUND"] = gsMesh.isCompound;
defines["MAX_PART_COUNT"] = GetGaussianSplattingMaxPartCount(engine);
defines["USE_SOG"] = gsMesh.useSog;
defines["USE_SOG_V2"] = gsMesh.useSog && gsMesh.sogParams?.version === 2;
Comment thread
CedricGuillemet marked this conversation as resolved.

// Compensation
const splatMaterial = gsMesh.material as GaussianSplattingMaterial;
Expand Down Expand Up @@ -462,7 +482,9 @@ export class GaussianSplattingMaterial extends PushMaterial {
effect.setTexture("centersTexture", gsMesh.centersTexture);
effect.setTexture("colorsTexture", gsMesh.colorsTexture);

if (gsMesh.shTextures) {
if (gsMesh.useSog) {
GaussianSplattingMaterial._BindSogUniforms(gsMesh, effect);
} else if (gsMesh.shTextures) {
for (let i = 0; i < gsMesh.shTextures.length; i++) {
effect.setTexture(`shTexture${i}`, gsMesh.shTextures[i]);
}
Expand All @@ -472,6 +494,39 @@ export class GaussianSplattingMaterial extends PushMaterial {
gsMesh.bindExtraEffectUniforms(effect);
}
}

/**
* Bind SOG dequantization uniforms + raw textures.
* @internal
*/
protected static _BindSogUniforms(gsMesh: GaussianSplattingMesh, effect: Effect): void {
const p = gsMesh.sogParams;
if (!p) {
return;
}
effect.setTexture("sogQuatsTexture", gsMesh.rotationsATexture);
if (gsMesh.shTextures && gsMesh.shTextures.length >= 2) {
effect.setTexture("sogShNCentroidsTexture", gsMesh.shTextures[0]);
effect.setTexture("sogShNLabelsTexture", gsMesh.shTextures[1]);
}
if (p.codebookTexture) {
effect.setTexture("sogCodebookTexture", p.codebookTexture);
}
effect.setFloat3("sogMeansMin", p.meansMin[0], p.meansMin[1], p.meansMin[2]);
effect.setFloat3("sogMeansMax", p.meansMax[0], p.meansMax[1], p.meansMax[2]);
if (p.scalesMin && p.scalesMax) {
effect.setFloat3("sogScalesMin", p.scalesMin[0], p.scalesMin[1], p.scalesMin[2]);
effect.setFloat3("sogScalesMax", p.scalesMax[0], p.scalesMax[1], p.scalesMax[2]);
}
if (p.sh0Min && p.sh0Max) {
effect.setFloat4("sogSh0Min", p.sh0Min[0], p.sh0Min[1], p.sh0Min[2], p.sh0Min[3]);
effect.setFloat4("sogSh0Max", p.sh0Max[0], p.sh0Max[1], p.sh0Max[2], p.sh0Max[3]);
}
effect.setFloat("sogShnMin", p.shnMin ?? 0);
effect.setFloat("sogShnMax", p.shnMax ?? 0);
effect.setFloat("sogShCoeffCount", p.shCoeffCount ?? 0);
effect.setFloat("sogShCentroidsWidth", p.shCentroidsWidth ?? 0);
}
/**
* Binds the submesh to this material by preparing the effect and shader to draw
* @param world defines the world transformation matrix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ export class GaussianSplattingMeshBase extends Mesh {

private _delayedTextureUpdate: Nullable<IDelayedTextureUpdate> = null;
protected _useRGBACovariants = false;
protected _useSog = false;
protected _sogParams: any = null;
private _material: Nullable<Material> = null;

private _tmpCovariances = [0, 0, 0, 0, 0, 0];
Expand Down Expand Up @@ -632,6 +634,105 @@ export class GaussianSplattingMeshBase extends Mesh {
return this._shTextures;
}

/**
* True when this mesh holds raw SOG webp textures (dequantized in-shader) rather than the
* pre-decoded covariance/center/color textures produced by the standard splat loader.
*/
public get useSog(): boolean {
return this._useSog;
}

/**
* SOG dequantization parameters paired with the raw textures.
* Set by the splat loader when `useSogTextures: true`. Null otherwise.
*/
public get sogParams(): any {
return this._sogParams;
}

/**
* Install a set of raw SOG webp textures and bind the mesh to the in-shader dequantization path.
* @param pack SOG texture pack produced by ParseSogMetaAsTextures.
* @internal
*/
public setSogTextureData(pack: any): void {
this._useSog = true;
this._sogParams = pack;
Comment thread
CedricGuillemet marked this conversation as resolved.
this._vertexCount = pack.splatCount;
this._shDegree = pack.shDegree ?? 0;
Comment thread
CedricGuillemet marked this conversation as resolved.

// Stride-4 (xyz + 1) — required by the depth-sort worker and the centers texture path.
this._splatPositions = pack.positions;

// Reuse existing texture slots for SOG textures (the shader, under USE_SOG, samples them as RGBA8).
this._covariancesATexture?.dispose();
this._covariancesBTexture?.dispose();
this._centersTexture?.dispose();
this._colorsTexture?.dispose();
this._rotationsATexture?.dispose();
if (this._shTextures) {
for (const t of this._shTextures) {
t.dispose();
}
}

this._centersTexture = pack.meansTextureL;
this._covariancesATexture = pack.meansTextureU;
this._covariancesBTexture = pack.scalesTexture;
this._rotationsATexture = pack.quatsTexture;
this._colorsTexture = pack.sh0Texture;

Comment thread
CedricGuillemet marked this conversation as resolved.
Outdated
const shTextures: BaseTexture[] = [];
if (pack.shCentroidsTexture) {
shTextures.push(pack.shCentroidsTexture);
}
if (pack.shLabelsTexture) {
shTextures.push(pack.shLabelsTexture);
}
this._shTextures = shTextures.length ? shTextures : null;

// Force pipeline rebuild so the USE_SOG define and extra samplers are picked up.
this._material?.resetDrawCache();

const size = pack.meansTextureL.getSize();
this._textureSize.x = size.width;
this._textureSize.y = size.height;

this._updateSplatIndexBuffer(this._vertexCount);
this._instantiateWorker();

// Compute bounds from the CPU-decoded positions (stride-4) so the mesh is not frustum-culled.
const positions = pack.positions as Float32Array;
const minimum = new Vector3(Number.POSITIVE_INFINITY, Number.POSITIVE_INFINITY, Number.POSITIVE_INFINITY);
const maximum = new Vector3(Number.NEGATIVE_INFINITY, Number.NEGATIVE_INFINITY, Number.NEGATIVE_INFINITY);
for (let i = 0; i < this._vertexCount; i++) {
const x = positions[i * 4 + 0];
const y = positions[i * 4 + 1];
const z = positions[i * 4 + 2];
if (x < minimum.x) {
minimum.x = x;
}
if (y < minimum.y) {
minimum.y = y;
}
if (z < minimum.z) {
minimum.z = z;
}
if (x > maximum.x) {
maximum.x = x;
}
if (y > maximum.y) {
maximum.y = y;
}
if (z > maximum.z) {
maximum.z = z;
}
}
this.getBoundingInfo().reConstruct(minimum, maximum, this.getWorldMatrix());
this.setEnabled(true);
this._sortIsDirty = true;
}

/**
* Gets the kernel size
* Documentation and mathematical explanations here:
Expand Down
127 changes: 123 additions & 4 deletions packages/dev/core/src/Shaders/ShadersInclude/gaussianSplatting.fx
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ struct Splat {
vec4 rotationB;
vec4 rotationScale;
#endif
#ifdef USE_SOG
float splatIndex;
#endif
};

float getSplatIndex(int localIndex)
Expand Down Expand Up @@ -80,25 +83,98 @@ Splat readSplat(float splatIndex)
{
Splat splat;
vec2 splatUV = getDataUV(splatIndex, dataTextureSize);
#ifdef USE_SOG
// --- SOG raw-texture path. All samplers are RGBA8 normalized with nearest filtering.
ivec2 sogUVi = ivec2(int(splatUV.x * dataTextureSize.x), int(splatUV.y * dataTextureSize.y));
vec4 mL = texelFetch(centersTexture, sogUVi, 0); // means_l
vec4 mU = texelFetch(covariancesATexture, sogUVi, 0); // means_u
vec4 sRaw = texelFetch(covariancesBTexture, sogUVi, 0); // scales (3 bytes valid)
vec4 qRaw = texelFetch(sogQuatsTexture, sogUVi, 0); // quaternion 3+mode
vec4 c0 = texelFetch(colorsTexture, sogUVi, 0); // sh0

// Reconstruct position: q = (u<<8)|l; n = lerp(min,max,q/65535); pos = sign(n)*(exp(|n|)-1)
vec3 q16 = (mU.xyz * 256.0 + mL.xyz) * (255.0 / 65535.0);
vec3 nPos = mix(sogMeansMin, sogMeansMax, q16);
vec3 center = sign(nPos) * (exp(abs(nPos)) - vec3(1.0));
splat.center = vec4(center, 1.0);

// Reconstruct scale (v1: lerp+exp ; v2: codebook lookup)
#ifdef USE_SOG_V2
// codebook layout: [scales 0..255 | sh0 256..511 | shN 512..767], width=768
vec3 sIdx = floor(sRaw.xyz * 255.0 + 0.5);
vec3 splatScale;
splatScale.x = exp(texelFetch(sogCodebookTexture, ivec2(int(sIdx.x), 0), 0).r);
splatScale.y = exp(texelFetch(sogCodebookTexture, ivec2(int(sIdx.y), 0), 0).r);
splatScale.z = exp(texelFetch(sogCodebookTexture, ivec2(int(sIdx.z), 0), 0).r);
#else
vec3 splatScale = exp(mix(sogScalesMin, sogScalesMax, sRaw.xyz));
#endif

// Reconstruct quaternion (largest-omitted, mode in alpha as 252+omitted-index)
const float invSqrt2 = 0.70710678118;
vec3 qabc = (qRaw.xyz - vec3(0.5)) * 2.0 * invSqrt2;
int qMode = int(qRaw.w * 255.0 + 0.5) - 252;
float qd = sqrt(max(0.0, 1.0 - dot(qabc, qabc)));
vec4 quat;
if (qMode == 0) quat = vec4(qd, qabc.x, qabc.y, qabc.z);
else if (qMode == 1) quat = vec4(qabc.x, qd, qabc.y, qabc.z);
else if (qMode == 2) quat = vec4(qabc.x, qabc.y, qd, qabc.z);
else quat = vec4(qabc.x, qabc.y, qabc.z, qd);

// Build rotation matrix from quaternion (w,x,y,z)
float qw = quat.x, qx = quat.y, qy = quat.z, qz = quat.w;
mat3 R = mat3(
1.0 - 2.0*(qy*qy + qz*qz), 2.0*(qx*qy + qw*qz), 2.0*(qx*qz - qw*qy),
2.0*(qx*qy - qw*qz), 1.0 - 2.0*(qx*qx + qz*qz), 2.0*(qy*qz + qw*qx),
2.0*(qx*qz + qw*qy), 2.0*(qy*qz - qw*qx), 1.0 - 2.0*(qx*qx + qy*qy)
);

// Covariance = R * diag(2s)^2 * R^T to match the CPU path (which scales by 2x before squaring).
mat3 S2 = mat3(4.0*splatScale.x*splatScale.x, 0.0, 0.0,
0.0, 4.0*splatScale.y*splatScale.y, 0.0,
0.0, 0.0, 4.0*splatScale.z*splatScale.z);
mat3 Sigma = R * S2 * transpose(R);
splat.covA = vec4(Sigma[0][0], Sigma[0][1], Sigma[0][2], Sigma[1][1]);
splat.covB = vec4(Sigma[1][2], Sigma[2][2], 0.0, 0.0);

// Color (sh0): RGB -> 0.5 + c*SH_C0 ; alpha -> sigmoid(c) (v1) or codebook[a] (v2)
const float SH_C0 = 0.28209479177387814;
#ifdef USE_SOG_V2
vec3 c3;
c3.x = texelFetch(sogCodebookTexture, ivec2(256 + int(c0.x * 255.0 + 0.5), 0), 0).r;
c3.y = texelFetch(sogCodebookTexture, ivec2(256 + int(c0.y * 255.0 + 0.5), 0), 0).r;
c3.z = texelFetch(sogCodebookTexture, ivec2(256 + int(c0.z * 255.0 + 0.5), 0), 0).r;
vec3 colRgb = vec3(0.5) + c3 * SH_C0;
float colA = c0.w; // already 0..1
#else
vec4 cLerp = mix(sogSh0Min, sogSh0Max, c0);
vec3 colRgb = vec3(0.5) + cLerp.xyz * SH_C0;
float colA = 1.0 / (1.0 + exp(-cLerp.w));
#endif
splat.color = vec4(colRgb, colA);
splat.splatIndex = splatIndex;
#else
splat.center = texture2D(centersTexture, splatUV);
splat.color = texture2D(colorsTexture, splatUV);
#if !defined(IS_FOR_VOXELIZATION)
splat.covA = texture2D(covariancesATexture, splatUV) * splat.center.w;
splat.covB = texture2D(covariancesBTexture, splatUV) * splat.center.w;
#endif
#endif

#if SH_DEGREE > 0 || IS_COMPOUND
ivec2 splatUVint = getDataUVint(splatIndex, dataTextureSize);
#endif
#if SH_DEGREE > 0
#if SH_DEGREE > 0 && !defined(USE_SOG)
splat.sh0 = texelFetch(shTexture0, splatUVint, 0);
#endif
#if SH_DEGREE > 1
#if SH_DEGREE > 1 && !defined(USE_SOG)
splat.sh1 = texelFetch(shTexture1, splatUVint, 0);
#endif
#if SH_DEGREE > 2
#if SH_DEGREE > 2 && !defined(USE_SOG)
splat.sh2 = texelFetch(shTexture2, splatUVint, 0);
#endif
#if SH_DEGREE > 3
#if SH_DEGREE > 3 && !defined(USE_SOG)
splat.sh3 = texelFetch(shTexture3, splatUVint, 0);
splat.sh4 = texelFetch(shTexture4, splatUVint, 0);
#endif
Expand Down Expand Up @@ -205,6 +281,48 @@ vec4 decompose(uint value)
return components * vec4(2./255.) - vec4(1.);
}

#ifdef USE_SOG
vec3 computeSH(Splat splat, vec3 dir)
{
#if SH_DEGREE > 0
vec3 sh[25];
sh[0] = vec3(0., 0., 0.);

// Read 16-bit label for this splat from the labels texture (LSB in r, MSB in g).
ivec2 labelSize = textureSize(sogShNLabelsTexture, 0);
int idx = int(splat.splatIndex + 0.5);
int lx = idx - (idx / labelSize.x) * labelSize.x;
int ly = idx / labelSize.x;
vec4 labelRaw = texelFetch(sogShNLabelsTexture, ivec2(lx, ly), 0);
int n = int(labelRaw.r * 255.0 + 0.5) + int(labelRaw.g * 255.0 + 0.5) * 256;

int coeffs = int(sogShCoeffCount + 0.5);
int u = (n - (n / 64) * 64) * coeffs;
int v = n / 64;

for (int k = 0; k < 24; k++) {
if (k >= coeffs) break;
vec4 centroidRaw = texelFetch(sogShNCentroidsTexture, ivec2(u + k, v), 0);
vec3 shCoeff;
#ifdef USE_SOG_V2
int rIdx = int(centroidRaw.r * 255.0 + 0.5);
int gIdx = int(centroidRaw.g * 255.0 + 0.5);
int bIdx = int(centroidRaw.b * 255.0 + 0.5);
shCoeff.r = texelFetch(sogCodebookTexture, ivec2(512 + rIdx, 0), 0).r;
shCoeff.g = texelFetch(sogCodebookTexture, ivec2(512 + gIdx, 0), 0).r;
shCoeff.b = texelFetch(sogCodebookTexture, ivec2(512 + bIdx, 0), 0).r;
#else
shCoeff = mix(vec3(sogShnMin), vec3(sogShnMax), centroidRaw.rgb);
#endif
sh[k + 1] = shCoeff;
}

return computeColorFromSHDegree(dir, sh, 1., 1., 1., 1.);
#else
return vec3(0., 0., 0.);
#endif
}
#else
vec3 computeSHWeighted(Splat splat, vec3 dir, float _so1, float _so2, float _so3, float _so4)
{
vec3 sh[25];
Expand Down Expand Up @@ -292,6 +410,7 @@ vec3 computeSH(Splat splat, vec3 dir)
#endif
return computeSHWeighted(splat, dir, _w1, _w2, _w3, _w4);
}
#endif
#else
vec3 computeSH(Splat splat, vec3 dir)
{
Expand Down
Loading
Loading