Skip to content

Commit

Permalink
Rowsupport in banded axpy methods (#316)
Browse files Browse the repository at this point in the history
* short-circuit in axpy

* rowsupport in banded_dense_axpy
  • Loading branch information
jishnub authored Feb 14, 2023
1 parent acd0ab5 commit bcedda1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
25 changes: 17 additions & 8 deletions src/generic/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -951,27 +951,36 @@ _banded_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix, notbandedX, notba
Xl, Xu = bandwidths(X)
Yl, Yu = bandwidths(Y)

if -Xl > Xu #= no bands in X =#
return Y
end

@boundscheck if Xl > Yl
# test that all entries are zero in extra bands
for j=1:size(X,2),k=max(1,j+Yl+1):min(j+Xl,n)
# test that all entries are zero in extra bands below the diagonal
for j=rowsupport(X),k=max(1,j+Yl+1):min(j+Xl,n)
if inbands_getindex(X, k, j) 0
throw(BandError(X, (k,j)))
throw(BandError(Y, (k,j)))
end
end
end
@boundscheck if Xu > Yu
# test that all entries are zero in extra bands
for j=1:size(X,2),k=max(1,j-Xu):min(j-Yu-1,n)
# test that all entries are zero in extra bands above the diagonal
for j=rowsupport(X),k=max(1,j-Xu):min(j-Yu-1,n)
if inbands_getindex(X, k, j) 0
throw(BandError(X, (k,j)))
throw(BandError(Y, (k,j)))
end
end
end

if -Yl > Yu #= no bands in Y =#
return Y
end

# only fill overlapping bands
l = min(Xl,Yl)
u = min(Xu,Yu)

@inbounds for j=1:m,k=max(1,j-u):min(n,j+l)
@inbounds for j=rowsupport(X), k=max(1,j-u):min(n,j+l)
inbands_setindex!(Y, a*inbands_getindex(X,k,j) + inbands_getindex(Y,k,j) ,k, j)
end
Y
Expand All @@ -981,7 +990,7 @@ function banded_dense_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix)
if size(X) != size(Y)
throw(DimensionMismatch("+"))
end
@inbounds for j=1:size(X,2),k=colrange(X,j)
@inbounds for j=rowsupport(X), k=colrange(X,j)
Y[k,j] += a*inbands_getindex(X,k,j)
end
Y
Expand Down
15 changes: 15 additions & 0 deletions test/test_broadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,21 @@ import BandedMatrices: BandedStyle, BandedRows
@test bandwidths(2A+B) == bandwidths(2.0.*A .+ B) == (2,2)
B .= 2.0 .* A .+ B
@test B == C

@testset "trivial cases" begin
B = brand(2,4,-1,0) # no bands in B
B2 = brand(2,4,0,-1) # no bands in B2
C = brand(size(B)...,1,1)
D = copy(C)
axpy!(0.1, B, C) # no bands in src
@test C == D
@test_throws BandError axpy!(0.1, C, B)
@test_throws BandError axpy!(0.1, C, B2)
D = copy(B)
C .= 0
axpy!(0.1, C, B) # no bands in dest, but src is zero
@test B == D
end
end

@testset "gbmv!" begin
Expand Down

0 comments on commit bcedda1

Please sign in to comment.