From 518b07a9318a4a37886a987ae3419ebf92940626 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 16 Oct 2023 11:43:58 +0200 Subject: [PATCH] Move SparseArrays support to an extension --- Project.toml | 11 ++- ext/ChainRulesCoreSparseArraysExt.jl | 103 +++++++++++++++++++++++++++ src/ChainRulesCore.jl | 6 +- src/accumulation.jl | 3 - src/projection.jl | 98 ------------------------- 5 files changed, 117 insertions(+), 104 deletions(-) create mode 100644 ext/ChainRulesCoreSparseArraysExt.jl diff --git a/Project.toml b/Project.toml index 93ac93294..ded9488fc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,18 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.17.0" +version = "1.18.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +[weakdeps] +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[extensions] +ChainRulesCoreSparseArraysExt = "SparseArrays" + [compat] BenchmarkTools = "0.5" Compat = "2, 3, 4" @@ -19,8 +25,9 @@ julia = "1.6" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "StaticArrays"] +test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "SparseArrays", "StaticArrays"] diff --git a/ext/ChainRulesCoreSparseArraysExt.jl b/ext/ChainRulesCoreSparseArraysExt.jl new file mode 100644 index 000000000..e4714f4c1 --- /dev/null +++ b/ext/ChainRulesCoreSparseArraysExt.jl @@ -0,0 +1,103 @@ +module ChainRulesCoreSparseArraysExt + +using ChainRulesCore +using ChainRulesCore: project_type, _projection_mismatch +using SparseArrays: 
SparseVector, SparseMatrixCSC, nzrange, rowvals + +ChainRulesCore.is_inplaceable_destination(::SparseVector) = true +ChainRulesCore.is_inplaceable_destination(::SparseMatrixCSC) = true + +# Word from on high is that we should regard all un-stored values of sparse arrays as +# structural zeros. Thus ProjectTo needs to store nzind, and get only those. +# This implementation is very naive, can probably be made more efficient. + +function ChainRulesCore.ProjectTo(x::SparseVector{T}) where {T<:Number} + return ProjectTo{SparseVector}(; + element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x) + ) +end +function (project::ProjectTo{SparseVector})(dx::AbstractArray) + dy = if axes(dx) == project.axes + dx + else + if size(dx, 1) != length(project.axes[1]) + throw(_projection_mismatch(project.axes, size(dx))) + end + reshape(dx, project.axes) + end + nzval = map(i -> project.element(dy[i]), project.nzind) + return SparseVector(length(dx), project.nzind, nzval) +end +function (project::ProjectTo{SparseVector})(dx::SparseVector) + if size(dx) != map(length, project.axes) + throw(_projection_mismatch(project.axes, size(dx))) + end + # When sparsity pattern is unchanged, all the time is in checking this, + # perhaps some simple hash/checksum might be good enough? + samepattern = project.nzind == dx.nzind + # samepattern = length(project.nzind) == length(dx.nzind) + if eltype(dx) <: project_type(project.element) && samepattern + return dx + elseif samepattern + nzval = map(project.element, dx.nzval) + SparseVector(length(dx), dx.nzind, nzval) + else + nzind = project.nzind + # Or should we intersect? Can this exploit sorting? 
+ # nzind = intersect(project.nzind, dx.nzind) + nzval = map(i -> project.element(dx[i]), nzind) + return SparseVector(length(dx), nzind, nzval) + end +end + +function ChainRulesCore.ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number} + return ProjectTo{SparseMatrixCSC}(; + element=ProjectTo(zero(T)), + axes=axes(x), + rowval=rowvals(x), + nzranges=nzrange.(Ref(x), axes(x, 2)), + colptr=x.colptr, + ) +end +# You need not really store nzranges, you can get them from colptr -- TODO +# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1) +function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) + dy = if axes(dx) == project.axes + dx + else + if size(dx) != (length(project.axes[1]), length(project.axes[2])) + throw(_projection_mismatch(project.axes, size(dx))) + end + reshape(dx, project.axes) + end + nzval = Vector{project_type(project.element)}(undef, length(project.rowval)) + k = 0 + for col in project.axes[2] + for i in project.nzranges[col] + row = project.rowval[i] + val = dy[row, col] + nzval[k += 1] = project.element(val) + end + end + m, n = map(length, project.axes) + return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval) +end + +function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC) + if size(dx) != map(length, project.axes) + throw(_projection_mismatch(project.axes, size(dx))) + end + samepattern = dx.colptr == project.colptr && dx.rowval == project.rowval + # samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end] + if eltype(dx) <: project_type(project.element) && samepattern + return dx + elseif samepattern + nzval = map(project.element, dx.nzval) + m, n = size(dx) + return SparseMatrixCSC(m, n, dx.colptr, dx.rowval, nzval) + else + invoke(project, Tuple{AbstractArray}, dx) + end +end + +end # module diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index b81ab4fba..94e8242b1 100644 --- a/src/ChainRulesCore.jl +++ 
b/src/ChainRulesCore.jl @@ -2,7 +2,6 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using Base.Meta using LinearAlgebra -using SparseArrays: SparseVector, SparseMatrixCSC using Compat: hasfield, hasproperty export frule, rrule # core function @@ -36,4 +35,9 @@ include("ignore_derivatives.jl") include("deprecated.jl") +# SparseArrays support on Julia < 1.9 +if !isdefined(Base, :get_extension) + include("../ext/ChainRulesCoreSparseArraysExt.jl") +end + end # module diff --git a/src/accumulation.jl b/src/accumulation.jl index dc4ccd3bf..26d0fbb27 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -56,9 +56,6 @@ is_inplaceable_destination(::Any) = false is_inplaceable_destination(::Array) = true is_inplaceable_destination(:: Array{<:Integer}) = false -is_inplaceable_destination(::SparseVector) = true -is_inplaceable_destination(::SparseMatrixCSC) = true - function is_inplaceable_destination(x::SubArray) alpha = is_inplaceable_destination(parent(x)) beta = x.indices isa Tuple{Vararg{Union{Integer, Base.Slice, UnitRange}}} diff --git a/src/projection.jl b/src/projection.jl index 811802536..e4ed4d8dc 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -515,101 +515,3 @@ function (project::ProjectTo{Tridiagonal})(dx::AbstractArray) return Tridiagonal(dy) end # Note that backing(::Tridiagonal) doesn't work, https://github.com/JuliaDiff/ChainRulesCore.jl/issues/392 - -##### -##### `SparseArrays` -##### - -using SparseArrays -# Word from on high is that we should regard all un-stored values of sparse arrays as -# structural zeros. Thus ProjectTo needs to store nzind, and get only those. -# This implementation very naiive, can probably be made more efficient. 
- -function ProjectTo(x::SparseVector{T}) where {T<:Number} - return ProjectTo{SparseVector}(; - element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x) - ) -end -function (project::ProjectTo{SparseVector})(dx::AbstractArray) - dy = if axes(dx) == project.axes - dx - else - if size(dx, 1) != length(project.axes[1]) - throw(_projection_mismatch(project.axes, size(dx))) - end - reshape(dx, project.axes) - end - nzval = map(i -> project.element(dy[i]), project.nzind) - return SparseVector(length(dx), project.nzind, nzval) -end -function (project::ProjectTo{SparseVector})(dx::SparseVector) - if size(dx) != map(length, project.axes) - throw(_projection_mismatch(project.axes, size(dx))) - end - # When sparsity pattern is unchanged, all the time is in checking this, - # perhaps some simple hash/checksum might be good enough? - samepattern = project.nzind == dx.nzind - # samepattern = length(project.nzind) == length(dx.nzind) - if eltype(dx) <: project_type(project.element) && samepattern - return dx - elseif samepattern - nzval = map(project.element, dx.nzval) - SparseVector(length(dx), dx.nzind, nzval) - else - nzind = project.nzind - # Or should we intersect? Can this exploit sorting? 
- # nzind = intersect(project.nzind, dx.nzind) - nzval = map(i -> project.element(dx[i]), nzind) - return SparseVector(length(dx), nzind, nzval) - end -end - -function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number} - return ProjectTo{SparseMatrixCSC}(; - element=ProjectTo(zero(T)), - axes=axes(x), - rowval=rowvals(x), - nzranges=nzrange.(Ref(x), axes(x, 2)), - colptr=x.colptr, - ) -end -# You need not really store nzranges, you can get them from colptr -- TODO -# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1) -function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) - dy = if axes(dx) == project.axes - dx - else - if size(dx) != (length(project.axes[1]), length(project.axes[2])) - throw(_projection_mismatch(project.axes, size(dx))) - end - reshape(dx, project.axes) - end - nzval = Vector{project_type(project.element)}(undef, length(project.rowval)) - k = 0 - for col in project.axes[2] - for i in project.nzranges[col] - row = project.rowval[i] - val = dy[row, col] - nzval[k += 1] = project.element(val) - end - end - m, n = map(length, project.axes) - return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval) -end - -function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC) - if size(dx) != map(length, project.axes) - throw(_projection_mismatch(project.axes, size(dx))) - end - samepattern = dx.colptr == project.colptr && dx.rowval == project.rowval - # samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end] - if eltype(dx) <: project_type(project.element) && samepattern - return dx - elseif samepattern - nzval = map(project.element, dx.nzval) - m, n = size(dx) - return SparseMatrixCSC(m, n, dx.colptr, dx.rowval, nzval) - else - invoke(project, Tuple{AbstractArray}, dx) - end -end