Skip to content

Commit

Permalink
GS WebGPU + NME ui fixes (#15778)
Browse files Browse the repository at this point in the history
* GS WebGPU

* debugging wgsl

* nme fixes

* missing imports

* wgsl imports

* clean up shaders, coroutine batch size

* more missing imports

* webgpu + nme

* removed empty line
  • Loading branch information
CedricGuillemet authored Nov 8, 2024
1 parent ae594bf commit 1598912
Show file tree
Hide file tree
Showing 15 changed files with 236 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import { Camera } from "core/Cameras/camera";

import "../../Shaders/gaussianSplatting.fragment";
import "../../Shaders/gaussianSplatting.vertex";
import "../../ShadersWGSL/gaussianSplatting.fragment";
import "../../ShadersWGSL/gaussianSplatting.vertex";
import {
BindFogParameters,
BindLogDepth,
Expand All @@ -24,6 +26,7 @@ import {
PrepareDefinesForMisc,
PrepareUniformsAndSamplersList,
} from "../materialHelper.functions";
import { ShaderLanguage } from "../shaderLanguage";

/**
* @internal
Expand Down Expand Up @@ -160,6 +163,15 @@ export class GaussianSplattingMaterial extends PushMaterial {
defines: join,
onCompiled: this.onCompiled,
onError: this.onError,
indexParameters: {},
shaderLanguage: this._shaderLanguage,
extraInitializationsAsync: async () => {
if (this._shaderLanguage === ShaderLanguage.WGSL) {
await Promise.all([import("../../ShadersWGSL/gaussianSplatting.fragment"), import("../../ShadersWGSL/gaussianSplatting.vertex")]);
} else {
await Promise.all([import("../../Shaders/gaussianSplatting.fragment"), import("../../Shaders/gaussianSplatting.vertex")]);
}
},
},
engine
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type { NodeMaterialBuildState } from "../../nodeMaterialBuildState";
import { NodeMaterialBlockTargets } from "../../Enums/nodeMaterialBlockTargets";
import type { NodeMaterialConnectionPoint } from "../../nodeMaterialBlockConnectionPoint";
import { RegisterClass } from "../../../../Misc/typeStore";
import { ShaderLanguage } from "core/Materials/shaderLanguage";

/**
* Block used for the Gaussian Splatting Fragment part
Expand Down Expand Up @@ -70,7 +71,12 @@ export class GaussianBlock extends NodeMaterialBlock {
const color = this.splatColor;
const output = this._outputs[0];

state.compilationString += `${state._declareOutput(output)} = gaussianColor(${color.associatedVariableName});\n`;
if (state.shaderLanguage === ShaderLanguage.WGSL) {
state.compilationString += `${state._declareOutput(output)} = gaussianColor(${color.associatedVariableName}, input.vPosition);\n`;
} else {
state.compilationString += `${state._declareOutput(output)} = gaussianColor(${color.associatedVariableName});\n`;
}

return this;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { NodeMaterialBlockTargets } from "../../Enums/nodeMaterialBlockTargets";
import type { NodeMaterialConnectionPoint } from "../../nodeMaterialBlockConnectionPoint";
import { RegisterClass } from "../../../../Misc/typeStore";
import { VertexBuffer } from "core/Meshes/buffer";
import { ShaderLanguage } from "core/Materials/shaderLanguage";

/**
* Block used for the Gaussian Splatting
Expand Down Expand Up @@ -109,13 +110,19 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
const projection = this.projection;
const output = this.splatVertex;

let splatScaleParameter = "vec2(1.,1.)";
const addF = state.fSuffix;
let splatScaleParameter = `vec2${addF}(1.,1.)`;
if (splatScale.isConnected) {
splatScaleParameter = splatScale.associatedVariableName;
}

state.compilationString += `${state._declareOutput(output)} = gaussianSplatting(position, ${splatPosition.associatedVariableName}, ${splatScaleParameter}, covA, covB, ${world.associatedVariableName}, ${view.associatedVariableName}, ${projection.associatedVariableName});\n`;

let input = "position";
let uniforms = "";
if (state.shaderLanguage === ShaderLanguage.WGSL) {
input = "input.position";
uniforms = ", uniforms.focal, uniforms.invViewport";
}
state.compilationString += `${state._declareOutput(output)} = gaussianSplatting(${input}, ${splatPosition.associatedVariableName}, ${splatScaleParameter}, covA, covB, ${world.associatedVariableName}, ${view.associatedVariableName}, ${projection.associatedVariableName}${uniforms});\n`;
return this;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
export * from "./gaussianSplattingBlock";
export * from "./splatReaderBlock";
export * from "./gaussianBlock";

// Gaussian
export * from "../../../../ShadersWGSL/ShadersInclude/gaussianSplattingVertexDeclaration";
export * from "../../../../Shaders/ShadersInclude/gaussianSplattingVertexDeclaration";
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { GaussianSplattingMaterial } from "core/Materials/GaussianSplatting/gaus
import type { Mesh } from "core/Meshes/mesh";
import type { Effect } from "core/Materials/effect";
import type { NodeMaterial } from "../../nodeMaterial";
import { ShaderLanguage } from "core/Materials/shaderLanguage";

/**
* Block used for Reading components of the Gaussian Splatting
Expand Down Expand Up @@ -105,10 +106,16 @@ export class SplatReaderBlock extends NodeMaterialBlock {
const splatColor = this.splatColor;

const splatVariablename = state._getFreeVariableName("splat");
state.compilationString += `Splat ${splatVariablename} = readSplat(${splatIndex.associatedVariableName});\n`;

state.compilationString += "vec3 covA = splat.covA.xyz; vec3 covB = vec3(splat.covA.w, splat.covB.xy);\n";
state.compilationString += "vPosition = position;";
if (state.shaderLanguage === ShaderLanguage.WGSL) {
state.compilationString += `var ${splatVariablename}: Splat = readSplat(${splatIndex.associatedVariableName}, uniforms.dataTextureSize);\n`;
state.compilationString += `var covA: vec3f = splat.covA.xyz; var covB: vec3f = vec3f(splat.covA.w, splat.covB.xy);\n`;
state.compilationString += "vertexOutputs.vPosition = input.position;\n";
} else {
state.compilationString += `Splat ${splatVariablename} = readSplat(${splatIndex.associatedVariableName});\n`;
state.compilationString += `vec3 covA = splat.covA.xyz; vec3 covB = vec3(splat.covA.w, splat.covB.xy);\n`;
state.compilationString += "vPosition = position;\n";
}
state.compilationString += `${state._declareOutput(splatPosition)} = ${splatVariablename}.center.xyz;\n`;
state.compilationString += `${state._declareOutput(splatColor)} = ${splatVariablename}.color;\n`;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ export class GaussianSplattingMesh extends Mesh {

private static _RowOutputLength = 3 * 4 + 3 * 4 + 4 + 4; // Vector3 position, Vector3 scale, 1 u8 quaternion, 1 color with alpha
private static _SH_C0 = 0.28209479177387814;
// batch size between 2 yield calls. This value is a tradeoff between updates overhead and framerate hiccups
// This step is faster the PLY conversion. So batch size can be bigger
private static _SplatBatchSize = 327680;
// batch size between 2 yield calls during the PLY to splat conversion.
private static _PlyConversionBatchSize = 32768;

/**
* Set the number of batch (a batch is 16384 splats) after which a display update is performed
Expand Down Expand Up @@ -781,7 +786,7 @@ export class GaussianSplattingMesh extends Mesh {

for (let i = 0; i < header.vertexCount; i++) {
GaussianSplattingMesh._GetSplat(header, i, compressedChunks, offset);
if (i % 30000 === 0 && useCoroutine) {
if (i % GaussianSplattingMesh._PlyConversionBatchSize === 0 && useCoroutine) {
yield;
}
}
Expand Down Expand Up @@ -1089,7 +1094,7 @@ export class GaussianSplattingMesh extends Mesh {
} else {
for (let i = 0; i < vertexCount; i++) {
this._makeSplat(i, i, fBuffer, uBuffer, covA, covB, colorArray, minimum, maximum);
if (isAsync && i % 327680 === 0) {
if (isAsync && i % GaussianSplattingMesh._SplatBatchSize === 0) {
yield;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
fn getDataUV(index: f32, dataTextureSize: vec2f) -> vec2<f32> {
let y: f32 = floor(index / dataTextureSize.x);
let x: f32 = index - y * dataTextureSize.x;
return vec2f((x + 0.5), (y + 0.5));
}

struct Splat {
center: vec4f,
color: vec4f,
covA: vec4f,
covB: vec4f,
};

fn readSplat(splatIndex: f32, dataTextureSize: vec2f) -> Splat {
var splat: Splat;
let splatUV = getDataUV(splatIndex, dataTextureSize);
let splatUVi32 = vec2<i32>(i32(splatUV.x), i32(splatUV.y));
splat.center = textureLoad(centersTexture, splatUVi32, 0);
splat.color = textureLoad(colorsTexture, splatUVi32, 0);
splat.covA = textureLoad(covariancesATexture, splatUVi32, 0) * splat.center.w;
splat.covB = textureLoad(covariancesBTexture, splatUVi32, 0) * splat.center.w;

return splat;
}

fn gaussianSplatting(
meshPos: vec2<f32>,
worldPos: vec3<f32>,
scale: vec2<f32>,
covA: vec3<f32>,
covB: vec3<f32>,
worldMatrix: mat4x4<f32>,
viewMatrix: mat4x4<f32>,
projectionMatrix: mat4x4<f32>,
focal: vec2f,
invViewport: vec2f
) -> vec4f {
let modelView = viewMatrix * worldMatrix;
let camspace = viewMatrix * vec4f(worldPos, 1.0);
let pos2d = projectionMatrix * camspace;

let bounds = 1.2 * pos2d.w;
if (pos2d.z < 0. || pos2d.x < -bounds || pos2d.x > bounds || pos2d.y < -bounds || pos2d.y > bounds) {
return vec4f(0.0, 0.0, 2.0, 1.0);
}

let Vrk = mat3x3<f32>(
covA.x, covA.y, covA.z,
covA.y, covB.x, covB.y,
covA.z, covB.y, covB.z
);

let J = mat3x3<f32>(
focal.x / camspace.z, 0.0, -(focal.x * camspace.x) / (camspace.z * camspace.z),
0.0, focal.y / camspace.z, -(focal.y * camspace.y) / (camspace.z * camspace.z),
0.0, 0.0, 0.0
);

let invy = mat3x3<f32>(
1.0, 0.0, 0.0,
0.0, -1.0, 0.0,
0.0, 0.0, 1.0
);

let T = invy * transpose(mat3x3<f32>(
modelView[0].xyz,
modelView[1].xyz,
modelView[2].xyz)) * J;
let cov2d = transpose(T) * Vrk * T;

let mid = (cov2d[0][0] + cov2d[1][1]) / 2.0;
let radius = length(vec2<f32>((cov2d[0][0] - cov2d[1][1]) / 2.0, cov2d[0][1]));
let lambda1 = mid + radius;
let lambda2 = mid - radius;

if (lambda2 < 0.0) {
return vec4f(0.0, 0.0, 2.0, 1.0);
}

let diagonalVector = normalize(vec2<f32>(cov2d[0][1], lambda1 - cov2d[0][0]));
let majorAxis = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector;
let minorAxis = min(sqrt(2.0 * lambda2), 1024.0) * vec2<f32>(diagonalVector.y, -diagonalVector.x);

let vCenter = vec2<f32>(pos2d.x, pos2d.y);
return vec4f(
vCenter + ((meshPos.x * majorAxis + meshPos.y * minorAxis) * invViewport * pos2d.w) * scale,
pos2d.z,
pos2d.w
);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
fn gaussianColor(inColor: vec4f, inPosition: vec2f) -> vec4f
{
var A : f32 = -dot(inPosition, inPosition);
if (A > -4.0)
{
var B: f32 = exp(A) * inColor.a;

#include<logDepthFragment>

var color: vec3f = inColor.rgb;

#ifdef FOG
#include<fogFragment>
#endif

return vec4f(color, B);
} else {
return vec4f(0.0);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include<sceneUboDeclaration>
#include<meshUboDeclaration>

attribute position: vec2f;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
attribute position: vec2f;
15 changes: 15 additions & 0 deletions packages/dev/core/src/ShadersWGSL/gaussianSplatting.fragment.fx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include<clipPlaneFragmentDeclaration>
#include<logDepthDeclaration>
#include<fogFragmentDeclaration>

varying vColor: vec4f;
varying vPosition: vec2f;

#include<gaussianSplattingFragmentDeclaration>

@fragment
fn main(input: FragmentInputs) -> FragmentOutputs {
#include<clipPlaneFragment>

fragmentOutputs.color = gaussianColor(input.vColor, input.vPosition);
}
45 changes: 45 additions & 0 deletions packages/dev/core/src/ShadersWGSL/gaussianSplatting.vertex.fx
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include<sceneUboDeclaration>
#include<meshUboDeclaration>

#include<clipPlaneVertexDeclaration>
#include<fogVertexDeclaration>
#include<logDepthDeclaration>

// Attributes
attribute splatIndex: f32;
attribute position: vec2f;

// Uniforms
uniform invViewport: vec2f;
uniform dataTextureSize: vec2f;
uniform focal: vec2f;

// textures
var covariancesATexture: texture_2d<f32>;
var covariancesBTexture: texture_2d<f32>;
var centersTexture: texture_2d<f32>;
var colorsTexture: texture_2d<f32>;

// Output
varying vColor: vec4f;
varying vPosition: vec2f;

#include<gaussianSplatting>

@vertex
fn main(input : VertexInputs) -> FragmentInputs {

var splat: Splat = readSplat(input.splatIndex, uniforms.dataTextureSize);
var covA: vec3f = splat.covA.xyz;
var covB: vec3f = vec3f(splat.covA.w, splat.covB.xy);

let worldPos: vec4f = mesh.world * vec4f(splat.center.xyz, 1.0);

vertexOutputs.vColor = splat.color;
vertexOutputs.vPosition = input.position;
vertexOutputs.position = gaussianSplatting(input.position, worldPos.xyz, vec2f(1.0, 1.0), covA, covB, mesh.world, scene.view, scene.projection, uniforms.focal, uniforms.invViewport);

#include<clipPlaneVertex>
#include<fogVertex>
#include<logDepthVertex>
}
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,9 @@ export class PreviewManager {
this._prepareScene();
});
break;
case PreviewType.Custom:
this._globalState.filesInput.loadFiles({ target: { files: this._globalState.listOfCustomPreviewFiles } });
return;
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions packages/tools/nodeEditor/src/components/preview/previewType.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ export enum PreviewType {
Explosion,
Fire,

Custom,

// Env
Room,

// Gaussian Splatting
Parrot,
BricksSkull,
Plants,

Custom,

// Env
Room,
}
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ export class PropertyTabComponent extends React.Component<IPropertyTabComponentP
case NodeMaterialModes.Particle:
this.props.globalState.previewType = PreviewType.Bubbles;
break;
case NodeMaterialModes.GaussianSplatting:
this.props.globalState.previewType = PreviewType.BricksSkull;
break;
}

this.props.globalState.listOfCustomPreviewFiles = [];
Expand Down

0 comments on commit 1598912

Please sign in to comment.