Skip to content

Commit

Permalink
Reinstate StaticMatrix * AbstractVector -> StaticVector
Browse files Browse the repository at this point in the history
  • Loading branch information
c42f committed Apr 11, 2017
1 parent 8d3ec32 commit e7309a5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
20 changes: 18 additions & 2 deletions src/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero
# Manage dispatch of * and A_mul_B!
# TODO RowVector? (Inner product?)

@inline *(A::StaticMatrix, B::AbstractVector) = _A_mul_B(Size(A), A, B)
@inline *(A::StaticMatrix, B::StaticVector) = _A_mul_B(Size(A), Size(B), A, B)
@inline *(A::StaticMatrix, B::StaticMatrix) = _A_mul_B(Size(A), Size(B), A, B)
@inline *(A::StaticVector, B::StaticMatrix) = *(reshape(A, Size(Size(A)[1], 1)), B)
Expand All @@ -45,6 +46,23 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero

# Implementations

@generated function _A_mul_B(::Size{sa}, a::StaticMatrix{Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb}
if sa[2] != 0
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(sub2ind(sa, k, j))]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
else
exprs = [:(zero(T)) for k = 1:sa[1]]
end

return quote
@_inline_meta
if length(b) != sa[2]
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $(size(b))"))
end
T = promote_matprod(Ta, Tb)
@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))
end
end

@generated function _A_mul_B(::Size{sa}, ::Size{sb}, a::StaticMatrix{Ta}, b::StaticVector{Tb}) where {sa, sb, Ta, Tb}
if sb[1] != sa[2]
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
Expand All @@ -63,8 +81,6 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero
end
end

# TODO: I removed StaticMatrix * AbstractVector. Reinstate?

# outer product
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, a::StaticVector{Ta}, b::RowVector{Tb, <:StaticVector}) where {sa, sb, Ta, Tb}
newsize = (sa[1], sb[2])
Expand Down
6 changes: 3 additions & 3 deletions test/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@test @inferred(v*v') === @SMatrix [1 2; 2 4]

v3 = [1, 2]
@test_broken m*v3 === @SVector [5, 11]
@test m*v3 === @SVector [5, 11]

m2 = @MMatrix [1 2; 3 4]
v4 = @MVector [1, 2]
Expand All @@ -32,11 +32,11 @@

m5 = @SMatrix [1.0 2.0; 3.0 4.0]
v7 = [1.0, 2.0]
@test_broken (m5*v7)::SVector @SVector [5.0, 11.0]
@test (m5*v7)::SVector @SVector [5.0, 11.0]

m6 = @SMatrix Float32[1.0 2.0; 3.0 4.0]
v8 = Float64[1.0, 2.0]
@test_broken (m6*v8)::SVector{2,Float64} @SVector [5.0, 11.0]
@test (m6*v8)::SVector{2,Float64} @SVector [5.0, 11.0]

end

Expand Down

0 comments on commit e7309a5

Please sign in to comment.