From 4258a601d017dd0ae244c960bd8dc562eeabdcb7 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 26 Mar 2023 01:39:45 -0400 Subject: [PATCH] Add a preference system for turning on/off slow fallbacks This gives a good way to balance development vs usage. For development, you want to just error if you hit any slower path. But for users, code should just work. Thus the slower fallbacks were given a preference system for allowing error throwing, without forcing all users to have to always see errors on new types just for more optimizations. --- src/ArrayInterface.jl | 88 +++++++++++++++++++++++++------------------ 1 file changed, 52 insertions(+), 36 deletions(-) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 00dd753bc..175330104 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -1,5 +1,7 @@ module ArrayInterface +const SLOWFALLBACKS = @load_preference("slow_fallbacks", true) + using LinearAlgebra using SparseArrays using SuiteSparse @@ -282,7 +284,7 @@ function ismutable end ismutable(::Type{T}) -> Bool Query whether instances of type `T` are mutable or not, see -https://github.com/JuliaDiffEq/RecursiveArrayTools.jl/issues/19. +https://github.com/SciML/RecursiveArrayTools.jl/issues/19. """ ismutable(x) = ismutable(typeof(x)) function ismutable(::Type{T}) where {T <: AbstractArray} @@ -460,12 +462,15 @@ Returns the number. """ bunchkaufman_instance(a::Number) = a -""" -bunchkaufman_instance(a::Any) -> cholesky(a, check=false) +@static if SLOWFALLBACKS + """ + bunchkaufman_instance(a::Any) -> cholesky(a, check=false) -Returns the number. -""" -bunchkaufman_instance(a::Any) = bunchkaufman(a, check = false) + Slow fallback which gets the instance via factorization. Should get + specialized for new matrix types. + """ + bunchkaufman_instance(a::Any) = bunchkaufman(a, check = false) +end """ cholesky_instance(A, pivot = LinearAlgebra.RowMaximum()) -> cholesky_factorization_instance @@ -487,13 +492,15 @@ Returns the number. """ cholesky_instance(a::Number, pivot = LinearAlgebra.RowMaximum()) = a -""" -cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) -> cholesky(a, check=false) +@static if SLOWFALLBACKS + """ + cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) -> cholesky(a, check=false) -Slow fallback which gets the instance via factorization. Should get -specialized for new matrix types. -""" -cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) = cholesky(a, pivot, check = false) + Slow fallback which gets the instance via factorization. Should get + specialized for new matrix types. + """ + cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) = cholesky(a, pivot, check = false) +end """ ldlt_instance(A) -> ldlt_factorization_instance @@ -515,13 +522,15 @@ Returns the number. """ ldlt_instance(a::Number) = a -""" -ldlt_instance(a::Any) -> ldlt(a, check=false) +@static if SLOWFALLBACKS + """ + ldlt_instance(a::Any) -> ldlt(a, check=false) -Slow fallback which gets the instance via factorization. Should get -specialized for new matrix types. -""" -ldlt_instance(a::Any) = ldlt(a) + Slow fallback which gets the instance via factorization. Should get + specialized for new matrix types. + """ + ldlt_instance(a::Any) = ldlt(a) +end """ lu_instance(A) -> lu_factorization_instance @@ -558,13 +567,15 @@ Returns the number. """ lu_instance(a::Number) = a -""" - lu_instance(a::Any) -> lu(a, check=false) +@static if SLOWFALLBACKS + """ + lu_instance(a::Any) -> lu(a, check=false) -Slow fallback which gets the instance via factorization. Should get -specialized for new matrix types. -""" -lu_instance(a::Any) = lu(a, check = false) + Slow fallback which gets the instance via factorization. Should get + specialized for new matrix types. + """ + lu_instance(a::Any) = lu(a, check = false) +end """ qr_instance(A) -> qr_factorization_instance @@ -588,13 +599,15 @@ Returns the number. """ qr_instance(a::Number) = a -""" - qr_instance(a::Any) -> qr(a) +@static if SLOWFALLBACKS + """ + qr_instance(a::Any) -> qr(a) -Slow fallback which gets the instance via factorization. Should get -specialized for new matrix types. -""" -qr_instance(a::Any) = qr(a)# check = false) + Slow fallback which gets the instance via factorization. Should get + specialized for new matrix types. + """ + qr_instance(a::Any) = qr(a)# check = false) +end """ svd_instance(A) -> qr_factorization_instance @@ -613,13 +626,15 @@ Returns the number. """ svd_instance(a::Number) = a -""" - svd_instance(a::Any) -> svd(a) +@static if SLOWFALLBACKS + """ + svd_instance(a::Any) -> svd(a) -Slow fallback which gets the instance via factorization. Should get -specialized for new matrix types. -""" -svd_instance(a::Any) = svd(a) #check = false) + Slow fallback which gets the instance via factorization. Should get + specialized for new matrix types. + """ + svd_instance(a::Any) = svd(a) #check = false) +end """ safevec(v) @@ -1034,3 +1049,4 @@ import Requires end end # module +