Skip to content

Commit

Permalink
Merge pull request #436 from avik-pal/ap/defaults
Browse files Browse the repository at this point in the history
Support StaticArrays Properly
  • Loading branch information
ChrisRackauckas authored Dec 6, 2023
2 parents 9787717 + 21559ed commit 98a377d
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 10 deletions.
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.20.0"
version = "2.20.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand All @@ -26,6 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
Expand All @@ -48,14 +49,14 @@ LinearSolveBandedMatricesExt = "BandedMatrices"
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"]
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]

[compat]
Aqua = "0.8"
Expand Down Expand Up @@ -101,6 +102,8 @@ SciMLOperators = "0.3"
Setfield = "1"
SparseArrays = "1.9"
Sparspak = "0.3.6"
StaticArraysCore = "1"
StaticArrays = "1"
Test = "1"
UnPack = "1"
julia = "1.9"
Expand All @@ -126,7 +129,8 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays"]
2 changes: 2 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ PrecompileTools.@recompile_invalidations begin
using Requires
import InteractiveUtils

import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix

using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing,
chkfinite, chkstride1,
Expand Down
7 changes: 7 additions & 0 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ function defaultalg(A, b, assump::OperatorAssumptions)
DefaultAlgorithmChoice.LUFactorization
end

# For static arrays GMRES allocates a lot. Use factorization
elseif A isa StaticArray
DefaultAlgorithmChoice.LUFactorization

# This catches the cases where a factorization overload could exist
# For example, BlockBandedMatrix
elseif A !== nothing && ArrayInterface.isstructured(A)
Expand All @@ -186,6 +190,9 @@ function defaultalg(A, b, assump::OperatorAssumptions)
end
elseif assump.condition === OperatorCondition.WellConditioned
DefaultAlgorithmChoice.NormalCholeskyFactorization
elseif A isa StaticArray
# Static Array doesn't have QR() \ b defined
DefaultAlgorithmChoice.SVDFactorization
elseif assump.condition === OperatorCondition.IllConditioned
if is_underdetermined(A)
# Underdetermined
Expand Down
42 changes: 36 additions & 6 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ end

_ldiv!(x, A, b) = ldiv!(x, A, b)

_ldiv!(x, A, b::SVector) = (x .= A \ b)
_ldiv!(::SVector, A, b::SVector) = (A \ b)
_ldiv!(::SVector, A, b) = (A \ b)

function _ldiv!(x::Vector, A::Factorization, b::Vector)
# workaround https://github.com/JuliaLang/julia/issues/43507
# Fallback if working with non-square matrices
Expand Down Expand Up @@ -74,6 +78,8 @@ function do_factorization(alg::LUFactorization, A, b, u)
if A isa AbstractSparseMatrixCSC
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
elseif !ArrayInterface.can_setindex(typeof(A))
fact = lu(A, alg.pivot, check = false)
else
fact = lu!(A, alg.pivot, check = false)
end
Expand Down Expand Up @@ -136,10 +142,14 @@ end

function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
fact = qr!(A, alg.pivot)
if ArrayInterface.can_setindex(typeof(A))
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
fact = qr!(A, alg.pivot)
else
fact = qr(A) # CUDA.jl does not allow other args!
end
else
fact = qr(A) # CUDA.jl does not allow other args!
fact = qr(A, alg.pivot)
end
return fact
end
Expand Down Expand Up @@ -202,6 +212,16 @@ function do_factorization(alg::CholeskyFactorization, A, b, u)
return fact
end

function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {S1, S2}
# StaticArrays doesn't have the pivot argument. Prevent generic fallback.
# CholeskyFactorization is part of DefaultLinearSolver, so it is possible that `A` is
# not Hermitian.
(!issquare(A) || !ishermitian(A)) && return nothing
cholesky(A)
end

function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
Expand Down Expand Up @@ -276,11 +296,15 @@ SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())

function do_factorization(alg::SVDFactorization, A, b, u)
A = convert(AbstractMatrix, A)
fact = svd!(A; full = alg.full, alg = alg.alg)
if ArrayInterface.can_setindex(typeof(A))
fact = svd!(A; alg.full, alg.alg)
else
fact = svd(A; alg.full)
end
return fact
end

function init_cacheval(alg::SVDFactorization, A::Matrix, b, u, Pl, Pr,
function init_cacheval(alg::SVDFactorization, A::Union{Matrix, SMatrix}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ArrayInterface.svd_instance(convert(AbstractMatrix, A))
Expand Down Expand Up @@ -882,7 +906,8 @@ end
function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
A_ = convert(AbstractMatrix, A)
ArrayInterface.cholesky_instance(Symmetric((A)' * A, :L), alg.pivot)
end

function init_cacheval(alg::NormalCholeskyFactorization,
Expand Down Expand Up @@ -1128,6 +1153,11 @@ function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int,
end
end

function init_cacheval(::SparspakFactorization, ::StaticArray, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs...)
A = cache.A
if cache.isfresh
Expand Down
2 changes: 1 addition & 1 deletion test/default_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ solve(prob)
A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)
JET.@test_opt init(prob, nothing)
VERSION v"1.10-" && JET.@test_opt init(prob, nothing)
JET.@test_opt solve(prob, LUFactorization())
JET.@test_opt solve(prob, GenericLUFactorization())
@test_skip JET.@test_opt solve(prob, QRFactorization())
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ if GROUP == "All" || GROUP == "Core"
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
@time @safetestset "Traits" include("traits.jl")
@time @safetestset "BandedMatrices" include("banded.jl")
@time @safetestset "Static Arrays" include("static_arrays.jl")
end

if GROUP == "LinearSolveCUDA"
Expand Down
24 changes: 24 additions & 0 deletions test/static_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using LinearSolve, StaticArrays, LinearAlgebra

A = SMatrix{5, 5}(Hermitian(rand(5, 5) + I))
b = SVector{5}(rand(5))

for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(),
KrylovJL_GMRES())
sol = solve(LinearProblem(A, b), alg)
@test norm(A * sol .- b) < 1e-10
end

A = SMatrix{7, 5}(rand(7, 5))
b = SVector{7}(rand(7))

for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
@test_nowarn solve(LinearProblem(A, b), alg)
end

A = SMatrix{5, 7}(rand(5, 7))
b = SVector{5}(rand(5))

for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
@test_nowarn solve(LinearProblem(A, b), alg)
end

0 comments on commit 98a377d

Please sign in to comment.