Skip to content

Commit

Permalink
Fix matmul for array elements in OneElMat * StridedMat
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Feb 3, 2024
1 parent c88f035 commit 0d76191
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,11 @@ function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractMatrix, alpha,
Aval = A.val
if iszero(beta)
C .= Ref(zero(eltype(C)))
y .= Aval .* view(B, nzcol, :) .* alpha
y .= Ref(Aval) .* view(B, nzcol, :) .* alpha
else
view(C, 1:nzrow-1, :) .*= beta
view(C, nzrow+1:size(C,1), :) .*= beta
y .= Aval .* view(B, nzcol, :) .* alpha .+ y .* beta
y .= Ref(Aval) .* view(B, nzcol, :) .* alpha .+ y .* beta
end
C
end
Expand Down
94 changes: 81 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2270,22 +2270,90 @@ end
end
@testset "array elements" begin
A = [SMatrix{2,3}(1:6)*(i+j) for i in 1:3, j in 1:2]
B = OneElement(SMatrix{3,2}(1:6), (size(A,2),2), (size(A,2),4))
C = [SMatrix{2,2}(1:4) for i in 1:size(A,1), j in 1:size(B,2)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
B = OneElement(SMatrix{3,2}(1:6), (size(A,2),), (size(A,2),))
C = [SMatrix{2,2}(1:4) for i in 1:size(A,1)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
@testset "StridedMatrix * OneElementMatrix" begin
B = OneElement(SMatrix{3,2}(1:6), (size(A,2),2), (size(A,2),4))
C = [SMatrix{2,2}(1:4) for i in axes(A,1), j in axes(B,2)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end
@testset "StridedMatrix * OneElementVector" begin
B = OneElement(SMatrix{3,2}(1:6), (size(A,2),), (size(A,2),))
C = [SMatrix{2,2}(1:4) for i in axes(A,1)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end

A = OneElement(SMatrix{3,2}(1:6), (3,2), (5,4))
@testset "OneElementMatrix * StridedMatrix" begin
B = [SMatrix{2,3}(1:6)*(i+j) for i in axes(A,2), j in 1:2]
C = [SMatrix{3,3}(1:9) for i in axes(A,1), j in axes(B,2)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end
@testset "OneElementMatrix * StridedVector" begin
B = [SMatrix{2,3}(1:6)*i for i in axes(A,2)]
C = [SMatrix{3,3}(1:9) for i in axes(A,1)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end
@testset "OneElementMatrix * OneElementMatrix" begin
B = OneElement(SMatrix{2,3}(1:6), (2,4), (size(A,2), 3))
C = [SMatrix{3,3}(1:9) for i in axes(A,1), j in axes(B,2)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end
@testset "OneElementMatrix * OneElementVector" begin
B = OneElement(SMatrix{2,3}(1:6), 2, size(A,2))
C = [SMatrix{3,3}(1:9) for i in axes(A,1)]
@test mul!(copy(C), A, B) == A * B
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
end
end
@testset "non-commutative" begin
A = OneElement(quat(rand(4)...), (2,3), (3,4))
for (B,C) in (
# OneElementMatrix * OneElementVector
(OneElement(quat(rand(4)...), 3, size(A,2)),
[quat(rand(4)...) for i in axes(A,1)]),

# OneElementMatrix * OneElementMatrix
(OneElement(quat(rand(4)...), (3,2), (size(A,2), 4)),
[quat(rand(4)...) for i in axes(A,1), j in 1:4]),
)
@test mul!(copy(C), A, B) A * B
α, β = quat(0,0,1,0), quat(1,0,1,0)
@test mul!(copy(C), A, B, α, β) mul!(copy(C), A, Array(B), α, β) A * B * α + C * β
end

A = [quat(rand(4)...)*(i+j) for i in 1:2, j in 1:3]
B = OneElement(quat(rand(4)...), 1, size(A,2))
C = [quat(rand(4)...) for i in axes(A,1)]
@test mul!(copy(C), A, B) A * B
α, β = quat(0,0,1,0), quat(1,0,1,0)
@test mul!(copy(C), A, B, α, β) mul!(copy(C), A, Array(B), α, β) A * B * α + C * β
for (B,C) in (
# StridedMatrix * OneElementVector
(OneElement(quat(rand(4)...), 1, size(A,2)),
[quat(rand(4)...) for i in axes(A,1)]),

# StridedMatrix * OneElementMatrix
(OneElement(quat(rand(4)...), (2,2), (size(A,2), 4)),
[quat(rand(4)...) for i in axes(A,1), j in 1:4]),
)
@test mul!(copy(C), A, B) A * B
α, β = quat(0,0,1,0), quat(1,0,1,0)
@test mul!(copy(C), A, B, α, β) mul!(copy(C), A, Array(B), α, β) A * B * α + C * β
end

A = OneElement(quat(rand(4)...), (2,2), (3, 4))
for (B,C) in (
# OneElementMatrix * StridedMatrix
([quat(rand(4)...) for i in axes(A,2), j in 1:3],
[quat(rand(4)...) for i in axes(A,1), j in 1:3]),

# OneElementMatrix * StridedVector
([quat(rand(4)...) for i in axes(A,2)],
[quat(rand(4)...) for i in axes(A,1)]),
)
@test mul!(copy(C), A, B) A * B
α, β = quat(0,0,1,0), quat(1,0,1,0)
@test mul!(copy(C), A, B, α, β) mul!(copy(C), A, Array(B), α, β) A * B * α + C * β
end
end
end

Expand Down

0 comments on commit 0d76191

Please sign in to comment.