From bcedda19264e24d9b9c3cbb0916e47fb987f39a4 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 14 Feb 2023 19:50:50 +0530 Subject: [PATCH] Rowsupport in banded axpy methods (#316) * short-circuit in axpy * rowsupport in banded_dense_axpy --- src/generic/broadcast.jl | 25 +++++++++++++++++-------- test/test_broadcasting.jl | 15 +++++++++++++++ 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/src/generic/broadcast.jl b/src/generic/broadcast.jl index 047ed47e..d43f6d98 100644 --- a/src/generic/broadcast.jl +++ b/src/generic/broadcast.jl @@ -951,27 +951,36 @@ _banded_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix, notbandedX, notba Xl, Xu = bandwidths(X) Yl, Yu = bandwidths(Y) + if -Xl > Xu #= no bands in X =# + return Y + end + @boundscheck if Xl > Yl - # test that all entries are zero in extra bands - for j=1:size(X,2),k=max(1,j+Yl+1):min(j+Xl,n) + # test that all entries are zero in extra bands below the diagonal + for j=rowsupport(X),k=max(1,j+Yl+1):min(j+Xl,n) if inbands_getindex(X, k, j) ≠ 0 - throw(BandError(X, (k,j))) + throw(BandError(Y, (k,j))) end end end @boundscheck if Xu > Yu - # test that all entries are zero in extra bands - for j=1:size(X,2),k=max(1,j-Xu):min(j-Yu-1,n) + # test that all entries are zero in extra bands above the diagonal + for j=rowsupport(X),k=max(1,j-Xu):min(j-Yu-1,n) if inbands_getindex(X, k, j) ≠ 0 - throw(BandError(X, (k,j))) + throw(BandError(Y, (k,j))) end end end + if -Yl > Yu #= no bands in Y =# + return Y + end + + # only fill overlapping bands l = min(Xl,Yl) u = min(Xu,Yu) - @inbounds for j=1:m,k=max(1,j-u):min(n,j+l) + @inbounds for j=rowsupport(X), k=max(1,j-u):min(n,j+l) inbands_setindex!(Y, a*inbands_getindex(X,k,j) + inbands_getindex(Y,k,j) ,k, j) end Y @@ -981,7 +990,7 @@ function banded_dense_axpy!(a::Number, X::AbstractMatrix, Y::AbstractMatrix) if size(X) != size(Y) throw(DimensionMismatch("+")) end - @inbounds for j=1:size(X,2),k=colrange(X,j) + @inbounds for j=rowsupport(X), k=colrange(X,j) Y[k,j] += a*inbands_getindex(X,k,j) end Y diff --git a/test/test_broadcasting.jl b/test/test_broadcasting.jl index df6b37ea..9d826da2 100644 --- a/test/test_broadcasting.jl +++ b/test/test_broadcasting.jl @@ -306,6 +306,21 @@ import BandedMatrices: BandedStyle, BandedRows @test bandwidths(2A+B) == bandwidths(2.0.*A .+ B) == (2,2) B .= 2.0 .* A .+ B @test B == C + + @testset "trivial cases" begin + B = brand(2,4,-1,0) # no bands in B + B2 = brand(2,4,0,-1) # no bands in B2 + C = brand(size(B)...,1,1) + D = copy(C) + axpy!(0.1, B, C) # no bands in src + @test C == D + @test_throws BandError axpy!(0.1, C, B) + @test_throws BandError axpy!(0.1, C, B2) + D = copy(B) + C .= 0 + axpy!(0.1, C, B) # no bands in dest, but src is zero + @test B == D + end end @testset "gbmv!" begin