Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
partial fix to linear solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 30, 2024
1 parent 68b226a commit de88368
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 23 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BatchedRoutinesCUDAExt = ["CUDA"]
BatchedRoutinesFiniteDiffExt = ["FiniteDiff"]
BatchedRoutinesForwardDiffExt = ["ForwardDiff"]
BatchedRoutinesLinearSolveExt = ["LinearSolve"]
BatchedRoutinesReverseDiffExt = ["ReverseDiff"]
BatchedRoutinesZygoteExt = ["Zygote"]

Expand Down
128 changes: 128 additions & 0 deletions ext/BatchedRoutinesLinearSolveExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
module BatchedRoutinesLinearSolveExt

using ArrayInterface: ArrayInterface
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, getdata
using LinearAlgebra: LinearAlgebra
using LinearSolve: LinearSolve

# `LinearProblem` constructor specialised on `UniformBlockDiagonalOperator`.
#
# Forwards every positional and keyword argument unchanged to
# `LinearSolve.LinearProblem{true}`. Per the original author's note, not overloading the
# constructor causes problems in the adjoint code.
function LinearSolve.LinearProblem(op::UniformBlockDiagonalOperator, b, args...; kwargs...)
    problem = LinearSolve.LinearProblem{true}(op, b, args...; kwargs...)
    return problem
end

# Default algorithm selection for `UniformBlockDiagonalOperator`.
#
# Maps the operator assumptions onto a `DefaultAlgorithmChoice` and wraps the result in a
# `LinearSolve.DefaultLinearSolver`:
#   * `assump.issq`                   -> LUFactorization
#   * WellConditioned                 -> NormalCholeskyFactorization
#   * IllConditioned/VeryIllConditioned
#                                     -> QRFactorizationPivoted when
#                                        `LinearSolve.is_underdetermined(op)`, otherwise
#                                        QRFactorization
#   * SuperIllConditioned             -> SVDFactorization
# Any other condition raises an error. `b` is not consulted.
function LinearSolve.defaultalg(
        op::UniformBlockDiagonalOperator, b, assump::LinearSolve.OperatorAssumptions{Bool})
    choice = if assump.issq
        LinearSolve.DefaultAlgorithmChoice.LUFactorization
    elseif assump.condition === LinearSolve.OperatorCondition.WellConditioned
        LinearSolve.DefaultAlgorithmChoice.NormalCholeskyFactorization
    elseif assump.condition === LinearSolve.OperatorCondition.IllConditioned ||
           assump.condition === LinearSolve.OperatorCondition.VeryIllConditioned
        # Both conditioning levels select QR; only the pivoting differs.
        LinearSolve.is_underdetermined(op) ?
        LinearSolve.DefaultAlgorithmChoice.QRFactorizationPivoted :
        LinearSolve.DefaultAlgorithmChoice.QRFactorization
    elseif assump.condition === LinearSolve.OperatorCondition.SuperIllConditioned
        LinearSolve.DefaultAlgorithmChoice.SVDFactorization
    else
        error("Special factorization not handled in current default algorithm.")
    end
    return LinearSolve.DefaultLinearSolver(choice)
end

# GenericLUFactorization
#
# Cache initialisation: factorises an operator wrapping a 0×0×1 array obtained from
# `similar(getdata(A), 0, 0, 1)` rather than `A` itself.
# NOTE(review): presumably this only seeds the cache with a value of the right
# factorization type — confirm against LinearSolve's cache handling.
function LinearSolve.init_cacheval(alg::LinearSolve.GenericLUFactorization,
        A::UniformBlockDiagonalOperator, b, u, Pl, Pr, maxiters::Int,
        abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
    empty_data = similar(getdata(A), 0, 0, 1)
    placeholder = UniformBlockDiagonalOperator(empty_data)
    return LinearAlgebra.generic_lufact!(placeholder, alg.pivot; check=false)
end

# Factorises `A` with `LinearAlgebra.generic_lufact!`, using the pivoting stored in `alg`
# and `check=false`. `b` and `u` are accepted for the interface but not used.
function LinearSolve.do_factorization(
        alg::LinearSolve.GenericLUFactorization, A::UniformBlockDiagonalOperator, b, u)
    factorization = LinearAlgebra.generic_lufact!(A, alg.pivot; check=false)
    return factorization
end

# LUFactorization
#
# Cache initialisation: runs `LinearAlgebra.lu!` on an operator wrapping a 0×0×1 array
# created with `similar(getdata(A), 0, 0, 1)`; `A` itself is left untouched here.
# NOTE(review): presumably only the type of the returned factorization matters — confirm.
function LinearSolve.init_cacheval(
        alg::LinearSolve.LUFactorization, A::UniformBlockDiagonalOperator, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
    stub = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
    lu_stub = LinearAlgebra.lu!(stub, alg.pivot; check=false)
    return lu_stub
end

# Factorises `A` via `LinearAlgebra.lu!` with `alg.pivot` and `check=false`.
# `b` and `u` are part of the interface but are not read.
function LinearSolve.do_factorization(
        alg::LinearSolve.LUFactorization, A::UniformBlockDiagonalOperator, b, u)
    pivoting = alg.pivot
    return LinearAlgebra.lu!(A, pivoting; check=false)
end

# QRFactorization
#
# Cache initialisation: applies `LinearAlgebra.qr!` (with `alg.pivot`) to an operator
# wrapping a 0×0×1 array allocated through `similar(getdata(A), 0, 0, 1)`.
# NOTE(review): presumably a type-only placeholder for the cache — confirm.
function LinearSolve.init_cacheval(
        alg::LinearSolve.QRFactorization, A::UniformBlockDiagonalOperator, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
    zero_sized = similar(getdata(A), 0, 0, 1)
    return LinearAlgebra.qr!(UniformBlockDiagonalOperator(zero_sized), alg.pivot)
end

# QR-factorises `A` with the pivoting stored in `alg`. When `alg.inplace` is set the
# mutating `qr!` is called on `A`; otherwise the non-mutating `qr` is used.
# `b` and `u` are not read.
function LinearSolve.do_factorization(
        alg::LinearSolve.QRFactorization, A::UniformBlockDiagonalOperator, b, u)
    if alg.inplace
        return LinearAlgebra.qr!(A, alg.pivot)
    else
        return LinearAlgebra.qr(A, alg.pivot)
    end
end

# CholeskyFactorization
#
# Cache initialisation: asks `ArrayInterface.cholesky_instance` for a factorization of an
# operator wrapping a 0×0×1 array built with `similar(getdata(A), 0, 0, 1)`.
# NOTE(review): presumably a type-only placeholder for the cache — confirm.
function LinearSolve.init_cacheval(alg::LinearSolve.CholeskyFactorization,
        A::UniformBlockDiagonalOperator, b, u, Pl, Pr, maxiters::Int,
        abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
    blank = similar(getdata(A), 0, 0, 1)
    blank_op = UniformBlockDiagonalOperator(blank)
    return ArrayInterface.cholesky_instance(blank_op, alg.pivot)
end

# Factorises `A` with `LinearAlgebra.cholesky!`, passing `alg.pivot` and `check=false`.
# `b` and `u` are accepted for the interface but not used.
function LinearSolve.do_factorization(
        alg::LinearSolve.CholeskyFactorization, A::UniformBlockDiagonalOperator, b, u)
    chol = LinearAlgebra.cholesky!(A, alg.pivot; check=false)
    return chol
end

# NormalCholeskyFactorization
#
# Cache initialisation: identical in shape to the plain Cholesky case — a 0×0×1 array from
# `similar(getdata(A), 0, 0, 1)` is wrapped in an operator and handed to
# `ArrayInterface.cholesky_instance` together with `alg.pivot`.
# NOTE(review): presumably a type-only placeholder for the cache — confirm.
function LinearSolve.init_cacheval(alg::LinearSolve.NormalCholeskyFactorization,
        A::UniformBlockDiagonalOperator, b, u, Pl, Pr, maxiters::Int,
        abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
    seed_op = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
    instance = ArrayInterface.cholesky_instance(seed_op, alg.pivot)
    return instance
end

# Solve step for `NormalCholeskyFactorization` on a `UniformBlockDiagonalOperator` cache.
#
# When the cache is marked fresh, `A' * A` is Cholesky-factorised (with `alg.pivot`,
# `check=false`), stored in `cache.cacheval`, and the fresh flag is cleared. The solution is
# then written into `cache.u` by `ldiv!` using the cached factorization and `A' * cache.b`,
# and wrapped with `build_linear_solution`. `kwargs` are accepted but not used.
function LinearSolve.solve!(cache::LinearSolve.LinearCache{<:UniformBlockDiagonalOperator},
        alg::LinearSolve.NormalCholeskyFactorization; kwargs...)
    op = cache.A
    if cache.isfresh
        # Refactorise the normal-equations operator only when the cache is stale.
        cache.cacheval = LinearAlgebra.cholesky!(op' * op, alg.pivot; check=false)
        cache.isfresh = false
    end
    normal_fact = LinearSolve.@get_cacheval(cache, :NormalCholeskyFactorization)
    normal_rhs = op' * cache.b
    sol = LinearAlgebra.ldiv!(cache.u, normal_fact, normal_rhs)
    return LinearSolve.SciMLBase.build_linear_solution(alg, sol, nothing, cache)
end

# SVDFactorization
#
# Cache initialisation: delegates to `ArrayInterface.svd_instance` on an operator wrapping
# a 0×0×1 array created with `similar(getdata(A), 0, 0, 1)`. `alg` is used for dispatch only.
# NOTE(review): presumably a type-only placeholder for the cache — confirm.
function LinearSolve.init_cacheval(
        alg::LinearSolve.SVDFactorization, A::UniformBlockDiagonalOperator, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
    empty_block = similar(getdata(A), 0, 0, 1)
    empty_op = UniformBlockDiagonalOperator(empty_block)
    return ArrayInterface.svd_instance(empty_op)
end

# Computes the SVD of `A` with `LinearAlgebra.svd!`, forwarding the `full` and `alg`
# settings stored on the algorithm object as keyword arguments. `b` and `u` are not read.
function LinearSolve.do_factorization(
        alg::LinearSolve.SVDFactorization, A::UniformBlockDiagonalOperator, b, u)
    return LinearAlgebra.svd!(A; full=alg.full, alg=alg.alg)
end

end
18 changes: 7 additions & 11 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,24 +86,20 @@ end
# batched_mul rrule
function CRC.rrule(::typeof(_batched_mul), A::AbstractArray{T1, 3},
B::AbstractArray{T2, 3}) where {T1, T2}
function ∇batched_mul(_Δ)
∇batched_mul = @closure _Δ -> begin
Δ = CRC.unthunk(_Δ)
∂A = CRC.@thunk begin
tmp = batched_mul(Δ, batched_adjoint(B))
size(A, 3) == 1 ? sum(tmp; dims=3) : tmp
end
∂B = CRC.@thunk begin
tmp = batched_mul(batched_adjoint(A), Δ)
size(B, 3) == 1 ? sum(tmp; dims=3) : tmp
end
tmpA = batched_mul(Δ, batched_adjoint(B))
∂A = size(A, 3) == 1 ? sum(tmpA; dims=3) : tmpA
tmpB = batched_mul(batched_adjoint(A), Δ)
∂B = size(B, 3) == 1 ? sum(tmpB; dims=3) : tmpB
return (NoTangent(), ∂A, ∂B)
end
return batched_mul(A, B), ∇batched_mul
end

# constructor
function CRC.rrule(::Type{<:UniformBlockDiagonalOperator}, data)
function ∇UniformBlockDiagonalOperator(Δ)
∇UniformBlockDiagonalOperator = @closure Δ -> begin
∂data = Δ isa UniformBlockDiagonalOperator ? getdata(Δ) :
(Δ isa NoTangent ? NoTangent() : Δ)
return (NoTangent(), ∂data)
Expand All @@ -113,7 +109,7 @@ end

function CRC.rrule(::typeof(getproperty), op::UniformBlockDiagonalOperator, x::Symbol)
@assert x === :data
∇getproperty(Δ) = (NoTangent(), UniformBlockDiagonalOperator(Δ))
∇getproperty = @closure Δ -> (NoTangent(), UniformBlockDiagonalOperator(Δ))
return op.data, ∇getproperty
end

Expand Down
17 changes: 9 additions & 8 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,19 @@ function Base.show(io::IO, mime::MIME"text/plain", F::GenericBatchedFactorizatio
show(io, mime, first(F.fact))
end

for fact in (:qr, :lu, :cholesky)
for fact in (:qr, :lu, :cholesky, :generic_lufact, :svd)
fact! = Symbol(fact, :!)
@eval begin
function LinearAlgebra.$(fact)(op::UniformBlockDiagonalOperator, args...; kwargs...)
if isdefined(LinearAlgebra, fact)
@eval function LinearAlgebra.$(fact)(
op::UniformBlockDiagonalOperator, args...; kwargs...)
return LinearAlgebra.$(fact!)(copy(op), args...; kwargs...)
end
end

function LinearAlgebra.$(fact!)(
op::UniformBlockDiagonalOperator, args...; kwargs...)
fact = map(Aᵢ -> LinearAlgebra.$(fact!)(Aᵢ, args...; kwargs...), batchview(op))
return GenericBatchedFactorization(LinearAlgebra.$(fact!), fact)
end
@eval function LinearAlgebra.$(fact!)(
op::UniformBlockDiagonalOperator, args...; kwargs...)
fact = map(Aᵢ -> LinearAlgebra.$(fact!)(Aᵢ, args...; kwargs...), batchview(op))
return GenericBatchedFactorization(LinearAlgebra.$(fact!), fact)
end
end

Expand Down
22 changes: 18 additions & 4 deletions test/integration_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testitem "LinearSolve" setup=[SharedTestSetup] begin
using FiniteDiff, LinearSolve, Zygote
using FiniteDiff, LinearAlgebra, LinearSolve, Zygote

rng = get_stable_rng(1001)

Expand All @@ -13,12 +13,26 @@
prob2 = LinearProblem(A2, b)

if dims[1] == dims[2]
solvers = [LUFactorization(), QRFactorization(), KrylovJL_GMRES()]
solvers = [LUFactorization(), QRFactorization(),
KrylovJL_GMRES(), svd_factorization(mode), nothing]
else
solvers = [QRFactorization(), KrylovJL_LSMR()]
solvers = [
QRFactorization(), KrylovJL_LSMR(), NormalCholeskyFactorization(),
QRFactorization(LinearAlgebra.ColumnNorm()),
svd_factorization(mode), nothing]
end

@testset "solver: $(solver)" for solver in solvers
@testset "solver: $(nameof(typeof(solver)))" for solver in solvers
# FIXME: SVD doesn't define ldiv on CUDA side
if mode == "CUDA"
@show solver, solver isa SVDFactorization
if solver isa SVDFactorization || (solver isa QRFactorization &&
solver.pivot isa LinearAlgebra.ColumnNorm)
# ColumnNorm is not implemented on CUDA
continue
end
end

x1 = solve(prob1, solver)
x2 = solve(prob2, solver)
@test x1.u ≈ x2.u
Expand Down
11 changes: 11 additions & 0 deletions test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,18 @@ end

# Returns a `StableRNG` seeded with `seed` (12345 when no seed is given).
function get_stable_rng(seed=12345)
    return StableRNG(seed)
end

# SVD Helper till https://github.com/SciML/LinearSolve.jl/issues/488 is resolved
using LinearSolve: LinearSolve

# Builds the `SVDFactorization` used by the tests for a given backend `mode`:
#   * "CPU"  -> `LinearSolve.SVDFactorization()` with its defaults
#   * "CUDA" -> `LinearSolve.SVDFactorization(true, CUDA.CUSOLVER.JacobiAlgorithm())`
# Any other `mode` raises an error.
function svd_factorization(mode)
    if mode == "CPU"
        return LinearSolve.SVDFactorization()
    elseif mode == "CUDA"
        return LinearSolve.SVDFactorization(true, CUDA.CUSOLVER.JacobiAlgorithm())
    end
    error("Unsupported mode: $mode")
end

export @jet, @test_gradients, check_approx
export GROUP, MODES, cpu_testing, cuda_testing, get_default_rng, get_stable_rng
export svd_factorization

end

0 comments on commit de88368

Please sign in to comment.