Skip to content

Commit

Permalink
Fix rendering.
Browse files Browse the repository at this point in the history
  • Loading branch information
schwa committed Dec 9, 2024
1 parent 6ae8e6f commit a1c5598
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ namespace GaussianSplatAntimatter15RenderShaders {
uint instance_id[[instance_id]],
constant SplatX *splats [[buffer(2)]],
constant IndexedDistance *indexedDistances [[buffer(3)]],
constant float4x4 &viewMatrix [[buffer(4)]],
constant float4x4 &projectionMatrix [[buffer(5)]],
constant float2 &focal [[buffer(6)]]
// constant float4x4 &modelMatrix [[buffer(4)]],
constant float4x4 &viewMatrix [[buffer(5)]],
constant float4x4 &projectionMatrix [[buffer(6)]],
constant float2 &focal [[buffer(7)]],
constant float2 &viewport [[buffer(8)]]
) {
VertexOut out;
auto splatIndex = indexedDistances[instance_id].index;
Expand Down Expand Up @@ -90,7 +92,13 @@ namespace GaussianSplatAntimatter15RenderShaders {
out.relativePosition = in.position.xy;

float2 vCenter = pos2d.xy / pos2d.w;
out.position = float4(vCenter + in.position.x * majorAxis / pos2d.w + in.position.y * minorAxis / pos2d.w, 0.0, 1.0);
out.position = float4(vCenter + in.position.x * majorAxis / viewport + in.position.y * minorAxis / viewport, 0.0, 1.0);

// gl_Position = vec4(
// vCenter
// + position.x * majorAxis / viewport
// + position.y * minorAxis / viewport, 0.0, 1.0);
//

return out;
}
Expand All @@ -101,7 +109,6 @@ namespace GaussianSplatAntimatter15RenderShaders {
half4 fragmentMain(
FragmentIn in [[stage_in]]
) {
// return half4(1, 0, 0, 1);
float A = -dot(in.relativePosition, in.relativePosition);
if (A < -4.0) {
discard_fragment();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import RenderKit
import simd
import SIMDSupport
import SwiftUI
import Constraints3D

public struct GaussianSplatAntimatter15RenderView: View {

Expand All @@ -14,6 +15,9 @@ public struct GaussianSplatAntimatter15RenderView: View {
@State
private var pass: GaussianSplatAntimatter15RenderPass?

@State
private var cameraTransform: Transform = .identity

@MainActor
public init() {
print(MemoryLayout<SplatX>.size, MemoryLayout<SplatX>.stride)
Expand All @@ -29,22 +33,23 @@ public struct GaussianSplatAntimatter15RenderView: View {
ZStack {
if let pass {
RenderView(pass: pass)
.modifier(NewBallControllerViewModifier(constraint: .init(radius: 5), transform: $cameraTransform))
}
}
.task {
// try! load(url: url)

let splatB = SplatB(position: [0, 0, 0], scale: [1, 1, 1], color: [255, 0, 255, 255], rotation: [1, 0, 0, 0])
let splatX = convertSplatBToSplatX(splat: splatB)
print(splatX)
// let splat = SplatD(position: [0, 0, 0], scale: [1, 0.5, 0.25], color: [1, 0, 1, 1], rotation: .init(angle: .zero, axis: [0, 0, 0]))

let splatB = SplatB(position: [0, 0, 0], scale: [1, 0.5, 0.25], color: [255, 0, 255, 255], rotation: [255, 128, 128, 128])
let splatX = convertSplatBToSplatX(splat: splatB)
try! load(splats: [splatX])
}
}

func load(splats: [SplatX]) throws {
let device = MTLCreateSystemDefaultDevice()!
let cameraMatrix = simd_float4x4(translate: [0, 0, 200])
let cameraMatrix = simd_float4x4(translate: [0, 0, 5])
let indexedDistances = splats.enumerated().map { IndexedDistance(index: UInt32($0.offset), distance: $0.element.position.distance(to: cameraMatrix.translation)) }.sorted(by: \.distance)

let splatsBuffer = try device.makeBuffer(bytesOf: splats, options: [])
Expand All @@ -54,7 +59,7 @@ public struct GaussianSplatAntimatter15RenderView: View {

func load(url: URL) throws {
let device = MTLCreateSystemDefaultDevice()!
let cameraMatrix = simd_float4x4(translate: [0, 0, 100])
let cameraMatrix = simd_float4x4(translate: [0, 0, 10])
let (splatCount, splats, indexedDistances) = try Data(contentsOf: url).withUnsafeBytes { (bytes: UnsafeRawBufferPointer) -> (Int, MTLBuffer, MTLBuffer) in
try bytes.withMemoryRebound(to: SplatX.self) { (splats: UnsafeBufferPointer<SplatX>) -> (Int, MTLBuffer, MTLBuffer) in
let indexedDistances = splats.enumerated().map { IndexedDistance(index: UInt32($0.offset), distance: $0.element.position.distance(to: cameraMatrix.translation)) }.sorted(by: \.distance)
Expand Down Expand Up @@ -89,6 +94,7 @@ struct GaussianSplatAntimatter15RenderPass: RenderPassProtocol {
var viewMatrix: Int = -1
var projectionMatrix: Int = -1
var focal: Int = -1
var viewport: Int = -1
}

var id: PassID
Expand Down Expand Up @@ -138,35 +144,31 @@ struct GaussianSplatAntimatter15RenderPass: RenderPassProtocol {
}

func render(commandBuffer: any MTLCommandBuffer, renderPassDescriptor: MTLRenderPassDescriptor, info: RenderKit.PassInfo, state: State) throws {

let perspectiveProjection = PerspectiveProjection(verticalAngleOfView: .degrees(90), zClip: 0.001 ... 250)
let projectionMatrix = perspectiveProjection.projectionMatrix(for: info.drawableSize)





try commandBuffer.withRenderCommandEncoder(descriptor: renderPassDescriptor, label: "\(type(of: self))", useDebugGroup: true) { commandEncoder in
if info.configuration.depthStencilPixelFormat != .invalid {
commandEncoder.setDepthStencilState(state.depthStencilState)
}
commandEncoder.setRenderPipelineState(state.renderPipelineState)
let drawableSize = try renderPassDescriptor.colorAttachments[0].size
let focalSize = drawableSize * projectionMatrix.diagonal.xy / 2
print(focalSize)
commandEncoder.withDebugGroup("VertexShader") {
// Tristrip quad
let vertices: [SIMD2<Float>] = [
[-1, -1],
[-1, 1],
[1, -1],
[1, 1]
[-2, -2],
[-2, 2],
[2, -2],
[2, 2]
]
commandEncoder.setVertexBytes(of: vertices, index: 0)
commandEncoder.setVertexBuffer(splats(), offset: 0, index: state.vertexBindings.splats)
commandEncoder.setVertexBuffer(indexedDistances(), offset: 0, index: state.vertexBindings.indexedDistances)
commandEncoder.setVertexBytes(of: viewMatrix, index: state.vertexBindings.viewMatrix)
commandEncoder.setVertexBytes(of: projectionMatrix, index: state.vertexBindings.projectionMatrix)
commandEncoder.setVertexBytes(of: focalSize, index: state.vertexBindings.focal)
commandEncoder.setVertexBytes(of: drawableSize, index: state.vertexBindings.viewport)
}
// commandEncoder.draw(state.quadMesh, instanceCount: splats.splats.count)
commandEncoder.drawPrimitives(type: .triangleStrip, vertexStart: 0, vertexCount: 4, instanceCount: splatCount)
Expand All @@ -176,6 +178,12 @@ struct GaussianSplatAntimatter15RenderPass: RenderPassProtocol {
}
}

public extension SplatX {
init(_ splat: SplatB) {
self = convertSplatBToSplatX(splat: splat)
}
}

// Conversion function
public func convertSplatBToSplatX(splat: SplatB) -> SplatX {
// Extract position
Expand Down Expand Up @@ -230,7 +238,5 @@ public func convertSplatBToSplatX(splat: SplatB) -> SplatX {
u3: u3,
color: color
)

print("SplatX: \(splatX)")
return splatX
}

0 comments on commit a1c5598

Please sign in to comment.