Skip to content

Commit

Permalink
Remove arithtype in matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Aug 25, 2016
1 parent e59aa5d commit 0157d0c
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

# matmul.jl: Everything to do with dense matrix multiplication

arithtype(T) = T
arithtype(::Type{Bool}) = Int

matprod(x, y) = x*y + x*y

# multiply by diagonal matrix as vector
Expand Down Expand Up @@ -78,11 +75,11 @@ At_mul_B{T<:BlasComplex}(x::StridedVector{T}, y::StridedVector{T}) = [BLAS.dotu(

# Matrix-vector multiplication
function (*){T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_B!(similar(x, TS, size(A,1)), A, convert(AbstractVector{TS}, x))
end
function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_B!(similar(x,TS,size(A,1)),A,x)
end
(*)(A::AbstractVector, B::AbstractMatrix) = reshape(A,length(A),1)*B
Expand All @@ -101,22 +98,22 @@ end
A_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'N', A, x)

function At_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_B!(similar(x,TS,size(A,2)), A, convert(AbstractVector{TS}, x))
end
function At_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_B!(similar(x,TS,size(A,2)), A, x)
end
At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x)
At_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'T', A, x)

function Ac_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
Ac_mul_B!(similar(x,TS,size(A,2)),A,convert(AbstractVector{TS},x))
end
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
Ac_mul_B!(similar(x,TS,size(A,2)), A, x)
end

Expand All @@ -127,7 +124,7 @@ Ac_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_m
# Matrix-matrix multiplication

function (*){T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_B!(similar(B, TS, (size(A,1), size(B,2))), A, B)
end
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
Expand All @@ -144,14 +141,14 @@ end
A_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'N', A, B)

function At_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)
At_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'T', 'N', A, B)

function A_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_Bt!(similar(B, TS, (size(A,1), size(B,1))), A, B)
end
A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)
Expand All @@ -168,7 +165,7 @@ end
A_mul_Bt!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'T', A, B)

function At_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractVecOrMat{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_Bt!(similar(B, TS, (size(A,2), size(B,1))), A, B)
end
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'T', 'T', A, B)
Expand All @@ -177,7 +174,7 @@ At_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generi
Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B(A, B)
Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = At_mul_B!(C, A, B)
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
Ac_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)
Expand All @@ -186,14 +183,14 @@ Ac_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic
A_mul_Bc{T<:BlasFloat,S<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt(A, B)
A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{S}) = A_mul_Bt!(C, A, B)
function A_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(matprod, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_Bc!(similar(B,TS,(size(A,1),size(B,1))),A,B)
end
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = is(A,B) ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)
A_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'C', A, B)

Ac_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) =
Ac_mul_Bc!(similar(B, promote_op(matprod, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!(similar(B, promote_op(matprod, T, S), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'C', 'C', A, B)
Ac_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'C', A, B)
Ac_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'T', A, B)
Expand Down Expand Up @@ -426,7 +423,7 @@ end
function generic_matmatmul{T,S}(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, promote_op(matprod, arithtype(T), arithtype(S)), mA, nB)
C = similar(B, promote_op(matprod, T, S), mA, nB)
generic_matmatmul!(C, tA, tB, A, B)
end

Expand Down

0 comments on commit 0157d0c

Please sign in to comment.