Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for GPU type 3 transform #69

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ This is a full-featured Julia interface to [FINUFFT](https://github.com/flatiron

## Installation

FINUFFT.jl requires Julia v1.6 or later, and has been tested up to v1.10. From the Pkg REPL mode (hit `]` in REPL to enter), run
FINUFFT.jl requires Julia v1.6 or later, and has been tested up to v1.11. From the Pkg REPL mode (hit `]` in REPL to enter), run

```julia
add FINUFFT
Expand Down Expand Up @@ -105,7 +105,7 @@ see [examples/time2d1.jl](examples/time2d1.jl)

Finally, the more involved codes [test/test_nufft.jl](test/test_nufft.jl)
and [test/test_cuda.jl](test/test_cuda.jl)
tests `dtype=Float64` and `dtype=Float32` precisions for all supported transform types, and can be used as references.
test `dtype=Float64` and `dtype=Float32` precisions for all supported transform types, and can be used as references.
The outputs are tested there for mathematical correctness.
In the 1D type 1 it also tests a vectorized simple, a guru call and
a vectorized guru call.
Expand Down
27 changes: 25 additions & 2 deletions src/cufinufft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
status = true
if !CUDA.functional()
status = false
@warn "CUDA installation is not functional"

Check warning on line 21 in src/cufinufft.jl

View workflow job for this annotation

GitHub Actions / Documentation

CUDA installation is not functional
end
if !cufinufft_jll.is_available()
status = false
@warn "cuFINUFFT binary is not available on this platform"

Check warning on line 25 in src/cufinufft.jl

View workflow job for this annotation

GitHub Actions / Documentation

cuFINUFFT binary is not available on this platform
end
USE_CUDA[] = status
end
Expand Down Expand Up @@ -122,15 +122,14 @@

n_modes = ones(Int64,3)
if type==3
throw("Type 3 not implemented yet")
@assert ndims(n_modes_or_dim) == 0
dim = n_modes_or_dim
else
@assert length(n_modes_or_dim)<=3 && length(n_modes_or_dim)>=1
dim = length(n_modes_or_dim)
n_modes[1:dim] .= n_modes_or_dim
end

if dtype==Float64
tol = Float64(eps)
ret = ccall( (:cufinufft_makeplan, libcufinufft),
Expand Down Expand Up @@ -414,6 +413,30 @@
plan.plan_ptr,output,input
)
end
elseif type==3
nk = plan.nk
if ntrans==1
@assert size(output)==(nk,ntrans) || size(output)==(nk,)

Check warning on line 419 in src/cufinufft.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft.jl#L416-L419

Added lines #L416 - L419 were not covered by tests
else
@assert size(output)==(nk,ntrans)

Check warning on line 421 in src/cufinufft.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft.jl#L421

Added line #L421 was not covered by tests
end
if T==Float64
ret = ccall( (:cufinufft_execute, libcufinufft),

Check warning on line 424 in src/cufinufft.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft.jl#L423-L424

Added lines #L423 - L424 were not covered by tests
Cint,
(cufinufft_plan_c,
CuRef{ComplexF64},
CuRef{ComplexF64}),
plan.plan_ptr,input,output
)
else
ret = ccall( (:cufinufftf_execute, libcufinufft),

Check warning on line 432 in src/cufinufft.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft.jl#L432

Added line #L432 was not covered by tests
Cint,
(cufinufft_plan_c,
CuRef{ComplexF32},
CuRef{ComplexF32}),
plan.plan_ptr,input,output
)
end
else
ret = ERR_TYPE_NOTVALID
end
Expand Down
101 changes: 101 additions & 0 deletions src/cufinufft_simple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,35 @@
cufinufft_destroy!(plan)
end

"""
nufft1d3!(xj :: CuArray{Float64} or CuArray{Float32},
cj :: CuArray{ComplexF64} or CuArray{ComplexF32},
iflag :: Integer,
eps :: Real,
sk :: CuArray{Float64} or CuArray{Float32},
fk :: CuArray{ComplexF64} or CuArray{ComplexF32};
kwargs...
)

CUDA version.
"""
function nufft1d3!(xj :: CuArray{T},

Check warning on line 75 in src/cufinufft_simple.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft_simple.jl#L75

Added line #L75 was not covered by tests
cj :: CuArray{Complex{T}},
iflag :: Integer,
eps :: Real,
sk :: CuArray{T},
fk :: CuArray{Complex{T}};
kwargs...) where T <: finufftReal
(nj, nk) = valid_setpts(3,1,xj,T[],T[],sk)
ntrans = valid_ntr(xj,cj)

Check warning on line 83 in src/cufinufft_simple.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft_simple.jl#L82-L83

Added lines #L82 - L83 were not covered by tests

checkkwdtype(T; kwargs...)
plan = _cufinufft_makeplan(T,3,1,iflag,ntrans,eps;kwargs...)
cufinufft_setpts!(plan,xj,CuVector{T}(),CuVector{T}(),sk)
cufinufft_exec!(plan,cj,fk)
cufinufft_destroy!(plan)

Check warning on line 89 in src/cufinufft_simple.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft_simple.jl#L85-L89

Added lines #L85 - L89 were not covered by tests
end

## 2D

"""
Expand Down Expand Up @@ -123,6 +152,40 @@
cufinufft_destroy!(plan)
end

"""
nufft2d3!(xj :: CuArray{Float64} or CuArray{Float32},
yj :: CuArray{Float64} or CuArray{Float32},
cj :: CuArray{ComplexF64} or CuArray{ComplexF32},
iflag :: Integer,
eps :: Real,
sk :: CuArray{Float64} or CuArray{Float32},
tk :: CuArray{Float64} or CuArray{Float32},
fk :: CuArray{ComplexF64} or CuArray{ComplexF32};
kwargs...
)

CUDA version.
"""
function nufft2d3!(xj :: CuArray{T},

Check warning on line 169 in src/cufinufft_simple.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft_simple.jl#L169

Added line #L169 was not covered by tests
yj :: CuArray{T},
cj :: CuArray{Complex{T}},
iflag :: Integer,
eps :: Real,
sk :: CuArray{T},
tk :: CuArray{T},
fk :: CuArray{Complex{T}};
kwargs...) where T <: finufftReal
(nj, nk) = valid_setpts(3,2,xj,yj,T[],sk,tk)
ntrans = valid_ntr(xj,cj)

Check warning on line 179 in src/cufinufft_simple.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft_simple.jl#L178-L179

Added lines #L178 - L179 were not covered by tests

checkkwdtype(T; kwargs...)
plan = _cufinufft_makeplan(T,3,2,iflag,ntrans,eps;kwargs...)
cufinufft_setpts!(plan,xj,yj,CuVector{T}(),sk,tk)
cufinufft_exec!(plan,cj,fk)
cufinufft_destroy!(plan)

Check warning on line 185 in src/cufinufft_simple.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft_simple.jl#L181-L185

Added lines #L181 - L185 were not covered by tests
end


## 3D

"""
Expand Down Expand Up @@ -188,3 +251,41 @@
cufinufft_exec!(plan,fk,cj)
cufinufft_destroy!(plan)
end

"""
nufft3d3!(xj :: CuArray{Float64} or CuArray{Float32},
yj :: CuArray{Float64} or CuArray{Float32},
zj :: CuArray{Float64} or CuArray{Float32},
cj :: CuArray{ComplexF64} or CuArray{ComplexF32},
iflag :: Integer,
eps :: Real,
sk :: CuArray{Float64} or CuArray{Float32},
tk :: CuArray{Float64} or CuArray{Float32},
uk :: CuArray{Float64} or CuArray{Float32},
fk :: CuArray{ComplexF64} or CuArray{ComplexF32};
kwargs...
)

CUDA version.
"""
function nufft3d3!(xj :: CuArray{T},

Check warning on line 271 in src/cufinufft_simple.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft_simple.jl#L271

Added line #L271 was not covered by tests
yj :: CuArray{T},
zj :: CuArray{T},
cj :: CuArray{Complex{T}},
iflag :: Integer,
eps :: Real,
sk :: CuArray{T},
tk :: CuArray{T},
uk :: CuArray{T},
fk :: CuArray{Complex{T}};
kwargs...) where T <: finufftReal
(nj, nk) = valid_setpts(3,3,xj,yj,zj,sk,tk,uk)
ntrans = valid_ntr(xj,cj)

Check warning on line 283 in src/cufinufft_simple.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft_simple.jl#L282-L283

Added lines #L282 - L283 were not covered by tests

checkkwdtype(T; kwargs...)
plan = _cufinufft_makeplan(T,3,3,iflag,ntrans,eps;kwargs...)
cufinufft_setpts!(plan,xj,yj,zj,sk,tk,uk)
cufinufft_exec!(plan,cj,fk)
cufinufft_destroy!(plan)

Check warning on line 289 in src/cufinufft_simple.jl

View check run for this annotation

Codecov / codecov/patch

src/cufinufft_simple.jl#L285-L289

Added lines #L285 - L289 were not covered by tests
end

5 changes: 4 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ const nufft_c_opts = nufft_opts # for backward compatibility - remove?
gpu_stream :: Ptr{Cvoid}
modeord :: Cint # (type 1,2 only): 0 CMCL-style increasing mode order
# 1 FFT-style mode order
debug :: Cint # 0: no debug, 1: debug
end

Options struct passed to cuFINUFFT, see C documentation.
Expand Down Expand Up @@ -162,6 +163,8 @@ mutable struct cufinufft_opts
gpu_stream :: Ptr{Cvoid}

modeord :: Cint # (type 1,2 only): 0 CMCL-style increasing mode order
# 1 FFT-style mode order
# 1 FFT-style mode order

debug :: Cint # 0: no debug, 1: debug
cufinufft_opts() = new()
end
45 changes: 45 additions & 0 deletions test/test_cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ function test_cuda(tol::Real, dtype::DataType)
y_d = CuArray(y)
z_d = CuArray(z)
c_d = CuArray(c)
s_d = CuArray(s)
t_d = CuArray(t)
u_d = CuArray(u)
F1D_d = CuArray(F1D)
F2D_d = CuArray(F2D)
F3D_d = CuArray(F3D)
Expand Down Expand Up @@ -128,6 +131,21 @@ function test_cuda(tol::Real, dtype::DataType)
relerr_1d2_guru = norm(vec(out2)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_1d2_guru < errfac*tol
end

# 1D3
@testset "1D3" begin
out_d = CUDA.zeros(Complex{T},nk)
ref = zeros(Complex{T},nk)
for k=1:nk
for j=1:nj
ref[k] += c[j] * exp(1im*s[k]*x[j])
end
end
nufft1d3!(x_d,c_d,1,tol,s_d,out_d)
relerr_1d3 = norm(vec(Array(out_d))-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_1d3 < errfac*tol
end

end

## 2D
Expand Down Expand Up @@ -177,6 +195,20 @@ function test_cuda(tol::Real, dtype::DataType)
relerr_2d2_guru = norm(vec(out2)-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_2d2_guru < errfac*tol
end

@testset "2D3" begin
# 2D3
out_d = CUDA.zeros(Complex{T},nk)
ref = zeros(Complex{T},nk)
for k=1:nk
for j=1:nj
ref[k] += c[j] * exp(1im*(s[k]*x[j]+t[k]*y[j]))
end
end
nufft2d3!(x_d,y_d,c_d,1,tol,s_d,t_d,out_d)
relerr_2d3 = norm(vec(Array(out_d))-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_2d3 < errfac*tol
end
end

## 3D
Expand Down Expand Up @@ -231,6 +263,19 @@ function test_cuda(tol::Real, dtype::DataType)
@test relerr_3d2_guru < errfac*tol
end

@testset "3D3" begin
# 3D3
out_d = CUDA.zeros(Complex{T},nk)
ref = zeros(Complex{T},nk)
for k=1:nk
for j=1:nj
ref[k] += c[j] * exp(1im*(s[k]*x[j]+t[k]*y[j]+u[k]*z[j]))
end
end
nufft3d3!(x_d,y_d,z_d,c_d,1,tol,s_d,t_d,u_d,out_d)
relerr_3d3 = norm(vec(Array(out_d))-vec(ref), Inf) / norm(vec(ref), Inf)
@test relerr_3d3 < errfac*tol
end
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/test_nufft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ function test_nufft(tol::Real, dtype::DataType)
@test reldiff < errdifffac*tol
end

@testset "3D3" begin
@testset "2D3" begin
# 2D3
out = zeros(Complex{T},nk)
ref = zeros(Complex{T},nk)
Expand Down
Loading