From a1c55984be38dcab90ea75499179dcfdc08252e7 Mon Sep 17 00:00:00 2001 From: Jonathan Wight Date: Mon, 9 Dec 2024 11:15:40 -0800 Subject: [PATCH] Fix rendering. --- .../GaussianSplatAntimatter15RenderShaders.h | 17 +++++--- .../GaussianSplatAntimatter15RenderView.swift | 40 +++++++++++-------- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/Sources/GaussianSplatShaders/include/GaussianSplatAntimatter15RenderShaders.h b/Sources/GaussianSplatShaders/include/GaussianSplatAntimatter15RenderShaders.h index ffe02f4..21a8acf 100644 --- a/Sources/GaussianSplatShaders/include/GaussianSplatAntimatter15RenderShaders.h +++ b/Sources/GaussianSplatShaders/include/GaussianSplatAntimatter15RenderShaders.h @@ -43,9 +43,11 @@ namespace GaussianSplatAntimatter15RenderShaders { uint instance_id[[instance_id]], constant SplatX *splats [[buffer(2)]], constant IndexedDistance *indexedDistances [[buffer(3)]], - constant float4x4 &viewMatrix [[buffer(4)]], - constant float4x4 &projectionMatrix [[buffer(5)]], - constant float2 &focal [[buffer(6)]] +// constant float4x4 &modelMatrix [[buffer(4)]], + constant float4x4 &viewMatrix [[buffer(5)]], + constant float4x4 &projectionMatrix [[buffer(6)]], + constant float2 &focal [[buffer(7)]], + constant float2 &viewport [[buffer(8)]] ) { VertexOut out; auto splatIndex = indexedDistances[instance_id].index; @@ -90,7 +92,13 @@ namespace GaussianSplatAntimatter15RenderShaders { out.relativePosition = in.position.xy; float2 vCenter = pos2d.xy / pos2d.w; - out.position = float4(vCenter + in.position.x * majorAxis / pos2d.w + in.position.y * minorAxis / pos2d.w, 0.0, 1.0); + out.position = float4(vCenter + in.position.x * majorAxis / viewport + in.position.y * minorAxis / viewport, 0.0, 1.0); + +// gl_Position = vec4( +// vCenter +// + position.x * majorAxis / viewport +// + position.y * minorAxis / viewport, 0.0, 1.0); +// return out; } @@ -101,7 +109,6 @@ namespace GaussianSplatAntimatter15RenderShaders { half4 fragmentMain( FragmentIn in [[stage_in]] ) { -// return half4(1, 0, 0, 1); float A = -dot(in.relativePosition, in.relativePosition); if (A < -4.0) { discard_fragment(); diff --git a/Sources/GaussianSplatSupport/GaussianSplatAntimatter15RenderView.swift b/Sources/GaussianSplatSupport/GaussianSplatAntimatter15RenderView.swift index 4b13e7d..ebc0543 100644 --- a/Sources/GaussianSplatSupport/GaussianSplatAntimatter15RenderView.swift +++ b/Sources/GaussianSplatSupport/GaussianSplatAntimatter15RenderView.swift @@ -6,6 +6,7 @@ import RenderKit import simd import SIMDSupport import SwiftUI +import Constraints3D public struct GaussianSplatAntimatter15RenderView: View { @@ -14,6 +15,9 @@ public struct GaussianSplatAntimatter15RenderView: View { @State private var pass: GaussianSplatAntimatter15RenderPass? + @State + private var cameraTransform: Transform = .identity + @MainActor public init() { print(MemoryLayout.size, MemoryLayout.stride) @@ -29,22 +33,23 @@ public struct GaussianSplatAntimatter15RenderView: View { ZStack { if let pass { RenderView(pass: pass) + .modifier(NewBallControllerViewModifier(constraint: .init(radius: 5), transform: $cameraTransform)) } } .task { // try! load(url: url) - let splatB = SplatB(position: [0, 0, 0], scale: [1, 1, 1], color: [255, 0, 255, 255], rotation: [1, 0, 0, 0]) - let splatX = convertSplatBToSplatX(splat: splatB) - print(splatX) + // let splat = SplatD(position: [0, 0, 0], scale: [1, 0.5, 0.25], color: [1, 0, 1, 1], rotation: .init(angle: .zero, axis: [0, 0, 0])) + let splatB = SplatB(position: [0, 0, 0], scale: [1, 0.5, 0.25], color: [255, 0, 255, 255], rotation: [255, 128, 128, 128]) + let splatX = convertSplatBToSplatX(splat: splatB) try! load(splats: [splatX]) } } func load(splats: [SplatX]) throws { let device = MTLCreateSystemDefaultDevice()! - let cameraMatrix = simd_float4x4(translate: [0, 0, 200]) + let cameraMatrix = simd_float4x4(translate: [0, 0, 5]) let indexedDistances = splats.enumerated().map { IndexedDistance(index: UInt32($0.offset), distance: $0.element.position.distance(to: cameraMatrix.translation)) }.sorted(by: \.distance) let splatsBuffer = try device.makeBuffer(bytesOf: splats, options: []) @@ -54,7 +59,7 @@ public struct GaussianSplatAntimatter15RenderView: View { func load(url: URL) throws { let device = MTLCreateSystemDefaultDevice()! - let cameraMatrix = simd_float4x4(translate: [0, 0, 100]) + let cameraMatrix = simd_float4x4(translate: [0, 0, 10]) let (splatCount, splats, indexedDistances) = try Data(contentsOf: url).withUnsafeBytes { (bytes: UnsafeRawBufferPointer) -> (Int, MTLBuffer, MTLBuffer) in try bytes.withMemoryRebound(to: SplatX.self) { (splats: UnsafeBufferPointer) -> (Int, MTLBuffer, MTLBuffer) in let indexedDistances = splats.enumerated().map { IndexedDistance(index: UInt32($0.offset), distance: $0.element.position.distance(to: cameraMatrix.translation)) }.sorted(by: \.distance) @@ -89,6 +94,7 @@ struct GaussianSplatAntimatter15RenderPass: RenderPassProtocol { var viewMatrix: Int = -1 var projectionMatrix: Int = -1 var focal: Int = -1 + var viewport: Int = -1 } var id: PassID @@ -138,14 +144,8 @@ struct GaussianSplatAntimatter15RenderPass: RenderPassProtocol { } func render(commandBuffer: any MTLCommandBuffer, renderPassDescriptor: MTLRenderPassDescriptor, info: RenderKit.PassInfo, state: State) throws { - let perspectiveProjection = PerspectiveProjection(verticalAngleOfView: .degrees(90), zClip: 0.001 ... 250) let projectionMatrix = perspectiveProjection.projectionMatrix(for: info.drawableSize) - - - - - try commandBuffer.withRenderCommandEncoder(descriptor: renderPassDescriptor, label: "\(type(of: self))", useDebugGroup: true) { commandEncoder in if info.configuration.depthStencilPixelFormat != .invalid { commandEncoder.setDepthStencilState(state.depthStencilState) @@ -153,13 +153,14 @@ struct GaussianSplatAntimatter15RenderPass: RenderPassProtocol { commandEncoder.setRenderPipelineState(state.renderPipelineState) let drawableSize = try renderPassDescriptor.colorAttachments[0].size let focalSize = drawableSize * projectionMatrix.diagonal.xy / 2 + print(focalSize) commandEncoder.withDebugGroup("VertexShader") { // Tristrip quad let vertices: [SIMD2] = [ - [-1, -1], - [-1, 1], - [1, -1], - [1, 1] + [-2, -2], + [-2, 2], + [2, -2], + [2, 2] ] commandEncoder.setVertexBytes(of: vertices, index: 0) commandEncoder.setVertexBuffer(splats(), offset: 0, index: state.vertexBindings.splats) @@ -167,6 +168,7 @@ struct GaussianSplatAntimatter15RenderPass: RenderPassProtocol { commandEncoder.setVertexBytes(of: viewMatrix, index: state.vertexBindings.viewMatrix) commandEncoder.setVertexBytes(of: projectionMatrix, index: state.vertexBindings.projectionMatrix) commandEncoder.setVertexBytes(of: focalSize, index: state.vertexBindings.focal) + commandEncoder.setVertexBytes(of: drawableSize, index: state.vertexBindings.viewport) } // commandEncoder.draw(state.quadMesh, instanceCount: splats.splats.count) commandEncoder.drawPrimitives(type: .triangleStrip, vertexStart: 0, vertexCount: 4, instanceCount: splatCount) @@ -176,6 +178,12 @@ struct GaussianSplatAntimatter15RenderPass: RenderPassProtocol { } } +public extension SplatX { + init(_ splat: SplatB) { + self = convertSplatBToSplatX(splat: splat) + } +} + // Conversion function public func convertSplatBToSplatX(splat: SplatB) -> SplatX { // Extract position @@ -230,7 +238,5 @@ public func convertSplatBToSplatX(splat: SplatB) -> SplatX { u3: u3, color: color ) - - print("SplatX: \(splatX)") return splatX }