diff --git a/ext/LuxLibTrackerAMDGPUExt.jl b/ext/LuxLibTrackerAMDGPUExt.jl
index 803b70fd..a3ecd174 100644
--- a/ext/LuxLibTrackerAMDGPUExt.jl
+++ b/ext/LuxLibTrackerAMDGPUExt.jl
@@ -1,6 +1,7 @@
 module LuxLibTrackerAMDGPUExt
 
 using AMDGPU: AMDGPU
+using LuxLib: LuxLib
 using NNlib: NNlib, ConvDims, PoolDims
 using Tracker: Tracker, TrackedArray
 
diff --git a/src/LuxLib.jl b/src/LuxLib.jl
index e962279e..776a2f5d 100644
--- a/src/LuxLib.jl
+++ b/src/LuxLib.jl
@@ -9,7 +9,7 @@ using PrecompileTools: @recompile_invalidations
     using FastClosures: @closure
    using GPUArraysCore: GPUArraysCore
    using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel
-    using LinearAlgebra: LinearAlgebra, mul!
+    using LinearAlgebra: LinearAlgebra, BLAS, mul!
    using LuxCore: LuxCore
    using Markdown: @doc_str
    using NNlib: NNlib
diff --git a/src/api/conv.jl b/src/api/conv.jl
index 1c80afdd..c292be15 100644
--- a/src/api/conv.jl
+++ b/src/api/conv.jl
@@ -54,13 +54,13 @@ end
 @inline function fused_conv_bias_activation(
         σ::F, weight::AbstractArray, ::Val{false}, x::AbstractArray, ::Val{false},
         b::Union{Nothing, AbstractArray}, ::Val{false}, cdims::ConvDims) where {F}
-    return __fused_conv_bias_activation_impl(σ, weight, x, b, cdims)
+    return _fused_conv_bias_activation_impl(σ, weight, x, b, cdims)
 end
 
 @inline function fused_conv_bias_activation(
         σ::F, weight::AbstractArray, ::Val, x::AbstractArray, ::Val,
         b::Union{Nothing, AbstractArray}, ::Val, cdims::ConvDims) where {F}
-    return __generic_conv_bias_activation(σ, weight, x, b, cdims)
+    return _generic_conv_bias_activation(σ, weight, x, b, cdims)
 end
 
 # SubArray Inputs: copy a subarray to make it contiguous in memory
@@ -81,13 +81,13 @@ end
 @inline function fused_conv_bias_activation(
         σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
         b::AbstractArray{bT, N}, cdims::ConvDims) where {F, wT, xT, bT, N}
-    return __generic_conv_bias_activation(σ, weight, x, b, cdims)
+    return _generic_conv_bias_activation(σ, weight, x, b, cdims)
 end
 
 @inline function fused_conv_bias_activation(
         σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
         b::Nothing, cdims::ConvDims) where {F, wT, xT, N}
-    return __generic_conv_bias_activation(σ, weight, x, b, cdims)
+    return _generic_conv_bias_activation(σ, weight, x, b, cdims)
 end
 
 # Mixed Precision GPU Inputs
diff --git a/src/impl/fused_conv.jl b/src/impl/fused_conv.jl
index b159b651..5243e416 100644
--- a/src/impl/fused_conv.jl
+++ b/src/impl/fused_conv.jl
@@ -1,3 +1,27 @@
+@inline function _generic_conv_bias_activation(
+        act::F, weight::AbstractArray, args...) where {F}
+    old_threads = __maybe_reduce_BLAS_threads(weight)
+    ret = __generic_conv_bias_activation(act, weight, args...)
+    __reset_BLAS_threads(old_threads)
+    return ret
+end
+
+for aType in (AbstractArray, GPUArraysCore.AnyGPUArray)
+    @eval begin
+        @inline function __generic_conv_bias_activation(
+                act::F, weight::$(aType){T, N}, x::$(aType){T, N},
+                bias::$(aType){T, N}, cdims::ConvDims) where {T, N, F}
+            return __apply_bias_activation(act, conv(x, weight, cdims), bias)
+        end
+
+        @inline function __generic_conv_bias_activation(
+                act::F, weight::$(aType){T, N}, x::$(aType){T, N},
+                bias::Nothing, cdims::ConvDims) where {T, N, F}
+            return __apply_bias_activation(act, conv(x, weight, cdims), bias)
+        end
+    end
+end
+
 @inline function __generic_conv_bias_activation(
         act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
         bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N, F}
@@ -15,8 +39,8 @@ end
         x::GPUArraysCore.AnyGPUArray{xT, N}, bias::GPUArraysCore.AnyGPUArray{bT, N},
         cdims::ConvDims) where {wT, xT, bT, N, F}
     T = promote_type(wT, xT)
-    return __apply_bias_activation(
-        act, conv(_oftype_array(T, x), _oftype_array(T, weight), cdims), bias)
+    return __generic_conv_bias_activation(
+        act, _oftype_array(T, weight), _oftype_array(T, x), _oftype_array(T, bias), cdims)
 end
 
 @inline function __generic_conv_bias_activation(
@@ -24,14 +48,21 @@ end
         x::GPUArraysCore.AnyGPUArray{xT, N}, bias::Nothing,
         cdims::ConvDims) where {wT, xT, N, F}
     T = promote_type(wT, xT)
-    return __apply_bias_activation(
-        act, conv(_oftype_array(T, x), _oftype_array(T, weight), cdims), bias)
+    return __generic_conv_bias_activation(
+        act, _oftype_array(T, weight), _oftype_array(T, x), bias, cdims)
+end
+
+@inline function _fused_conv_bias_activation_impl(
+        act::F, weight::AbstractArray, args...) where {F}
+    old_threads = __maybe_reduce_BLAS_threads(weight)
+    ret = __fused_conv_bias_activation_impl(act, weight, args...)
+    __reset_BLAS_threads(old_threads)
+    return ret
 end
 
 # This implementation is different from `conv_bias_act` in that it defines the proper rrules
 # and fuses operations into a single kernel if it is possible. Unfortunately there are
 # certain configurations where CUDNN allows caching intermediates, but we don't do that rn.
-
 @inline function __fused_conv_bias_activation_impl(
         act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
         bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F}
@@ -92,11 +123,13 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
         y = __apply_bias_activation!!(act, y, bias, Val(false))
     end
     ∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin
+        old_threads = __maybe_reduce_BLAS_threads(weight)
         Δ = CRC.unthunk(NNlib.colmajor(Δ))
         ∂y = act === identity ? Δ : __activation_gradient(Δ, y, act, NotaNumber())
         ∂b = __added_bias_gradient(bias, ∂y)
         ∂x = NNlib.∇conv_data(∂y, weight, cdims)
         ∂w = NNlib.∇conv_filter(x, ∂y, cdims)
+        __reset_BLAS_threads(old_threads)
         return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent()
     end
     return y, ∇__fused_conv_bias_activation_impl_no_cached
@@ -108,11 +141,13 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
     if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
         z, y = __apply_bias_activation!!(act, y, bias, Val(true))
         ∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin
+            old_threads = __maybe_reduce_BLAS_threads(weight)
             Δ = CRC.unthunk(NNlib.colmajor(Δ))
             ∂y = __activation_gradient(Δ, z, act, y)
             ∂b = __added_bias_gradient(bias, ∂y)
             ∂x = NNlib.∇conv_data(∂y, weight, cdims)
             ∂w = NNlib.∇conv_filter(x, ∂y, cdims)
+            __reset_BLAS_threads(old_threads)
             return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent()
         end
         return z, ∇__fused_conv_bias_activation_impl_cached_crc
@@ -120,10 +155,12 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
 
     z, pb_f = CRC.rrule_via_ad(cfg, __apply_bias_activation, act, y, bias)
     ∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin
+        old_threads = __maybe_reduce_BLAS_threads(weight)
         Δ = NNlib.colmajor(Δ)
         _, _, ∂y, ∂b = pb_f(Δ)
         ∂x = NNlib.∇conv_data(∂y, weight, cdims)
         ∂w = NNlib.∇conv_filter(x, ∂y, cdims)
+        __reset_BLAS_threads(old_threads)
         return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent()
     end
 
diff --git a/src/utils.jl b/src/utils.jl
index bc219fd5..e823327f 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -180,3 +180,22 @@ end
     end
     return @. Δ * only_derivative(out, act, x)
 end
+
+# Reduce BLAS threads if we are going to use a native Julia implementation
+@inline function __maybe_reduce_BLAS_threads(x::AbstractArray)::Int
+    if ArrayInterface.fast_scalar_indexing(x)
+        old_threads = BLAS.get_num_threads()
+        BLAS.set_num_threads(1)
+        return old_threads
+    end
+    return -1
+end
+
+CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray)
+
+@inline function __reset_BLAS_threads(old_threads::Int)
+    old_threads ≥ 1 && BLAS.set_num_threads(old_threads)
+    return nothing
+end
+
+CRC.@non_differentiable __reset_BLAS_threads(::Int)
diff --git a/test/conv_tests.jl b/test/conv_tests.jl
index c695ec69..b2d9495c 100644
--- a/test/conv_tests.jl
+++ b/test/conv_tests.jl
@@ -16,62 +16,68 @@
         return _expand(Val(2 * N), pad)
     end
 
+    anonact = x -> gelu(x)
+
     @testset "$mode" for (mode, aType, on_gpu) in MODES
         # These are not all possible combinations but rather a representative set to keep
         # CI timings under check
         # Most of the actual tests happen upstream in Lux
-        @testset "$(Tw) x $(Tx)" for (Tw, Tx) in [
-            (Float16, Float16), (Float32, Float16), (Float32, Float32),
-            (Float32, Float64), (Float64, Float64)]
-            for hasbias in (true, false),
-                activation in (identity, tanh, tanh_fast, sigmoid,
-                    sigmoid_fast, relu, gelu, x -> gelu(x)),
-                (kernel, padding, stride, groups) in (
-                    ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1),
-                    ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2))
+        @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for (Tw, Tx) in [
+                (Float16, Float16), (Float32, Float16), (Float32, Float32),
+                (Float32, Float64), (Float64, Float64)],
+            hasbias in (true, false),
+            activation in (
+                identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact),
+            (kernel, padding, stride, groups) in (
+                ((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1),
+                ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2))
 
-                weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType
-                x = __generate_fixed_array(
-                    Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |> aType
-                bias = hasbias ?
-                       aType(__generate_fixed_array(
-                    Tx, ntuple(Returns(1), length(kernel))..., 8, 1)) : nothing
+            weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType
+            x = __generate_fixed_array(Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |>
+                aType
+            bias = hasbias ?
+                   aType(__generate_fixed_array(
+                Tx, ntuple(Returns(1), length(kernel))..., 8, 1)) : nothing
 
-                cdims = DenseConvDims(
-                    x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride),
-                    dilation=1, groups)
+            cdims = DenseConvDims(
+                x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride),
+                dilation=1, groups)
 
-                y = fused_conv_bias_activation(activation, weight, x, bias, cdims)
+            y = fused_conv_bias_activation(activation, weight, x, bias, cdims)
 
-                y_generic = LuxLib.__generic_conv_bias_activation(
-                    activation, weight, x, bias, cdims)
+            y_generic = LuxLib.__generic_conv_bias_activation(
+                activation, weight, x, bias, cdims)
 
-                @test y ≈ y_generic
-                @test eltype(y) == promote_type(Tw, Tx)
+            fp16 = Tx == Float16 || Tw == Float16
+            atol = fp16 ? 1.0f-1 : 1.0f-3
+            rtol = fp16 ? 1.0f-1 : 1.0f-3
+            # Operation reordering has an effect on the accuracy of the results
+            @test y≈y_generic atol=atol rtol=rtol
+            @test eltype(y) == promote_type(Tw, Tx)
 
-                @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims)
-                @jet fused_conv_bias_activation(activation, weight, x, bias, cdims)
+            @inferred fused_conv_bias_activation(activation, weight, x, bias, cdims)
+            @jet fused_conv_bias_activation(activation, weight, x, bias, cdims)
 
-                # FIXME: GPU compilation of the gradients for mixed precision seems broken
-                Tw !== Tx && on_gpu && continue
+            # FIXME: GPU compilation of the gradients for mixed precision seems broken
+            Tw !== Tx && on_gpu && continue
 
-                __f = (σ, w, x, b, cdims) -> sum(
-                    abs2, fused_conv_bias_activation(σ, w, x, b, cdims))
+            __f = (σ, w, x, b, cdims) -> sum(
+                abs2, fused_conv_bias_activation(σ, w, x, b, cdims))
 
-                if mode != "AMDGPU"
+            if mode != "AMDGPU" && activation !== anonact
+                @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims)
+            else
+                try
                     @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims)
-                else
-                    try
-                        @inferred Zygote.gradient(__f, activation, weight, x, bias, cdims)
-                        @test true
-                    catch
-                        @test_broken false
-                    end
+                    @test true
+                catch
+                    @test_broken false
                 end
-
-                fp16 = Tx == Float16 || Tw == Float16
-                atol = fp16 ? 1.0f-1 : 1.0f-3
-                rtol = fp16 ? 1.0f-1 : 1.0f-3
+            end
+            if mode === "AMDGPU"
+                @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_tracker=true skip_finite_differences=$(Tx !=
+                    Tw)
+            else
                 # FiniteDiffencing doesn't work great for MP because of how LuxTestUtils is
                 # implemented.
                 @eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(Tx !=
diff --git a/test/groupnorm_tests.jl b/test/groupnorm_tests.jl
index 8cd39d74..72f5f6df 100644
--- a/test/groupnorm_tests.jl
+++ b/test/groupnorm_tests.jl
@@ -74,7 +74,7 @@ end
     end
 end
 
-@testitem "Group Normalization Generic Fallback" tags=[:nworkers, :normalization] setup=[
+@testitem "Group Normalization Generic Fallback" tags=[:singleworker, :normalization] setup=[
     SharedTestSetup, GroupNormSetup] begin
     @testset "$mode" for (mode, aType, on_gpu) in MODES
         @testset "eltype $T, size $sz, ngroups $groups, $act" for T in (
diff --git a/test/runtests.jl b/test/runtests.jl
index ad617f06..477c60da 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4,13 +4,12 @@ const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all")
 @info "Running tests for group: $LUXLIB_TEST_GROUP"
 
 if LUXLIB_TEST_GROUP == "all"
-    # Instance Normalization Tests causes stalling on CUDA CI
     ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker])
     ReTestItems.runtests(@__DIR__; tags=[:nworkers])
 else
     tag = Symbol(LUXLIB_TEST_GROUP)
-    # Instance Normalization Tests causes stalling on CUDA CI
+
     ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker, tag])
     ReTestItems.runtests(@__DIR__; tags=[:nworkers, tag])
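# ----------------------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): the BLAS-thread pattern that the new
# `_generic_conv_bias_activation` / `_fused_conv_bias_activation_impl` wrappers apply
# around the actual kernels. BLAS is pinned to a single thread only for arrays with fast
# scalar indexing (i.e. CPU arrays, where the native Julia conv path does its own
# threading), and the previous thread count is restored afterwards -- the rrule pullbacks
# do the same around the backward pass. This standalone variant uses a hypothetical
# `with_reduced_blas_threads` helper and try/finally in place of the patch's paired
# `__maybe_reduce_BLAS_threads` / `__reset_BLAS_threads` calls.
using LinearAlgebra: BLAS
using ArrayInterface: fast_scalar_indexing

function with_reduced_blas_threads(f, x::AbstractArray)
    # Only touch the BLAS configuration when a native (CPU) kernel will run
    old_threads = fast_scalar_indexing(x) ? BLAS.get_num_threads() : -1
    old_threads ≥ 1 && BLAS.set_num_threads(1)
    try
        return f()
    finally
        # Always restore the caller's thread count, even if `f` throws
        old_threads ≥ 1 && BLAS.set_num_threads(old_threads)
    end
end

# Usage sketch: y = with_reduced_blas_threads(() -> some_cpu_heavy_kernel(x), x)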
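# ----------------------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): the mixed-precision handling used by
# the GPU methods of `__generic_conv_bias_activation`. Rather than promoting inside the
# `conv` call, the patch promotes weight/x (and bias) to a common element type first and
# then re-dispatches to the same-eltype method generated by the `for aType in (...)` loop.
# The plain `T.(...)` broadcasts below stand in for the internal `_oftype_array` helper.
using NNlib: NNlib, conv, DenseConvDims

function promoted_conv(weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
        cdims::DenseConvDims) where {wT, xT, N}
    T = promote_type(wT, xT)              # common element type, as in the patch
    return conv(T.(x), T.(weight), cdims) # both operands converted before the convolution
end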