Skip to content

Commit

Permalink
Begin to clean up CPUSorter.
Browse files Browse the repository at this point in the history
  • Loading branch information
schwa committed Oct 14, 2024
1 parent 48c021f commit e28f9a4
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions Sources/GaussianSplatSupport/CPUSorter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,16 @@ import os
import simd
import SIMDSupport

internal actor CPUSorter <Splat> where Splat: SplatProtocol {
internal static func sort(device: MTLDevice, splats: TypedMTLBuffer<Splat>, camera: simd_float4x4, model: simd_float4x4) throws -> SplatIndices {
var indexedDistances = try device.makeTypedBuffer(data: [IndexedDistance](repeating: .init(), count: splats.count)).labelled("Splats-IndexDistances-0")
var temporaryIndexedDistances = [IndexedDistance](repeating: .init(), count: splats.count)
sort(splats: splats, indexedDistances: &indexedDistances, temporaryIndexedDistances: &temporaryIndexedDistances, camera: camera, model: model)
return .init(state: .init(camera: camera, model: model, count: splats.count), indices: indexedDistances)
}
protocol CPUSorterProtocol {
associatedtype Splat: SplatProtocol
static func sort(device: MTLDevice, splats: TypedMTLBuffer<Splat>, camera: simd_float4x4, model: simd_float4x4) throws -> SplatIndices
}

internal actor CPUSorter <Splat> where Splat: SplatProtocol {
private static func sort(splats: TypedMTLBuffer<Splat>, indexedDistances: inout TypedMTLBuffer<IndexedDistance>, temporaryIndexedDistances: inout [IndexedDistance], camera: simd_float4x4, model: simd_float4x4) {
guard splats.count > 1 else {
return
}
let start = getMachTime()
releaseAssert(splats.count <= indexedDistances.capacity, "Too few indexed distances \(indexedDistances.count) for \(splats.capacity) splats.")
releaseAssert(splats.count <= temporaryIndexedDistances.count, "Too few temporary indexed distances \(temporaryIndexedDistances.count) for \(splats.count) splats.")
indexedDistances.withUnsafeMutableBufferPointer { indexedDistances in
Expand All @@ -43,8 +40,6 @@ internal actor CPUSorter <Splat> where Splat: SplatProtocol {
}
}
indexedDistances.count = splats.count
let end = getMachTime()
// print("XYZZY: \(Measurement(value: end - start, unit: UnitDuration.seconds).converted(to: UnitDuration.milliseconds))")
}

private var device: MTLDevice
Expand Down Expand Up @@ -99,6 +94,19 @@ internal actor CPUSorter <Splat> where Splat: SplatProtocol {
}
}

// MARK: -

internal extension CPUSorter {
static func sort(device: MTLDevice, splats: TypedMTLBuffer<Splat>, camera: simd_float4x4, model: simd_float4x4) throws -> SplatIndices {
var indexedDistances = try device.makeTypedBuffer(data: [IndexedDistance](repeating: .init(), count: splats.count)).labelled("Splats-IndexDistances-0")
var temporaryIndexedDistances = [IndexedDistance](repeating: .init(), count: splats.count)
sort(splats: splats, indexedDistances: &indexedDistances, temporaryIndexedDistances: &temporaryIndexedDistances, camera: camera, model: model)
return .init(state: .init(camera: camera, model: model, count: splats.count), indices: indexedDistances)
}
}

// MARK: -

extension OSAllocatedUnfairLock where State == Int {
func postIncrement() -> State {
withLock { state in
Expand Down

0 comments on commit e28f9a4

Please sign in to comment.