Skip to content

Commit

Permalink
fix 5-arg mul! for vectors of vectors (#47665)
Browse files Browse the repository at this point in the history
Co-authored-by: N5N3 <[email protected]>
(cherry picked from commit 902e8a7)
  • Loading branch information
ranocha authored and KristofferC committed Nov 28, 2022
1 parent 0865ae0 commit cfbb86a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ function generic_matvecmul!(C::AbstractVector{R}, tA, A::AbstractVecOrMat, B::Ab
end
for k = 1:mB
aoffs = (k-1)*Astride
b = _add(B[k], false)
b = _add(B[k])
for i = 1:mA
C[i] += A[aoffs + i] * b
end
Expand Down
52 changes: 52 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,58 @@ end
end
end

@testset "generic_matvecmul for vectors of vectors" begin
@testset "matrix of scalars" begin
u = [[1, 2], [3, 4]]
A = [1 2; 3 4]
v = [[0, 0], [0, 0]]
Au = [[7, 10], [15, 22]]
@test A * u == Au
mul!(v, A, u)
@test v == Au
mul!(v, A, u, 2, -1)
@test v == Au
end

@testset "matrix of matrices" begin
u = [[1, 2], [3, 4]]
A = Matrix{Matrix{Int}}(undef, 2, 2)
A[1, 1] = [1 2; 3 4]
A[1, 2] = [5 6; 7 8]
A[2, 1] = [9 10; 11 12]
A[2, 2] = [13 14; 15 16]
v = [[0, 0], [0, 0]]
Au = [[44, 64], [124, 144]]
@test A * u == Au
mul!(v, A, u)
@test v == Au
mul!(v, A, u, 2, -1)
@test v == Au
end
end

@testset "generic_matmatmul for matrices of vectors" begin
B = Matrix{Vector{Int}}(undef, 2, 2)
B[1, 1] = [1, 2]
B[2, 1] = [3, 4]
B[1, 2] = [5, 6]
B[2, 2] = [7, 8]
A = [1 2; 3 4]
C = Matrix{Vector{Int}}(undef, 2, 2)
AB = Matrix{Vector{Int}}(undef, 2, 2)
AB[1, 1] = [7, 10]
AB[2, 1] = [15, 22]
AB[1, 2] = [19, 22]
AB[2, 2] = [43, 50]
@test A * B == AB
mul!(C, A, B)
@test C == AB
mul!(C, A, B, 2, -1)
@test C == AB
LinearAlgebra._generic_matmatmul!(C, 'N', 'N', A, B, LinearAlgebra.MulAddMul(2, -1))
@test C == AB
end

@testset "fallbacks & such for BlasFloats" begin
AA = rand(Float64, 6, 6)
BB = rand(Float64, 6, 6)
Expand Down

0 comments on commit cfbb86a

Please sign in to comment.