Skip to content

Commit

Permalink
Merge pull request #2052 from CliMA/ck/refactor
Browse files Browse the repository at this point in the history
Refactor and use less internals
  • Loading branch information
charleskawczynski authored Oct 21, 2024
2 parents 2e0c495 + e220729 commit 67dd505
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 23 deletions.
4 changes: 2 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ steps:
key: unit_data_copyto
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_copyto.jl"

- label: "Unit: getindex_field"
key: unit_data_getindex_field
- label: "Unit: cartesian_field_index"
key: unit_data_cartesian_field_index
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_cartesian_field_index.jl"

- label: "Unit: mapreduce"
Expand Down
12 changes: 12 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,18 @@ device_dispatch(x::MArray) = ToCPU()
@inline singleton(@nospecialize(::IH1JH2)) = IH1JH2Singleton()
@inline singleton(@nospecialize(::IV1JH2)) = IV1JH2Singleton()

@inline singleton(::Type{IJKFVH}) = IJKFVHSingleton()
@inline singleton(::Type{IJFH}) = IJFHSingleton()
@inline singleton(::Type{IFH}) = IFHSingleton()
@inline singleton(::Type{DataF}) = DataFSingleton()
@inline singleton(::Type{IJF}) = IJFSingleton()
@inline singleton(::Type{IF}) = IFSingleton()
@inline singleton(::Type{VF}) = VFSingleton()
@inline singleton(::Type{VIJFH}) = VIJFHSingleton()
@inline singleton(::Type{VIFH}) = VIFHSingleton()
@inline singleton(::Type{IH1JH2}) = IH1JH2Singleton()
@inline singleton(::Type{IV1JH2}) = IV1JH2Singleton()


include("copyto.jl")
include("fused_copyto.jl")
Expand Down
8 changes: 5 additions & 3 deletions src/InputOutput/readers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,16 +467,18 @@ function read_field(reader::HDF5Reader, name::AbstractString)
end
topology = Spaces.topology(space)
ArrayType = ClimaComms.array_type(topology)
data_layout = attrs(obj)["data_layout"]
DataLayout = _scan_data_layout(data_layout)
h_dim = DataLayouts.h_dim(DataLayouts.singleton(DataLayout))
if topology isa Topologies.Topology2D
nd = ndims(obj)
localidx = ntuple(d -> d < nd ? (:) : topology.local_elem_gidx, nd)
localidx =
ntuple(d -> d == h_dim ? topology.local_elem_gidx : (:), nd)
data = ArrayType(obj[localidx...])
else
data = ArrayType(read(obj))
end
data_layout = attrs(obj)["data_layout"]
Nij = size(data, findfirst("I", data_layout)[1])
DataLayout = _scan_data_layout(data_layout)
# For when `Nh` is added back to the type space
# Nhd = Nh_dim(data_layout)
# Nht = Nhd == -1 ? () : (size(data, Nhd),)
Expand Down
6 changes: 4 additions & 2 deletions src/InputOutput/writers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ end

# write fields
function write!(writer::HDF5Writer, field::Fields.Field, name::AbstractString)
values = Fields.field_values(field)
space = axes(field)
staggering = Spaces.staggering(space)
grid = Spaces.grid(space)
Expand All @@ -434,8 +435,9 @@ function write!(writer::HDF5Writer, field::Fields.Field, name::AbstractString)
if topology isa Topologies.Topology2D &&
!(writer.context isa ClimaComms.SingletonCommsContext)
nelems = Topologies.nelems(topology)
dims = ntuple(d -> d == nd ? nelems : size(array, d), nd)
localidx = ntuple(d -> d < nd ? (:) : topology.local_elem_gidx, nd)
h_dim = DataLayouts.h_dim(DataLayouts.singleton(values))
dims = ntuple(d -> d == h_dim ? nelems : size(array, d), nd)
localidx = ntuple(d -> d == h_dim ? topology.local_elem_gidx : (:), nd)
dataset = create_dataset(
writer.file,
"fields/$name",
Expand Down
6 changes: 3 additions & 3 deletions src/Topologies/dss_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,17 +269,17 @@ function create_ghost_buffer(
if data isa DataLayouts.IJFH
send_data = DataLayouts.IJFH{S, Nij}(typeof(parent(data)), Nhsend)
recv_data = DataLayouts.IJFH{S, Nij}(typeof(parent(data)), Nhrec)
k = stride(parent(send_data), 4)
else
Nv, _, _, Nf, _ = DataLayouts.farray_size(data)
Nv = DataLayouts.nlevels(data)
Nf = DataLayouts.ncomponents(data)
send_data = DataLayouts.VIJFH{S, Nv, Nij}(
similar(parent(data), (Nv, Nij, Nij, Nf, Nhsend)),
)
recv_data = DataLayouts.VIJFH{S, Nv, Nij}(
similar(parent(data), (Nv, Nij, Nij, Nf, Nhrec)),
)
k = stride(parent(send_data), 5)
end
k = stride(parent(send_data), DataLayouts.h_dim(data))

graph_context = ClimaComms.graph_context(
topology.context,
Expand Down
15 changes: 9 additions & 6 deletions test/Fields/unit_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,15 @@ end
Nh = n1 * n2
space = spectral_space_2D(n1 = n1, n2 = n2, Nij = Nij)

nt_field = Fields.Field(
IJFH{NamedTuple{(:a, :b), Tuple{Float64, Float64}}, Nij}(
ones(Nij, Nij, 2, Nh),
),
space,
)
S = NamedTuple{(:a, :b), Tuple{Float64, Float64}}
context = ClimaComms.context(space)
device = ClimaComms.device(context)
ArrayType = ClimaComms.array_type(device)
FT = Spaces.undertype(space)
data = IJFH{S}(ArrayType{FT}, ones; Nij, Nh)

nt_field = Fields.Field(data, space)

nt_sum = sum(nt_field)
@test nt_sum isa NamedTuple{(:a, :b), Tuple{Float64, Float64}}
@test nt_sum.a 8.0 * 10.0 rtol = 10eps()
Expand Down
2 changes: 1 addition & 1 deletion test/InputOutput/hybrid3dcubedsphere.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end
ᶜlocal_geometry = Fields.local_geometry_field(center_space)
ᶠlocal_geometry = Fields.local_geometry_field(face_space)

Y = Fields.FieldVector(c = ᶜlocal_geometry, f = ᶠlocal_geometry)
Y = Fields.FieldVector(; c = ᶜlocal_geometry, f = ᶠlocal_geometry)

# write field vector to hdf5 file
writer = InputOutput.HDF5Writer(filename, comms_ctx)
Expand Down
2 changes: 1 addition & 1 deletion test/InputOutput/hybrid3dcubedsphere_topography.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ end
ᶜlocal_geometry = Fields.local_geometry_field(center_space)
ᶠlocal_geometry = Fields.local_geometry_field(face_space)

Y = Fields.FieldVector(c = ᶜlocal_geometry, f = ᶠlocal_geometry)
Y = Fields.FieldVector(; c = ᶜlocal_geometry, f = ᶠlocal_geometry)

# write field vector to hdf5 file
writer = InputOutput.HDF5Writer(filename, comms_ctx)
Expand Down
5 changes: 0 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ UnitTest("Spectral elem - sphere diffusion" ,"Operators/spectralelement/s
UnitTest("Spectral elem - sphere diffusion vec" ,"Operators/spectralelement/sphere_diffusion_vec.jl"),
UnitTest("Spectral elem - sphere hyperdiff" ,"Operators/spectralelement/unit_sphere_hyperdiffusion.jl"),
UnitTest("Spectral elem - sphere hyperdiff vec" ,"Operators/spectralelement/unit_sphere_hyperdiffusion_vec.jl"),
# UnitTest("Spectral elem - sphere hyperdiff vec" ,"Operators/spectralelement/sphere_geometry_distributed.jl"), # MPI-only
UnitTest("FD ops - column" ,"Operators/finitedifference/unit_column.jl"),
UnitTest("FD ops - opt" ,"Operators/finitedifference/opt.jl"),
UnitTest("FD ops - wfact" ,"Operators/finitedifference/wfact.jl"),
UnitTest("FD ops - linsolve" ,"Operators/finitedifference/linsolve.jl"),
# UnitTest("FD ops - examples" ,"Operators/finitedifference/opt_examples.jl"), # only opt tests? (check coverage)
UnitTest("Hybrid - 2D" ,"Operators/hybrid/unit_2d.jl"),
UnitTest("Hybrid - 3D" ,"Operators/hybrid/unit_3d.jl"),
UnitTest("Hybrid - dss opt" ,"Operators/hybrid/dss_opt.jl"),
Expand Down Expand Up @@ -89,15 +87,12 @@ UnitTest("MatrixFields - non-scalar broadcasting (3)" ,"MatrixFields/matrix_fiel
UnitTest("MatrixFields - non-scalar broadcasting (4)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_4.jl"),
UnitTest("MatrixFields - non-scalar broadcasting (5)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_5.jl"),
UnitTest("MatrixFields - flat spaces" ,"MatrixFields/flat_spaces.jl"),

# UnitTest("MatrixFields - matrix field broadcast" ,"MatrixFields/matrix_field_broadcasting.jl"), # too long
# UnitTest("MatrixFields - operator matrices" ,"MatrixFields/operator_matrices.jl"), # too long
# UnitTest("MatrixFields - field matrix solvers" ,"MatrixFields/field_matrix_solvers.jl"), # too long
UnitTest("Hypsography - 2d" ,"Hypsography/2d.jl"),
UnitTest("Hypsography - 3d sphere" ,"Hypsography/3dsphere.jl"),
UnitTest("Remapping" ,"Operators/remapping.jl"),
UnitTest("Limiter" ,"Limiters/limiter.jl"),
# UnitTest("Limiter" ,"Limiters/distributed/dlimiter.jl"), # requires MPI
UnitTest("InputOutput - hdf5" ,"InputOutput/hdf5.jl"),
UnitTest("InputOutput - spectralelement2d" ,"InputOutput/spectralelement2d.jl"),
UnitTest("InputOutput - hybrid2dbox" ,"InputOutput/hybrid2dbox.jl"),
Expand Down

0 comments on commit 67dd505

Please sign in to comment.