Inference failure in broadcast expression #1981

Closed
charleskawczynski opened this issue Sep 11, 2024 · 1 comment · Fixed by #1984
Assignees: charleskawczynski
Labels: bug (Something isn't working), performance

@charleskawczynski (Member)

Found in CliMA/ClimaAtmos.jl#3290:

```
ci/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl:176
│││┌ materialize!(dest::ClimaCore.Fields.Field{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:911
││││┌ materialize!(::ClimaCore.Fields.FieldStyle{…}, dest::ClimaCore.Fields.Field{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:914
│││││┌ copyto!(dest::ClimaCore.Fields.Field{…}, bc::Base.Broadcast.Broadcasted{…}) @ ClimaCore.Fields /central/scratch/esm/slurm-buildkite/climaatmos-ci/depot/default/packages/ClimaCore/HTRMg/src/Fields/broadcast.jl:149
││││││┌ copyto!(dest::ClimaCore.DataLayouts.VIJFH{…}, bc::Base.Broadcast.Broadcasted{…}) @ ClimaCore.DataLayouts /central/scratch/esm/slurm-buildkite/climaatmos-ci/depot/default/packages/ClimaCore/HTRMg/src/DataLayouts/copyto.jl:5
│││││││┌ copyto!(dest::ClimaCore.DataLayouts.VIJFH{…}, bc::Base.Broadcast.Broadcasted{…}, ::ClimaCore.DataLayouts.ToCPU) @ ClimaCore.DataLayouts /central/scratch/esm/slurm-buildkite/climaatmos-ci/depot/default/packages/ClimaCore/HTRMg/src/DataLayouts/copyto.jl:148
││││││││┌ copyto!(dest::ClimaCore.DataLayouts.VF{…}, bc::Base.Broadcast.Broadcasted{…}) @ ClimaCore.DataLayouts /central/scratch/esm/slurm-buildkite/climaatmos-ci/depot/default/packages/ClimaCore/HTRMg/src/DataLayouts/copyto.jl:5
│││││││││┌ copyto!(dest::ClimaCore.DataLayouts.VF{…}, bc::Base.Broadcast.Broadcasted{…}, ::ClimaCore.DataLayouts.ToCPU) @ ClimaCore.DataLayouts /central/scratch/esm/slurm-buildkite/climaatmos-ci/depot/default/packages/ClimaCore/HTRMg/src/DataLayouts/copyto.jl:120
││││││││││┌ getindex(bc::Base.Broadcast.Broadcasted{…}, I::CartesianIndex{…}) @ Base.Broadcast ./broadcast.jl:635
│││││││││││┌ checkbounds(bc::Base.Broadcast.Broadcasted{…}, I::CartesianIndex{…}) @ Base.Broadcast ./broadcast.jl:647
││││││││││││┌ axes(bc::Base.Broadcast.Broadcasted{ClimaCore.DataLayouts.VFStyle{…}, Nothing, typeof(ifelse), Tuple{…}}) @ Base.Broadcast ./broadcast.jl:234
│││││││││││││┌ _axes(bc::Base.Broadcast.Broadcasted{ClimaCore.DataLayouts.VFStyle{…}, Nothing, typeof(ifelse), Tuple{…}}, ::Nothing) @ Base.Broadcast ./broadcast.jl:236
││││││││││││││┌ combine_axes(::Base.Broadcast.Broadcasted{…}, ::Base.Broadcast.Broadcasted{…}, ::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:523
│││││││││││││││┌ combine_axes(A::Base.Broadcast.Broadcasted{…}, B::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:524
││││││││││││││││┌ axes(bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:234
│││││││││││││││││┌ _axes(bc::Base.Broadcast.Broadcasted{…}, ::Nothing) @ Base.Broadcast ./broadcast.jl:236
││││││││││││││││││ failed to optimize due to recursion: Base.Broadcast._axes(::Base.Broadcast.Broadcasted{…}, ::Nothing)
│││││││││││││││││└────────────────────
││││││││││││││││┌ axes(bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:234
│││││││││││││││││ failed to optimize due to recursion: axes(::Base.Broadcast.Broadcasted{…})
││││││││││││││││└────────────────────
│││││││││││││││┌ combine_axes(A::Base.Broadcast.Broadcasted{…}, B::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:524
```

This points to https://github.com/CliMA/ClimaAtmos.jl/blob/a22b643fc6a22dccd5a8b8d17d12222b114333eb/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl#L176-L180, and the same issue appears to affect https://github.com/CliMA/ClimaAtmos.jl/blob/a22b643fc6a22dccd5a8b8d17d12222b114333eb/src/parameterized_tendencies/microphysics/microphysics_wrappers.jl#L254-L264.

This is likely a significant performance hit: the failure sits inside `getindex` on the `Broadcasted` object, which is called for every point in the `copyto!` loop, so it is important to fix.
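The trace shows a `Broadcasted{…VFStyle{…}, Nothing, typeof(ifelse), …}`, i.e. a broadcast object whose axes type is `Nothing`. The Base-only sketch below (plain `Vector`s, not ClimaCore `Field`s or DataLayouts) illustrates that path: when the `axes` field is `nothing`, `axes(bc)`, which `checkbounds` calls inside `getindex`, recomputes the axes through `combine_axes`, recursing into nested `Broadcasted` arguments, while `instantiate` computes them once and stores them. It only illustrates the Base mechanism and is not the fix that closed this issue.

```julia
# Base-only sketch: plain Vectors, not ClimaCore Fields.
using Base.Broadcast: broadcasted, instantiate, materialize

x = [1.0, 2.0, 3.0]

# Lazy, nested broadcast with the same shape as `ifelse(false, 1.0, f.(...))`.
inner = broadcasted(*, x, 2.0)
bc = broadcasted(ifelse, false, 1.0, inner)

# The lazily built object stores no axes ...
bc.axes === nothing      # true
# ... so `axes(bc)` (reached via `checkbounds` in `getindex`) recomputes them
# with `combine_axes`, which calls `axes(inner)` on the nested Broadcasted.
axes(bc)                 # (Base.OneTo(3),)

# `instantiate` computes the axes once and stores them in the object.
ibc = instantiate(bc)
ibc.axes                 # (Base.OneTo(3),)

materialize(bc)          # [2.0, 4.0, 6.0]
```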

@charleskawczynski charleskawczynski added bug Something isn't working performance labels Sep 11, 2024
@charleskawczynski charleskawczynski self-assigned this Sep 11, 2024
@charleskawczynski (Member, Author) commented Sep 11, 2024

Here is a reproducer:

```julia
using Test
using JET  # provides @test_opt
import ClimaComms
import ClimaCore
import ClimaCore:
    Fields,
    Domains,
    Topologies,
    Meshes,
    Spaces,
    Geometry,
    Quadratures

# Small extruded cubed-sphere space (2 elements per panel edge, degree-2 GLL
# quadrature, 4 vertical elements). Returns (center_space, face_space).
function toy_sphere(::Type{FT}) where {FT}
    context = ClimaComms.context()
    helem = npoly = 2
    hdomain = Domains.SphereDomain(FT(1e7))
    hmesh = Meshes.EquiangularCubedSphere(hdomain, helem)
    htopology = Topologies.Topology2D(context, hmesh)
    quad = Quadratures.GLL{npoly + 1}()
    hspace = Spaces.SpectralElementSpace2D(htopology, quad)
    vdomain = Domains.IntervalDomain(
        Geometry.ZPoint{FT}(zero(FT)),
        Geometry.ZPoint{FT}(FT(1e4));
        boundary_names = (:bottom, :top),
    )
    vmesh = Meshes.IntervalMesh(vdomain, nelems = 4)
    vtopology = Topologies.IntervalTopology(context, vmesh)
    vspace = Spaces.CenterFiniteDifferenceSpace(vtopology)
    center_space = Spaces.ExtrudedFiniteDifferenceSpace(hspace, vspace)
    face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space)
    return (center_space, face_space)
end

# Toy parameter struct; only its type is used by the kernel below.
struct VarTimescaleAcnv{FT}
    τ::FT
    α::FT
end
# Treat the struct as a scalar (a 1-tuple) inside broadcasts.
Base.broadcastable(x::VarTimescaleAcnv) = tuple(x)

# Simplified rate expression; returns FT for FT inputs.
function conv_q_liq_to_q_rai(
    ::VarTimescaleAcnv{FT},
    q_liq::FT,
    ρ::FT,
    N_d::FT,
) where {FT}
    return max(0, q_liq) / (1 * (N_d / 1e8)^1)
end

# Same pattern as the ClimaAtmos code: `ifelse` wrapping a nested call.
function ifelsekernel!(Sᵖ, ρ)
    var = VarTimescaleAcnv(1.0, 2.0)
    @. Sᵖ = ifelse(false, 1.0, conv_q_liq_to_q_rai(var, 2.0, ρ, 2.0))
    return nothing
end

# https://github.com/CliMA/ClimaCore.jl/issues/1981
@testset "ifelse kernel" begin
    (cspace, fspace) = toy_sphere(Float64)
    ρ = Fields.Field(Float64, cspace)  # uninitialized; values don't matter
    S = Fields.Field(Float64, cspace)
    ifelsekernel!(S, ρ)                # compile and run once
    @test_opt ifelsekernel!(S, ρ)      # JET optimization check reproducing the report
end
```
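A complementary sanity check (a sketch, not part of the original report): the scalar kernel from the reproducer infers a concrete `Float64` on its own, which is consistent with the JET report pointing at the broadcast `axes`/`combine_axes` machinery rather than at `conv_q_liq_to_q_rai`. The struct and function are copied from the reproducer; the plain-`Vector` broadcast at the end is an assumption-free Base analog and does not exercise ClimaCore's DataLayouts path.

```julia
using Test: @inferred

# Copied from the reproducer; plain Julia, no ClimaCore required.
struct VarTimescaleAcnv{FT}
    τ::FT
    α::FT
end
Base.broadcastable(x::VarTimescaleAcnv) = tuple(x)
function conv_q_liq_to_q_rai(
    ::VarTimescaleAcnv{FT},
    q_liq::FT,
    ρ::FT,
    N_d::FT,
) where {FT}
    return max(0, q_liq) / (1 * (N_d / 1e8)^1)
end

var = VarTimescaleAcnv(1.0, 2.0)

# @inferred throws if the inferred return type differs from the actual
# result type, so getting past this line means the scalar call infers Float64.
y = @inferred conv_q_liq_to_q_rai(var, 2.0, 1.0, 2.0)

# Same broadcast expression over plain Vectors (Base broadcasting only).
ρ = ones(4)
S = similar(ρ)
@. S = ifelse(false, 1.0, conv_q_liq_to_q_rai(var, 2.0, ρ, 2.0))
```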
