Skip to content

Commit

Permalink
Add negative stride support to BLAS Level 1/2 functions
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Nov 5, 2021
1 parent d00d457 commit 501bc34
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 159 deletions.
162 changes: 65 additions & 97 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,18 @@ end


# Level 1
# `isdense(x)` tells whether `x` is a `DenseArray`, looking through contiguous
# views, reshapes and reinterprets down to the array they wrap.
isdense(::DenseArray) = true
isdense(::Any) = false
isdense(x::Union{Base.FastContiguousSubArray,Base.ReshapedArray,Base.ReinterpretArray}) =
    isdense(parent(x))
"""
    ptrst1(x::AbstractArray) -> (ptr, st)

Return the pointer and the increment used to hand `x` to a BLAS routine as a vector.

For inputs recognised by `isdense` the result is `(pointer(x), 1)`. Otherwise `st` is
`stride(x, 1)` and `ptr` points to the first element when `st >= 0`, or to the last
element when `st < 0`.

Throws an `ArgumentError` if `x` has more than one dimension and its strides are not
the ones implied by `stride(x, 1)` and `size(x)`.

The returned pointer is only valid while `x` is alive; callers wrap its use in
`GC.@preserve x`.
"""
@inline function ptrst1(x::AbstractArray)
    # Simplify the runtime check when possible: dense storage needs no stride test.
    isdense(x) && return pointer(x), 1
    # `stride(x, 1)` (a single integer) seeds the expected strides; the previous
    # `strides(x, 1)` has no method and raised a `MethodError` for every non-dense
    # input with `ndims(x) != 1`.
    ndims(x) == 1 || strides(x) == Base.size_to_strides(stride(x, 1), size(x)...) ||
        throw(ArgumentError("only support vector like inputs"))
    st = stride(x, 1)
    # With a negative stride the last element is the one at the lowest address.
    ptr = st >= 0 ? pointer(x) : pointer(x, lastindex(x))
    return ptr, st
end
## copy

"""
Expand Down Expand Up @@ -249,7 +261,10 @@ for (fname, elty) in ((:dscal_,:Float64),
DX
end

scal!(DA::$elty, DX::AbstractArray{$elty}) = scal!(length(DX),DA,DX,stride(DX,1))
scal!(DA::$elty, DX::AbstractArray{$elty}) = let (p, st) = ptrst1(DX)
GC.@preserve DX scal!(length(DX), DA, p, abs(st))
DX
end
end
end
scal(n, DA, DX, incx) = scal!(n, DA, copy(DX), incx)
Expand Down Expand Up @@ -353,75 +368,18 @@ for (fname, elty) in ((:cblas_zdotu_sub,:ComplexF64),
end
end

# Return the common length of `x` and `y`, throwing a `DimensionMismatch` when the
# two lengths differ.
@inline function _dot_length_check(x,y)
    len = length(x)
    len == length(y) ||
        throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
    return len
end

for (elty, f) in ((Float32, :dot), (Float64, :dot),
(ComplexF32, :dotc), (ComplexF64, :dotc),
(ComplexF32, :dotu), (ComplexF64, :dotu))
@eval begin
function $f(x::DenseArray{$elty}, y::DenseArray{$elty})
n = _dot_length_check(x,y)
$f(n, x, 1, y, 1)
end

function $f(x::StridedVector{$elty}, y::DenseArray{$elty})
n = _dot_length_check(x,y)
xstride = stride(x,1)
ystride = stride(y,1)
x_delta = xstride < 0 ? n : 1
GC.@preserve x $f(n,pointer(x,x_delta),xstride,y,ystride)
end

function $f(x::DenseArray{$elty}, y::StridedVector{$elty})
n = _dot_length_check(x,y)
xstride = stride(x,1)
ystride = stride(y,1)
y_delta = ystride < 0 ? n : 1
GC.@preserve y $f(n,x,xstride,pointer(y,y_delta),ystride)
end

function $f(x::StridedVector{$elty}, y::StridedVector{$elty})
n = _dot_length_check(x,y)
xstride = stride(x,1)
ystride = stride(y,1)
x_delta = xstride < 0 ? n : 1
y_delta = ystride < 0 ? n : 1
GC.@preserve x y $f(n,pointer(x,x_delta),xstride,pointer(y,y_delta),ystride)
function $f(x::AbstractVector{$elty}, y::AbstractVector{$elty})
n, m = length(x), length(y)
n == m || throw(DimensionMismatch("dot product arguments have lengths $n and $m"))
GC.@preserve x y $f(n, ptrst1(x)..., ptrst1(y)...)
end
end
end

# Dot product of two real BLAS vectors of equal length; the length and the increments
# are taken from the arguments and forwarded to the 5-argument `dot`.
function dot(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasReal
    require_one_based_indexing(DX, DY)
    len = length(DX)
    len == length(DY) ||
        throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
    incx = stride(DX, 1)
    incy = stride(DY, 1)
    return dot(len, DX, incx, DY, incy)
end
# `dotc` for two complex BLAS vectors of equal length; the length and the increments
# are taken from the arguments and forwarded to the 5-argument `dotc`.
function dotc(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
    require_one_based_indexing(DX, DY)
    len = length(DX)
    len == length(DY) ||
        throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
    incx = stride(DX, 1)
    incy = stride(DY, 1)
    return dotc(len, DX, incx, DY, incy)
end
# `dotu` for two complex BLAS vectors of equal length; the length and the increments
# are taken from the arguments and forwarded to the 5-argument `dotu`.
function dotu(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
    require_one_based_indexing(DX, DY)
    len = length(DX)
    len == length(DY) ||
        throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
    incx = stride(DX, 1)
    incy = stride(DY, 1)
    return dotu(len, DX, incx, DY, incy)
end

## nrm2

"""
Expand Down Expand Up @@ -453,7 +411,10 @@ for (fname, elty, ret_type) in ((:dnrm2_,:Float64,:Float64),
end
end
end
# One-argument `nrm2`: the length and the increment are derived from `x` itself.
function nrm2(x::Union{AbstractVector,DenseArray})
    return nrm2(length(x), x, stride1(x))
end
# openblas returns 0 for negative stride, so the absolute value of the stride is passed
function nrm2(x::Union{AbstractArray})
    ptr, inc = ptrst1(x)
    return GC.@preserve x nrm2(length(x), ptr, abs(inc))
end

## asum

Expand Down Expand Up @@ -490,8 +451,9 @@ for (fname, elty, ret_type) in ((:dasum_,:Float64,:Float64),
end
end
end
# One-argument `asum`: the length and the increment are derived from `x` itself.
function asum(x::Union{AbstractVector,DenseArray})
    return asum(length(x), x, stride1(x))
end

# The pointer/stride pair comes from `ptrst1`; the absolute value of the stride is passed.
function asum(x::Union{AbstractArray})
    ptr, inc = ptrst1(x)
    return GC.@preserve x asum(length(x), ptr, abs(inc))
end
## axpy

"""
Expand Down Expand Up @@ -538,7 +500,8 @@ function axpy!(alpha::Number, x::Union{DenseArray{T},StridedVector{T}}, y::Union
if length(x) != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
return axpy!(length(x), convert(T,alpha), x, stride(x, 1), y, stride(y, 1))
GC.@preserve x y axpy!(length(x), convert(T,alpha), ptrst1(x)..., ptrst1(y)...)
y
end

function axpy!(alpha::Number, x::Array{T}, rx::Union{UnitRange{Ti},AbstractRange{Ti}},
Expand All @@ -555,9 +518,9 @@ function axpy!(alpha::Number, x::Array{T}, rx::Union{UnitRange{Ti},AbstractRange
GC.@preserve x y axpy!(
length(rx),
convert(T, alpha),
pointer(x) + (first(rx) - 1)*sizeof(T),
pointer(x, minimum(rx)),
step(rx),
pointer(y) + (first(ry) - 1)*sizeof(T),
pointer(y, minimum(ry)),
step(ry))

return y
Expand Down Expand Up @@ -609,7 +572,8 @@ function axpby!(alpha::Number, x::Union{DenseArray{T},AbstractVector{T}}, beta::
if length(x) != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
return axpby!(length(x), convert(T, alpha), x, stride(x, 1), convert(T, beta), y, stride(y, 1))
GC.@preserve x y axpby!(length(x), convert(T, alpha), ptrst1(x)..., convert(T, beta), ptrst1(y)...)
y
end

## iamax
Expand Down Expand Up @@ -666,10 +630,7 @@ for (fname, elty) in ((:dgemv_,:Float64),
chkstride1(A)
lda = stride(A,2)
lda >= max(1, size(A,1)) || error("`stride(A,2)` must be at least `max(1, size(A,1))`")
sX = stride(X,1)
pX = pointer(X, sX > 0 ? firstindex(X) : lastindex(X))
sY = stride(Y,1)
pY = pointer(Y, sY > 0 ? firstindex(Y) : lastindex(Y))
(pX, sX), (pY, sY) = ptrst1(X), ptrst1(Y)
GC.@preserve X Y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty},
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Expand Down Expand Up @@ -750,14 +711,15 @@ for (fname, elty) in ((:dgbmv_,:Float64),
y::AbstractVector{$elty})
require_one_based_indexing(A, x, y)
chkstride1(A)
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(px, stx), (py, sty) = ptrst1(x), ptrst1(y)
GC.@preserve x y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt},
Ptr{$elty}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
Ref{BlasInt}, Clong),
trans, m, size(A,2), kl,
ku, alpha, A, max(1,stride(A,2)),
x, stride(x,1), beta, y, stride(y,1), 1)
px, stx, beta, py, sty, 1)
y
end
function gbmv(trans::AbstractChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
Expand Down Expand Up @@ -810,13 +772,14 @@ for (fname, elty, lib) in ((:dsymv_,:Float64,libblastrampoline),
throw(DimensionMismatch("A has size $(size(A)), and y has length $(length(y))"))
end
chkstride1(A)
ccall((@blasfunc($fname), $lib), Cvoid,
(px, stx), (py, sty) = ptrst1(x), ptrst1(y)
GC.@preserve x y ccall((@blasfunc($fname), $lib), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ref{$elty},
Ptr{$elty}, Ref{BlasInt}, Clong),
uplo, n, alpha, A,
max(1,stride(A,2)), x, stride(x,1), beta,
y, stride(y,1), 1)
max(1,stride(A,2)), px, stx, beta,
py, sty, 1)
y
end
function symv(uplo::AbstractChar, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
Expand Down Expand Up @@ -872,15 +835,14 @@ for (fname, elty) in ((:zhemv_,:ComplexF64),
end
chkstride1(A)
lda = max(1, stride(A, 2))
incx = stride(x, 1)
incy = stride(y, 1)
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(px, stx), (py, sty) = ptrst1(x), ptrst1(y)
GC.@preserve x y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ref{$elty},
Ptr{$elty}, Ref{BlasInt}, Clong),
uplo, n, α, A,
lda, x, incx, β,
y, incy, 1)
lda, px, stx, β,
py, sty, 1)
y
end
function hemv(uplo::AbstractChar, α::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
Expand Down Expand Up @@ -968,7 +930,8 @@ function hpmv!(uplo::AbstractChar,
if 2*length(AP) < N*(N + 1)
throw(DimensionMismatch("Packed Hermitian matrix A has size smaller than length(x) = $(N)."))
end
return hpmv!(uplo, N, convert(T, α), AP, x, stride(x, 1), convert(T, β), y, stride(y, 1))
GC.@preserve x y hpmv!(uplo, N, convert(T, α), AP, ptrst1(x)..., convert(T, β), ptrst1(y)...)
y
end

"""
Expand Down Expand Up @@ -1009,13 +972,14 @@ for (fname, elty) in ((:dsbmv_,:Float64),
function sbmv!(uplo::AbstractChar, k::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty}, beta::($elty), y::AbstractVector{$elty})
require_one_based_indexing(A, x, y)
chkstride1(A)
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(px, stx), (py, sty) = ptrst1(x), ptrst1(y)
GC.@preserve x y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty},
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ref{$elty}, Ptr{$elty}, Ref{BlasInt}, Clong),
uplo, size(A,2), k, alpha,
A, max(1,stride(A,2)), x, stride(x,1),
beta, y, stride(y,1), 1)
A, max(1,stride(A,2)), px, stx,
beta, py, sty, 1)
y
end
function sbmv(uplo::AbstractChar, k::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
Expand Down Expand Up @@ -1118,7 +1082,8 @@ function spmv!(uplo::AbstractChar,
if 2*length(AP) < N*(N + 1)
throw(DimensionMismatch("Packed symmetric matrix A has size smaller than length(x) = $(N)."))
end
return spmv!(uplo, N, convert(T, α), AP, x, stride(x, 1), convert(T, β), y, stride(y, 1))
GC.@preserve x y spmv!(uplo, N, convert(T, α), AP, ptrst1(x)..., convert(T, β), ptrst1(y)...)
y
end

"""
Expand Down Expand Up @@ -1159,13 +1124,14 @@ for (fname, elty) in ((:zhbmv_,:ComplexF64),
function hbmv!(uplo::AbstractChar, k::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty}, beta::($elty), y::AbstractVector{$elty})
require_one_based_indexing(A, x, y)
chkstride1(A)
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(px, stx), (py, sty) = ptrst1(x), ptrst1(y)
GC.@preserve x y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty},
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ref{$elty}, Ptr{$elty}, Ref{BlasInt}, Clong),
uplo, size(A,2), k, alpha,
A, max(1,stride(A,2)), x, stride(x,1),
beta, y, stride(y,1), 1)
A, max(1,stride(A,2)), px, stx,
beta, py, sty, 1)
y
end
function hbmv(uplo::AbstractChar, k::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
Expand Down Expand Up @@ -1219,12 +1185,13 @@ for (fname, elty) in ((:dtrmv_,:Float64),
throw(DimensionMismatch("A has size ($n,$n), x has length $(length(x))"))
end
chkstride1(A)
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
px, stx = ptrst1(x)
GC.@preserve x ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt},
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Clong, Clong, Clong),
uplo, trans, diag, n,
A, max(1,stride(A,2)), x, max(1,stride(x, 1)), 1, 1, 1)
A, max(1,stride(A,2)), px, stx, 1, 1, 1)
x
end
function trmv(uplo::AbstractChar, trans::AbstractChar, diag::AbstractChar, A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
Expand Down Expand Up @@ -1274,12 +1241,13 @@ for (fname, elty) in ((:dtrsv_,:Float64),
throw(DimensionMismatch("size of A is $n != length(x) = $(length(x))"))
end
chkstride1(A)
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
px, stx = ptrst1(x)
GC.@preserve x ccall((@blasfunc($fname), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt},
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Clong, Clong, Clong),
uplo, trans, diag, n,
A, max(1,stride(A,2)), x, stride(x, 1), 1, 1, 1)
A, max(1,stride(A,2)), px, stx, 1, 1, 1)
x
end
function trsv(uplo::AbstractChar, trans::AbstractChar, diag::AbstractChar, A::AbstractMatrix{$elty}, x::AbstractVector{$elty})
Expand Down Expand Up @@ -1993,9 +1961,9 @@ function copyto!(dest::Array{T}, rdest::Union{UnitRange{Ti},AbstractRange{Ti}},
end
GC.@preserve src dest BLAS.blascopy!(
length(rsrc),
pointer(src) + (first(rsrc) - 1) * sizeof(T),
pointer(src, minimum(rsrc)),
step(rsrc),
pointer(dest) + (first(rdest) - 1) * sizeof(T),
pointer(dest, minimum(rdest)),
step(rdest))

return dest
Expand Down
Loading

0 comments on commit 501bc34

Please sign in to comment.