Skip to content

Commit

Permalink
starting to implement subarray functionality in blas/lapack calls
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeXing committed Aug 20, 2011
1 parent d694ef6 commit 5d12f44
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 29 deletions.
2 changes: 1 addition & 1 deletion j/abstractarray.j
Original file line number Diff line number Diff line change
Expand Up @@ -1305,4 +1305,4 @@ end

# One-line description of a SubArray: the dims2string of its size, followed by
# the literal " SubArray of ", followed by the summary of its parent array.
# NOTE(review): this span is a rendered diff hunk without +/- markers; the
# final line appears twice, presumably the removed and added copies of the
# same line (the hunk reports 1 addition & 1 deletion) — confirm against the file.
summary{T,N}(s::SubArray{T,N}) =
strcat(dims2string(size(s)), " SubArray of ",
summary(s.parent))
summary(s.parent))
1 change: 1 addition & 0 deletions j/array.j
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Convenience aliases for one- and two-dimensional arrays.
typealias Vector{T} Array{T,1}
typealias Matrix{T} Array{T,2}
# Either a Vector{T} or a Matrix{T}.
typealias VecOrMat{T} Union(Vector{T}, Matrix{T})
# Either a Matrix{T} or a two-dimensional SubArray{T,2,Array{T}}.
# Used as the matrix argument type by the BLAS/LAPACK wrappers.
# NOTE(review): the third SubArray parameter is presumably the parent array
# type, which would restrict this to sub-arrays of plain Arrays — confirm
# against the SubArray definition.
typealias DenseMat{T} Union(Matrix{T},SubArray{T,2,Array{T}})

## Basic functions ##

Expand Down
29 changes: 18 additions & 11 deletions j/linalg_blas.j
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,22 @@ macro jl_blas_gemm_macro(fname, eltype)
quote

# Wrapper that calls the BLAS routine named by the macro argument `fname`,
# looked up in libBLAS with dlsym and invoked through ccall.
#
# Per the ccall signature below, transA/transB are passed as Ptr{Uint8}; m, n, k
# and the leading dimensions lda/ldb/ldc are converted with int32 and passed as
# Ptr{Int32}; alpha, beta and the three matrices are passed as Ptr{$eltype}.
# The call is declared to return Void; nothing is returned explicitly.
#
# NOTE(review): this span is a rendered diff hunk without +/- markers. The
# tail of the signature and the ccall argument list each appear twice:
# presumably the Array-typed / A,B,C lines are the removed version and the
# DenseMat-typed / a,b,c (pointer) lines are their replacement — confirm
# against the committed file before treating this as compilable code.
function jl_blas_gemm(transA, transB, m::Int, n::Int, k::Int,
alpha::($eltype), A::Array{$eltype}, lda::Int,
B::Array{$eltype}, ldb::Int,
beta::($eltype), C::Array{$eltype}, ldc::Int)
alpha::($eltype), A::DenseMat{$eltype}, lda::Int,
B::DenseMat{$eltype}, ldb::Int,
beta::($eltype), C::DenseMat{$eltype}, ldc::Int)
# Take pointers explicitly so the matrix arguments are passed to ccall as
# pointers rather than as the array objects themselves.
a = pointer(A)
b = pointer(B)
c = pointer(C)
ccall(dlsym(libBLAS, $fname),
Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32},
Ptr{$eltype}, Ptr{$eltype}, Ptr{Int32},
Ptr{$eltype}, Ptr{Int32},
Ptr{$eltype}, Ptr{$eltype}, Ptr{Int32}),
transA, transB, int32(m), int32(n), int32(k),
alpha, A, int32(lda),
B, int32(ldb),
beta, C, int32(ldc))
alpha, a, int32(lda),
b, int32(ldb),
beta, c, int32(ldc))
end

end
Expand All @@ -144,8 +147,8 @@ end
# Instantiate the jl_blas_gemm wrapper for the two complex element types,
# binding each to the BLAS symbol given as the first macro argument.
@jl_blas_gemm_macro :zgemm_ Complex128
@jl_blas_gemm_macro :cgemm_ Complex64

function (*){T<:Union(Float64,Float32,Complex128,Complex64)}(A::Matrix{T},
B::Matrix{T})
function (*){T<:Union(Float64,Float32,Complex128,Complex64)}(A::DenseMat{T},
B::DenseMat{T})
(mA, nA) = size(A)
(mB, nB) = size(B)

Expand All @@ -154,14 +157,18 @@ function (*){T<:Union(Float64,Float32,Complex128,Complex64)}(A::Matrix{T},
if mA == 2 && nA == 2 && nB == 2; return matmul2x2(A,B); end
if mA == 3 && nA == 3 && nB == 3; return matmul3x3(A,B); end

if stride(A, 1) != 1 || stride(B, 1) != 1
return invoke(*, (AbstractArray, AbstractArray), A, B)
end

# Result array does not need to be initialized as long as beta==0
C = Array(T, mA, nB)

jl_blas_gemm("N", "N", mA, nB, nA,
convert(T, 1.0), A, mA,
B, nA,
convert(T, 1.0), A, stride(A, 2),
B, stride(B, 2),
convert(T, 0.0), C, mA)
return C
end

# TODO: Use DGEMV for matvec.
# TODO: Use DGEMV for matvec.
44 changes: 27 additions & 17 deletions j/linalg_lapack.j
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ macro jl_lapack_potrf_macro(potrf, eltype)
# INTEGER INFO, LDA, N
# * .. Array Arguments ..
# DOUBLE PRECISION A( LDA, * )
function jl_lapack_potrf(uplo, n, A::AbstractMatrix{$eltype}, lda)
function jl_lapack_potrf(uplo, n, A::DenseMat{$eltype}, lda)
info = Array(Int32, 1)
a = pointer(A)
ccall(dlsym(libLAPACK, $potrf),
Expand All @@ -30,7 +30,8 @@ end
#does not check that the input matrix is symmetric/Hermitian
#(uses upper triangular half)
#Possible TODO: "economy mode"
function chol{T<:Union(Float32,Float64,Complex64,Complex128)}(A::Union(Matrix,SubArray{T,2}))
function chol{T<:Union(Float32,Float64,Complex64,Complex128)}(A::DenseMat{T})
if stride(A,1) != 1; error("chol: matrix columns must have contiguous elements"); end
n = int32(size(A, 1))
if isa(A, Matrix)
R = triu(A)
Expand Down Expand Up @@ -74,8 +75,9 @@ end

lu(A::Matrix) = lu(A, false)
lu{T}(A::SubArray{T,2}) = lu(A,false)
function lu{T<:Union(Float32,Float64,Complex64,Complex128)}(A::Union(Matrix,SubArray{T,2}),
function lu{T<:Union(Float32,Float64,Complex64,Complex128)}(A::DenseMat{T},
economy::Bool)
if stride(A,1) != 1; error("lu: matrix columns must have contiguous elements"); end
m, n = size(A)
LU = A
if !economy
Expand Down Expand Up @@ -115,7 +117,7 @@ macro jl_lapack_qr_macro(real_geqp3, complex_geqp3, orgqr, ungqr, eltype, celtyp
# * .. Array Arguments ..
# INTEGER JPVT( * )
# DOUBLE PRECISION A( LDA, * ), TAU( * ), WORK( * )
function jl_lapack_geqp3(m, n, A::Matrix{$eltype}, lda, jpvt, tau, work, lwork)
function jl_lapack_geqp3(m, n, A::DenseMat{$eltype}, lda, jpvt, tau, work, lwork)
info = Array(Int32, 1)
a = pointer(A)
ccall(dlsym(libLAPACK, $real_geqp3),
Expand All @@ -133,7 +135,7 @@ macro jl_lapack_qr_macro(real_geqp3, complex_geqp3, orgqr, ungqr, eltype, celtyp
# INTEGER JPVT( * )
# DOUBLE PRECISION RWORK( * )
# COMPLEX*16 A( LDA, * ), TAU( * ), WORK( * )
function jl_lapack_geqp3(m, n, A::AbstractMatrix{$celtype}, lda, jpvt, tau, work, lwork, rwork)
function jl_lapack_geqp3(m, n, A::DenseMat{$eltype}, lda, jpvt, tau, work, lwork, rwork)
info = Array(Int32, 1)
a = pointer(A)
ccall(dlsym(libLAPACK, $complex_geqp3),
Expand All @@ -149,7 +151,7 @@ macro jl_lapack_qr_macro(real_geqp3, complex_geqp3, orgqr, ungqr, eltype, celtyp
# INTEGER INFO, K, LDA, LWORK, M, N
# * .. Array Arguments ..
# DOUBLE PRECISION A( LDA, * ), TAU( * ), WORK( * )
function jl_lapack_orgqr(m, n, k, A::AbstractMatrix{$eltype}, lda, tau, work, lwork)
function jl_lapack_orgqr(m, n, k, A::DenseMat{$eltype}, lda, tau, work, lwork)
info = Array(Int32, 1)
a = pointer(A)
ccall(dlsym(libLAPACK, $orgqr),
Expand All @@ -165,7 +167,7 @@ macro jl_lapack_qr_macro(real_geqp3, complex_geqp3, orgqr, ungqr, eltype, celtyp
# INTEGER INFO, K, LDA, LWORK, M, N
#* .. Array Arguments ..
# COMPLEX*16 A( LDA, * ), TAU( * ), WORK( * )
function jl_lapack_ungqr(m, n, k, A::AbstractMatrix{$celtype}, lda, tau, work, lwork)
function jl_lapack_ungqr(m, n, k, A::DenseMat{$eltype}, lda, tau, work, lwork)
info = Array(Int32, 1)
a = pointer(A)
ccall(dlsym(libLAPACK, $ungqr),
Expand All @@ -183,7 +185,7 @@ end
@jl_lapack_qr_macro :sgeqp3_ :cgeqp3_ :sorgqr_ :cungqr_ Float32 Complex64

#possible TODO: economy mode?
function qr{T<:Union(Float32,Float64,Complex64,Complex128)}(A::Union(Matrix{T},SubArray{T,2}))
function qr{T<:Union(Float32,Float64,Complex64,Complex128)}(A::DenseMat{T})
m, n = size(A)
if isa(A, Matrix)
QR = copy(A)
Expand Down Expand Up @@ -521,13 +523,15 @@ macro jl_lapack_backslash_macro(gesv, posv, gels, trtrs, eltype)
# * .. Array Arguments ..
# INTEGER IPIV( * )
# DOUBLE PRECISION A( LDA, * ), B( LDB, * )
# Wrapper that calls the LAPACK routine named by the macro argument `gesv`,
# looked up in libLAPACK with dlsym and invoked through ccall.
#
# n, nrhs, lda and ldb are converted with int32; ipiv is forwarded unchanged
# as a Ptr{Int32} argument; A and B are passed as Ptr{$eltype}.
# Returns info[1], the value the routine leaves in the one-element Int32
# array passed as its last argument.
#
# NOTE(review): this span is a rendered diff hunk without +/- markers. The
# signature and the ccall argument line each appear twice: presumably the
# Matrix-typed / A,B lines are the removed version and the AbstractMatrix-typed
# / a,b (pointer) lines are their replacement — confirm against the committed file.
# NOTE(review): the newer signature admits any AbstractMatrix, whereas the
# potrf/geqp3 wrappers in this file take DenseMat; confirm pointer(A) is
# defined for every type accepted here. B is untyped — same question.
function jl_lapack_gesv(n, nrhs, A::Matrix{$eltype}, lda, ipiv, B, ldb)
function jl_lapack_gesv(n, nrhs, A::AbstractMatrix{$eltype}, lda, ipiv, B, ldb)
# One-element buffer the routine writes its status code into.
info = Array(Int32, 1)
a = pointer(A)
b = pointer(B)
ccall(dlsym(libLAPACK, $gesv),
Void,
(Ptr{Int32}, Ptr{Int32}, Ptr{$eltype}, Ptr{Int32}, Ptr{Int32},
Ptr{$eltype}, Ptr{Int32}, Ptr{Int32}),
int32(n), int32(nrhs), A, int32(lda), ipiv, B, int32(ldb), info)
int32(n), int32(nrhs), a, int32(lda), ipiv, b, int32(ldb), info)
return info[1]
end

Expand All @@ -537,28 +541,32 @@ macro jl_lapack_backslash_macro(gesv, posv, gels, trtrs, eltype)
# INTEGER INFO, LDA, LDB, N, NRHS
# .. Array Arguments ..
# DOUBLE PRECISION A( LDA, * ), B( LDB, * )
# Wrapper that calls the LAPACK routine named by the macro argument `posv`,
# looked up in libLAPACK with dlsym and invoked through ccall.
#
# uplo is passed as Ptr{Uint8}; n, nrhs, lda and ldb are converted with int32;
# A and B are passed as Ptr{$eltype}.
# Returns info[1], the value the routine leaves in the one-element Int32
# array passed as its last argument.
#
# NOTE(review): this span is a rendered diff hunk without +/- markers. The
# signature and the ccall argument line each appear twice: presumably the
# Matrix-typed / A,B lines are the removed version and the AbstractMatrix-typed
# / a,b (pointer) lines are their replacement — confirm against the committed file.
# NOTE(review): the newer signature admits any AbstractMatrix and an untyped B;
# confirm pointer() is defined for every type that can reach this call.
function jl_lapack_posv(uplo, n, nrhs, A::Matrix{$eltype}, lda, B, ldb)
function jl_lapack_posv(uplo, n, nrhs, A::AbstractMatrix{$eltype}, lda, B, ldb)
# One-element buffer the routine writes its status code into.
info = Array(Int32, 1)
a = pointer(A)
b = pointer(B)
ccall(dlsym(libLAPACK, $posv),
Void,
(Ptr{Uint8}, Ptr{Int32}, Ptr{Int32}, Ptr{$eltype}, Ptr{Int32},
Ptr{$eltype}, Ptr{Int32}, Ptr{Int32}),
uplo, int32(n), int32(nrhs), A, int32(lda), B, int32(ldb), info)
uplo, int32(n), int32(nrhs), a, int32(lda), b, int32(ldb), info)
return info[1]
end

# SUBROUTINE DGELS( TRANS, M, N, NRHS, A, LDA, B, LDB, WORK, LWORK, INFO)
# * .. Scalar Arguments ..
# CHARACTER TRANS
# INTEGER INFO, LDA, LDB, LWORK, M, N, NRHS
# Wrapper that calls the LAPACK routine named by the macro argument `gels`,
# looked up in libLAPACK with dlsym and invoked through ccall.
#
# trans is passed as Ptr{Uint8}; m, n, nrhs, lda, ldb and lwork are converted
# with int32; A, B and work are passed as Ptr{$eltype} (work is forwarded
# unchanged, without an explicit pointer() call).
# Returns info[1], the value the routine leaves in the one-element Int32
# array passed as its last argument.
#
# NOTE(review): this span is a rendered diff hunk without +/- markers. The
# signature and the two ccall argument lines each appear twice: presumably the
# Matrix-typed / A,B lines are the removed version and the AbstractMatrix-typed
# / a,b (pointer) lines are their replacement — confirm against the committed file.
# NOTE(review): the newer signature admits any AbstractMatrix and an untyped B;
# confirm pointer() is defined for every type that can reach this call.
function jl_lapack_gels(trans, m, n, nrhs, A::Matrix{$eltype}, lda, B, ldb, work, lwork)
function jl_lapack_gels(trans, m, n, nrhs, A::AbstractMatrix{$eltype}, lda, B, ldb, work, lwork)
# One-element buffer the routine writes its status code into.
info = Array(Int32, 1)
a = pointer(A)
b = pointer(B)
ccall(dlsym(libLAPACK, $gels),
Void,
(Ptr{Uint8}, Ptr{Int32}, Ptr{Int32}, Ptr{Int32}, Ptr{$eltype}, Ptr{Int32},
Ptr{$eltype}, Ptr{Int32}, Ptr{$eltype}, Ptr{Int32}, Ptr{Int32}),
trans, int32(m), int32(n), int32(nrhs), A, int32(lda),
B, int32(ldb), work, int32(lwork), info)
trans, int32(m), int32(n), int32(nrhs), a, int32(lda),
b, int32(ldb), work, int32(lwork), info)
return info[1]
end

Expand All @@ -568,13 +576,15 @@ macro jl_lapack_backslash_macro(gesv, posv, gels, trtrs, eltype)
# INTEGER INFO, LDA, LDB, N, NRHS
# * .. Array Arguments ..
# DOUBLE PRECISION A( LDA, * ), B( LDB, * )
# Wrapper that calls the LAPACK routine named by the macro argument `trtrs`,
# looked up in libLAPACK with dlsym and invoked through ccall.
#
# uplo, trans and diag are passed as Ptr{Uint8}; n, nrhs, lda and ldb are
# converted with int32; A and B are passed as Ptr{$eltype}.
# Returns info[1], the value the routine leaves in the one-element Int32
# array passed as its last argument.
#
# NOTE(review): this span is a rendered diff hunk without +/- markers. The
# signature and the ccall argument line each appear twice: presumably the
# Matrix-typed / A,B lines are the removed version and the AbstractMatrix-typed
# / a,b (pointer) lines are their replacement — confirm against the committed file.
# NOTE(review): the newer signature admits any AbstractMatrix and an untyped B;
# confirm pointer() is defined for every type that can reach this call.
function jl_lapack_trtrs(uplo, trans, diag, n, nrhs, A::Matrix{$eltype}, lda, B, ldb)
function jl_lapack_trtrs(uplo, trans, diag, n, nrhs, A::AbstractMatrix{$eltype}, lda, B, ldb)
# One-element buffer the routine writes its status code into.
info = Array(Int32, 1)
a = pointer(A)
b = pointer(B)
ccall(dlsym(libLAPACK, $trtrs),
Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{Uint8}, Ptr{Int32}, Ptr{Int32},
Ptr{$eltype}, Ptr{Int32}, Ptr{$eltype}, Ptr{Int32}, Ptr{Int32}),
uplo, trans, diag, int32(n), int32(nrhs), A, int32(lda), B, int32(ldb), info)
uplo, trans, diag, int32(n), int32(nrhs), a, int32(lda), b, int32(ldb), info)
return info[1]
end

Expand Down

0 comments on commit 5d12f44

Please sign in to comment.