From e7309a5a2b55048efe3084312e8302575300bbb6 Mon Sep 17 00:00:00 2001 From: Chris Foster Date: Tue, 11 Apr 2017 22:48:31 +1000 Subject: [PATCH] Reinstate StaticMatrix * AbstractVector -> StaticVector --- src/matrix_multiply.jl | 20 ++++++++++++++++++-- test/matrix_multiply.jl | 6 +++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/matrix_multiply.jl b/src/matrix_multiply.jl index bc9c3fb3..6d9ae996 100644 --- a/src/matrix_multiply.jl +++ b/src/matrix_multiply.jl @@ -31,6 +31,7 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero # Manage dispatch of * and A_mul_B! # TODO RowVector? (Inner product?) +@inline *(A::StaticMatrix, B::AbstractVector) = _A_mul_B(Size(A), A, B) @inline *(A::StaticMatrix, B::StaticVector) = _A_mul_B(Size(A), Size(B), A, B) @inline *(A::StaticMatrix, B::StaticMatrix) = _A_mul_B(Size(A), Size(B), A, B) @inline *(A::StaticVector, B::StaticMatrix) = *(reshape(A, Size(Size(A)[1], 1)), B) @@ -45,6 +46,23 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero # Implementations +@generated function _A_mul_B(::Size{sa}, a::StaticMatrix{Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb} + if sa[2] != 0 + exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(sub2ind(sa, k, j))]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]] + else + exprs = [:(zero(T)) for k = 1:sa[1]] + end + + return quote + @_inline_meta + if length(b) != sa[2] + throw(DimensionMismatch("Tried to multiply arrays of size $sa and $(size(b))")) + end + T = promote_matprod(Ta, Tb) + @inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...))) + end +end + @generated function _A_mul_B(::Size{sa}, ::Size{sb}, a::StaticMatrix{Ta}, b::StaticVector{Tb}) where {sa, sb, Ta, Tb} if sb[1] != sa[2] throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb")) @@ -63,8 +81,6 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero end end -# TODO: I removed StaticMatrix * AbstractVector. Reinstate? - # outer product @generated function _A_mul_B(::Size{sa}, ::Size{sb}, a::StaticVector{Ta}, b::RowVector{Tb, <:StaticVector}) where {sa, sb, Ta, Tb} newsize = (sa[1], sb[2]) diff --git a/test/matrix_multiply.jl b/test/matrix_multiply.jl index 18e6eb14..d1ee29d7 100644 --- a/test/matrix_multiply.jl +++ b/test/matrix_multiply.jl @@ -16,7 +16,7 @@ @test @inferred(v*v') === @SMatrix [1 2; 2 4] v3 = [1, 2] - @test_broken m*v3 === @SVector [5, 11] + @test m*v3 === @SVector [5, 11] m2 = @MMatrix [1 2; 3 4] v4 = @MVector [1, 2] @@ -32,11 +32,11 @@ m5 = @SMatrix [1.0 2.0; 3.0 4.0] v7 = [1.0, 2.0] - @test_broken (m5*v7)::SVector ≈ @SVector [5.0, 11.0] + @test (m5*v7)::SVector ≈ @SVector [5.0, 11.0] m6 = @SMatrix Float32[1.0 2.0; 3.0 4.0] v8 = Float64[1.0, 2.0] - @test_broken (m6*v8)::SVector{2,Float64} ≈ @SVector [5.0, 11.0] + @test (m6*v8)::SVector{2,Float64} ≈ @SVector [5.0, 11.0] end