Skip to content

Commit

Permalink
Add missing matrix multiplication methods involving OneElement (#347)
Browse files Browse the repository at this point in the history
* Add missing matrix multiplication methods involving OneElement

* multiplications with Diagonal

* Add suggested comment to `__muloneel!`

Co-authored-by: Frames White <[email protected]>

---------

Co-authored-by: Frames White <[email protected]>
  • Loading branch information
jishnub and oxinabox authored Feb 2, 2024
1 parent cff6c44 commit 99278fa
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 16 deletions.
177 changes: 166 additions & 11 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,58 @@ function *(A::OneElementMatrix, B::AbstractFillVector)
OneElement(val, A.ind[1], size(A,1))
end

@inline function __mulonel!(y, A, x, alpha, beta)
# Special matrix types

function *(A::OneElementMatrix, D::Diagonal)
check_matmul_sizes(A, D)
nzcol = A.ind[2]
val = if nzcol in axes(D,1)
A.val * D[nzcol, nzcol]
else
A.val * zero(eltype(D))
end
OneElement(val, A.ind, size(A))
end
function *(D::Diagonal, A::OneElementMatrix)
check_matmul_sizes(D, A)
nzrow = A.ind[1]
val = if nzrow in axes(D,2)
D[nzrow, nzrow] * A.val
else
zero(eltype(D)) * A.val
end
OneElement(val, A.ind, size(A))
end

# Inplace multiplication

# We use this for out overloads for _mul! for OneElement because its more efficient
# due to how efficient 2 arg mul is when one or more of the args are OneElement
function __mulonel!(C, A, B, alpha, beta)
ABα = A * B * alpha
if iszero(beta)
C .= ABα
else
C .= ABα .+ C .* beta
end
return C
end
# These methods remove the ambituity in _mul!. This isn't strictly necessary, but this makes Aqua happy.
function _mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha, beta)
__mulonel!(C, A, B, alpha, beta)
end
function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::OneElementMatrix, alpha, beta)
__mulonel!(C, A, B, alpha, beta)
end

function mul!(C::AbstractMatrix, A::OneElementMatrix, B::OneElementMatrix, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end
function mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end

@inline function __mul!(y, A::AbstractMatrix, x::OneElement, alpha, beta)
αx = alpha * x.val
ind1 = x.ind[1]
if iszero(beta)
Expand All @@ -104,19 +155,19 @@ end
return y
end

function _mulonel!(y, A, x::OneElementVector, alpha::Number, beta::Number)
function _mul!(y::AbstractVector, A::AbstractMatrix, x::OneElementVector, alpha, beta)
check_matmul_sizes(y, A, x)
if x.ind[1] axes(x,1) # in this case x is all zeros
if iszero(getindex_value(x))
mul!(y, A, Zeros{eltype(x)}(axes(x)), alpha, beta)
return y
end
__mulonel!(y, A, x, alpha, beta)
__mul!(y, A, x, alpha, beta)
y
end

function _mulonel!(C, A, B::OneElementMatrix, alpha::Number, beta::Number)
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::OneElementMatrix, alpha, beta)
check_matmul_sizes(C, A, B)
if B.ind[1] axes(B,1) || B.ind[2] axes(B,2) # in this case x is all zeros
if iszero(getindex_value(B))
mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta)
return C
end
Expand All @@ -127,24 +178,128 @@ function _mulonel!(C, A, B::OneElementMatrix, alpha::Number, beta::Number)
view(C, :, B.ind[2]+1:size(C,2)) .*= beta
end
y = view(C, :, B.ind[2])
__mulonel!(y, A, B, alpha, beta)
__mul!(y, A, B, alpha, beta)
C
end
function _mul!(C::AbstractMatrix, A::Diagonal, B::OneElementMatrix, alpha, beta)
check_matmul_sizes(C, A, B)
if iszero(getindex_value(B))
mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta)
return C
end
if iszero(beta)
C .= zero(eltype(C))
else
view(C, :, 1:B.ind[2]-1) .*= beta
view(C, :, B.ind[2]+1:size(C,2)) .*= beta
end
ABα = A * B * alpha
nzrow, nzcol = B.ind
if iszero(beta)
C[B.ind...] = ABα[B.ind...]
else
y = view(C, :, nzcol)
y .= view(ABα, :, nzcol) .+ y .* beta
end
C
end

function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractMatrix, alpha, beta)
check_matmul_sizes(C, A, B)
if iszero(getindex_value(A))
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
return C
end
if iszero(beta)
C .= zero(eltype(C))
else
view(C, 1:A.ind[1]-1, :) .*= beta
view(C, A.ind[1]+1:size(C,1), :) .*= beta
end
y = view(C, A.ind[1], :)
ind2 = A.ind[2]
Aval = A.val
if iszero(beta)
y .= Aval .* view(B, ind2, :) .* alpha
else
y .= Aval .* view(B, ind2, :) .* alpha .+ y .* beta
end
C
end
function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::Diagonal, alpha, beta)
check_matmul_sizes(C, A, B)
if iszero(getindex_value(A))
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
return C
end
if iszero(beta)
C .= zero(eltype(C))
else
view(C, 1:A.ind[1]-1, :) .*= beta
view(C, A.ind[1]+1:size(C,1), :) .*= beta
end
ABα = A * B * alpha
nzrow, nzcol = A.ind
if iszero(beta)
C[A.ind...] = ABα[A.ind...]
else
y = view(C, nzrow, :)
y .= view(ABα, nzrow, :) .+ y .* beta
end
C
end

function _mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractVector, alpha, beta)
check_matmul_sizes(C, A, B)
if iszero(getindex_value(A))
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
return C
end
nzrow, nzcol = A.ind
if iszero(beta)
C .= zero(eltype(C))
else
view(C, 1:nzrow-1) .*= beta
view(C, nzrow+1:size(C,1)) .*= beta
end
Aval = A.val
if iszero(beta)
C[nzrow] = Aval * B[nzcol] * alpha
else
C[nzrow] = Aval * B[nzcol] * alpha + C[nzrow] * beta
end
C
end

for MT in (:StridedMatrix, :(Transpose{<:Any, <:StridedMatrix}), :(Adjoint{<:Any, <:StridedMatrix}))
@eval function mul!(y::StridedVector, A::$MT, x::OneElementVector, alpha::Number, beta::Number)
_mulonel!(y, A, x, alpha, beta)
_mul!(y, A, x, alpha, beta)
end
end
for MT in (:StridedMatrix, :(Transpose{<:Any, <:StridedMatrix}), :(Adjoint{<:Any, <:StridedMatrix}),
:Diagonal)
@eval function mul!(C::StridedMatrix, A::$MT, B::OneElementMatrix, alpha::Number, beta::Number)
_mulonel!(C, A, B, alpha, beta)
_mul!(C, A, B, alpha, beta)
end
@eval function mul!(C::StridedMatrix, A::OneElementMatrix, B::$MT, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end
end
function mul!(C::StridedVector, A::OneElementMatrix, B::StridedVector, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end

function mul!(y::AbstractVector, A::AbstractFillMatrix, x::OneElementVector, alpha::Number, beta::Number)
_mulonel!(y, A, x, alpha, beta)
_mul!(y, A, x, alpha, beta)
end
function mul!(C::AbstractMatrix, A::AbstractFillMatrix, B::OneElementMatrix, alpha::Number, beta::Number)
_mulonel!(C, A, B, alpha, beta)
_mul!(C, A, B, alpha, beta)
end
function mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractFillVector, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end
function mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractFillMatrix, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end

# adjoint/transpose
Expand Down
87 changes: 82 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2113,14 +2113,16 @@ end

@testset "matmul" begin
A = reshape(Float64[1:9;], 3, 3)
v = reshape(Float64[1:3;], 3)
testinds(w::AbstractArray) = testinds(size(w))
testinds(szw::Tuple{Int}) = (szw .- 1, szw .+ 1)
function testinds(szA::Tuple{Int,Int})
(szA .- 1, szA .+ (-1,0), szA .+ (0,-1), szA .+ 1, szA .+ (1,-1), szA .+ (-1,1))
end
function test_A_mul_OneElement(A, (w, w2))
@testset for ind in testinds(w)
x = OneElement(3, ind, size(w))
# test matvec if w is a vector, or matmat if w is a matrix
function test_mat_mul_OneElement(A, (w, w2), sz)
@testset for ind in testinds(sz)
x = OneElement(3, ind, sz)
xarr = Array(x)
Axarr = A * xarr
Aadjxarr = A' * xarr
Expand All @@ -2143,26 +2145,84 @@ end
@test mul!(w2, F, x, 1.0, 1.0) Array(F) * xarr .+ 1
end
end
function test_OneElementMatrix_mul_mat(A, (w, w2), sz)
@testset for ind in testinds(sz)
O = OneElement(3, ind, sz)
Oarr = Array(O)
OarrA = Oarr * A
OarrAadj = Oarr * A'

@test O * A OarrA
@test O * A' OarrAadj
@test O * transpose(A) Oarr * transpose(A)

@test mul!(w, O, A) OarrA
# check columnwise to ensure zero columns
@test all(((c1, c2),) -> c1 c2, zip(eachcol(w), eachcol(OarrA)))
@test mul!(w, O, A') OarrAadj
w .= 1
@test mul!(w, O, A, 1.0, 2.0) OarrA .+ 2
w .= 1
@test mul!(w, O, A', 1.0, 2.0) OarrAadj .+ 2

F = Fill(3, size(A))
w2 .= 1
@test mul!(w2, O, F, 1.0, 1.0) Oarr * Array(F) .+ 1
end
end
function test_OneElementMatrix_mul_vec(v, (w, w2), sz)
@testset for ind in testinds(sz)
O = OneElement(3, ind, sz)
Oarr = Array(O)
Oarrv = Oarr * v

@test O * v == Oarrv

@test mul!(w, O, v) == Oarrv
# check rowwise to ensure zero rows
@test all(((r1, r2),) -> r1 == r2, zip(eachrow(w), eachrow(Oarrv)))
w .= 1
@test mul!(w, O, v, 1.0, 2.0) == Oarrv .+ 2

F = Fill(3, size(v))
w2 .= 1
@test mul!(w2, O, F, 1.0, 1.0) == Oarr * Array(F) .+ 1
end
end
@testset "Matrix * OneElementVector" begin
w = zeros(size(A,1))
w2 = MVector{length(w)}(w)
test_A_mul_OneElement(A, (w, w2))
test_mat_mul_OneElement(A, (w, w2), size(w))
end
@testset "Matrix * OneElementMatrix" begin
C = zeros(size(A))
C2 = MMatrix{size(C)...}(C)
test_A_mul_OneElement(A, (C, C2))
test_mat_mul_OneElement(A, (C, C2), size(C))
end
@testset "OneElementMatrix * Vector" begin
w = zeros(size(v))
w2 = MVector{size(v)...}(v)
test_OneElementMatrix_mul_vec(v, (w, w2), size(A))
end
@testset "OneElementMatrix * Matrix" begin
C = zeros(size(A))
C2 = MMatrix{size(C)...}(C)
test_OneElementMatrix_mul_mat(A, (C, C2), size(A))
end
@testset "OneElementMatrix * OneElement" begin
@testset for ind in testinds(A)
O = OneElement(3, ind, size(A))
v = OneElement(4, ind[2], size(A,1))
@test O * v isa OneElement
@test O * v == Array(O) * Array(v)
@test mul!(ones(size(O,1)), O, v) == O * v
@test mul!(ones(size(O,1)), O, v, 2, 1) == 2 * O * v .+ 1

B = OneElement(4, ind, size(A))
@test O * B isa OneElement
@test O * B == Array(O) * Array(B)
@test mul!(ones(size(O,1), size(B,2)), O, B) == O * B
@test mul!(ones(size(O,1), size(B,2)), O, B, 2, 1) == 2 * O * B .+ 1
end

@test OneElement(3, (2,3), (5,4)) * OneElement(2, 2, 4) == Zeros(5)
Expand Down Expand Up @@ -2191,6 +2251,23 @@ end
B = Zeros(4)
@test A * B === Zeros(5)
end
@testset "Diagonal and OneElementMatrix" begin
for ind in ((2,3), (2,2), (10,10))
O = OneElement(3, ind, (4,3))
Oarr = Array(O)
C = zeros(size(O))
D = Diagonal(axes(O,1))
@test D * O == D * Oarr
@test mul!(C, D, O) == D * O
C .= 1
@test mul!(C, D, O, 2, 2) == 2 * D * O .+ 2
D = Diagonal(axes(O,2))
@test O * D == Oarr * D
@test mul!(C, O, D) == O * D
C .= 1
@test mul!(C, O, D, 2, 2) == 2 * O * D .+ 2
end
end
end

@testset "multiplication/division by a number" begin
Expand Down

0 comments on commit 99278fa

Please sign in to comment.