Commit
Add dispatch path for FP16 batched mul
pxl-th committed Apr 13, 2023
1 parent ee909e6 commit 8eebdcb
Showing 3 changed files with 26 additions and 7 deletions.
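In short: `Float16` is not a `LinearAlgebra.BlasFloat`, so NNlib's generic batched-GEMM path (constrained to `DenseArray{<:BlasFloat}`) never matches it, and half-precision batched multiplies on `ROCArray`s would take the slow generic fallback. This commit adds a dedicated `_batched_mul!` dispatch for `ROCArray{Float16}` that forwards straight to rocBLAS. A minimal usage sketch (assumes a working AMDGPU/ROCm setup; not part of this commit):

```julia
using AMDGPU, NNlib

# Batch of four 8x8 Float16 matrix pairs on the GPU.
A = ROCArray(rand(Float16, 8, 8, 4))
B = ROCArray(rand(Float16, 8, 8, 4))

# With this commit, the Float16 element type dispatches to the new
# _batched_mul! method and lowers to rocBLAS gemm_batched!.
C = batched_mul(A, B)
```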
8 changes: 2 additions & 6 deletions ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl
@@ -35,12 +35,6 @@ Base.show(io::IO, x::AnyROCBatchedAdjOrTrans) = show(io, adapt(Array, x))

Base.display(x::AnyROCBatchedAdjOrTrans) = display(adapt(Array, x))

-function NNlib._batched_gemm!(
-    ::Type{<: ROCArray}, transA::Char, transB::Char, α, A, B, β, C,
-)
-    AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C)
-end
-
function nnlib_padding(dims)
pd = NNlib.padding(dims)
if !all(pd[1:2:end] .== pd[2:2:end])
@@ -52,6 +46,8 @@ function nnlib_padding(dims)
    pd[1:2:end]
end

+include("batched_mul.jl")
+
@static if AMDGPU.functional(:MIOpen)
    using AMDGPU.MIOpen

24 changes: 24 additions & 0 deletions ext/NNlibAMDGPUExt/batched_mul.jl
@@ -0,0 +1,24 @@
+function _blas_at(x)
+    Base.stride(x, 1) == 1 && return x, 'N'
+    Base.stride(x, 2) == 1 && return batched_transpose(x), 'T'
+    throw(ArgumentError("""
+        Unsupported array layout for batched mul.
+        - Size: $(size(x))
+        - Strides: $(strides(x))
+        """))
+end
+
+function NNlib._batched_mul!(
+    ::Type{AT}, C, A, B, α::Float16, β::Float16,
+) where AT <: ROCArray{Float16}
+    blasA, transA = _blas_at(A)
+    blasB, transB = _blas_at(B)
+    NNlib._batched_gemm!(AT, transA, transB, α, blasA, blasB, β, C)
+    C
+end
+
+function NNlib._batched_gemm!(
+    ::Type{<:ROCArray{T}}, transA::Char, transB::Char, α::T, A, B, β::T, C,
+) where T <: Union{MIOPENFloat, Float64}
+    AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C)
+end
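The stride checks in `_blas_at` decide how each batch is handed to BLAS: unit stride along dimension 1 means the data is already column-major (`'N'`), while unit stride along dimension 2 indicates a transposed layout that BLAS can consume as `'T'` once re-wrapped. A CPU-side sketch of the same logic (illustrative only, not part of the commit):

```julia
using NNlib: batched_transpose

x = rand(Float16, 3, 4, 2)
strides(x)            # (1, 3, 12): stride(x, 1) == 1, so _blas_at returns (x, 'N')

xt = batched_transpose(x)
strides(xt)           # (3, 1, 12): stride(xt, 2) == 1, so _blas_at re-wraps
                      # the argument and returns (batched_transpose(xt), 'T')
```

Anything with non-unit strides in both leading dimensions throws the `ArgumentError` above rather than silently computing on a wrong layout.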
1 change: 0 additions & 1 deletion src/batched/batchedmul.jl
@@ -223,7 +223,6 @@ _batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray
    _batched_try_gemm!(DT, C, A, B, α, β)

function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat}
-
    alpha, beta = promote(α, β, zero(T))
    alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β)

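For context, the `promote` guard shown above is what keeps mismatched scalar types off the BLAS path: if promoting `α` and `β` together with `zero(T)` lands back in `T`, gemm can take them directly; otherwise the call returns through `batched_mul_generic!`. A worked illustration with hypothetical values:

```julia
T = Float32

# Scalars compatible with T: promotion stays in Float32, so gemm is used.
alpha, beta = promote(2.0f0, 0.0f0, zero(T))
alpha isa T && beta isa T    # true

# A Float64 scalar drags everything up to Float64, failing the check,
# so the call would fall back to batched_mul_generic!.
alpha, beta = promote(2.0, 0.0f0, zero(T))
alpha isa T && beta isa T    # false
```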
