Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inference using promote_op in multiply_matrix_at_index #1683

Merged
merged 3 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,6 @@ steps:
- label: "Unit: bidiag matrix row example (GPU)"
key: gpu_compat_bidiag_matrix_row
command: "julia --color=yes --project=test test/MatrixFields/gpu_compat_bidiag_matrix_row.jl"
soft_fail: true
agents:
slurm_gpus: 1

Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ ClimaCore.jl Release Notes
main
-------

v0.13.4
-------

- ![][badge-🐛bugfix] We fixed some fieldvector broadcasting on Julia 1.9. PR [#1658](https://github.com/CliMA/ClimaCore.jl/pull/1658).
- ![][badge-🚀performance] We fixed an inference failure with matrix field broadcasting. PR [#1683](https://github.com/CliMA/ClimaCore.jl/pull/1683).

v0.13.3
-------
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ClimaCore"
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
authors = ["CliMA Contributors <[email protected]>"]
version = "0.13.3"
version = "0.13.4"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
23 changes: 2 additions & 21 deletions src/MatrixFields/matrix_multiplication.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,29 +270,11 @@ function Operators.right_interior_idx(
end
end

pick_inferred_type(
::Type{Union{}},
::Type{Y},
) where {Y <: Geometry.LocalGeometry} = Y
pick_inferred_type(
::Type{X},
::Type{Union{}},
) where {X <: Geometry.LocalGeometry} = X
pick_inferred_type(::Type{T}, ::Type{T}) where {T <: Geometry.LocalGeometry} = T
pick_inferred_type(::Type{Union{}}, ::Type{Union{}}) =
error("Both LGs are not inferred")
pick_inferred_type(::Type{X}, ::Type{Y}) where {X, Y} =
error("LGs do not match: X=$X, Y=$Y")

function Operators.return_eltype(
::MultiplyColumnwiseBandMatrixField,
matrix1,
arg,
)
# LG1 = local_geometry_type(typeof(axes(matrix1)))
# LG2 = local_geometry_type(typeof(axes(arg)))
# LG = pick_inferred_type(LG1, LG2)
# return Operators.return_eltype(op, matrix1, arg, LG)
eltype(matrix1) <: BandMatrixRow || error(
"The first argument of ⋅ must have elements of type BandMatrixRow, but \
the given argument has elements of type $(eltype(matrix1))",
Expand Down Expand Up @@ -365,9 +347,8 @@ boundary_modified_ud(::BottomRightMatrixCorner, ud, column_space, i) =
# matrix field broadcast expressions to take roughly 3 or 4 times longer to
# evaluate, but this is less significant than the decrease in compilation time.
function multiply_matrix_at_index(loc, space, idx, hidx, matrix1, arg, bc)
# lg = Geometry.LocalGeometry(space, idx, hidx)
# prod_type = Operators.return_eltype(⋅, matrix1, arg, typeof(lg))
prod_type = Operators.return_eltype(⋅, matrix1, arg)
lg = Geometry.LocalGeometry(space, idx, hidx)
prod_type = Operators.return_eltype(⋅, matrix1, arg, typeof(lg))

column_space1 = column_axes(matrix1, space)
ld1, ud1 = outer_diagonals(eltype(matrix1))
Expand Down
2 changes: 1 addition & 1 deletion test/MatrixFields/operator_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ end
ᶠuvw,
ᶜadvect(ᶠuvw, ᶜwinterp(ᶠscalar, ᶠrbias(ᶜlbias(ᶠinterp(ᶜnested))))),
)),
time_ratio_limit = 20, # This case's ref function is fast on Buildkite.
time_ratio_limit = 25, # This case's ref function is fast on Buildkite.
test_broken_with_cuda = true, # TODO: Fix this.
)

Expand Down
Loading