Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rocBLAS wrapper #402

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions gen/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.9.0-DEV"
julia_version = "1.9.0-rc1"
manifest_format = "2.0"
project_hash = "3d1c3440c5f5d2f04d75370d10d97474962219fe"
project_hash = "26a57bd5ee5b16d5abc45095c0840c2c659a6a22"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
Expand All @@ -27,15 +27,15 @@ version = "0.4.2"

[[deps.Clang]]
deps = ["CEnum", "Clang_jll", "Downloads", "Pkg", "TOML"]
git-tree-sha1 = "9e605c9149e4a0182118f00c8d69ef76d59998ee"
git-tree-sha1 = "ac81f3ea7c53b20d64ad1609a0298d9731fbdcf8"
uuid = "40e3b903-d033-50b4-a0cc-940c62c95e31"
version = "0.16.6"
version = "0.17.3"

[[deps.Clang_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "TOML", "Zlib_jll", "libLLVM_jll"]
git-tree-sha1 = "c7c8938a36b2ab8e5eb9b6c937ba5049e1e666fa"
git-tree-sha1 = "b88c99c9093f9db49a40d0715ea0e3ae5bbd91f7"
uuid = "0ee61d77-7f21-5576-8119-9fcc46b10100"
version = "14.0.6+0"
version = "14.0.6+2"

[[deps.Dates]]
deps = ["Printf"]
Expand Down Expand Up @@ -101,9 +101,9 @@ version = "1.8.7+0"

[[deps.Libglvnd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll", "Xorg_libXext_jll"]
git-tree-sha1 = "7739f837d6447403596a75d19ed01fd08d6f56bf"
git-tree-sha1 = "6f73d1dd803986947b2c750138528a999a6c7733"
uuid = "7e76a0d4-f3c7-5321-8279-8d96eeed0f29"
version = "1.3.0+3"
version = "1.6.0+0"

[[deps.Libgpg_error_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand All @@ -113,9 +113,9 @@ version = "1.42.0+0"

[[deps.Libiconv_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "42b62845d70a619f063a7da093d995ec8e15e778"
git-tree-sha1 = "c7cb1f5d892775ba13767a87c7ada0b980ea0a71"
uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531"
version = "1.16.1+1"
version = "1.16.1+2"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand All @@ -133,7 +133,7 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.0+0"
version = "2.28.2+0"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
Expand All @@ -152,7 +152,7 @@ version = "1.2.0"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.8.0"
version = "1.9.0"

[[deps.Preferences]]
deps = ["TOML"]
Expand Down Expand Up @@ -195,10 +195,10 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"

[[deps.SQLite_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"]
git-tree-sha1 = "9d920c4ee8cd5684e23bf84f43ead45c0af796e7"
deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"]
git-tree-sha1 = "54d66b0f69f4578f4988fc08d579783fcdcd764f"
uuid = "76ed43ae-9a5d-5a62-8c75-30186b810ce8"
version = "3.39.4+0"
version = "3.41.0+0"

[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand All @@ -209,7 +209,7 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
version = "1.0.0"
version = "1.0.3"

[[deps.Tar]]
deps = ["ArgTools", "SHA"]
Expand All @@ -225,9 +225,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[deps.XML2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "Zlib_jll"]
git-tree-sha1 = "58443b63fb7e465a8a7210828c91c08b92132dff"
git-tree-sha1 = "93c41695bc1c08c46c5899f4fe06d6ead504bb73"
uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a"
version = "2.9.14+0"
version = "2.10.3+0"

[[deps.XSLT_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"]
Expand All @@ -237,9 +237,9 @@ version = "1.1.34+0"

[[deps.XZ_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "d1d49166bc58e698ab38804d7bde2aef43e4b594"
git-tree-sha1 = "7928d348322698fb93d5c14b184fdc176c8afc82"
uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800"
version = "5.2.7+0"
version = "5.2.9+0"

[[deps.Xorg_libX11_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libxcb_jll", "Xorg_xtrans_jll"]
Expand Down Expand Up @@ -314,9 +314,9 @@ version = "1.79.0+0"

[[deps.fts_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "78732b942383d2cb521df8a1a0814911144e663d"
git-tree-sha1 = "aa21810b841ae26d2fc7f780cb1596b4170a4c49"
uuid = "d65627f6-89bd-53e8-8ab5-8b75ff535eee"
version = "1.2.7+1"
version = "1.2.8+0"

[[deps.hsa_rocr_jll]]
deps = ["Artifacts", "Elfutils_jll", "JLLWrappers", "Libdl", "NUMA_jll", "Pkg", "ROCmDeviceLibs_jll", "XML2_jll", "Zlib_jll", "hsakmt_roct_jll"]
Expand All @@ -333,7 +333,7 @@ version = "5.2.3+0"
[[deps.libLLVM_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8f36deef-c2a5-5394-99ed-8e07531fb29a"
version = "14.0.6+0"
version = "14.0.6+2"

[[deps.libdrm_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libpciaccess_jll"]
Expand Down
1 change: 1 addition & 0 deletions gen/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
MIOpen_jll = "2409bb75-d5ef-542a-ac68-1cfd4c37dc24"
hsa_rocr_jll = "dd59ff1a-a01a-568d-8b29-0669330f116a"
rocBLAS_jll = "1ef8cab2-a151-54b4-a57f-5fbb4046a4ab"
18 changes: 18 additions & 0 deletions gen/rocblas/generator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using Clang.Generators
using rocBLAS_jll

include_dir = normpath(rocBLAS_jll.artifact_dir, "include")
rocblas_dir = joinpath(include_dir, "rocblas")
options = load_options("rocblas/rocblas-generator.toml")

args = get_default_args()
push!(args, "-I$include_dir")

headers = [
joinpath(rocblas_dir, header)
for header in readdir(rocblas_dir)
if endswith(header, ".h")
]

ctx = create_context(headers, args, options)
build!(ctx)
4 changes: 4 additions & 0 deletions gen/rocblas/rocblas-generator.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[general]
library_name = "librocblas"
output_file_path = "./librocblas.jl"
export_symbol_prefixes = []
32 changes: 14 additions & 18 deletions src/blas/error.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,39 @@
export ROCBLASError

struct ROCBLASError <: Exception
code::rocblas_status_t
code::rocblas_status
msg::AbstractString
end
Base.show(io::IO, err::ROCBLASError) = print(io, "ROCBLASError(code $(err.code), $(err.msg))")

function ROCBLASError(code::rocblas_status_t)
function ROCBLASError(code::rocblas_status)
msg = status_message(code)
return ROCBLASError(code, msg)
end

function status_message(status)
if status == ROCBLAS_STATUS_SUCCESS
if status == rocblas_status_success
return "the operation completed successfully"
elseif status == ROCBLAS_STATUS_INVALID_HANDLE
elseif status == rocblas_status_invalid_handle
return "handle not initialized, invalid or null"
elseif status == ROCBLAS_STATUS_NOT_IMPLEMENTED
elseif status == rocblas_status_not_implemented
return "this function is not implemented"
elseif status == ROCBLAS_STATUS_INVALID_POINTER
elseif status == rocblas_status_invalid_pointer
return "invalid pointer parameter"
elseif status == ROCBLAS_STATUS_INVALID_SIZE
elseif status == rocblas_status_invalid_size
return "invalid size parameter"
elseif status == ROCBLAS_STATUS_MEMORY_ERROR
elseif status == rocblas_status_memory_error
return "failed internal memory allocation, copy or dealloc"
elseif status == ROCBLAS_STATUS_INTERNAL_ERROR
elseif status == rocblas_status_internal_error
return "an internal operation failed"
else
return "unknown status"
return "unknown status: $status"
end
end

macro check(blas_func)
quote
local err::rocblas_status_t
err = $(esc(blas_func::Expr))
if err != ROCBLAS_STATUS_SUCCESS
throw(ROCBLASError(err))
end
err
function check(status::rocblas_status)
if status != rocblas_status_success
throw(ROCBLASError(status))
end
return status
end
45 changes: 10 additions & 35 deletions src/blas/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,8 @@ function gemv_wrapper!(y::ROCVector{T}, tA::Char, A::ROCMatrix{T}, x::ROCVector{
if mA != length(y)
throw(DimensionMismatch("first dimension of A, $mA, does not match length of y, $(length(y))"))
end
if mA == 0
return y
end
if nA == 0
return rmul!(y, 0)
end
mA == 0 && return y
nA == 0 && return rmul!(y, 0)
gemv!(tA, alpha, A, x, beta, y)
end

Expand Down Expand Up @@ -207,15 +203,9 @@ for (t, uploc, isunitc) in (
)
@eval begin
LinearAlgebra.lmul!(A::$t{T, ROCMatrix{T}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
trmm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B, B)
trmm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B)
LinearAlgebra.rmul!(A::ROCMatrix{T}, B::$t{T, ROCMatrix{T}}) where T <: ROCBLASFloat =
trmm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A, A)

# Optimization to avoid copy.
LinearAlgebra.mul!(Y::ROCMatrix{T}, A::$t{T, ROCMatrix{T}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
trmm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B, Y)
LinearAlgebra.mul!(Y::ROCMatrix{T}, A::ROCMatrix{T}, B::$t{T, ROCMatrix{T}}) where T <: ROCBLASFloat =
trmm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A, Y)
trmm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A)

LinearAlgebra.ldiv!(A::$t{T, ROCMatrix{T}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
trsm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B)
Expand All @@ -234,33 +224,18 @@ for (t, uploc, isunitc) in (
@eval begin
# Multiplication.
LinearAlgebra.lmul!(A::$t{<: Any, <: Transpose{T, <: ROCMatrix{T}}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
trmm!('L', $uploc, 'T', $isunitc, one(T), parent(parent(A)), B, B)
trmm!('L', $uploc, 'T', $isunitc, one(T), parent(parent(A)), B)
LinearAlgebra.lmul!(A::$t{<: Any, <: Adjoint{T, <: ROCMatrix{T}}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
trmm!('L', $uploc, 'T', $isunitc, one(T), parent(parent(A)), B, B)
trmm!('L', $uploc, 'T', $isunitc, one(T), parent(parent(A)), B)
LinearAlgebra.lmul!(A::$t{<: Any, <: Adjoint{T, <: ROCMatrix{T}}}, B::ROCMatrix{T}) where T <: ROCBLASComplex =
trmm!('L', $uploc, 'C', $isunitc, one(T), parent(parent(A)), B, B)
trmm!('L', $uploc, 'C', $isunitc, one(T), parent(parent(A)), B)

LinearAlgebra.rmul!(A::ROCMatrix{T}, B::$t{<: Any, <: Transpose{T, <: ROCMatrix{T}}}) where T <: ROCBLASFloat =
trmm!('R', $uploc, 'T', $isunitc, one(T), parent(parent(B)), A, A)
trmm!('R', $uploc, 'T', $isunitc, one(T), parent(parent(B)), A)
LinearAlgebra.rmul!(A::ROCMatrix{T}, B::$t{<: Any, <: Adjoint{T, <: ROCMatrix{T}}}) where T <: ROCBLASFloat =
trmm!('R', $uploc, 'T', $isunitc, one(T), parent(parent(B)), A, A)
trmm!('R', $uploc, 'T', $isunitc, one(T), parent(parent(B)), A)
LinearAlgebra.rmul!(A::ROCMatrix{T}, B::$t{<: Any, <: Adjoint{T, <: ROCMatrix{T}}}) where T <: ROCBLASComplex =
trmm!('R', $uploc, 'C', $isunitc, one(T), parent(parent(B)), A, A)

# Optimization.
LinearAlgebra.mul!(Y::ROCMatrix{T}, A::$t{<: Any, <: Transpose{T, <: ROCMatrix{T}}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
trmm!('L', $uploc, 'T', $isunitc, one(T), parent(parent(A)), B, Y)
LinearAlgebra.mul!(Y::ROCMatrix{T}, A::$t{<: Any, <: Adjoint{T, <: ROCMatrix{T}}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
trmm!('L', $uploc, 'T', $isunitc, one(T), parent(parent(A)), B, Y)
LinearAlgebra.mul!(Y::ROCMatrix{T}, A::$t{<: Any, <: Adjoint{T, <: ROCMatrix{T}}}, B::ROCMatrix{T}) where T <: ROCBLASComplex =
trmm!('L', $uploc, 'C', $isunitc, one(T), parent(parent(A)), B, Y)

LinearAlgebra.mul!(Y::ROCMatrix{T}, A::ROCMatrix{T}, B::$t{<: Any, <: Transpose{T, <: ROCMatrix{T}}}) where T <: ROCBLASFloat =
trmm!('R', $uploc, 'T', $isunitc, one(T), parent(parent(B)), A, Y)
LinearAlgebra.mul!(Y::ROCMatrix{T}, A::ROCMatrix{T}, B::$t{<: Any, <: Adjoint{T, <: ROCMatrix{T}}}) where T <: ROCBLASFloat =
trmm!('R', $uploc, 'T', $isunitc, one(T), parent(parent(B)), A, Y)
LinearAlgebra.mul!(Y::ROCMatrix{T}, A::ROCMatrix{T}, B::$t{<: Any, <: Adjoint{T, <: ROCMatrix{T}}}) where T <: ROCBLASComplex =
trmm!('R', $uploc, 'C', $isunitc, one(T), parent(parent(B)), A, Y)
trmm!('R', $uploc, 'C', $isunitc, one(T), parent(parent(B)), A)

# Left division.
LinearAlgebra.ldiv!(A::$t{<: Any, <: Transpose{T, <: ROCMatrix{T}}}, B::ROCMatrix{T}) where T <: ROCBLASFloat =
Expand Down
Loading