Skip to content

Commit

Permalink
[rocSPARSE] Update the interface for sparse products
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 3, 2024
1 parent a674fdf commit 68dcb2e
Showing 1 changed file with 46 additions and 42 deletions.
88 changes: 46 additions & 42 deletions src/sparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,48 +22,52 @@ function mm_wrapper(
mm!(transa, transb, alpha, A, B, beta, C, 'O')
end

tag_wrappers = (
(identity, identity),
(T -> :(HermOrSym{T, <:$T}), A -> :(parent($A))))

op_wrappers = (
(identity, T -> 'N', identity),
(T -> :(Transpose{<:T, <:$T}), T -> 'T', A -> :(parent($A))),
(T -> :(Adjoint{<:T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))

for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(taga(:(ROCSparseMatrix{T})))

@eval begin
function LinearAlgebra.mul!(
C::ROCVector{T}, A::$TypeA, B::DenseROCVector{T},
alpha::Number, beta::Number,
) where T <: Union{Float16, ComplexF16, BlasFloat}
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
end

function LinearAlgebra.mul!(
C::ROCVector{Complex{T}}, A::$TypeA, B::DenseROCVector{Complex{T}},
alpha::Number, beta::Number,
) where T <: Union{Float16, BlasFloat}
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
end
end

for (tagb, untagb) in tag_wrappers, (wrapb, transb, unwrapb) in op_wrappers
TypeB = wrapb(tagb(:(DenseROCMatrix{T})))

@eval begin
function LinearAlgebra.mul!(
C::ROCMatrix{T}, A::$TypeA, B::$TypeB,
alpha::Number, beta::Number,
) where T <: Union{Float16, ComplexF16, BlasFloat}
mm_wrapper(
$transa(T), $transb(T), alpha,
$(untaga(unwrapa(:A))), $(untagb(unwrapb(:B))), beta, C)
end
end
end
# legacy methods with final MulAddMul argument
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::ROCSparseVector{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseMatrix{T}, B::DenseROCMatrix{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

function LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
mv_wrapper(tA, alpha, A, B, beta, C)
end

function LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::ROCSparseVector{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
mv_wrapper(tA, alpha, A, ROCVector{T}(B), beta, C)
end

function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseMatrix{T}, B::DenseROCMatrix{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
mm_wrapper(tA, tB, alpha, A, B, beta, C)
end

# legacy methods with final MulAddMul argument
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSR{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCOO{T}, _add::MulAddMul) where T <: BlasFloat =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
mm!(tA, tB, alpha, A, B, beta, C, 'O')
end
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSR{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
mm!(tA, tB, alpha, A, B, beta, C, 'O')
end
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCOO{T}, alpha::Number, beta::Number) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
mm!(tA, tB, alpha, A, B, beta, C, 'O')
end

Base.:(+)(A::ROCSparseMatrixCSR, B::ROCSparseMatrixCSR) = geam(one(eltype(A)), A, one(eltype(A)), B, 'O')
Expand Down

0 comments on commit 68dcb2e

Please sign in to comment.