Skip to content

Commit

Permalink
Implement sum (#446)
Browse files Browse the repository at this point in the history
* add sum without dims

* add sum; dims=1

* support for dims = 2 and error handling

* fix for empty matrices and added unit tests

* style

* make improvements

* add test_sum.jl to runtests.jl

* fix method dispatch issue in a way that mimics Base.sum

* update unit tests, reduce memory allocation, improve style

* update unit tests, add sum!, move to AbstractBandedMatrix.jl

* revert tests to \approx

* make improvements in AbstractBandedMatrix.jl

* test special cases of sum!

* add some tests and avoid CI failure

---------

Co-authored-by: Sheehan Olver <[email protected]>
  • Loading branch information
max-vassili3v and dlfivefifty authored Jul 16, 2024
1 parent 47c15ab commit ea616cc
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/BandedMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Base: axes, axes1, getproperty, getindex, setindex!, *, +, -, ==, <, <=,
>=, /, \, adjoint, transpose, showerror, convert, size, view,
unsafe_indices, first, last, size, length, unsafe_length, step, to_indices,
to_index, show, fill!, similar, copy, promote_rule, real, imag,
copyto!, Array
copyto!, Array, sum, sum!

using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle, Broadcasted
import Base.Broadcast: BroadcastStyle, broadcasted
Expand Down Expand Up @@ -99,4 +99,5 @@ end
include("precompile.jl")



end #module
1 change: 1 addition & 0 deletions src/banded/BandedMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1007,3 +1007,4 @@ function resize(A::BandedSubBandedMatrix, n::Integer, m::Integer)
l,u = bandwidths(A)
_BandedMatrix(reshape(resize!(vec(copy(bandeddata(A))), (l+u+1)*m), l+u+1, m), n, l,u)
end

65 changes: 65 additions & 0 deletions src/generic/AbstractBandedMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,68 @@ if VERSION >= v"1.9"
copy(A::Adjoint{T,<:AbstractBandedMatrix}) where T = copy(parent(A))'
copy(A::Transpose{T,<:AbstractBandedMatrix}) where T = transpose(copy(parent(A)))
end

function sum!(ret::AbstractArray, A::AbstractBandedMatrix)
#Behaves similarly to Base.sum!
fill!(ret, zero(eltype(ret)))
n,m = size(A)
s = size(ret)
l = length(s)
#Check for singleton dimension and perform respective sum
if s[1] == 1 && (l == 1 || s[2]==1)
for j = 1:m, i = colrange(A, j)
ret .+= A[i, j]
end
elseif s[1] == n && (l == 1 || s[2]==1)
for i = 1:n, j = rowrange(A, i)
ret[i, 1] += A[i, j]
end
elseif s[1] == 1 && s[2] == m
for j = 1:m, i = colrange(A, j)
ret[1, j] += A[i, j]
end
elseif s[1] == n && s[2] == m
copyto!(ret,A)
else
throw(DimensionMismatch("reduction on matrix of size ($n, $m) with output size $s"))
end
#return the value to mimic Base.sum!
ret
end

function sum(A::AbstractBandedMatrix; dims=:)
if dims isa Colon
l, u = bandwidths(A)
ret = zero(eltype(A))
if l + u < 0
return ret
end
n, m = size(A)
for j = 1:m, i = colrange(A, j)
ret += A[i, j]
end
ret
elseif dims > 2
A
elseif dims == 2
l, u = bandwidths(A)
n, m = size(A)
ret = zeros(eltype(A), n, 1)
if l + u < 0
return ret
end
sum!(ret, A)
ret
elseif dims == 1
l, u = bandwidths(A)
n, m = size(A)
ret = zeros(eltype(A), 1, m)
if l + u < 0
return ret
end
sum!(ret, A)
ret
else
throw(ArgumentError("dimension must be ≥ 1, got $dims"))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ include("test_symbanded.jl")
include("test_tribanded.jl")
include("test_interface.jl")
include("test_miscs.jl")
include("test_sum.jl")
2 changes: 1 addition & 1 deletion test/test_broadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Random.seed!(0)
@test identity.(A) isa BandedMatrix
@test bandwidths(identity.(A)) == bandwidths(A)

@test (z -> exp(z)-1).(A) == (z -> exp(z)-1).(Matrix(A))
@test (z -> exp(z)-1).(A) (z -> exp(z)-1).(Matrix(A)) # for some reason == is breaking on Mac CI
@test (z -> exp(z)-1).(A) isa BandedMatrix
@test bandwidths((z -> exp(z)-1).(A)) == bandwidths(A)

Expand Down
34 changes: 34 additions & 0 deletions test/test_sum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module TestSum

using Test, BandedMatrices, Random

Random.seed!(0)
r = brand(rand(1:10_000),rand(1:10_000),rand(-20:100),rand(-20:100))
empty_r = brand(rand(1:1_000),rand(1:1_000),rand(1:100),rand(-200:-101))
n,m = size(empty_r)
matr = Matrix(r)
@testset "sum" begin
@test sum(empty_r) == 0
@test sum(empty_r; dims = 2) == zeros(n,1)
@test sum(empty_r; dims = 1) == zeros(1,m)

@test sum(r) sum(matr) rtol = 1e-10
@test sum(r; dims=2) sum(matr; dims=2) rtol = 1e-10
@test sum(r; dims=1) sum(matr; dims=1) rtol = 1e-10
@test sum(r; dims=3) == r
@test_throws ArgumentError sum(r; dims=0)

v = [1.0]
sum!(v, r)
@test v == sum!(v, Matrix(r))
n2, m2 = size(r)
v = ones(n2)
@test sum!(v, r) == sum!(v, Matrix(r))
V = zeros(1,m2)
@test sum!(V, r) === V sum!(zeros(1,m2), Matrix(r))
V = zeros(n2,m2)
@test sum!(V, r) === V == r
@test_throws DimensionMismatch sum!(zeros(Float64, n2 + 1, m2 + 1), r)
end

end

0 comments on commit ea616cc

Please sign in to comment.