Skip to content

Commit

Permalink
Refactor CPU sorter.
Browse files Browse the repository at this point in the history
  • Loading branch information
schwa committed Oct 14, 2024
1 parent e621e6a commit ad237a6
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 77 deletions.
67 changes: 0 additions & 67 deletions Sources/GaussianSplatSupport/CPUSorter.swift

This file was deleted.

4 changes: 2 additions & 2 deletions Sources/GaussianSplatSupport/GaussianSplatViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public class GaussianSplatViewModel <Splat> where Splat: SplatProtocol {
@ObservationIgnored
private var logger: Logger?

var cpuSorter: CPUSorter<Splat>?
var cpuSorter: AsyncSortManager<Splat>?
var cpuSorterTask: Task<Void, Never>?

// TODO: bang and try!
Expand Down Expand Up @@ -116,7 +116,7 @@ public class GaussianSplatViewModel <Splat> where Splat: SplatProtocol {
loadProgress.completedUnitCount = Int64(splatCloud.count)
loadProgress.totalUnitCount = Int64(splatCloud.count)

let cpuSorter = try CPUSorter<Splat>(device: device, splatCloud: splatCloud, capacity: splatCloud.capacity)
let cpuSorter = try AsyncSortManager<Splat>(device: device, splatCloud: splatCloud, capacity: splatCloud.capacity)
let cpuSorterTask = Task {
for await splatIndices in await cpuSorter.sortedIndicesChannel().buffer(policy: .bufferingLatest(1)) {
Traces.shared.trace(name: "Sorted Splats")
Expand Down
60 changes: 60 additions & 0 deletions Sources/GaussianSplatSupport/Sorting/AsyncSortManager.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import AsyncAlgorithms
import BaseSupport
import GaussianSplatShaders
@preconcurrency import Metal
import MetalSupport
import os
import simd
import SIMDSupport

internal actor AsyncSortManager <Splat> where Splat: SplatProtocol {
private var splatCloud: SplatCloud<Splat>
private var _sortRequestChannel: AsyncChannel<SortState> = .init()
private var _sortedIndicesChannel: AsyncChannel<SplatIndices> = .init()
private var logger: Logger? = Logger()
private var sorter: CPUSplatRadixSorter<Splat>

internal init(device: MTLDevice, splatCloud: SplatCloud<Splat>, capacity: Int) throws {
self.sorter = .init(device: device, capacity: capacity)
self.splatCloud = splatCloud
Task(priority: .high) {
do {
try await self.sort()
}
catch is CancellationError {
}
catch {
await logger?.log("Failed to sort splats: \(error)")
}
}
}

internal func sortedIndicesChannel() -> AsyncChannel<SplatIndices> {
_sortedIndicesChannel
}

nonisolated
internal func requestSort(camera: simd_float4x4, model: simd_float4x4, count: Int) {
Task {
await _sortRequestChannel.send(.init(camera: camera, model: model, count: count))
}
}

internal func sort() async throws {
// swiftlint:disable:next empty_count
for await state in _sortRequestChannel.removeDuplicates() where state.count > 0 {
let currentIndexedDistances = try sorter.sort(splats: splatCloud.splats, camera: state.camera, model: state.model)
await self._sortedIndicesChannel.send(.init(state: state, indices: currentIndexedDistances))
}
}
}

// MARK: -

internal extension AsyncSortManager {
static func sort(device: MTLDevice, splats: TypedMTLBuffer<Splat>, camera: simd_float4x4, model: simd_float4x4) throws -> SplatIndices {
let sorter = CPUSplatRadixSorter<Splat>(device: device, capacity: splats.capacity)
let indices = try sorter.sort(splats: splats, camera: camera, model: model)
return .init(state: .init(camera: camera, model: model, count: splats.count), indices: indices)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,20 @@ import SIMDSupport

internal class CPUSplatRadixSorter <Splat> where Splat: SplatProtocol {
private var device: MTLDevice
private var splatCloud: SplatCloud<Splat>
private var temporaryIndexedDistances: [IndexedDistance]
private var capacity: Int
private var logger: Logger? = Logger()

internal init(device: MTLDevice, splatCloud: SplatCloud<Splat>, capacity: Int) throws {
internal init(device: MTLDevice, capacity: Int) {
self.device = device
self.capacity = capacity
releaseAssert(capacity > 0, "You shouldn't be creating a CPUSorter with a capacity of zero.")
self.splatCloud = splatCloud
releaseAssert(capacity > 0, "You shouldn't be creating a sorter with a capacity of zero.")
temporaryIndexedDistances = .init(repeating: .init(), count: capacity)
}

internal func sort(camera: simd_float4x4, model: simd_float4x4) async throws -> TypedMTLBuffer<IndexedDistance> {
internal func sort(splats: TypedMTLBuffer<Splat>, camera: simd_float4x4, model: simd_float4x4) throws -> TypedMTLBuffer<IndexedDistance> {
var currentIndexedDistances = try device.makeTypedBuffer(element: IndexedDistance.self, capacity: capacity).labelled("Splats-IndexDistances")
cpuRadixSort(splats: splatCloud.splats, indexedDistances: &currentIndexedDistances, temporaryIndexedDistances: &temporaryIndexedDistances, camera: camera, model: model)
cpuRadixSort(splats: splats, indexedDistances: &currentIndexedDistances, temporaryIndexedDistances: &temporaryIndexedDistances, camera: camera, model: model)
return currentIndexedDistances
}
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/GaussianSplatSupport/SplatCloud.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public final class SplatCloud <Splat>: Equatable, @unchecked Sendable where Spla

public init(device: MTLDevice, splats: TypedMTLBuffer<Splat>) throws {
self.splats = splats
self.indexedDistances = try CPUSorter.sort(device: device, splats: splats, camera: .identity, model: .identity)
self.indexedDistances = try AsyncSortManager.sort(device: device, splats: splats, camera: .identity, model: .identity)
}

public convenience init(device: MTLDevice, capacity: Int) throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public struct GaussianSplatLobbyView: View {
let image = skyboxGradient.image(size: .init(width: 1024, height: 1024))

guard var cgImage = ImageRenderer(content: image).cgImage else {
fatalError()
fatalError("Could not render image.")
}
let bitmapInfo: CGBitmapInfo
if cgImage.byteOrderInfo == .order32Little {
Expand Down

0 comments on commit ad237a6

Please sign in to comment.