Skip to content

Commit

Permalink
Update to KA 0.9 & remove runtime dispatches (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Jun 1, 2023
1 parent 123f638 commit 5195701
Show file tree
Hide file tree
Showing 29 changed files with 340 additions and 372 deletions.
10 changes: 2 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ version = "0.1.0"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
Expand All @@ -20,17 +18,13 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
ROCKernels = "7eb9e9f0-4bd3-4c4c-8bef-26bd9629d9b9"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AMDGPU = "0.4.8"
CUDA = "4.0.1"
CUDAKernels = "0.4.7"
KernelAbstractions = "0.8.6"
ROCKernels = "0.3.2"
AMDGPU = "0.4"
KernelAbstractions = "0.9"
Zygote = "0.6.55"
38 changes: 25 additions & 13 deletions src/Nerf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using ImageCore
using ImageTransformations
using JSON
using KernelAbstractions
using KernelAbstractions: @atomic
using KernelAbstractions: @atomic, unsafe_free!
using LinearAlgebra
using Preferences
using Quaternions
Expand All @@ -17,6 +17,8 @@ using StaticArrays
using Statistics
using Zygote

# TODO rand on device

include("kautils.jl")

struct Ray
Expand Down Expand Up @@ -82,25 +84,33 @@ include("models/basic.jl")
include("marching_cubes/marching_cubes.jl")
include("marching_tetrahedra/marching_tetrahedra.jl")

@info "[Nerf.jl] Backend: $BACKEND"
@info "[Nerf.jl] Device: $DEVICE"
"""
    sync_free!(Backend, args...)

Wait for all work queued on `Backend` to finish, then release the memory of
every array in `args` via `unsafe_free!`.

Callers in this package launch kernels without waiting on them and then free
the buffers those kernels read or write. Synchronizing first guarantees that no
queued kernel still touches a buffer when it is released, independent of how a
particular backend implements `unsafe_free!`.

# Arguments
- `Backend`: the KernelAbstractions backend instance whose queue is synchronized.
- `args...`: arrays to release; they must not be accessed afterwards.

# Returns
A tuple with the result of `unsafe_free!` for each element of `args`
(same as before the synchronization was added).
"""
function sync_free!(Backend, args...)
    # Block until previously launched kernels have completed.
    KernelAbstractions.synchronize(Backend)
    return unsafe_free!.(args)
end

@info "[Nerf.jl] Backend: $BACKEND_NAME"
@info "[Nerf.jl] Device: $Backend"

# TODO:
# - use Flux for models
# - make the renderer non-allocating (except for the NN part)
# - get rid of `sync_free!`

function main()
dev = DEVICE
config_file = joinpath(pkgdir(Nerf), "data", "raccoon_sofa2", "transforms.json")
dataset = Dataset(dev; config_file)
dataset = Dataset(Backend; config_file)

model = BasicModel(BasicField(dev))
trainer = Trainer(model, dataset)
model = BasicModel(BasicField(Backend))
trainer = Trainer(model, dataset; n_rays=512)

camera = Camera(MMatrix{3, 4, Float32}(I), dataset.intrinsics)
renderer = Renderer(dev, camera, trainer.bbox, trainer.cone)
renderer = Renderer(Backend, camera, trainer.bbox, trainer.cone)

for i in 1:20_000
loss = step!(trainer)
@show i, loss

i % 1000 == 0 || continue
i % 250 == 0 || continue

pose_idx = clamp(round(Int, rand() * length(dataset)), 1, length(dataset))
set_projection!(camera, get_pose(dataset, pose_idx)...)
Expand Down Expand Up @@ -132,9 +142,9 @@ end

function benchmark()
config_file = joinpath(pkgdir(Nerf), "data", "raccoon_sofa2", "transforms.json")
dataset = Dataset(DEVICE; config_file)
model = BasicModel(BasicField(DEVICE))
trainer = Trainer(model, dataset)
dataset = Dataset(Backend; config_file)
model = BasicModel(BasicField(Backend))
trainer = Trainer(model, dataset; n_rays=512)

# GC.enable_logging(true)

Expand All @@ -143,9 +153,11 @@ function benchmark()
@time trainer_benchmark(trainer, 10)
@time trainer_benchmark(trainer, 1000)

return nothing

camera = Camera(MMatrix{3, 4, Float32}(I), dataset.intrinsics)
set_projection!(camera, get_pose(dataset, 1)...)
renderer = Renderer(DEVICE, camera, trainer.bbox, trainer.cone)
renderer = Renderer(Backend, camera, trainer.bbox, trainer.cone)

Core.println("Renderer benchmark")

Expand Down
62 changes: 32 additions & 30 deletions src/acceleration/occupancy.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
include("morton.jl")
include("indexing.jl")

mutable struct OccupancyGrid{D <: AbstractArray{Float32, 4}, B <: AbstractVector{UInt8}}
# Occupancy grid holding a density volume and a packed binary vector.
# Declared `mutable` because `mean_density` is reassigned in `update_binary!`.
mutable struct OccupancyGrid{
D <: AbstractArray{Float32, 4},
B <: AbstractVector{UInt8},
}
# 4-D Float32 array; the constructor allocates it with size
# (resolution, resolution, resolution, n_levels).
density::D
# UInt8 vector; the constructor allocates `length(density) ÷ 8` entries
# (presumably one bit per density cell -- TODO confirm against the kernels).
binary::B
# Initialized to 0f0 by the constructor; `update_binary!` sets it to the mean
# of `max(0f0, x)` over the first level of `density`.
mean_density::Float32
end

function OccupancyGrid(dev; n_levels::Int, resolution::Int = 128)
function OccupancyGrid(Backend; n_levels::Int, resolution::Int = 128)
dim_size = ntuple(i -> resolution, Val{3}())
density = zeros(dev, Float32, (dim_size..., n_levels))
binary = zeros(dev, UInt8, length(density) ÷ 8)
density = KernelAbstractions.zeros(Backend, Float32, (dim_size..., n_levels))
binary = KernelAbstractions.zeros(Backend, UInt8, length(density) ÷ 8)
OccupancyGrid(density, binary, 0f0)
end

Expand All @@ -27,7 +30,7 @@ end
get_voxel_diameter(UInt32(get_resolution(oc)), level)
end

get_device(::OccupancyGrid{D, B}) where {D, B} = device_from_type(D)
"""
    KernelAbstractions.get_backend(og::OccupancyGrid)

Return the backend of the grid, obtained by querying the backend of its
`density` array.
"""
function KernelAbstractions.get_backend(og::OccupancyGrid)
    return get_backend(og.density)
end

"""
    get_resolution(oc::OccupancyGrid)

Return the size of the first dimension of `oc.density`.
"""
@inline function get_resolution(oc::OccupancyGrid)
    return size(oc.density, 1)
end

Expand Down Expand Up @@ -70,61 +73,60 @@ function update!(

step ÷= update_frequency

dev = get_device(oc)
points = similar(dev, SVector{3, Float32}, (n_samples,))
indices = similar(dev, UInt32, (n_samples,))
Backend = get_backend(oc)
points = allocate(Backend, SVector{3, Float32}, (n_samples,))
indices = allocate(Backend, UInt32, (n_samples,))

gp_kernel = generate_points!(dev)
wait(gp_kernel(
gp_kernel = generate_points!(Backend)
gp_kernel(
points, indices, rng_state, density, bbox,
-0.01f0, UInt32(step); ndrange=n_uniform))
-0.01f0, UInt32(step); ndrange=n_uniform)
rng_state = advance(rng_state)
if n_non_uniform > 0
offset = (n_uniform + 1):n_samples
wait(gp_kernel(
gp_kernel(
@view(points[offset]), @view(indices[offset]), rng_state,
density, bbox, threshold, UInt32(step); ndrange=n_non_uniform))
density, bbox, threshold, UInt32(step); ndrange=n_non_uniform)
end
rng_state = advance(rng_state)

raw_points = reshape(reinterpret(Float32, points), 3, :)
log_densities = density_eval_fn(raw_points)
unsafe_free!(points)
sync_free!(Backend, points)

tmp_density = zeros(dev, Float32, size(oc.density))
wait(distribute_density!(dev)(
tmp_density = KernelAbstractions.zeros(Backend, Float32, size(oc.density))
distribute_density!(Backend)(
reinterpret(UInt32, tmp_density), log_densities,
indices, cone.min_stepsize; ndrange=length(indices)))
unsafe_free!(indices)
unsafe_free!(log_densities)
indices, cone.min_stepsize; ndrange=length(indices))
sync_free!(Backend, indices, log_densities)

wait(ema_update!(dev)(
oc.density, tmp_density, decay; ndrange=length(oc.density)))
unsafe_free!(tmp_density)
ema_update!(Backend)(
oc.density, tmp_density, decay; ndrange=length(oc.density))
sync_free!(Backend, tmp_density)

update_binary!(oc; threshold)
return rng_state
end

function update_binary!(oc::OccupancyGrid; threshold::Float32 = 0.01f0)
dev = get_device(oc)
Backend = get_backend(oc)

oc.mean_density = mean(x -> max(0f0, x), @view(oc.density[:, :, :, 1]))
threshold = min(threshold, oc.mean_density)
wait(distribute_to_binary!(dev)(
oc.binary, oc.density, threshold; ndrange=length(oc.binary)))
distribute_to_binary!(Backend)(
oc.binary, oc.density, threshold; ndrange=length(oc.binary))

binary_level_length = offset_binary(oc, 1)
binary_resolution = UInt32(size(oc.density, 1) ÷ 8)
ndrange = binary_level_length ÷ 8
n_levels = size(oc.density, 4)

bmp_kernel = binary_max_pool!(dev)
bmp_kernel = binary_max_pool!(Backend)
for l in 1:(n_levels - 1)
s, m, e = binary_level_length .* ((l - 1), l, (l + 1))
prev_level = @view(oc.binary[(s + 1):m])
curr_level = @view(oc.binary[(m + 1):e])
wait(bmp_kernel(curr_level, prev_level, binary_resolution; ndrange))
bmp_kernel(curr_level, prev_level, binary_resolution; ndrange)
end
end

Expand Down Expand Up @@ -220,11 +222,11 @@ end
function mark_invisible_regions!(
oc::OccupancyGrid; intrinsics, rotations, translations,
)
dev = get_device(oc)
Backend = get_backend(oc)
res_scale = 0.5f0 .* intrinsics.resolution ./ intrinsics.focal
wait(_mark_invisible_regions!(dev)(
_mark_invisible_regions!(Backend)(
oc.density, rotations, translations, res_scale;
ndrange=length(oc.density)))
ndrange=length(oc.density))
end

@kernel function _mark_invisible_regions!(
Expand Down
26 changes: 12 additions & 14 deletions src/data/dataset.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
include("images.jl")
include("intrinsics.jl")

struct Dataset{D, R, T}
images::D
struct Dataset{
I <: Images,
R <: AbstractVector{SMatrix{3, 3, Float32, 9}},
T <: AbstractVector{SVector{3, Float32}},
C <: CameraIntrinsics,
}
images::I
rotations::R
translations::T
intrinsics::CameraIntrinsics
intrinsics::C

frame_filenames::Vector{String}
rotations_host::Vector{SMatrix{3, 3, Float32, 9}}
Expand All @@ -17,7 +22,7 @@ struct Dataset{D, R, T}
end

function Dataset(
dev; config_file::String,
backend; config_file::String,
scale::Float32 = 0.33f0,
offset = SVector{3, Float32}(0.5f0, 0.5f0, 0.5f0),
)
Expand All @@ -32,7 +37,6 @@ function Dataset(
has_metadata = (
"w" in keys(config) && "fl_x" in keys(config) &&
"cx" in keys(config) && "k1" in keys(config))
is_synthetic = !has_metadata

# HACK:
# config file for synthetic NeRF datasets does not specify file extension,
Expand Down Expand Up @@ -82,15 +86,9 @@ function Dataset(
intrinsics = CameraIntrinsics(width(images), height(images), fov)
end

if dev isa CPU
device_rotations = rotations
device_translations = translations
device_images = images
else
device_rotations = to_device(dev, rotations)
device_translations = to_device(dev, translations)
device_images = adapt(type_from_device(dev), images)
end
device_rotations = adapt(backend, rotations)
device_translations = adapt(backend, translations)
device_images = adapt(backend, images)
Dataset(
device_images, device_rotations, device_translations, intrinsics,
frame_filenames, rotations, translations, scale, offset, bbox_scale)
Expand Down
1 change: 1 addition & 0 deletions src/data/images.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ end
@inline function sample(i::Images, xy::SVector{2, Float32}, image_idx::UInt32)
width::UInt32, height::UInt32 = size(i.data, 2), size(i.data, 3)
pixel = to_pixel(xy, width, height)
# TODO inbounds
SVector{3, Float32}(
i.data[1, pixel[1], pixel[2], image_idx],
i.data[2, pixel[1], pixel[2], image_idx],
Expand Down
Loading

0 comments on commit 5195701

Please sign in to comment.