diff --git a/ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl b/ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl
index 5e7d007e6..c9f78add1 100644
--- a/ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl
+++ b/ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl
@@ -35,12 +35,6 @@ Base.show(io::IO, x::AnyROCBatchedAdjOrTrans) = show(io, adapt(Array, x))
 
 Base.display(x::AnyROCBatchedAdjOrTrans) = display(adapt(Array, x))
 
-function NNlib._batched_gemm!(
-    ::Type{<: ROCArray}, transA::Char, transB::Char, α, A, B, β, C,
-)
-    AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C)
-end
-
 function nnlib_padding(dims)
     pd = NNlib.padding(dims)
     if !all(pd[1:2:end] .== pd[2:2:end])
@@ -52,6 +46,8 @@ function nnlib_padding(dims)
     pd[1:2:end]
 end
 
+include("batched_mul.jl")
+
 @static if AMDGPU.functional(:MIOpen)
 
 using AMDGPU.MIOpen
diff --git a/ext/NNlibAMDGPUExt/batched_mul.jl b/ext/NNlibAMDGPUExt/batched_mul.jl
new file mode 100644
index 000000000..191af87d1
--- /dev/null
+++ b/ext/NNlibAMDGPUExt/batched_mul.jl
@@ -0,0 +1,24 @@
+function _blas_at(x)
+    Base.stride(x, 1) == 1 && return x, 'N'
+    Base.stride(x, 2) == 1 && return batched_transpose(x), 'T'
+    throw(ArgumentError("""
+        Unsupported array layout for batched mul.
+        - Size: $(size(x))
+        - Strides: $(strides(x))
+        """))
+end
+
+function NNlib._batched_mul!(
+    ::Type{AT}, C, A, B, α::Float16, β::Float16,
+) where AT <: ROCArray{Float16}
+    blasA, transA = _blas_at(A)
+    blasB, transB = _blas_at(B)
+    NNlib._batched_gemm!(AT, transA, transB, α, blasA, blasB, β, C)
+    C
+end
+
+function NNlib._batched_gemm!(
+    ::Type{<:ROCArray{T}}, transA::Char, transB::Char, α::T, A, B, β::T, C,
+) where T <: Union{MIOPENFloat, Float64}
+    AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C)
+end
diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl
index 9e6bef84b..a3b7efc74 100644
--- a/src/batched/batchedmul.jl
+++ b/src/batched/batchedmul.jl
@@ -223,7 +223,6 @@ _batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray
     _batched_try_gemm!(DT, C, A, B, α, β)
 
 function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat}
-
     alpha, beta = promote(α, β, zero(T))
     alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β)
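
Note on the layout check: `_blas_at` inspects strides to decide whether a batch
can be handed to rocBLAS as-is ('N') or rewrapped as a transpose ('T'), so no
permuted copy is materialized. A minimal CPU sketch of that decision, using
NNlib's exported `batched_transpose`; the helper name `blas_flag` is
hypothetical, for illustration only:

    using NNlib: batched_transpose

    # Mirrors the stride test in _blas_at: a column-major batch passes through
    # as 'N'; a batch-transposed (row-major slice) layout is flagged 'T' so
    # BLAS itself handles the transposition.
    function blas_flag(x::AbstractArray{<:Any,3})
        Base.stride(x, 1) == 1 && return 'N'
        Base.stride(x, 2) == 1 && return 'T'
        throw(ArgumentError("unsupported layout, strides = $(strides(x))"))
    end

    A = rand(Float16, 4, 3, 2)
    blas_flag(A)                      # 'N'
    blas_flag(batched_transpose(A))   # 'T'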
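
Hypothetical usage on an AMD GPU, assuming a functional AMDGPU.jl setup (array
sizes are illustrative). With this patch, Float16 inputs should take the new
`ROCArray{Float16}` method and reach `rocBLAS.gemm_batched!` rather than the
generic fallback:

    using AMDGPU, NNlib

    A = ROCArray(rand(Float16, 8, 16, 4))   # batch of four 8×16 matrices
    B = ROCArray(rand(Float16, 16, 8, 4))
    C = batched_mul(A, B)                   # should hit the rocBLAS path added here
    size(C)                                 # (8, 8, 4)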