Skip to content

Commit

Permalink
Update ROCSparse for Julia v1.10 (#613)
Browse files Browse the repository at this point in the history
* [rocSPARSE] Update the interface for sparse products

* Fix test for rocsparse/interfaces.jl

* Fix again the tests for rocsparse/interfaces.jl
  • Loading branch information
amontoison authored Nov 22, 2024
1 parent e61e088 commit 78c1036
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 83 deletions.
88 changes: 46 additions & 42 deletions src/sparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,48 +22,52 @@ function mm_wrapper(
mm!(transa, transb, alpha, A, B, beta, C, 'O')
end

tag_wrappers = (
(identity, identity),
(T -> :(HermOrSym{T, <:$T}), A -> :(parent($A))))

op_wrappers = (
(identity, T -> 'N', identity),
(T -> :(Transpose{<:T, <:$T}), T -> 'T', A -> :(parent($A))),
(T -> :(Adjoint{<:T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))

for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(taga(:(ROCSparseMatrix{T})))

@eval begin
function LinearAlgebra.mul!(
C::ROCVector{T}, A::$TypeA, B::DenseROCVector{T},
alpha::Number, beta::Number,
) where T <: Union{Float16, ComplexF16, BlasFloat}
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
end

function LinearAlgebra.mul!(
C::ROCVector{Complex{T}}, A::$TypeA, B::DenseROCVector{Complex{T}},
alpha::Number, beta::Number,
) where T <: Union{Float16, BlasFloat}
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
end
end

for (tagb, untagb) in tag_wrappers, (wrapb, transb, unwrapb) in op_wrappers
TypeB = wrapb(tagb(:(DenseROCMatrix{T})))

@eval begin
function LinearAlgebra.mul!(
C::ROCMatrix{T}, A::$TypeA, B::$TypeB,
alpha::Number, beta::Number,
) where T <: Union{Float16, ComplexF16, BlasFloat}
mm_wrapper(
$transa(T), $transb(T), alpha,
$(untaga(unwrapa(:A))), $(untagb(unwrapb(:B))), beta, C)
end
end
end
# legacy methods with final MulAddMul argument
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::ROCSparseVector{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseMatrix{T}, B::DenseROCMatrix{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

function LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
mv_wrapper(tA, alpha, A, B, beta, C)
end

function LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::ROCSparseVector{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
mv_wrapper(tA, alpha, A, ROCVector{T}(B), beta, C)
end

function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseMatrix{T}, B::DenseROCMatrix{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
mm_wrapper(tA, tB, alpha, A, B, beta, C)
end

# legacy methods with final MulAddMul argument
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSR{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCOO{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
mm!(tA, tB, alpha, A, B, beta, C, 'O')
end
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSR{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
mm!(tA, tB, alpha, A, B, beta, C, 'O')
end
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCOO{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
mm!(tA, tB, alpha, A, B, beta, C, 'O')
end

Base.:(+)(A::ROCSparseMatrixCSR, B::ROCSparseMatrixCSR) = geam(one(eltype(A)), A, one(eltype(A)), B, 'O')
Expand Down
58 changes: 17 additions & 41 deletions test/rocsparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,47 +79,23 @@
LinearAlgebra.mul!(dc, f(dA), db, alpha, beta)
@test c collect(dc)

A = A + transpose(A)
dA = ROCSparseMatrixCSR(A)

@assert issymmetric(A)
LinearAlgebra.mul!(c, f(Symmetric(A)), b, alpha, beta)
LinearAlgebra.mul!(dc, f(Symmetric(dA)), db, alpha, beta)
@test c collect(dc)
end

@testset "$f(A)*b Complex{$elty}*$elty" for elty in (
Float32, Float64,
), f in (
identity, transpose, adjoint,
)
n = 10
alpha = rand()
beta = rand()
A = sprand(Complex{elty}, n, n, rand())
b = rand(Complex{elty}, n)
c = rand(Complex{elty}, n)
alpha = beta = 1.0
c = zeros(Complex{elty}, n)

dA = ROCSparseMatrixCSR(A)
db = ROCArray(b)
dc = ROCArray(c)

# test with empty inputs
@test Array(dA * AMDGPU.zeros(Complex{elty}, n, 0)) == zeros(Complex{elty}, n, 0)

LinearAlgebra.mul!(c, f(A), b, alpha, beta)
LinearAlgebra.mul!(dc, f(dA), db, alpha, beta)
@test c collect(dc)

A = A + transpose(A)
dA = ROCSparseMatrixCSR(A)

@assert issymmetric(A)
LinearAlgebra.mul!(c, f(Symmetric(A)), b, alpha, beta)
LinearAlgebra.mul!(dc, f(Symmetric(dA)), db, alpha, beta)
@test c collect(dc)
if f in (identity, transpose)
A = A + transpose(A)
dA = ROCSparseMatrixCSR(A)

@assert issymmetric(A)
LinearAlgebra.mul!(c, f(Symmetric(A)), b, alpha, beta)
LinearAlgebra.mul!(dc, f(Symmetric(dA)), db, alpha, beta)
@test c collect(dc)
else
A = A + adjoint(A)
dA = ROCSparseMatrixCSR(A)

@assert ishermitian(A)
LinearAlgebra.mul!(c, f(Hermitian(A)), b, alpha, beta)
LinearAlgebra.mul!(dc, f(Hermitian(dA)), db, alpha, beta)
@test c collect(dc)
end
end

@testset "$f(A)*$h(B) $elty" for elty in (
Expand Down

0 comments on commit 78c1036

Please sign in to comment.