diff --git a/src/oneelement.jl b/src/oneelement.jl index f1e2fd8e..28bc2c4b 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -145,12 +145,12 @@ function mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha end @inline function __mul!(y, A::AbstractMatrix, x::OneElement, alpha, beta) - αx = alpha * x.val + xα = Ref(x.val * alpha) ind1 = x.ind[1] if iszero(beta) - y .= αx .* view(A, :, ind1) + y .= view(A, :, ind1) .* xα else - y .= αx .* view(A, :, ind1) .+ beta .* y + y .= view(A, :, ind1) .* xα .+ y .* beta end return y end @@ -171,13 +171,14 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::OneElementMatrix, alpha, mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta) return C end + nzrow, nzcol = B.ind if iszero(beta) - C .= zero(eltype(C)) + C .= Ref(zero(eltype(C))) else - view(C, :, 1:B.ind[2]-1) .*= beta - view(C, :, B.ind[2]+1:size(C,2)) .*= beta + view(C, :, 1:nzcol-1) .*= beta + view(C, :, nzcol+1:size(C,2)) .*= beta end - y = view(C, :, B.ind[2]) + y = view(C, :, nzcol) __mul!(y, A, B, alpha, beta) C end @@ -187,17 +188,14 @@ function _mul!(C::AbstractMatrix, A::Diagonal, B::OneElementMatrix, alpha, beta) mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta) return C end - if iszero(beta) - C .= zero(eltype(C)) - else - view(C, :, 1:B.ind[2]-1) .*= beta - view(C, :, B.ind[2]+1:size(C,2)) .*= beta - end - ABα = A * B * alpha nzrow, nzcol = B.ind + ABα = A * B * alpha if iszero(beta) - C[B.ind...] = ABα[B.ind...] + C .= Ref(zero(eltype(C))) + C[nzrow, nzcol] = ABα[nzrow, nzcol] else + view(C, :, 1:nzcol-1) .*= beta + view(C, :, nzcol+1:size(C,2)) .*= beta y = view(C, :, nzcol) y .= view(ABα, :, nzcol) .+ y .* beta end @@ -210,19 +208,16 @@ function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractMatrix, alpha, mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta) return C end - if iszero(beta) - C .= zero(eltype(C)) - else - view(C, 1:A.ind[1]-1, :) .*= beta - view(C, A.ind[1]+1:size(C,1), :) .*= beta - end - y = view(C, A.ind[1], :) - ind2 = A.ind[2] + nzrow, nzcol = A.ind + y = view(C, nzrow, :) Aval = A.val if iszero(beta) - y .= Aval .* view(B, ind2, :) .* alpha + C .= Ref(zero(eltype(C))) + y .= Ref(Aval) .* view(B, nzcol, :) .* alpha else - y .= Aval .* view(B, ind2, :) .* alpha .+ y .* beta + view(C, 1:nzrow-1, :) .*= beta + view(C, nzrow+1:size(C,1), :) .*= beta + y .= Ref(Aval) .* view(B, nzcol, :) .* alpha .+ y .* beta end C end @@ -232,17 +227,14 @@ function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::Diagonal, alpha, beta) mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta) return C end - if iszero(beta) - C .= zero(eltype(C)) - else - view(C, 1:A.ind[1]-1, :) .*= beta - view(C, A.ind[1]+1:size(C,1), :) .*= beta - end - ABα = A * B * alpha nzrow, nzcol = A.ind + ABα = A * B * alpha if iszero(beta) - C[A.ind...] = ABα[A.ind...] + C .= Ref(zero(eltype(C))) + C[nzrow, nzcol] = ABα[nzrow, nzcol] else + view(C, 1:nzrow-1, :) .*= beta + view(C, nzrow+1:size(C,1), :) .*= beta y = view(C, nzrow, :) y .= view(ABα, nzrow, :) .+ y .* beta end @@ -256,16 +248,13 @@ function _mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractVector, alpha, return C end nzrow, nzcol = A.ind - if iszero(beta) - C .= zero(eltype(C)) - else - view(C, 1:nzrow-1) .*= beta - view(C, nzrow+1:size(C,1)) .*= beta - end Aval = A.val if iszero(beta) + C .= Ref(zero(eltype(C))) C[nzrow] = Aval * B[nzcol] * alpha else + view(C, 1:nzrow-1) .*= beta + view(C, nzrow+1:size(C,1)) .*= beta C[nzrow] = Aval * B[nzcol] * alpha + C[nzrow] * beta end C diff --git a/test/runtests.jl b/test/runtests.jl index 9d6d5378..0f6af632 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2268,6 +2268,93 @@ end @test mul!(C, O, D, 2, 2) == 2 * O * D .+ 2 end end + @testset "array elements" begin + A = [SMatrix{2,3}(1:6)*(i+j) for i in 1:3, j in 1:2] + @testset "StridedMatrix * OneElementMatrix" begin + B = OneElement(SMatrix{3,2}(1:6), (size(A,2),2), (size(A,2),4)) + C = [SMatrix{2,2}(1:4) for i in axes(A,1), j in axes(B,2)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + @testset "StridedMatrix * OneElementVector" begin + B = OneElement(SMatrix{3,2}(1:6), (size(A,2),), (size(A,2),)) + C = [SMatrix{2,2}(1:4) for i in axes(A,1)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + + A = OneElement(SMatrix{3,2}(1:6), (3,2), (5,4)) + @testset "OneElementMatrix * StridedMatrix" begin + B = [SMatrix{2,3}(1:6)*(i+j) for i in axes(A,2), j in 1:2] + C = [SMatrix{3,3}(1:9) for i in axes(A,1), j in axes(B,2)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + @testset "OneElementMatrix * StridedVector" begin + B = [SMatrix{2,3}(1:6)*i for i in axes(A,2)] + C = [SMatrix{3,3}(1:9) for i in axes(A,1)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + @testset "OneElementMatrix * OneElementMatrix" begin + B = OneElement(SMatrix{2,3}(1:6), (2,4), (size(A,2), 3)) + C = [SMatrix{3,3}(1:9) for i in axes(A,1), j in axes(B,2)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + @testset "OneElementMatrix * OneElementVector" begin + B = OneElement(SMatrix{2,3}(1:6), 2, size(A,2)) + C = [SMatrix{3,3}(1:9) for i in axes(A,1)] + @test mul!(copy(C), A, B) == A * B + @test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C + end + end + @testset "non-commutative" begin + A = OneElement(quat(rand(4)...), (2,3), (3,4)) + for (B,C) in ( + # OneElementMatrix * OneElementVector + (OneElement(quat(rand(4)...), 3, size(A,2)), + [quat(rand(4)...) for i in axes(A,1)]), + + # OneElementMatrix * OneElementMatrix + (OneElement(quat(rand(4)...), (3,2), (size(A,2), 4)), + [quat(rand(4)...) for i in axes(A,1), j in 1:4]), + ) + @test mul!(copy(C), A, B) ≈ A * B + α, β = quat(0,0,1,0), quat(1,0,1,0) + @test mul!(copy(C), A, B, α, β) ≈ mul!(copy(C), A, Array(B), α, β) ≈ A * B * α + C * β + end + + A = [quat(rand(4)...)*(i+j) for i in 1:2, j in 1:3] + for (B,C) in ( + # StridedMatrix * OneElementVector + (OneElement(quat(rand(4)...), 1, size(A,2)), + [quat(rand(4)...) for i in axes(A,1)]), + + # StridedMatrix * OneElementMatrix + (OneElement(quat(rand(4)...), (2,2), (size(A,2), 4)), + [quat(rand(4)...) for i in axes(A,1), j in 1:4]), + ) + @test mul!(copy(C), A, B) ≈ A * B + α, β = quat(0,0,1,0), quat(1,0,1,0) + @test mul!(copy(C), A, B, α, β) ≈ mul!(copy(C), A, Array(B), α, β) ≈ A * B * α + C * β + end + + A = OneElement(quat(rand(4)...), (2,2), (3, 4)) + for (B,C) in ( + # OneElementMatrix * StridedMatrix + ([quat(rand(4)...) for i in axes(A,2), j in 1:3], + [quat(rand(4)...) for i in axes(A,1), j in 1:3]), + + # OneElementMatrix * StridedVector + ([quat(rand(4)...) for i in axes(A,2)], + [quat(rand(4)...) for i in axes(A,1)]), + ) + @test mul!(copy(C), A, B) ≈ A * B + α, β = quat(0,0,1,0), quat(1,0,1,0) + @test mul!(copy(C), A, B, α, β) ≈ mul!(copy(C), A, Array(B), α, β) ≈ A * B * α + C * β + end + end end @testset "multiplication/division by a number" begin