Skip to content

Commit

Permalink
Merge pull request #1554 from CliMA/sb/point-space-adapt
Browse files Browse the repository at this point in the history
Define Adapt for PointSpace
  • Loading branch information
Sbozzolo authored Nov 18, 2023
2 parents bcda120 + 6719dfa commit ee9badd
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ steps:
agents:
slurm_gpus: 1

- label: "Unit: cuda point spaces"
key: "point_space_cuda"
command: "julia --color=yes --check-bounds=yes --project=test test/Spaces/point_cuda.jl CUDA"
artifact_paths:
- output/point_cuda
agents:
slurm_gpus: 1

- label: "Unit: cuda dss 2-process test"
key: "gpu_ddss2_test"
command:
Expand Down
4 changes: 4 additions & 0 deletions src/Spaces/pointspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ function PointSpace(
return PointSpace(Adapt.adapt(ArrayType, local_geometry_data))
end


Adapt.adapt_structure(to, space::PointSpace) =
PointSpace(Adapt.adapt(to, space.local_geometry))

function PointSpace(
context::ClimaComms.AbstractCommsContext,
coord::Geometry.Abstract1DPoint{FT},
Expand Down
23 changes: 23 additions & 0 deletions test/Spaces/point_cuda.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import ClimaCore
import ClimaCore: Domains, Topologies, Meshes, Spaces, Geometry, column, Fields
import ClimaComms
using Test

compare(cpu, gpu) = all(parent(cpu) .≈ Array(parent(gpu)))
compare(cpu, gpu, sym) =
all(parent(getproperty(cpu, sym)) .≈ Array(parent(getproperty(gpu, sym))))

@testset "CuArray-backed point spaces" begin
cpu_context =
ClimaComms.SingletonCommsContext(ClimaComms.CPUSingleThreaded())
gpu_context = ClimaComms.SingletonCommsContext(ClimaComms.CUDADevice())

point = Geometry.ZPoint(1.0)

cpuspace = Spaces.PointSpace(cpu_context, point)
gpuspace = Spaces.PointSpace(gpu_context, point)

# Test that all geometries match with CPU version:
@test compare(cpuspace, gpuspace, :local_geometry)

end

0 comments on commit ee9badd

Please sign in to comment.