Skip to content

Commit

Permalink
Fix rocBLAS wrapper & refactor it
Browse files Browse the repository at this point in the history
Also use `queue` instead of `default_queue` to avoid triggering depwarn.
  • Loading branch information
pxl-th committed Mar 29, 2023
1 parent c43c447 commit 6414dc0
Show file tree
Hide file tree
Showing 8 changed files with 313 additions and 668 deletions.
162 changes: 99 additions & 63 deletions src/blas/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,44 @@ Base.argmax(xs::ROCBLASArray{<:ROCBLASReal}) = iamax(xs)
#
############

# In-place triangular matrix-vector product (`lmul!`) and triangular solve (`ldiv!`)
# for the four triangular wrappers around a plain `ROCMatrix`.
# Each tuple is (wrapper type, triangle flag, unit-diagonal flag).
for (tri, uplo, diag) in (
    (:LowerTriangular, 'L', 'N'),
    (:UnitLowerTriangular, 'L', 'U'),
    (:UpperTriangular, 'U', 'N'),
    (:UnitUpperTriangular, 'U', 'U'),
)
    @eval begin
        # B <- A * B, forwarded to `trmv!` without transposition.
        function LinearAlgebra.lmul!(
            A::$tri{T, ROCMatrix{T}}, B::ROCVector{T},
        ) where {T <: ROCBLASFloat}
            return trmv!($uplo, 'N', $diag, parent(A), B)
        end

        # B <- A \ B, forwarded to `trsv!` without transposition.
        function LinearAlgebra.ldiv!(
            A::$tri{T, ROCMatrix{T}}, B::ROCVector{T},
        ) where {T <: ROCBLASFloat}
            return trsv!($uplo, 'N', $diag, parent(A), B)
        end
    end
end

# Triangular wrappers around a lazily transposed/adjointed `ROCMatrix`.
# The triangle flag is flipped relative to the wrapper name because the call operates
# on the doubly-unwrapped matrix (`parent(parent(A))`) with a transposition flag.
for (tri, uplo, diag) in (
    (:LowerTriangular, 'U', 'N'),
    (:UnitLowerTriangular, 'U', 'U'),
    (:UpperTriangular, 'L', 'N'),
    (:UnitUpperTriangular, 'L', 'U'),
)
    # (lazy wrapper, element-type bound, operation flag passed to trmv!/trsv!).
    # `Adjoint` appears twice: 'T' under the `ROCBLASFloat` bound and 'C' under the
    # `ROCBLASComplex` bound.
    for (wrap, elty, trans) in (
        (:Transpose, :ROCBLASFloat, 'T'),
        (:Adjoint, :ROCBLASFloat, 'T'),
        (:Adjoint, :ROCBLASComplex, 'C'),
    )
        @eval begin
            function LinearAlgebra.lmul!(
                A::$tri{<: Any, <: $wrap{T, <: ROCMatrix{T}}}, B::ROCVector{T},
            ) where {T <: $elty}
                return trmv!($uplo, $trans, $diag, parent(parent(A)), B)
            end

            function LinearAlgebra.ldiv!(
                A::$tri{<: Any, <: $wrap{T, <: ROCMatrix{T}}}, B::ROCVector{T},
            ) where {T <: $elty}
                return trsv!($uplo, $trans, $diag, parent(parent(A)), B)
            end
        end
    end
end

#########
# GEMV
##########
Expand All @@ -76,12 +114,8 @@ function gemv_wrapper!(y::ROCVector{T}, tA::Char, A::ROCMatrix{T}, x::ROCVector{
if mA != length(y)
throw(DimensionMismatch("first dimension of A, $mA, does not match length of y, $(length(y))"))
end
if mA == 0
return y
end
if nA == 0
return rmul!(y, 0)
end
mA == 0 && return y
nA == 0 && return rmul!(y, 0)
gemv!(tA, alpha, A, x, beta, y)
end

Expand Down Expand Up @@ -161,60 +195,62 @@ LinearAlgebra.mul!(C::ROCMatrix{T}, adjA::LinearAlgebra.Adjoint{<:Any, <:ROCMatr
# TRSM
########

# `ldiv!` / `rdiv!` with a triangular operand and a dense `ROCMatrix`, all forwarded to
# `rocBLAS.trsm!` with a unit scaling factor.
# Each tuple is (wrapper type, triangle flag, unit-diagonal flag).
for (tri, uplo, diag) in (
    (:UpperTriangular, 'U', 'N'),
    (:UnitUpperTriangular, 'U', 'U'),
    (:LowerTriangular, 'L', 'N'),
    (:UnitLowerTriangular, 'L', 'U'),
)
    # Triangular wrapper used directly (no transpose/adjoint).
    @eval begin
        function LinearAlgebra.ldiv!(
            A::$tri{T, ROCMatrix{T}}, B::ROCMatrix{T},
        ) where {T <: ROCBLASFloat}
            return rocBLAS.trsm!('L', $uplo, 'N', $diag, one(T), parent(A), B)
        end

        function LinearAlgebra.rdiv!(
            A::ROCMatrix{T}, B::$tri{T, ROCMatrix{T}},
        ) where {T <: ROCBLASFloat}
            return rocBLAS.trsm!('R', $uplo, 'N', $diag, one(T), parent(B), A)
        end
    end

    # Adjoint ('C') or transpose ('T') taken *outside* the triangular wrapper; the
    # underlying matrix is reached by unwrapping twice.
    for (wrap, trans) in ((:Adjoint, 'C'), (:Transpose, 'T'))
        @eval begin
            function LinearAlgebra.ldiv!(
                A::$wrap{T, $tri{T, ROCMatrix{T}}}, B::ROCMatrix{T},
            ) where {T <: ROCBLASFloat}
                return rocBLAS.trsm!(
                    'L', $uplo, $trans, $diag, one(T), parent(parent(A)), B)
            end

            function LinearAlgebra.rdiv!(
                A::ROCMatrix{T}, B::$wrap{T, $tri{T, ROCMatrix{T}}},
            ) where {T <: ROCBLASFloat}
                return rocBLAS.trsm!(
                    'R', $uplo, $trans, $diag, one(T), parent(parent(B)), A)
            end
        end
    end
end
# In-place triangular multiplication (`lmul!`/`rmul!`, via `trmm!`) and triangular
# solves (`ldiv!`/`rdiv!`, via `trsm!`) against a dense `ROCMatrix`.
# Each tuple is (wrapper type, triangle flag, unit-diagonal flag).
for (tri, uplo, diag) in (
    (:LowerTriangular, 'L', 'N'),
    (:UnitLowerTriangular, 'L', 'U'),
    (:UpperTriangular, 'U', 'N'),
    (:UnitUpperTriangular, 'U', 'U'),
)
    @eval begin
        # Triangular operand on the left: side flag 'L'.
        function LinearAlgebra.lmul!(
            A::$tri{T, ROCMatrix{T}}, B::ROCMatrix{T},
        ) where {T <: ROCBLASFloat}
            return trmm!('L', $uplo, 'N', $diag, one(T), parent(A), B)
        end

        # Triangular operand on the right: side flag 'R', dense matrix is updated.
        function LinearAlgebra.rmul!(
            A::ROCMatrix{T}, B::$tri{T, ROCMatrix{T}},
        ) where {T <: ROCBLASFloat}
            return trmm!('R', $uplo, 'N', $diag, one(T), parent(B), A)
        end

        function LinearAlgebra.ldiv!(
            A::$tri{T, ROCMatrix{T}}, B::ROCMatrix{T},
        ) where {T <: ROCBLASFloat}
            return trsm!('L', $uplo, 'N', $diag, one(T), parent(A), B)
        end

        function LinearAlgebra.rdiv!(
            A::ROCMatrix{T}, B::$tri{T, ROCMatrix{T}},
        ) where {T <: ROCBLASFloat}
            return trsm!('R', $uplo, 'N', $diag, one(T), parent(B), A)
        end
    end
end

# Triangular wrappers around a lazily transposed/adjointed `ROCMatrix`.
# The triangle flag is flipped relative to the wrapper name because the call operates
# on the doubly-unwrapped matrix with a transposition flag.
for (tri, uplo, diag) in (
    (:LowerTriangular, 'U', 'N'),
    (:UnitLowerTriangular, 'U', 'U'),
    (:UpperTriangular, 'L', 'N'),
    (:UnitUpperTriangular, 'L', 'U'),
)
    # (lazy wrapper, element-type bound, operation flag passed to trmm!/trsm!).
    # `Adjoint` appears twice: 'T' under the `ROCBLASFloat` bound and 'C' under the
    # `ROCBLASComplex` bound.
    for (wrap, elty, trans) in (
        (:Transpose, :ROCBLASFloat, 'T'),
        (:Adjoint, :ROCBLASFloat, 'T'),
        (:Adjoint, :ROCBLASComplex, 'C'),
    )
        @eval begin
            # Multiplication.
            function LinearAlgebra.lmul!(
                A::$tri{<: Any, <: $wrap{T, <: ROCMatrix{T}}}, B::ROCMatrix{T},
            ) where {T <: $elty}
                return trmm!('L', $uplo, $trans, $diag, one(T), parent(parent(A)), B)
            end

            function LinearAlgebra.rmul!(
                A::ROCMatrix{T}, B::$tri{<: Any, <: $wrap{T, <: ROCMatrix{T}}},
            ) where {T <: $elty}
                return trmm!('R', $uplo, $trans, $diag, one(T), parent(parent(B)), A)
            end

            # Left division.
            function LinearAlgebra.ldiv!(
                A::$tri{<: Any, <: $wrap{T, <: ROCMatrix{T}}}, B::ROCMatrix{T},
            ) where {T <: $elty}
                return trsm!('L', $uplo, $trans, $diag, one(T), parent(parent(A)), B)
            end

            # Right division.
            function LinearAlgebra.rdiv!(
                A::ROCMatrix{T}, B::$tri{<: Any, <: $wrap{T, <: ROCMatrix{T}}},
            ) where {T <: $elty}
                return trsm!('R', $uplo, $trans, $diag, one(T), parent(parent(B)), A)
            end
        end
    end
end
16 changes: 3 additions & 13 deletions src/blas/librocblas.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using CEnum

# rocBLAS handle type, represented on the Julia side as an untyped (`Cvoid`) pointer.
const rocblas_handle = Ptr{Cvoid}

@cenum rocblas_status_::UInt32 begin
Expand Down Expand Up @@ -39,9 +37,7 @@ const rocblas_float = Cfloat

# Double-precision scalar type: alias for C `double`.
const rocblas_double = Cdouble

# 16-bit value kept as a raw `UInt16` word in the single field `data`.
# NOTE(review): this struct and the alias right below bind the same name; they look
# like the before/after sides of a change — confirm that only one is meant to be live.
struct rocblas_half
data::UInt16
end
# Alias of the same name to Julia's native `Float16`.
const rocblas_half = Float16

struct rocblas_int8x4
a::Int8
Expand All @@ -50,15 +46,9 @@ struct rocblas_int8x4
d::Int8
end

# Pair of `Cfloat` fields `x` and `y` (presumably real and imaginary part — confirm).
# NOTE(review): each struct here is followed by a `const` alias of the same name; they
# look like the before/after sides of a change — confirm that only one is meant to be live.
struct rocblas_float_complex
x::Cfloat
y::Cfloat
end
# Alias of the same name to Julia's native `ComplexF32`.
const rocblas_float_complex = ComplexF32

# Pair of `Cdouble` fields `x` and `y` (presumably real and imaginary part — confirm).
struct rocblas_double_complex
x::Cdouble
y::Cdouble
end
# Alias of the same name to Julia's native `ComplexF64`.
const rocblas_double_complex = ComplexF64

@cenum rocblas_operation_::UInt32 begin
rocblas_operation_none = 111
Expand Down
8 changes: 8 additions & 0 deletions src/blas/rocBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@ import AMDGPU: HIP
import .HIP: HIPContext, HIPStream, hipContext_t, hipStream_t, hipEvent_t

using LinearAlgebra
using CEnum

include("librocblas.jl")
include("error.jl")
include("wrappers.jl")
include("highlevel.jl")

"""
    rocblas_get_version_string() -> String

Return the version string reported by rocBLAS.

A zero-filled 64-byte buffer is handed to the two-argument
`rocblas_get_version_string(buf, len)` wrapper, the returned status is passed through
`check`, and the bytes before the first NUL are returned as a `String`.
"""
function rocblas_get_version_string()
    buf = zeros(UInt8, 64)
    rocblas_get_version_string(buf, length(buf)) |> check
    # Read the result from the Julia-owned array rather than from a raw pointer into
    # it: a pointer taken before the call does not keep the array alive afterwards,
    # and a buffer without a terminating NUL must not be scanned past its end.
    nul = findfirst(iszero, buf)
    len = nul === nothing ? length(buf) : nul - 1
    return String(buf[1:len])
end

# rocBLAS version as a `VersionNumber`, built from the first three dot-separated
# components of the library's version string.
function version()
    parts = split(rocblas_get_version_string(), '.')
    return VersionNumber(join(parts[1:3], '.'))
end
Expand Down
Loading

0 comments on commit 6414dc0

Please sign in to comment.