Skip to content

Commit

Permalink
Merge pull request #1982 from CliMA/ck/inference_repro2
Browse files Browse the repository at this point in the history
Add a broken inference test for field broadcasting
  • Loading branch information
charleskawczynski authored Sep 11, 2024
2 parents 42cd28b + 1f093c5 commit 4369a5a
Showing 1 changed file with 56 additions and 20 deletions.
76 changes: 56 additions & 20 deletions test/Fields/field_opt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#=
julia --project
using Revise; include(joinpath("test", "Fields", "field_opt.jl"))
=#
# These tests require running with `--check-bounds=[auto|no]`
using Test
using StaticArrays, IntervalSets
Expand Down Expand Up @@ -307,27 +311,28 @@ end
end

# https://github.com/CliMA/ClimaCore.jl/issues/1062
function toy_sphere(::Type{FT}) where {FT}
context = ClimaComms.context()
helem = npoly = 2
hdomain = Domains.SphereDomain(FT(1e7))
hmesh = Meshes.EquiangularCubedSphere(hdomain, helem)
htopology = Topologies.Topology2D(context, hmesh)
quad = Quadratures.GLL{npoly + 1}()
hspace = Spaces.SpectralElementSpace2D(htopology, quad)
vdomain = Domains.IntervalDomain(
Geometry.ZPoint{FT}(zero(FT)),
Geometry.ZPoint{FT}(FT(1e4));
boundary_names = (:bottom, :top),
)
vmesh = Meshes.IntervalMesh(vdomain, nelems = 4)
vtopology = Topologies.IntervalTopology(context, vmesh)
vspace = Spaces.CenterFiniteDifferenceSpace(vtopology)
center_space = Spaces.ExtrudedFiniteDifferenceSpace(hspace, vspace)
face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space)
return (center_space, face_space)
end

@testset "Allocations with copyto! on FieldVectors" begin
function toy_sphere(::Type{FT}) where {FT}
context = ClimaComms.context()
helem = npoly = 2
hdomain = Domains.SphereDomain(FT(1e7))
hmesh = Meshes.EquiangularCubedSphere(hdomain, helem)
htopology = Topologies.Topology2D(context, hmesh)
quad = Quadratures.GLL{npoly + 1}()
hspace = Spaces.SpectralElementSpace2D(htopology, quad)
vdomain = Domains.IntervalDomain(
Geometry.ZPoint{FT}(zero(FT)),
Geometry.ZPoint{FT}(FT(1e4));
boundary_names = (:bottom, :top),
)
vmesh = Meshes.IntervalMesh(vdomain, nelems = 4)
vtopology = Topologies.IntervalTopology(context, vmesh)
vspace = Spaces.CenterFiniteDifferenceSpace(vtopology)
center_space = Spaces.ExtrudedFiniteDifferenceSpace(hspace, vspace)
face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space)
return (center_space, face_space)
end
function field_vec(center_space, face_space)
Y = Fields.FieldVector(
c = map(Fields.coordinate_field(center_space)) do coord
Expand Down Expand Up @@ -357,4 +362,35 @@ end
palloc = @allocated foo!(obj)
@test palloc == 0
end

struct VarTimescaleAcnv{FT}
τ::FT
α::FT
end
Base.broadcastable(x::VarTimescaleAcnv) = tuple(x)
function conv_q_liq_to_q_rai(
(; τ, α)::VarTimescaleAcnv{FT},
q_liq::FT,
ρ::FT,
N_d::FT,
) where {FT}
return max(0, q_liq) / (1 * (N_d / 1e8)^1)
end
function ifelsekernel!(Sᵖ, ρ)
var = VarTimescaleAcnv(1.0, 2.0)
@. Sᵖ = ifelse(false, 1.0, conv_q_liq_to_q_rai(var, 2.0, ρ, 2.0))
return nothing
end

using JET
# https://github.com/CliMA/ClimaCore.jl/issues/1981
# TODO: improve the testset name once we better under
@testset "ifelse kernel" begin
(cspace, fspace) = toy_sphere(Float64)
ρ = Fields.Field(Float64, cspace)
S = Fields.Field(Float64, cspace)
ifelsekernel!(S, ρ)
@test_opt broken = true ifelsekernel!(S, ρ)
end

nothing

0 comments on commit 4369a5a

Please sign in to comment.