-
Notifications
You must be signed in to change notification settings - Fork 39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement sum #446
Implement sum #446
Changes from 10 commits
87c311f
b6530b9
bec277e
cd92d7f
0768d34
cb197f5
d0109be
8da60df
de15c53
0e0bc12
cfc895e
dc9a82d
393bf62
5ec4b3c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -336,3 +336,66 @@ | |||||
copy(A::Adjoint{T,<:AbstractBandedMatrix}) where T = copy(parent(A))' | ||||||
copy(A::Transpose{T,<:AbstractBandedMatrix}) where T = transpose(copy(parent(A))) | ||||||
end | ||||||
|
||||||
function sum!(ret::AbstractArray, A::AbstractBandedMatrix) | ||||||
#Behaves similarly to Base.sum! | ||||||
ret .= 0 | ||||||
n,m = size(A) | ||||||
s = size(ret) | ||||||
l = length(s) | ||||||
#Check for singleton dimension and perform respective sum | ||||||
if s[1] == 1 && (l == 1 || s[2]==1) | ||||||
for j = 1:m, i = colrange(A, j) | ||||||
ret .+= A[i, j] | ||||||
end | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add tests fro this special case |
||||||
elseif s[1] == n && (l == 1 || s[2]==1) | ||||||
for i = 1:n, j = rowrange(A, i) | ||||||
ret[i, 1] += A[i, j] | ||||||
end | ||||||
elseif s[1] == 1 && s[2] == m | ||||||
for j = 1:m, i = colrange(A, j) | ||||||
ret[1, j] += A[i, j] | ||||||
end | ||||||
elseif s[1] == n && s[2] == m | ||||||
ret = A | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not changing
Suggested change
|
||||||
else | ||||||
throw(DimensionMismatch("reduction on matrix of size ($n, $m) with output size $s")) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add test using |
||||||
end | ||||||
end | ||||||
|
||||||
function sum(A::AbstractBandedMatrix; dims=:) | ||||||
if dims isa Colon | ||||||
l, u = bandwidths(A) | ||||||
ret = zero(eltype(A)) | ||||||
if l + u < 0 | ||||||
return ret | ||||||
end | ||||||
n, m = size(A) | ||||||
for j = 1:m, i = colrange(A, j) | ||||||
ret += A[i, j] | ||||||
end | ||||||
ret | ||||||
elseif dims > 2 | ||||||
A | ||||||
elseif dims == 2 | ||||||
l, u = bandwidths(A) | ||||||
n, m = size(A) | ||||||
ret = zeros(eltype(A), n, 1) | ||||||
if l + u < 0 | ||||||
return ret | ||||||
end | ||||||
sum!(ret, A) | ||||||
ret | ||||||
elseif dims == 1 | ||||||
l, u = bandwidths(A) | ||||||
n, m = size(A) | ||||||
ret = zeros(eltype(A), 1, m) | ||||||
if l + u < 0 | ||||||
return ret | ||||||
end | ||||||
sum!(ret, A) | ||||||
ret | ||||||
else | ||||||
throw(ArgumentError("dimension must be ≥ 1, got $dims")) | ||||||
end | ||||||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
module TestSum | ||
|
||
using Test, BandedMatrices, Random | ||
|
||
r = brand(Float64,rand(1:10_000),rand(1:10_000),rand(-20:100),rand(-20:100)) | ||
dlfivefifty marked this conversation as resolved.
Show resolved
Hide resolved
|
||
empty_r = brand(Float64,rand(1:1_000),rand(1:1_000),rand(1:100),rand(-200:-101)) | ||
n,m = size(empty_r) | ||
matr = Matrix(r) | ||
@testset "sum" begin | ||
@test sum(empty_r) == 0 | ||
@test sum(empty_r; dims = 2) == zeros(n,1) | ||
@test sum(empty_r; dims = 1) == zeros(1,m) | ||
@test sum(r) == sum(matr) | ||
@test sum(r; dims=2) == sum(matr; dims=2) | ||
@test sum(r; dims=1) == sum(matr; dims=1) | ||
@test sum(r; dims=3) == r | ||
@test_throws ArgumentError sum(r; dims=0) | ||
end | ||
|
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.