Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement sum #446

Merged
merged 14 commits into from
Jul 16, 2024
3 changes: 2 additions & 1 deletion src/BandedMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Base: axes, axes1, getproperty, getindex, setindex!, *, +, -, ==, <, <=,
>=, /, \, adjoint, transpose, showerror, convert, size, view,
unsafe_indices, first, last, size, length, unsafe_length, step, to_indices,
to_index, show, fill!, similar, copy, promote_rule, real, imag,
copyto!, Array
copyto!, Array, sum

using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle, Broadcasted
import Base.Broadcast: BroadcastStyle, broadcasted
Expand Down Expand Up @@ -99,4 +99,5 @@ end
include("precompile.jl")



end #module
56 changes: 56 additions & 0 deletions src/banded/BandedMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1007,3 +1007,59 @@
l,u = bandwidths(A)
_BandedMatrix(reshape(resize!(vec(copy(bandeddata(A))), (l+u+1)*m), l+u+1, m), n, l,u)
end

function sum(A::BandedMatrix; dims=:)
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
if(dims == :)
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
l, u = bandwidths(A)
ret = zero(eltype(A))
if(l + u < 0)
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
return ret

Check warning on line 1016 in src/banded/BandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/banded/BandedMatrix.jl#L1016

Added line #L1016 was not covered by tests
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
end
n, m = size(A)
#Only get nonempty bands
lower, upper = min(n-1, l), min(m-1, u)
for i = -lower:upper
ret += sum(A[band(i)])
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
end
ret
elseif(dims > 2)
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
A
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
elseif(dims == 2)
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
l, u = bandwidths(A)
n, m = size(A)
ret = zeros(eltype(A), (n, 1))
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
if(l + u < 0)
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
return ret

Check warning on line 1032 in src/banded/BandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/banded/BandedMatrix.jl#L1032

Added line #L1032 was not covered by tests
end
lower, upper = min(n-1, l), min(m-1, u)
for i = -lower:upper
b = A[band(i)]
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
if(i <= 0)
ret[1-i:length(b)-i, 1] += b
else
ret[1:length(b), 1] += b
end
end
ret
elseif(dims == 1)
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
l, u = bandwidths(A)
n, m = size(A)
ret = zeros(eltype(A), (1, m))
if(l + u < 0)
return ret

Check warning on line 1049 in src/banded/BandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/banded/BandedMatrix.jl#L1049

Added line #L1049 was not covered by tests
end
lower, upper = min(n-1, l), min(m-1, u)
for i=-lower:upper
b = A[band(i)]
if(i <= 0)
ret[1, 1:length(b)] += b
else
ret[1, i+1:i+length(b)] += b
end
end
ret
else
throw(ArgumentError("dimension must be ≥ 1, got $dims"))
end
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ include("test_symbanded.jl")
include("test_tribanded.jl")
include("test_interface.jl")
include("test_miscs.jl")
include("test_sum.jl")
15 changes: 15 additions & 0 deletions test/test_sum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module TestSum

using Test, BandedMatrices, Random

r = brand(Float64,rand(1:10_000),rand(1:10_000),rand(-20:100),rand(-20:100))
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
matr = Matrix(r)
@testset "sum" begin
@test sum(r) ≈ sum(matr) atol = 1e-10
@test sum(r; dims=2) ≈ sum(matr; dims=2) atol = 1e-10
@test sum(r; dims=1) ≈ sum(matr; dims=1) atol = 1e-10
@test sum(r; dims=3) == r
@test_throws ArgumentError sum(r; dims=0)
end

end
Loading