Skip to content

Commit

Permalink
Improve inference in _SpectralElementGrid2D
Browse files Browse the repository at this point in the history
charleskawczynski committed Jun 16, 2024
1 parent 9c72fc3 commit 7c0f864
Showing 2 changed files with 15 additions and 10 deletions.
23 changes: 14 additions & 9 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
@@ -171,6 +171,15 @@ function SpectralElementGrid2D(
end
end

function get_CoordType2D(topology)
domain = Topologies.domain(topology)
return if domain isa Domains.SphereDomain
FT = Domains.float_type(domain)
Geometry.LatLongPoint{FT} # Domains.coordinate_type(topology)
else
Topologies.coordinate_type(topology)
end
end

function _SpectralElementGrid2D(
topology::Topologies.Topology2D,
@@ -195,17 +204,13 @@ function _SpectralElementGrid2D(
# 1. allocate buffers externally
DA = ClimaComms.array_type(topology)
domain = Topologies.domain(topology)
if domain isa Domains.SphereDomain
CoordType3D = Topologies.coordinate_type(topology)
FT = Geometry.float_type(CoordType3D)
CoordType2D = Geometry.LatLongPoint{FT} # Domains.coordinate_type(topology)
global_geometry =
Geometry.SphericalGlobalGeometry(topology.mesh.domain.radius)
FT = Domains.float_type(domain)
global_geometry = if domain isa Domains.SphereDomain
Geometry.SphericalGlobalGeometry(topology.mesh.domain.radius)
else
CoordType2D = Topologies.coordinate_type(topology)
FT = Geometry.float_type(CoordType2D)
global_geometry = Geometry.CartesianGlobalGeometry()
Geometry.CartesianGlobalGeometry()
end
CoordType2D = get_CoordType2D(topology)
AIdx = Geometry.coordinate_axis(CoordType2D)
nlelems = Topologies.nlocalelems(topology)
ngelems = Topologies.nghostelems(topology)
2 changes: 1 addition & 1 deletion test/Spaces/opt_spaces.jl
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ end
space = TU.CenterExtrudedFiniteDifferenceSpace(Float32; context=ClimaComms.context())
result = JET.@report_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space); enable_bubble=false)
n_found = length(JET.get_reports(result.analyzer, result.result))
n_allowed = 351
n_allowed = 189
@test n_found n_allowed
n_found < n_allowed && @info "Inference may have improved. (found, allowed) = ($n_found, $n_allowed)"
end

0 comments on commit 7c0f864

Please sign in to comment.