Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make use of Aliasing API for alias_A and alias_b #553

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.36.1"
version = "2.37.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down Expand Up @@ -105,7 +105,7 @@ RecursiveArrayTools = "3.8"
RecursiveFactorization = "0.2.14"
Reexport = "1"
SafeTestsets = "0.1"
SciMLBase = "2.26.3"
SciMLBase = "2.58.0"
SciMLOperators = "0.3.7"
Setfield = "1"
SparseArrays = "1.10"
Expand Down
46 changes: 42 additions & 4 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ __init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltyp

function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
args...;
alias_A = default_alias_A(alg, prob.A, prob.b),
alias_b = default_alias_b(alg, prob.A, prob.b),
abstol = default_tol(real(eltype(prob.b))),
reltol = default_tol(real(eltype(prob.b))),
maxiters::Int = length(prob.b),
Expand All @@ -149,10 +147,50 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
Pr = nothing,
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
alias = LinearAliasSpecifier(),
kwargs...)
(;A, b, u0, p) = prob

A = if alias_A || A isa SMatrix
has_A = haskey(kwargs,:alias_A)
has_b = haskey(kwargs,:alias_b)

if has_A || has_b
aliases = LinearAliasSpecifier()
if has_A
Base.depwarn("alias_A keyword argument is deprecated, to set `alias_A`,
please use a LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))", :alias_A)
SciMLBase.@reset aliases.alias_A = values(kwargs).alias_A
else
SciMLBase.@reset aliases.alias_A = default_alias_A(alg, prob.A, prob.b)
end

if has_b
Base.depwarn("alias_b keyword argument is deprecated, to set `alias_b`,
please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))", :alias_b)
SciMLBase.@reset aliases.alias_b = values(kwargs).alias_b
else
SciMLBase.@reset aliases.alias_b = default_alias_b(alg, prob.A, prob.b)
end

aliases
else
# If alias isa Bool, all fields of ODEAliases set to alias
if alias isa Bool
aliases = LinearAliasSpecifier(alias = alias)
elseif alias isa LinearAliasSpecifier || isnothing(alias)
aliases = alias
end

if isnothing(aliases.alias_A)
SciMLBase.@reset aliases.alias_A = default_alias_A(alg,prob.A,prob.b)
end
if isnothing(aliases.alias_b)
SciMLBase.@reset aliases.alias_b = default_alias_b(alg,prob.A,prob.b)
end
aliases
end

A = if aliases.alias_A || A isa SMatrix
A
elseif A isa Array
copy(A)
Expand All @@ -164,7 +202,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,

b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
Array(b) # the solution to a linear solve will always be dense!
elseif alias_b || b isa SVector
elseif aliases.alias_b || b isa SVector
b
elseif b isa Array
copy(b)
Expand Down
Loading