diff --git a/test/Fields/inference_repro.jl b/test/Fields/inference_repro.jl index c8f9278bd2..bb47476ecf 100644 --- a/test/Fields/inference_repro.jl +++ b/test/Fields/inference_repro.jl @@ -4,15 +4,25 @@ import ClimaComms ClimaComms.@import_required_backends import ClimaCore: Fields, Domains, Geometry, Meshes, Spaces +macro ConstantValue(T) + quote + Base.broadcastable(x::$T) = x + Base.axes(x::$T) = () + Base.getindex(x::$T, i...) = x + Base.getindex(x::$T, i) = x + Base.ndims(x::Type{<:$T}) = 0 + end +end + struct LandParameters{FT} ρ_cloud_ice::FT end -Base.broadcastable(x::LandParameters) = tuple(x) +@ConstantValue LandParameters struct vanGenuchten{FT} α::FT end -Base.broadcastable(x::vanGenuchten) = tuple(x) +@ConstantValue vanGenuchten function phase_change_source( θ_l::FT, @@ -64,16 +74,5 @@ end using Test @testset "GPU inference failure" begin - if ClimaComms.device() isa ClimaComms.CUDADevice - @test_broken try - main(Float64) - true - catch e - @assert occursin("GPUCompiler.InvalidIRError", string(e)) - @assert occursin("dynamic function invocation", e.errors[1][1]) - false - end - else - main(Float64) - end + main(Float64) end