Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 11, 2024
1 parent 7c595c7 commit 9b3e67b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 12 deletions.
11 changes: 9 additions & 2 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,16 @@ Base.@propagate_inbounds function rcopyto_at!(
v,
)
dest, bc = pair.first, pair.second
if v <= size(dest, 4)
if 1 v <= size(dest, 4)
dest[I] = isascalar(bc) ? bc[] : bc[I]
end
return nothing
end
Base.@propagate_inbounds function rcopyto_at!(pair::Pair{<:DataF, <:Any}, I, v)
dest, bc = pair.first, pair.second
if 1 v <= size(dest, 4)
bcI = isascalar(bc) ? bc[] : bc[I]
dest[I] = bcI
dest[] = bcI
end
return nothing
end
Expand Down
15 changes: 14 additions & 1 deletion src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -611,8 +611,21 @@ function Base.copyto!(
fmbc::FusedMultiBroadcast{T},
) where {N, T <: NTuple{N, Pair{<:AbstractData, <:Any}}}
dest1 = first(fmbc.pairs).first
fmb_inst = FusedMultiBroadcast(
map(fmbc.pairs) do pair
bc = pair.second
bc′ = if isascalar(bc)
Base.Broadcast.instantiate(
Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()),
)
else
bc
end
Pair(pair.first, bc′)
end,
)
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
fused_copyto!(fmbc, dest1, ClimaComms.device(dest1))
fused_copyto!(fmb_inst, dest1, ClimaComms.device(dest1))
end

function fused_copyto!(
Expand Down
9 changes: 1 addition & 8 deletions src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,7 @@ function Base.copyto!(
fmb_data = FusedMultiBroadcast(
map(fmbc.pairs) do pair
bc = Base.Broadcast.instantiate(todata(pair.second))
bc′ = if isascalar(bc)
Base.Broadcast.instantiate(
Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()),
)
else
bc
end
Pair(field_values(pair.first), bc′)
Pair(field_values(pair.first), bc)
end,
)
check_mismatched_spaces(fmbc)
Expand Down
11 changes: 10 additions & 1 deletion test/Fields/field_multi_broadcast_fusion.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#=
julia --check-bounds=yes --project=test
julia -g2 --check-bounds=yes --project=test
julia --project=test
using Revise; include(joinpath("test", "Fields", "field_multi_broadcast_fusion.jl"))
=#
Expand Down Expand Up @@ -39,6 +40,8 @@ if !(@isdefined(TU))
import .TestUtilities as TU
end

@show ClimaComms.device()

function CenterExtrudedFiniteDifferenceSpaceLineHSpace(
::Type{FT};
zelem = 10,
Expand Down Expand Up @@ -338,7 +341,13 @@ end
FT = Float64
device = ClimaComms.device()
ArrayType = device isa ClimaComms.CUDADevice ? CuArray : Array
VF_data() = VF{FT}(ArrayType(ones(FT, 3, 2)))
colspace = TU.ColumnCenterFiniteDifferenceSpace(
FT;
zelem = 3,
context = ClimaComms.context(device),
)
VF_data() = Fields.Field(FT, colspace)

X = Fields.FieldVector(; x1 = VF_data(), x2 = VF_data(), x3 = VF_data())
Y = Fields.FieldVector(; y1 = VF_data(), y2 = VF_data(), y3 = VF_data())
test_kernel!(; fused!, unfused!, X, Y)
Expand Down

0 comments on commit 9b3e67b

Please sign in to comment.