Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
schwa committed Oct 24, 2024
1 parent 5783d04 commit 952026a
Show file tree
Hide file tree
Showing 24 changed files with 435 additions and 461 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ namespace GaussianSplatShaders {

VertexOut out;

auto indexedDistance = indexedDistances[instance_id];
auto splat = splats[indexedDistance.index];
const IndexedDistance indexedDistance = indexedDistances[instance_id];
const SplatC splat = splats[indexedDistance.index];
const float4 splatModelSpacePosition = float4(float3(splat.position), 1);
const float4 splatClipSpacePosition = uniforms.modelViewProjectionMatrix * splatModelSpacePosition;

Expand All @@ -80,7 +80,7 @@ namespace GaussianSplatShaders {
return out;
}

const float4 splatWorldSpacePosition = uniforms.modelViewMatrix * splatModelSpacePosition;
const float4 splatWorldSpacePosition = uniforms.modelViewMatrix * splatModelSpacePosition;
const float3 covPosition = splatWorldSpacePosition.xyz;
const Tuple2<float2> axes = decomposedCalcCovariance2D(covPosition, splat.cov_a, splat.cov_b, uniforms.modelViewMatrix, uniforms.focalSize, uniforms.limit);

Expand Down
66 changes: 66 additions & 0 deletions Sources/GaussianSplatSupport/GaussianSplatConfiguration.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import Metal
import PanoramaSupport
#if !targetEnvironment(simulator)
import MetalFX
#endif
import MetalKit
import os
import simd
import SwiftUI

public struct GaussianSplatConfiguration {
public enum SortMethod {
case gpuBitonic
case cpuRadix
}

public var debugMode: Bool
public var metalFXRate: Float
public var discardRate: Float
public var clearColor: MTLClearColor
public var skyboxTexture: MTLTexture?
public var verticalAngleOfView: Angle
public var sortMethod: SortMethod
public var renderSkybox: Bool = true
public var renderSplats: Bool = true

public init(debugMode: Bool = false, metalFXRate: Float = 2, discardRate: Float = 0.0, clearColor: MTLClearColor = .init(red: 0, green: 0, blue: 0, alpha: 1), skyboxTexture: MTLTexture? = nil, verticalAngleOfView: Angle = .degrees(90), sortMethod: SortMethod = .cpuRadix) {
self.debugMode = debugMode
self.metalFXRate = metalFXRate
self.discardRate = discardRate
self.clearColor = clearColor
self.skyboxTexture = skyboxTexture
self.verticalAngleOfView = verticalAngleOfView
self.sortMethod = sortMethod
}

@MainActor
public static func defaultSkyboxTexture(device: MTLDevice) -> MTLTexture? {
let gradient = LinearGradient(
stops: [
.init(color: .white, location: 0),
.init(color: .white, location: 0.4),
.init(color: Color(red: 135 / 255, green: 206 / 255, blue: 235 / 255), location: 0.5),
.init(color: Color(red: 135 / 255, green: 206 / 255, blue: 235 / 255), location: 1)
],
startPoint: .init(x: 0, y: 0),
endPoint: .init(x: 0, y: 1)
)

guard var cgImage = ImageRenderer(content: Rectangle().fill(gradient).frame(width: 1024, height: 1024)).cgImage else {
fatalError("Could not render image.")
}
let bitmapInfo: CGBitmapInfo
if cgImage.byteOrderInfo == .order32Little {
bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue | CGBitmapInfo.byteOrder32Big.rawValue)
} else {
bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue | CGBitmapInfo.byteOrder32Little.rawValue)
}
cgImage = cgImage.convert(bitmapInfo: bitmapInfo)!

let textureLoader = MTKTextureLoader(device: device)
let texture = try! textureLoader.newTexture(cgImage: cgImage, options: nil)
texture.label = "Skybox Gradient"
return texture
}
}
50 changes: 50 additions & 0 deletions Sources/GaussianSplatSupport/GaussianSplatSupport.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import BaseSupport
import CoreGraphics
import Foundation
import Metal
import os
Expand Down Expand Up @@ -32,3 +33,52 @@ internal func releaseAssert(_ condition: @autoclosure () -> Bool, _ message: @au
fatalError(message(), file: file, line: line)
}
}

public extension CGImage {
func convert(bitmapInfo: CGBitmapInfo) -> CGImage? {
let width = width
let height = height
let bitsPerComponent = 8
let bytesPerPixel = 4
let bytesPerRow = width * bytesPerPixel
let colorSpace = CGColorSpaceCreateDeviceRGB()
guard let context = CGContext(data: nil, width: width, height: height, bitsPerComponent: bitsPerComponent, bytesPerRow: bytesPerRow, space: colorSpace, bitmapInfo: bitmapInfo.rawValue) else {
return nil
}
context.draw(self, in: CGRect(x: 0, y: 0, width: width, height: height))
return context.makeImage()
}
}

internal func convertCGImageEndianness2(_ inputImage: CGImage) -> CGImage? {
let width = inputImage.width
let height = inputImage.height
let bitsPerComponent = 8
let bytesPerPixel = 4
let bytesPerRow = width * bytesPerPixel
let colorSpace = CGColorSpaceCreateDeviceRGB()

// Choose the appropriate bitmap info for the target endianness
let bitmapInfo: CGBitmapInfo
if inputImage.byteOrderInfo == .order32Little {
bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue | CGBitmapInfo.byteOrder32Big.rawValue)
} else {
bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue | CGBitmapInfo.byteOrder32Little.rawValue)
}

guard let context = CGContext(data: nil,
width: width,
height: height,
bitsPerComponent: bitsPerComponent,
bytesPerRow: bytesPerRow,
space: colorSpace,
bitmapInfo: bitmapInfo.rawValue) else {
return nil
}

// Draw the original image into the new context
context.draw(inputImage, in: CGRect(x: 0, y: 0, width: width, height: height))

// Create a new CGImage from the context
return context.makeImage()
}
62 changes: 14 additions & 48 deletions Sources/GaussianSplatSupport/GaussianSplatViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,6 @@ import Shapes3D
import simd
import SIMDSupport
import SwiftUI
import SwiftUISupport
import Traces

public struct GaussianSplatConfiguration {
public enum SortMethod {
case gpuBitonic
case cpuRadix
}

public var debugMode: Bool
public var metalFXRate: Float
public var discardRate: Float
public var gpuCounters: GPUCounters?
public var clearColor: MTLClearColor // TODO: make this a SwiftUI Color
public var skyboxTexture: MTLTexture?
public var verticalAngleOfView: Angle
public var sortMethod: SortMethod

public init(debugMode: Bool = false, metalFXRate: Float = 2, discardRate: Float = 0.0, gpuCounters: GPUCounters? = nil, clearColor: MTLClearColor = .init(red: 0, green: 0, blue: 0, alpha: 1), skyboxTexture: MTLTexture? = nil, verticalAngleOfView: Angle = .degrees(90), sortMethod: SortMethod = .gpuBitonic) {
self.debugMode = debugMode
self.metalFXRate = metalFXRate
self.discardRate = discardRate
self.gpuCounters = gpuCounters
self.clearColor = clearColor
self.skyboxTexture = skyboxTexture
self.verticalAngleOfView = verticalAngleOfView
self.sortMethod = sortMethod
}
}

// MARK: -

@Observable
@MainActor
Expand All @@ -62,8 +31,6 @@ public class GaussianSplatViewModel <Splat> where Splat: SplatProtocol {
}
}

public var splatResource: SplatResource

public var pass: GroupPass?

public var loadProgress = Progress()
Expand Down Expand Up @@ -100,9 +67,8 @@ public class GaussianSplatViewModel <Splat> where Splat: SplatProtocol {

// MARK: -

public init(device: MTLDevice, splatResource: SplatResource, splatCloud: SplatCloud<SplatC>, configuration: GaussianSplatConfiguration, logger: Logger? = nil) throws where Splat == SplatC {
public init(device: MTLDevice, splatCloud: SplatCloud<SplatC>, configuration: GaussianSplatConfiguration, logger: Logger? = nil) throws where Splat == SplatC {
self.device = device
self.splatResource = splatResource
self.configuration = configuration
self.logger = logger

Expand Down Expand Up @@ -164,49 +130,49 @@ public class GaussianSplatViewModel <Splat> where Splat: SplatProtocol {
enabled: sortEnabled,
splats: splats,
modelMatrix: simd_float3x3(truncating: splatsNode.transform.matrix),
cameraPosition: cameraNode.transform.translation
cameraPosition: cameraNode.transform.matrix.translation
)
GaussianSplatBitonicSortComputePass(
id: "SplatBitonicSort",
enabled: sortEnabled,
splats: splats
)
}
PanoramaShadingPass(id: "Panorama", scene: scene)
GaussianSplatRenderPass<Splat>(
id: "SplatRender",
enabled: true,
scene: scene,
discardRate: configuration.discardRate
)
}
GroupPass(id: "GaussianSplatRenderGroup-1", enabled: fullRedraw, renderPassDescriptor: offscreenRenderPassDescriptor1) {
GroupPass(id: "Panorama Render", enabled: configuration.renderSkybox && fullRedraw, renderPassDescriptor: offscreenRenderPassDescriptor1) {
PanoramaShadingPass(id: "Panorama", scene: scene)
}
GroupPass(id: "GaussianSplatRenderGroup-2", enabled: fullRedraw, renderPassDescriptor: offscreenRenderPassDescriptor2) {
GroupPass(id: "Splats Render", enabled: configuration.renderSplats && fullRedraw, renderPassDescriptor: offscreenRenderPassDescriptor2) {
GaussianSplatRenderPass<Splat>(
id: "SplatRender",
enabled: true,
scene: scene,
discardRate: configuration.discardRate
)
}

}
#if !targetEnvironment(simulator)
try SpatialUpscalingPass(id: "SpatialUpscalingPass", enabled: configuration.metalFXRate > 1 && fullRedraw, device: device, source: resources.downscaledTexture, destination: resources.outputTexture, colorProcessingMode: .perceptual)
let blitTexture = resources.outputTexture
#else
let blitTexture = resources.downscaledTexture
#endif
BlitTexturePass(id: "BlitTexturePass", source: resources.outputTexture, destination: nil)
BlitTexturePass(id: "BlitTexturePass", source: blitTexture, destination: nil)
}
}

public func drawableChanged(pixelFormat: MTLPixelFormat, size: SIMD2<Float>) throws {
print("###################", #function, pixelFormat, size)
try makeResources(pixelFormat: pixelFormat, size: size)
}

// MARK: -

private func makeResources(pixelFormat: MTLPixelFormat, size: SIMD2<Float>) throws {
#if !targetEnvironment(simulator)
let downscaledSize = SIMD2<Int>(ceil(size / configuration.metalFXRate))
#else
let downscaledSize = SIMD2<Int>(size)
#endif

let colorTextureDescriptor = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: pixelFormat, width: downscaledSize.x, height: downscaledSize.y, mipmapped: false)
colorTextureDescriptor.storageMode = .private
Expand Down
1 change: 1 addition & 0 deletions Sources/GaussianSplatSupport/SplatCloud+Support.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public extension SplatCloud where Splat == SplatC {
convert_b_to_c(splats)
}
}

let splats = try device.makeTypedBuffer(data: splatArray, options: .storageModeShared).labelled("Splats")
try self.init(device: device, splats: splats)
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/GaussianSplatSupport/SplatResource.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import Foundation
import simd
import SwiftUI

public struct SplatResource: Hashable {
public struct UFOSpecifier: Hashable {
public var name: String
public var url: URL
public var bounds: ConeBounds
Expand Down
35 changes: 35 additions & 0 deletions Sources/GaussianSplatSupport/UFOView.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import Foundation
import MetalKit
import Observation
import PanoramaSupport
import simd
import SwiftUI

// swiftlint:disable force_unwrapping

@available(iOS 17, macOS 14, visionOS 1, *)
public struct UFOView: View {
@Environment(\.metalDevice)
private var device

@Environment(GaussianSplatViewModel<SplatC>.self)
private var viewModel

@State
private var bounds: ConeBounds

public init(bounds: ConeBounds) {
self.bounds = bounds
}

public var body: some View {
@Bindable
var viewModel = viewModel

return GaussianSplatRenderView<SplatC>()
#if os(iOS)
.ignoresSafeArea()
#endif
.modifier(CameraConeController(cameraCone: bounds, transform: $viewModel.scene.unsafeCurrentCameraNode.transform))
}
}
4 changes: 2 additions & 2 deletions Sources/RenderKit/Passes/BlitTexturePass.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ public struct BlitTexturePass: GeneralPassProtocol {
public var id: PassID
public var enabled: Bool = true

public var source: Box<MTLTexture>
public var destination: Box<MTLTexture>?
internal var source: Box<MTLTexture>
internal var destination: Box<MTLTexture>?

public init(id: PassID, enabled: Bool = true, source: MTLTexture, destination: MTLTexture?) {
self.id = id
Expand Down
3 changes: 2 additions & 1 deletion Sources/RenderKit/Passes/SpatialUpscalingPass.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ public struct SpatialUpscalingPass: GeneralPassProtocol {

public var id: PassID
public var enabled: Bool = true
public var spatialScaler: Box<MTLFXSpatialScaler>
internal var spatialScaler: Box<MTLFXSpatialScaler>

public init(id: PassID, enabled: Bool = true, device: MTLDevice, source: MTLTexture, destination: MTLTexture, colorProcessingMode: MTLFXSpatialScalerColorProcessingMode) throws {
self.id = id
self.enabled = enabled

// TODO: We are doing this in init() when it really should happen in setup() because we can't easily cause a new setup if texture size changes.
let spatialScalerDescriptor = MTLFXSpatialScalerDescriptor()
spatialScalerDescriptor.inputWidth = source.width
spatialScalerDescriptor.inputHeight = source.height
Expand Down
4 changes: 2 additions & 2 deletions Sources/RenderKit/Renderer/RenderErrorHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ public struct RenderErrorHandler: Sendable {
}
}

public struct RenderErrorHandlerKey: EnvironmentKey {
public static let defaultValue = RenderErrorHandler()
struct RenderErrorHandlerKey: EnvironmentKey {
static let defaultValue = RenderErrorHandler()
}

public extension EnvironmentValues {
Expand Down
5 changes: 1 addition & 4 deletions Sources/RenderKitSceneGraph/Passes/DebugRenderPass.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ public struct DebugRenderPass: RenderPassProtocol {
}

public func setup(device: MTLDevice, configuration: some MetalConfigurationProtocol) throws -> State {
guard let bundle = Bundle.main.bundle(forTarget: "RenderKitShaders") else {
throw BaseError.error(.missingResource)
}
let library = try device.makeDebugLibrary(bundle: bundle)
let library = try device.makeDebugLibrary(bundle: Bundle.main.bundle(forTarget: "RenderKitShaders").safelyUnwrap())
let renderPipelineDescriptor = MTLRenderPipelineDescriptor(configuration)
renderPipelineDescriptor.vertexFunction = library.makeFunction(name: "DebugVertexShader")
renderPipelineDescriptor.fragmentFunction = library.makeFunction(name: "DebugFragmentShader")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ public struct DiffuseShadingRenderPass: RenderPassProtocol {
}

public func setup(device: MTLDevice, configuration: some MetalConfigurationProtocol) throws -> State {
guard let bundle = Bundle.main.bundle(forTarget: "RenderKitShaders") else {
throw BaseError.error(.missingResource)
}
let library = try device.makeDebugLibrary(bundle: bundle)
let library = try device.makeDebugLibrary(bundle: Bundle.main.bundle(forTarget: "RenderKitShaders").safelyUnwrap())
let useFlatShading = false
let constantValues = MTLFunctionConstantValues(dictionary: [0: useFlatShading])
let renderPipelineDescriptor = MTLRenderPipelineDescriptor(configuration)
Expand Down
Loading

0 comments on commit 952026a

Please sign in to comment.