diff --git a/ext/cuda/data_layouts.jl b/ext/cuda/data_layouts.jl index 2fede7993e..0c1c00013f 100644 --- a/ext/cuda/data_layouts.jl +++ b/ext/cuda/data_layouts.jl @@ -195,9 +195,16 @@ Base.@propagate_inbounds function rcopyto_at!( v, ) dest, bc = pair.first, pair.second - if v <= size(dest, 4) + if 1 ≤ v <= size(dest, 4) + dest[I] = isascalar(bc) ? bc[] : bc[I] + end + return nothing +end +Base.@propagate_inbounds function rcopyto_at!(pair::Pair{<:DataF, <:Any}, I, v) + dest, bc = pair.first, pair.second + if 1 ≤ v <= size(dest, 4) bcI = isascalar(bc) ? bc[] : bc[I] - dest[I] = bcI + dest[] = bcI end return nothing end diff --git a/src/DataLayouts/broadcast.jl b/src/DataLayouts/broadcast.jl index 6c6e10d66c..e6add579d0 100644 --- a/src/DataLayouts/broadcast.jl +++ b/src/DataLayouts/broadcast.jl @@ -611,8 +611,21 @@ function Base.copyto!( fmbc::FusedMultiBroadcast{T}, ) where {N, T <: NTuple{N, Pair{<:AbstractData, <:Any}}} dest1 = first(fmbc.pairs).first + fmb_inst = FusedMultiBroadcast( + map(fmbc.pairs) do pair + bc = pair.second + bc′ = if isascalar(bc) + Base.Broadcast.instantiate( + Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()), + ) + else + bc + end + Pair(pair.first, bc′) + end, + ) # check_fused_broadcast_axes(fmbc) # we should already have checked the axes - fused_copyto!(fmbc, dest1, ClimaComms.device(dest1)) + fused_copyto!(fmb_inst, dest1, ClimaComms.device(dest1)) end function fused_copyto!( diff --git a/src/Fields/broadcast.jl b/src/Fields/broadcast.jl index c5ef39dfa2..b5c3f0cfcd 100644 --- a/src/Fields/broadcast.jl +++ b/src/Fields/broadcast.jl @@ -157,14 +157,7 @@ function Base.copyto!( fmb_data = FusedMultiBroadcast( map(fmbc.pairs) do pair bc = Base.Broadcast.instantiate(todata(pair.second)) - bc′ = if isascalar(bc) - Base.Broadcast.instantiate( - Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()), - ) - else - bc - end - Pair(field_values(pair.first), bc′) + Pair(field_values(pair.first), bc) end, ) check_mismatched_spaces(fmbc) diff --git a/test/Fields/field_multi_broadcast_fusion.jl b/test/Fields/field_multi_broadcast_fusion.jl index 830bb955b2..ea601a9a95 100644 --- a/test/Fields/field_multi_broadcast_fusion.jl +++ b/test/Fields/field_multi_broadcast_fusion.jl @@ -1,5 +1,6 @@ #= julia --check-bounds=yes --project=test +julia -g2 --check-bounds=yes --project=test julia --project=test using Revise; include(joinpath("test", "Fields", "field_multi_broadcast_fusion.jl")) =# @@ -39,6 +40,8 @@ if !(@isdefined(TU)) import .TestUtilities as TU end +@show ClimaComms.device() + function CenterExtrudedFiniteDifferenceSpaceLineHSpace( ::Type{FT}; zelem = 10, @@ -338,7 +341,13 @@ end FT = Float64 device = ClimaComms.device() ArrayType = device isa ClimaComms.CUDADevice ? CuArray : Array - VF_data() = VF{FT}(ArrayType(ones(FT, 3, 2))) + colspace = TU.ColumnCenterFiniteDifferenceSpace( + FT; + zelem = 3, + context = ClimaComms.context(device), + ) + VF_data() = Fields.Field(FT, colspace) + X = Fields.FieldVector(; x1 = VF_data(), x2 = VF_data(), x3 = VF_data()) Y = Fields.FieldVector(; y1 = VF_data(), y2 = VF_data(), y3 = VF_data()) test_kernel!(; fused!, unfused!, X, Y)