Skip to content

Commit

Permalink
Use recursion to fix inference failure
Browse files Browse the repository at this point in the history
wip
  • Loading branch information
charleskawczynski committed Apr 15, 2024
1 parent d352572 commit 4f18ef6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 12 deletions.
2 changes: 0 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,6 @@ steps:
- label: "Unit: matrix field broadcasting (CPU)"
key: unit_matrix_field_broadcasting_cpu_scalar_14
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/matrix_fields_broadcasting/test_scalar_14.jl"
soft_fail: true

- label: "Unit: matrix field broadcasting (CPU)"
key: unit_matrix_field_broadcasting_cpu_scalar_15
Expand Down Expand Up @@ -819,7 +818,6 @@ steps:
- label: "Unit: matrix field broadcasting (GPU)"
key: unit_matrix_field_broadcasting_gpu_scalar_14
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/matrix_fields_broadcasting/test_scalar_14.jl"
soft_fail: true
agents:
slurm_gpus: 1
slurm_mem: 10GB
Expand Down
55 changes: 45 additions & 10 deletions src/MatrixFields/matrix_multiplication.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,34 @@ boundary_modified_ud(_, ud, column_space, i) = ud
boundary_modified_ud(::BottomRightMatrixCorner, ud, column_space, i) =
min(Operators.right_idx(column_space) - i, ud)

@inline matrix_rows(ld1, ud1, args) =
_matrix_rows(ntuple(i -> i + ld1 - 1, ud1 - ld1 + 1), args)
Base.@propagate_inbounds function _matrix_rows(
d::Union{Int, PlusHalf{Int}},
args,
)
(
bc,
boundary_modified_ld1,
boundary_modified_ud1,
space,
matrix2,
loc,
idx,
hidx,
) = args
if isnothing(bc) || boundary_modified_ld1 <= d <= boundary_modified_ud1
@inbounds Operators.getidx(space, matrix2, loc, idx + d, hidx)
else
zero(eltype(matrix2)) # This row is outside the matrix.
end
end
Base.@propagate_inbounds _matrix_rows(tup::Tuple, args) =
(_matrix_rows(first(tup), args), _matrix_rows(Base.tail(tup), args)...)
Base.@propagate_inbounds _matrix_rows(tup::Tuple{<:Any}, args) =
(_matrix_rows(first(tup), args),)
@inline _matrix_rows(tup::Tuple{}, args) = ()

# TODO: Use @propagate_inbounds here, and remove @inbounds from this function.
# As of Julia 1.8, doing this increases compilation time by more than an order
# of magnitude, and it also makes type inference fail for some complicated
Expand Down Expand Up @@ -341,16 +369,20 @@ function multiply_matrix_at_index(loc, space, idx, hidx, matrix1, arg, bc)
# of as a map from boundary_modified_ld1 to boundary_modified_ud1. For
# simplicity, use zero padding for rows that are outside the matrix.
# Wrap the rows in a BandMatrixRow so that they can be easily indexed.
matrix2_rows = map((ld1:ud1...,)) do d
# TODO: Use @propagate_inbounds_meta instead of @inline_meta.
Base.@_inline_meta
if isnothing(bc) ||
boundary_modified_ld1 <= d <= boundary_modified_ud1
@inbounds Operators.getidx(space, matrix2, loc, idx + d, hidx)
else
zero(eltype(matrix2)) # This row is outside the matrix.
end
end
matrix2_rows = matrix_rows(
ld1,
ud1,
(
bc,
boundary_modified_ld1,
boundary_modified_ud1,
space,
matrix2,
loc,
idx,
hidx,
),
)
matrix2_rows_wrapper = BandMatrixRow{ld1}(matrix2_rows...)

# Precompute the zero value to avoid inference issues caused by passing
Expand Down Expand Up @@ -443,4 +475,7 @@ if hasfield(Method, :recursion_relation)
for m in methods(multiply_matrix_at_index)
m.recursion_relation = dont_limit
end
for m in methods(matrix_rows)
m.recursion_relation = dont_limit
end
end

0 comments on commit 4f18ef6

Please sign in to comment.