From c88f035fc4e3a260de487eb1b0b23ddf1ace16eb Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 22 Dec 2023 20:21:15 +0530 Subject: [PATCH 1/2] Fix OneElement multiplication with array elements --- src/oneelement.jl | 67 ++++++++++++++++++++--------------------------- test/runtests.jl | 19 ++++++++++++++ 2 files changed, 47 insertions(+), 39 deletions(-) diff --git a/src/oneelement.jl b/src/oneelement.jl index f1e2fd8e..5e8c5d76 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -145,12 +145,12 @@ function mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha end @inline function __mul!(y, A::AbstractMatrix, x::OneElement, alpha, beta) - αx = alpha * x.val + xα = Ref(x.val * alpha) ind1 = x.ind[1] if iszero(beta) - y .= αx .* view(A, :, ind1) + y .= view(A, :, ind1) .* xα else - y .= αx .* view(A, :, ind1) .+ beta .* y + y .= view(A, :, ind1) .* xα .+ y .* beta end return y end @@ -171,13 +171,14 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::OneElementMatrix, alpha, mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta) return C end + nzrow, nzcol = B.ind if iszero(beta) - C .= zero(eltype(C)) + C .= Ref(zero(eltype(C))) else - view(C, :, 1:B.ind[2]-1) .*= beta - view(C, :, B.ind[2]+1:size(C,2)) .*= beta + view(C, :, 1:nzcol-1) .*= beta + view(C, :, nzcol+1:size(C,2)) .*= beta end - y = view(C, :, B.ind[2]) + y = view(C, :, nzcol) __mul!(y, A, B, alpha, beta) C end @@ -187,17 +188,14 @@ function _mul!(C::AbstractMatrix, A::Diagonal, B::OneElementMatrix, alpha, beta) mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta) return C end - if iszero(beta) - C .= zero(eltype(C)) - else - view(C, :, 1:B.ind[2]-1) .*= beta - view(C, :, B.ind[2]+1:size(C,2)) .*= beta - end - ABα = A * B * alpha nzrow, nzcol = B.ind + ABα = A * B * alpha if iszero(beta) - C[B.ind...] = ABα[B.ind...] + C .= Ref(zero(eltype(C))) + C[nzrow, nzcol] = ABα[nzrow, nzcol] else + view(C, :, 1:nzcol-1) .*= beta + view(C, :, nzcol+1:size(C,2)) .*= beta y = view(C, :, nzcol) y .= view(ABα, :, nzcol) .+ y .* beta end @@ -210,19 +208,16 @@ function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractMatrix, alpha, mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta) return C end - if iszero(beta) - C .= zero(eltype(C)) - else - view(C, 1:A.ind[1]-1, :) .*= beta - view(C, A.ind[1]+1:size(C,1), :) .*= beta - end - y = view(C, A.ind[1], :) - ind2 = A.ind[2] + nzrow, nzcol = A.ind + y = view(C, nzrow, :) Aval = A.val if iszero(beta) - y .= Aval .* view(B, ind2, :) .* alpha + C .= Ref(zero(eltype(C))) + y .= Aval .* view(B, nzcol, :) .* alpha else - y .= Aval .* view(B, ind2, :) .* alpha .+ y .* beta + view(C, 1:nzrow-1, :) .*= beta + view(C, nzrow+1:size(C,1), :) .*= beta + y .= Aval .* view(B, nzcol, :) .* alpha .+ y .* beta end C end @@ -232,17 +227,14 @@ function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::Diagonal, alpha, beta) mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta) return C end - if iszero(beta) - C .= zero(eltype(C)) - else - view(C, 1:A.ind[1]-1, :) .*= beta - view(C, A.ind[1]+1:size(C,1), :) .*= beta - end - ABα = A * B * alpha nzrow, nzcol = A.ind + ABα = A * B * alpha if iszero(beta) - C[A.ind...] = ABα[A.ind...] + C .= Ref(zero(eltype(C))) + C[nzrow, nzcol] = ABα[nzrow, nzcol] else + view(C, 1:nzrow-1, :) .*= beta + view(C, nzrow+1:size(C,1), :) .*= beta y = view(C, nzrow, :) y .= view(ABα, nzrow, :) .+ y .* beta end @@ -256,16 +248,13 @@ function _mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractVector, alpha, return C end nzrow, nzcol = A.ind - if iszero(beta) - C .= zero(eltype(C)) - else - view(C, 1:nzrow-1) .*= beta - view(C, nzrow+1:size(C,1)) .*= beta - end Aval = A.val if iszero(beta) + C .= Ref(zero(eltype(C))) C[nzrow] = Aval * B[nzcol] * alpha else + view(C, 1:nzrow-1) .*= beta + view(C, nzrow+1:size(C,1)) .*= beta C[nzrow] = Aval * B[nzcol] * alpha + C[nzrow] * beta end C diff --git a/test/runtests.jl b/test/runtests.jl index 9d6d5378..6f91fe17 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2268,6 +2268,25 @@ end @test mul!(C, O, D, 2, 2) == 2 * O * D .+ 2 end end + @testset "array elements" begin + A = [SMatrix{2,3}(1:6)*(i+j) for i in 1:3, j in 1:2] + B = OneElement(SMatrix{3,2}(1:6), (size(A,2),2), (size(A,2),4)) + C = [SMatrix{2,2}(1:4) for i in 1:size(A,1), j in 1:size(B,2)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + B = OneElement(SMatrix{3,2}(1:6), (size(A,2),), (size(A,2),)) + C = [SMatrix{2,2}(1:4) for i in 1:size(A,1)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + @testset "non-commutative" begin + A = [quat(rand(4)...)*(i+j) for i in 1:2, j in 1:3] + B = OneElement(quat(rand(4)...), 1, size(A,2)) + C = [quat(rand(4)...) for i in axes(A,1)] + @test mul!(copy(C), A, B) ≈ A * B + α, β = quat(0,0,1,0), quat(1,0,1,0) + @test mul!(copy(C), A, B, α, β) ≈ mul!(copy(C), A, Array(B), α, β) ≈ A * B * α + C * β + end end @testset "multiplication/division by a number" begin From 0d7619192463fa720ad7f01999ba6450161d9e21 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 3 Feb 2024 11:15:44 +0530 Subject: [PATCH 2/2] Fix matmul for array elements in OneElMat * StridedMat --- src/oneelement.jl | 4 +- test/runtests.jl | 94 ++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 83 insertions(+), 15 deletions(-) diff --git a/src/oneelement.jl b/src/oneelement.jl index 5e8c5d76..28bc2c4b 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -213,11 +213,11 @@ function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractMatrix, alpha, Aval = A.val if iszero(beta) C .= Ref(zero(eltype(C))) - y .= Aval .* view(B, nzcol, :) .* alpha + y .= Ref(Aval) .* view(B, nzcol, :) .* alpha else view(C, 1:nzrow-1, :) .*= beta view(C, nzrow+1:size(C,1), :) .*= beta - y .= Aval .* view(B, nzcol, :) .* alpha .+ y .* beta + y .= Ref(Aval) .* view(B, nzcol, :) .* alpha .+ y .* beta end C end diff --git a/test/runtests.jl b/test/runtests.jl index 6f91fe17..0f6af632 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2270,22 +2270,90 @@ end end @testset "array elements" begin A = [SMatrix{2,3}(1:6)*(i+j) for i in 1:3, j in 1:2] - B = OneElement(SMatrix{3,2}(1:6), (size(A,2),2), (size(A,2),4)) - C = [SMatrix{2,2}(1:4) for i in 1:size(A,1), j in 1:size(B,2)] - @test mul!(copy(C), A, B) == A * B - @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C - B = OneElement(SMatrix{3,2}(1:6), (size(A,2),), (size(A,2),)) - C = [SMatrix{2,2}(1:4) for i in 1:size(A,1)] - @test mul!(copy(C), A, B) == A * B - @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + @testset "StridedMatrix * OneElementMatrix" begin + B = OneElement(SMatrix{3,2}(1:6), (size(A,2),2), (size(A,2),4)) + C = [SMatrix{2,2}(1:4) for i in axes(A,1), j in axes(B,2)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + @testset "StridedMatrix * OneElementVector" begin + B = OneElement(SMatrix{3,2}(1:6), (size(A,2),), (size(A,2),)) + C = [SMatrix{2,2}(1:4) for i in axes(A,1)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + + A = OneElement(SMatrix{3,2}(1:6), (3,2), (5,4)) + @testset "OneElementMatrix * StridedMatrix" begin + B = [SMatrix{2,3}(1:6)*(i+j) for i in axes(A,2), j in 1:2] + C = [SMatrix{3,3}(1:9) for i in axes(A,1), j in axes(B,2)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + @testset "OneElementMatrix * StridedVector" begin + B = [SMatrix{2,3}(1:6)*i for i in axes(A,2)] + C = [SMatrix{3,3}(1:9) for i in axes(A,1)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + @testset "OneElementMatrix * OneElementMatrix" begin + B = OneElement(SMatrix{2,3}(1:6), (2,4), (size(A,2), 3)) + C = [SMatrix{3,3}(1:9) for i in axes(A,1), j in axes(B,2)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + @testset "OneElementMatrix * OneElementVector" begin + B = OneElement(SMatrix{2,3}(1:6), 2, size(A,2)) + C = [SMatrix{3,3}(1:9) for i in axes(A,1)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end end @testset "non-commutative" begin + A = OneElement(quat(rand(4)...), (2,3), (3,4)) + for (B,C) in ( + # OneElementMatrix * OneElementVector + (OneElement(quat(rand(4)...), 3, size(A,2)), + [quat(rand(4)...) for i in axes(A,1)]), + + # OneElementMatrix * OneElementMatrix + (OneElement(quat(rand(4)...), (3,2), (size(A,2), 4)), + [quat(rand(4)...) for i in axes(A,1), j in 1:4]), + ) + @test mul!(copy(C), A, B) ≈ A * B + α, β = quat(0,0,1,0), quat(1,0,1,0) + @test mul!(copy(C), A, B, α, β) ≈ mul!(copy(C), A, Array(B), α, β) ≈ A * B * α + C * β + end + A = [quat(rand(4)...)*(i+j) for i in 1:2, j in 1:3] - B = OneElement(quat(rand(4)...), 1, size(A,2)) - C = [quat(rand(4)...) for i in axes(A,1)] - @test mul!(copy(C), A, B) ≈ A * B - α, β = quat(0,0,1,0), quat(1,0,1,0) - @test mul!(copy(C), A, B, α, β) ≈ mul!(copy(C), A, Array(B), α, β) ≈ A * B * α + C * β + for (B,C) in ( + # StridedMatrix * OneElementVector + (OneElement(quat(rand(4)...), 1, size(A,2)), + [quat(rand(4)...) for i in axes(A,1)]), + + # StridedMatrix * OneElementMatrix + (OneElement(quat(rand(4)...), (2,2), (size(A,2), 4)), + [quat(rand(4)...) for i in axes(A,1), j in 1:4]), + ) + @test mul!(copy(C), A, B) ≈ A * B + α, β = quat(0,0,1,0), quat(1,0,1,0) + @test mul!(copy(C), A, B, α, β) ≈ mul!(copy(C), A, Array(B), α, β) ≈ A * B * α + C * β + end + + A = OneElement(quat(rand(4)...), (2,2), (3, 4)) + for (B,C) in ( + # OneElementMatrix * StridedMatrix + ([quat(rand(4)...) for i in axes(A,2), j in 1:3], + [quat(rand(4)...) for i in axes(A,1), j in 1:3]), + + # OneElementMatrix * StridedVector + ([quat(rand(4)...) for i in axes(A,2)], + [quat(rand(4)...) for i in axes(A,1)]), + ) + @test mul!(copy(C), A, B) ≈ A * B + α, β = quat(0,0,1,0), quat(1,0,1,0) + @test mul!(copy(C), A, B, α, β) ≈ mul!(copy(C), A, Array(B), α, β) ≈ A * B * α + C * β + end end end