Skip to content

Commit

Permalink
Refactor CPU sorter.
Browse files Browse the repository at this point in the history
  • Loading branch information
schwa committed Oct 14, 2024
1 parent e28f9a4 commit cc03aa9
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 141 deletions.
137 changes: 0 additions & 137 deletions Sources/GaussianSplatSupport/CPUSorter.swift

This file was deleted.

11 changes: 11 additions & 0 deletions Sources/GaussianSplatSupport/GaussianSplatSupport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,14 @@ public struct TupleBuffered<Element> {

extension TupleBuffered: Sendable where Element: Sendable {
}

extension OSAllocatedUnfairLock where State == Int {
func postIncrement() -> State {
withLock { state in
defer {
state += 1
}
return state
}
}
}
4 changes: 2 additions & 2 deletions Sources/GaussianSplatSupport/GaussianSplatViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public class GaussianSplatViewModel <Splat> where Splat: SplatProtocol {
@ObservationIgnored
private var logger: Logger?

var cpuSorter: CPUSorter<Splat>?
var cpuSorter: AsyncSortManager<Splat>?
var cpuSorterTask: Task<Void, Never>?

// TODO: bang and try!
Expand Down Expand Up @@ -116,7 +116,7 @@ public class GaussianSplatViewModel <Splat> where Splat: SplatProtocol {
loadProgress.completedUnitCount = Int64(splatCloud.count)
loadProgress.totalUnitCount = Int64(splatCloud.count)

let cpuSorter = try CPUSorter<Splat>(device: device, splatCloud: splatCloud, capacity: splatCloud.capacity)
let cpuSorter = try AsyncSortManager<Splat>(device: device, splatCloud: splatCloud, capacity: splatCloud.capacity)
let cpuSorterTask = Task {
for await splatIndices in await cpuSorter.sortedIndicesChannel().buffer(policy: .bufferingLatest(1)) {
Traces.shared.trace(name: "Sorted Splats")
Expand Down
60 changes: 60 additions & 0 deletions Sources/GaussianSplatSupport/Sorting/AsyncSortManager.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import AsyncAlgorithms
import BaseSupport
import GaussianSplatShaders
@preconcurrency import Metal
import MetalSupport
import os
import simd
import SIMDSupport

internal actor AsyncSortManager <Splat> where Splat: SplatProtocol {
private var splatCloud: SplatCloud<Splat>
private var _sortRequestChannel: AsyncChannel<SortState> = .init()
private var _sortedIndicesChannel: AsyncChannel<SplatIndices> = .init()
private var logger: Logger? = Logger()
private var sorter: CPUSplatRadixSorter<Splat>

internal init(device: MTLDevice, splatCloud: SplatCloud<Splat>, capacity: Int) throws {
self.sorter = .init(device: device, capacity: capacity)
self.splatCloud = splatCloud
Task(priority: .high) {
do {
try await self.sort()
}
catch is CancellationError {
}
catch {
await logger?.log("Failed to sort splats: \(error)")
}
}
}

internal func sortedIndicesChannel() -> AsyncChannel<SplatIndices> {
_sortedIndicesChannel
}

nonisolated
internal func requestSort(camera: simd_float4x4, model: simd_float4x4, count: Int) {
Task {
await _sortRequestChannel.send(.init(camera: camera, model: model, count: count))
}
}

internal func sort() async throws {
// swiftlint:disable:next empty_count
for await state in _sortRequestChannel.removeDuplicates() where state.count > 0 {
let currentIndexedDistances = try sorter.sort(splats: splatCloud.splats, camera: state.camera, model: state.model)
await self._sortedIndicesChannel.send(.init(state: state, indices: currentIndexedDistances))
}
}
}

// MARK: -

internal extension AsyncSortManager {
static func sort(device: MTLDevice, splats: TypedMTLBuffer<Splat>, camera: simd_float4x4, model: simd_float4x4) throws -> SplatIndices {
let sorter = CPUSplatRadixSorter<Splat>(device: device, capacity: splats.capacity)
let indices = try sorter.sort(splats: splats, camera: camera, model: model)
return .init(state: .init(camera: camera, model: model, count: splats.count), indices: indices)
}
}
75 changes: 75 additions & 0 deletions Sources/GaussianSplatSupport/Sorting/CPUSplatRadixSorter.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import AsyncAlgorithms
import BaseSupport
import GaussianSplatShaders
@preconcurrency import Metal
import MetalSupport
import os
import simd
import SIMDSupport

internal class CPUSplatRadixSorter <Splat> where Splat: SplatProtocol {
private var device: MTLDevice
private var temporaryIndexedDistances: [IndexedDistance]
private var capacity: Int
private var logger: Logger? = Logger()

internal init(device: MTLDevice, capacity: Int) {
self.device = device
self.capacity = capacity
releaseAssert(capacity > 0, "You shouldn't be creating a sorter with a capacity of zero.")
temporaryIndexedDistances = .init(repeating: .init(), count: capacity)
}

internal func sort(splats: TypedMTLBuffer<Splat>, camera: simd_float4x4, model: simd_float4x4) throws -> TypedMTLBuffer<IndexedDistance> {
var currentIndexedDistances = try device.makeTypedBuffer(element: IndexedDistance.self, capacity: capacity).labelled("Splats-IndexDistances")
cpuRadixSort(splats: splats, indexedDistances: &currentIndexedDistances, temporaryIndexedDistances: &temporaryIndexedDistances, camera: camera, model: model)
return currentIndexedDistances
}
}

// MARK: -

private func cpuRadixSort<Splat>(splats: TypedMTLBuffer<Splat>, indexedDistances: inout TypedMTLBuffer<IndexedDistance>, temporaryIndexedDistances: inout [IndexedDistance], camera: simd_float4x4, model: simd_float4x4) where Splat: SplatProtocol {
guard splats.count > 1 else {
return
}
releaseAssert(splats.count <= indexedDistances.capacity, "Too few indexed distances \(indexedDistances.count) for \(splats.capacity) splats.")
releaseAssert(splats.count <= temporaryIndexedDistances.count, "Too few temporary indexed distances \(temporaryIndexedDistances.count) for \(splats.count) splats.")
indexedDistances.withUnsafeMutableBufferPointer { indexedDistances in
let indexedDistances = UnsafeMutableBufferPointer<IndexedDistance>(start: indexedDistances.baseAddress, count: splats.count)
// Compute distances.
let modelView = camera.inverse * model
releaseAssert(splats.count <= indexedDistances.count, "Cannot sort \(splats.count) splats into \(indexedDistances.count) indexed distances.")
splats.withUnsafeBufferPointer { splats in
for index in 0..<splats.count {
let position = modelView * SIMD4<Float>(splats[index].floatPosition, 1.0)
let distance = position.z
indexedDistances[index] = .init(index: UInt32(index), distance: distance)
}
}
temporaryIndexedDistances.withUnsafeMutableBufferPointer { temporaryIndexedDistances in
let temporaryIndexedDistances = UnsafeMutableBufferPointer<IndexedDistance>(start: temporaryIndexedDistances.baseAddress, count: splats.count)
releaseAssert(splats.count == indexedDistances.count, "Mismatch between splats \(splats.count) and indexed distances \(indexedDistances.count).")
releaseAssert(splats.count == temporaryIndexedDistances.count, "Mismatch between splats \(splats.count) and temporary indexed distances \(temporaryIndexedDistances.count).")
releaseAssert(temporaryIndexedDistances.count == indexedDistances.count, "Mismatch between temporary indexed distances \(temporaryIndexedDistances.count) and indexed distances \(indexedDistances.count).")
RadixSortCPU<IndexedDistance>().radixSort(input: indexedDistances, temp: temporaryIndexedDistances)
}
}
indexedDistances.count = splats.count
}

// MARK: -

extension IndexedDistance: RadixSortable {
func key(shift: Int) -> Int {
let bits = distance.bitPattern
let signMask: UInt32 = 0x80000000
let key: UInt32 = (bits & signMask != 0) ? ~bits : bits ^ signMask
return (Int(key) >> shift) & 0xFF
}
}

// MARK: -

extension IndexedDistance: @unchecked @retroactive Sendable {
}
2 changes: 1 addition & 1 deletion Sources/GaussianSplatSupport/SplatCloud.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public final class SplatCloud <Splat>: Equatable, @unchecked Sendable where Spla

public init(device: MTLDevice, splats: TypedMTLBuffer<Splat>) throws {
self.splats = splats
self.indexedDistances = try CPUSorter.sort(device: device, splats: splats, camera: .identity, model: .identity)
self.indexedDistances = try AsyncSortManager.sort(device: device, splats: splats, camera: .identity, model: .identity)
}

public convenience init(device: MTLDevice, capacity: Int) throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public struct GaussianSplatLobbyView: View {
let image = skyboxGradient.image(size: .init(width: 1024, height: 1024))

guard var cgImage = ImageRenderer(content: image).cgImage else {
fatalError()
fatalError("Could not render image.")
}
let bitmapInfo: CGBitmapInfo
if cgImage.byteOrderInfo == .order32Little {
Expand Down

0 comments on commit cc03aa9

Please sign in to comment.