Skip to content

Commit

Permalink
Move SparseArrays support to an extension
Browse files Browse the repository at this point in the history
  • Loading branch information
David Widmann committed Oct 16, 2023
1 parent efc2f86 commit 518b07a
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 104 deletions.
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.17.0"
version = "1.18.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
ChainRulesCoreSparseArraysExt = "SparseArrays"

[compat]
BenchmarkTools = "0.5"
Compat = "2, 3, 4"
Expand All @@ -19,8 +25,9 @@ julia = "1.6"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "StaticArrays"]
test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "SparseArrays", "StaticArrays"]
103 changes: 103 additions & 0 deletions ext/ChainRulesCoreSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
module ChainRulesCoreSparseArraysExt

using ChainRulesCore
using ChainRulesCore: project_type, _projection_mismatch
using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals

ChainRulesCore.is_inplaceable_destination(::SparseVector) = true
ChainRulesCore.is_inplaceable_destination(::SparseMatrixCSC) = true

# Word from on high is that we should regard all un-stored values of sparse arrays as
# structural zeros. Thus ProjectTo needs to store nzind, and get only those.
# This implementation very naiive, can probably be made more efficient.

function ChainRulesCore.ProjectTo(x::SparseVector{T}) where {T<:Number}
return ProjectTo{SparseVector}(;
element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x)
)
end
function (project::ProjectTo{SparseVector})(dx::AbstractArray)
dy = if axes(dx) == project.axes
dx
else
if size(dx, 1) != length(project.axes[1])
throw(_projection_mismatch(project.axes, size(dx)))
end
reshape(dx, project.axes)
end
nzval = map(i -> project.element(dy[i]), project.nzind)
return SparseVector(length(dx), project.nzind, nzval)
end
function (project::ProjectTo{SparseVector})(dx::SparseVector)
if size(dx) != map(length, project.axes)
throw(_projection_mismatch(project.axes, size(dx)))

Check warning on line 33 in ext/ChainRulesCoreSparseArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/ChainRulesCoreSparseArraysExt.jl#L33

Added line #L33 was not covered by tests
end
# When sparsity pattern is unchanged, all the time is in checking this,
# perhaps some simple hash/checksum might be good enough?
samepattern = project.nzind == dx.nzind
# samepattern = length(project.nzind) == length(dx.nzind)
if eltype(dx) <: project_type(project.element) && samepattern
return dx
elseif samepattern
nzval = map(project.element, dx.nzval)
SparseVector(length(dx), dx.nzind, nzval)
else
nzind = project.nzind
# Or should we intersect? Can this exploit sorting?
# nzind = intersect(project.nzind, dx.nzind)
nzval = map(i -> project.element(dx[i]), nzind)
return SparseVector(length(dx), nzind, nzval)
end
end

function ChainRulesCore.ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number}
return ProjectTo{SparseMatrixCSC}(;
element=ProjectTo(zero(T)),
axes=axes(x),
rowval=rowvals(x),
nzranges=nzrange.(Ref(x), axes(x, 2)),
colptr=x.colptr,
)
end
# You need not really store nzranges, you can get them from colptr -- TODO
# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray)
dy = if axes(dx) == project.axes
dx
else
if size(dx) != (length(project.axes[1]), length(project.axes[2]))
throw(_projection_mismatch(project.axes, size(dx)))
end
reshape(dx, project.axes)
end
nzval = Vector{project_type(project.element)}(undef, length(project.rowval))
k = 0
for col in project.axes[2]
for i in project.nzranges[col]
row = project.rowval[i]
val = dy[row, col]
nzval[k += 1] = project.element(val)
end
end
m, n = map(length, project.axes)
return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval)
end

function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
if size(dx) != map(length, project.axes)
throw(_projection_mismatch(project.axes, size(dx)))

Check warning on line 88 in ext/ChainRulesCoreSparseArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/ChainRulesCoreSparseArraysExt.jl#L88

Added line #L88 was not covered by tests
end
samepattern = dx.colptr == project.colptr && dx.rowval == project.rowval
# samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end]
if eltype(dx) <: project_type(project.element) && samepattern
return dx
elseif samepattern
nzval = map(project.element, dx.nzval)
m, n = size(dx)
return SparseMatrixCSC(m, n, dx.colptr, dx.rowval, nzval)
else
invoke(project, Tuple{AbstractArray}, dx)
end
end

end # module
6 changes: 5 additions & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module ChainRulesCore
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
using Base.Meta
using LinearAlgebra
using SparseArrays: SparseVector, SparseMatrixCSC
using Compat: hasfield, hasproperty

export frule, rrule # core function
Expand Down Expand Up @@ -36,4 +35,9 @@ include("ignore_derivatives.jl")

include("deprecated.jl")

# SparseArrays support on Julia < 1.9
if !isdefined(Base, :get_extension)
include("../ext/ChainRulesCoreSparseArraysExt.jl")
end

end # module
3 changes: 0 additions & 3 deletions src/accumulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ is_inplaceable_destination(::Any) = false
is_inplaceable_destination(::Array) = true
is_inplaceable_destination(:: Array{<:Integer}) = false

is_inplaceable_destination(::SparseVector) = true
is_inplaceable_destination(::SparseMatrixCSC) = true

function is_inplaceable_destination(x::SubArray)
alpha = is_inplaceable_destination(parent(x))
beta = x.indices isa Tuple{Vararg{Union{Integer, Base.Slice, UnitRange}}}
Expand Down
98 changes: 0 additions & 98 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,101 +515,3 @@ function (project::ProjectTo{Tridiagonal})(dx::AbstractArray)
return Tridiagonal(dy)
end
# Note that backing(::Tridiagonal) doesn't work, https://github.com/JuliaDiff/ChainRulesCore.jl/issues/392

#####
##### `SparseArrays`
#####

using SparseArrays
# Word from on high is that we should regard all un-stored values of sparse arrays as
# structural zeros. Thus ProjectTo needs to store nzind, and get only those.
# This implementation very naiive, can probably be made more efficient.

function ProjectTo(x::SparseVector{T}) where {T<:Number}
return ProjectTo{SparseVector}(;
element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x)
)
end
function (project::ProjectTo{SparseVector})(dx::AbstractArray)
dy = if axes(dx) == project.axes
dx
else
if size(dx, 1) != length(project.axes[1])
throw(_projection_mismatch(project.axes, size(dx)))
end
reshape(dx, project.axes)
end
nzval = map(i -> project.element(dy[i]), project.nzind)
return SparseVector(length(dx), project.nzind, nzval)
end
function (project::ProjectTo{SparseVector})(dx::SparseVector)
if size(dx) != map(length, project.axes)
throw(_projection_mismatch(project.axes, size(dx)))
end
# When sparsity pattern is unchanged, all the time is in checking this,
# perhaps some simple hash/checksum might be good enough?
samepattern = project.nzind == dx.nzind
# samepattern = length(project.nzind) == length(dx.nzind)
if eltype(dx) <: project_type(project.element) && samepattern
return dx
elseif samepattern
nzval = map(project.element, dx.nzval)
SparseVector(length(dx), dx.nzind, nzval)
else
nzind = project.nzind
# Or should we intersect? Can this exploit sorting?
# nzind = intersect(project.nzind, dx.nzind)
nzval = map(i -> project.element(dx[i]), nzind)
return SparseVector(length(dx), nzind, nzval)
end
end

function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number}
return ProjectTo{SparseMatrixCSC}(;
element=ProjectTo(zero(T)),
axes=axes(x),
rowval=rowvals(x),
nzranges=nzrange.(Ref(x), axes(x, 2)),
colptr=x.colptr,
)
end
# You need not really store nzranges, you can get them from colptr -- TODO
# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray)
dy = if axes(dx) == project.axes
dx
else
if size(dx) != (length(project.axes[1]), length(project.axes[2]))
throw(_projection_mismatch(project.axes, size(dx)))
end
reshape(dx, project.axes)
end
nzval = Vector{project_type(project.element)}(undef, length(project.rowval))
k = 0
for col in project.axes[2]
for i in project.nzranges[col]
row = project.rowval[i]
val = dy[row, col]
nzval[k += 1] = project.element(val)
end
end
m, n = map(length, project.axes)
return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval)
end

function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
if size(dx) != map(length, project.axes)
throw(_projection_mismatch(project.axes, size(dx)))
end
samepattern = dx.colptr == project.colptr && dx.rowval == project.rowval
# samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end]
if eltype(dx) <: project_type(project.element) && samepattern
return dx
elseif samepattern
nzval = map(project.element, dx.nzval)
m, n = size(dx)
return SparseMatrixCSC(m, n, dx.colptr, dx.rowval, nzval)
else
invoke(project, Tuple{AbstractArray}, dx)
end
end

0 comments on commit 518b07a

Please sign in to comment.