From 300c4a9c6e788ccb59fc6844371e6f48f5807083 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 8 Dec 2023 14:51:40 -0500 Subject: [PATCH] Fix NormalCholesky on GPU --- Project.toml | 2 +- src/factorization.jl | 4 ++-- test/gpu/cuda.jl | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index fce114b23..36749c026 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LinearSolve" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" authors = ["SciML"] -version = "2.20.1" +version = "2.20.2" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/factorization.jl b/src/factorization.jl index e6476eb9d..2eb085767 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -896,7 +896,7 @@ default_alias_b(::NormalCholeskyFactorization, ::Any, ::Any) = true const PREALLOCATED_NORMALCHOLESKY = ArrayInterface.cholesky_instance(rand(1, 1), NoPivot()) function init_cacheval(alg::NormalCholeskyFactorization, - A::Union{AbstractSparseArray, + A::Union{AbstractSparseArray, GPUArraysCore.AbstractGPUArray, Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) @@ -921,7 +921,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization; A = cache.A A = convert(AbstractMatrix, A) if cache.isfresh - if A isa SparseMatrixCSC + if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray fact = cholesky(Symmetric((A)' * A, :L); check = false) else fact = cholesky(Symmetric((A)' * A, :L), alg.pivot; check = false) diff --git a/test/gpu/cuda.jl b/test/gpu/cuda.jl index d0aaf579f..75a181026 100644 --- a/test/gpu/cuda.jl +++ b/test/gpu/cuda.jl @@ -44,8 +44,8 @@ function test_interface(alg, prob1, prob2) return end -@testset "CudaOffloadFactorization" begin - test_interface(CudaOffloadFactorization(), prob1, prob2) +@testset "$alg" for alg in (CudaOffloadFactorization(), NormalCholeskyFactorization()) + test_interface(alg, prob1, prob2) end @testset "Simple GMRES: restart = $restart" for restart in (true, false)