This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

reduce BLAS threads for scalar indexing compatible convolutions
avik-pal committed Apr 24, 2024
1 parent c445922 commit 6a08a48
Showing 8 changed files with 116 additions and 54 deletions.
1 change: 1 addition & 0 deletions ext/LuxLibTrackerAMDGPUExt.jl
@@ -1,6 +1,7 @@
module LuxLibTrackerAMDGPUExt

using AMDGPU: AMDGPU
using LuxLib: LuxLib
using NNlib: NNlib, ConvDims, PoolDims
using Tracker: Tracker, TrackedArray

2 changes: 1 addition & 1 deletion src/LuxLib.jl
@@ -9,7 +9,7 @@ using PrecompileTools: @recompile_invalidations
using FastClosures: @closure
using GPUArraysCore: GPUArraysCore
using KernelAbstractions: KernelAbstractions, @Const, @index, @kernel
using LinearAlgebra: LinearAlgebra, mul!
using LinearAlgebra: LinearAlgebra, BLAS, mul!
using LuxCore: LuxCore
using Markdown: @doc_str
using NNlib: NNlib
8 changes: 4 additions & 4 deletions src/api/conv.jl
@@ -54,13 +54,13 @@ end
@inline function fused_conv_bias_activation(
σ::F, weight::AbstractArray, ::Val{false}, x::AbstractArray, ::Val{false},
b::Union{Nothing, AbstractArray}, ::Val{false}, cdims::ConvDims) where {F}
return __fused_conv_bias_activation_impl(σ, weight, x, b, cdims)
return _fused_conv_bias_activation_impl(σ, weight, x, b, cdims)
end

@inline function fused_conv_bias_activation(
σ::F, weight::AbstractArray, ::Val, x::AbstractArray, ::Val,
b::Union{Nothing, AbstractArray}, ::Val, cdims::ConvDims) where {F}
return __generic_conv_bias_activation(σ, weight, x, b, cdims)
return _generic_conv_bias_activation(σ, weight, x, b, cdims)
end

# SubArray Inputs: copy a subarray to make it contiguous in memory
@@ -81,13 +81,13 @@ end
@inline function fused_conv_bias_activation(
σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
b::AbstractArray{bT, N}, cdims::ConvDims) where {F, wT, xT, bT, N}
return __generic_conv_bias_activation(σ, weight, x, b, cdims)
return _generic_conv_bias_activation(σ, weight, x, b, cdims)
end

@inline function fused_conv_bias_activation(
σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
b::Nothing, cdims::ConvDims) where {F, wT, xT, N}
return __generic_conv_bias_activation(σ, weight, x, b, cdims)
return _generic_conv_bias_activation(σ, weight, x, b, cdims)
end

# Mixed Precision GPU Inputs
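
For context, a minimal CPU usage sketch of the `fused_conv_bias_activation` entry point whose internals are renamed above; the array sizes, `randn` initialization, and `relu` activation are illustrative assumptions, not taken from this commit.

using LuxLib: fused_conv_bias_activation
using NNlib: DenseConvDims, relu

weight = randn(Float32, 3, 3, 4, 8)   # 3x3 kernel, 4 => 8 channels
x = randn(Float32, 16, 16, 4, 2)      # WHCN input, batch size 2
bias = randn(Float32, 1, 1, 8, 1)
cdims = DenseConvDims(x, weight; stride=1, padding=1, dilation=1, groups=1)

y = fused_conv_bias_activation(relu, weight, x, bias, cdims)  # same 5-argument form used in test/conv_tests.jl
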
47 changes: 42 additions & 5 deletions src/impl/fused_conv.jl
@@ -1,3 +1,27 @@
@inline function _generic_conv_bias_activation(
act::F, weight::AbstractArray, args...) where {F}
old_threads = __maybe_reduce_BLAS_threads(weight)
ret = __generic_conv_bias_activation(act, weight, args...)
__reset_BLAS_threads(old_threads)
return ret
end

for aType in (AbstractArray, GPUArraysCore.AnyGPUArray)
@eval begin
@inline function __generic_conv_bias_activation(
act::F, weight::$(aType){T, N}, x::$(aType){T, N},
bias::$(aType){T, N}, cdims::ConvDims) where {T, N, F}
return __apply_bias_activation(act, conv(x, weight, cdims), bias)
end

@inline function __generic_conv_bias_activation(
act::F, weight::$(aType){T, N}, x::$(aType){T, N},
bias::Nothing, cdims::ConvDims) where {T, N, F}
return __apply_bias_activation(act, conv(x, weight, cdims), bias)
end
end
end

@inline function __generic_conv_bias_activation(
act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
bias::AbstractArray{bT, N}, cdims::ConvDims) where {wT, xT, bT, N, F}
@@ -15,23 +39,30 @@ end
x::GPUArraysCore.AnyGPUArray{xT, N}, bias::GPUArraysCore.AnyGPUArray{bT, N},
cdims::ConvDims) where {wT, xT, bT, N, F}
T = promote_type(wT, xT)
return __apply_bias_activation(
act, conv(_oftype_array(T, x), _oftype_array(T, weight), cdims), bias)
return __generic_conv_bias_activation(
act, _oftype_array(T, weight), _oftype_array(T, x), _oftype_array(T, bias), cdims)
end

@inline function __generic_conv_bias_activation(
act::F, weight::GPUArraysCore.AnyGPUArray{wT, N},
x::GPUArraysCore.AnyGPUArray{xT, N}, bias::Nothing,
cdims::ConvDims) where {wT, xT, N, F}
T = promote_type(wT, xT)
return __apply_bias_activation(
act, conv(_oftype_array(T, x), _oftype_array(T, weight), cdims), bias)
return __generic_conv_bias_activation(
act, _oftype_array(T, weight), _oftype_array(T, x), bias, cdims)
end

@inline function _fused_conv_bias_activation_impl(
act::F, weight::AbstractArray, args...) where {F}
old_threads = __maybe_reduce_BLAS_threads(weight)
ret = __fused_conv_bias_activation_impl(act, weight, args...)
__reset_BLAS_threads(old_threads)
return ret
end

# This implementation is different from `conv_bias_act` in that it defines the proper rrules
# and fuses operations into a single kernel where possible. Unfortunately, there are
# certain configurations where CUDNN allows caching intermediates, but we don't do that right now.

@inline function __fused_conv_bias_activation_impl(
act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N},
bias::Union{Nothing, AbstractArray}, cdims::ConvDims) where {wT, xT, N, F}
@@ -92,11 +123,13 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
y = __apply_bias_activation!!(act, y, bias, Val(false))
end
∇__fused_conv_bias_activation_impl_no_cached = @closure Δ -> begin
old_threads = __maybe_reduce_BLAS_threads(weight)
Δ = CRC.unthunk(NNlib.colmajor(Δ))
∂y = act === identity ? Δ : __activation_gradient(Δ, y, act, NotaNumber())
∂b = __added_bias_gradient(bias, ∂y)
∂x = NNlib.∇conv_data(∂y, weight, cdims)
∂w = NNlib.∇conv_filter(x, ∂y, cdims)
__reset_BLAS_threads(old_threads)
return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent()
end
return y, ∇__fused_conv_bias_activation_impl_no_cached
@@ -108,22 +141,26 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
z, y = __apply_bias_activation!!(act, y, bias, Val(true))
∇__fused_conv_bias_activation_impl_cached_crc = @closure Δ -> begin
old_threads = __maybe_reduce_BLAS_threads(weight)
Δ = CRC.unthunk(NNlib.colmajor(Δ))
∂y = __activation_gradient(Δ, z, act, y)
∂b = __added_bias_gradient(bias, ∂y)
∂x = NNlib.∇conv_data(∂y, weight, cdims)
∂w = NNlib.∇conv_filter(x, ∂y, cdims)
__reset_BLAS_threads(old_threads)
return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent()
end
return z, ∇__fused_conv_bias_activation_impl_cached_crc
end

z, pb_f = CRC.rrule_via_ad(cfg, __apply_bias_activation, act, y, bias)
∇__fused_conv_bias_activation_impl_cached = @closure Δ -> begin
old_threads = __maybe_reduce_BLAS_threads(weight)
Δ = NNlib.colmajor(Δ)
_, _, ∂y, ∂b = pb_f(Δ)
∂x = NNlib.∇conv_data(∂y, weight, cdims)
∂w = NNlib.∇conv_filter(x, ∂y, cdims)
__reset_BLAS_threads(old_threads)
return CRC.NoTangent(), CRC.NoTangent(), ∂w, ∂x, ∂b, CRC.NoTangent()
end

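The pullback closures above re-enter the same scalar-indexing convolution path as the forward pass, which is why each one saves and restores the BLAS thread count on its own. Below is a minimal ChainRulesCore-style sketch of that pattern; `my_scale` is a stand-in kernel, not the library's convolution.

using ChainRulesCore: ChainRulesCore, NoTangent
using LinearAlgebra: BLAS

my_scale(w, x) = w .* x   # stand-in for a thread-sensitive CPU kernel

function ChainRulesCore.rrule(::typeof(my_scale), w::AbstractArray, x::AbstractArray)
    old = BLAS.get_num_threads()
    BLAS.set_num_threads(1)               # forward pass with BLAS threading reduced
    y = my_scale(w, x)
    BLAS.set_num_threads(old)
    function my_scale_pullback(Δ)
        old_bwd = BLAS.get_num_threads()  # the pullback runs later, so reduce again
        BLAS.set_num_threads(1)
        ∂w = Δ .* x                       # gradient of w .* x w.r.t. w
        ∂x = Δ .* w                       # gradient of w .* x w.r.t. x
        BLAS.set_num_threads(old_bwd)
        return NoTangent(), ∂w, ∂x
    end
    return y, my_scale_pullback
end
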
19 changes: 19 additions & 0 deletions src/utils.jl
@@ -180,3 +180,22 @@ end
end
return @. Δ * only_derivative(out, act, x)
end

# Reduce BLAS threads if we are going to use a native Julia implementation
@inline function __maybe_reduce_BLAS_threads(x::AbstractArray)::Int
if ArrayInterface.fast_scalar_indexing(x)
old_threads = BLAS.get_num_threads()
BLAS.set_num_threads(1)
return old_threads
end
return -1
end

CRC.@non_differentiable __maybe_reduce_BLAS_threads(::AbstractArray)

@inline function __reset_BLAS_threads(old_threads::Int)
old_threads ≥ 1 && BLAS.set_num_threads(old_threads)
return nothing
end

CRC.@non_differentiable __reset_BLAS_threads(::Int)
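
A self-contained sketch of how these helpers are intended to be used around a scalar-indexing-compatible (CPU) kernel; the `with_reduced_blas_threads` wrapper, its lowercase helper names, and the try/finally guard are illustrative, not part of this commit.

using ArrayInterface: ArrayInterface
using LinearAlgebra: BLAS

# Mirrors __maybe_reduce_BLAS_threads / __reset_BLAS_threads above
function maybe_reduce_blas_threads(x::AbstractArray)::Int
    ArrayInterface.fast_scalar_indexing(x) || return -1   # GPU arrays: leave BLAS alone
    old = BLAS.get_num_threads()
    BLAS.set_num_threads(1)   # the native Julia conv path manages its own threading
    return old
end

reset_blas_threads(old::Int) = (old ≥ 1 && BLAS.set_num_threads(old); nothing)

function with_reduced_blas_threads(f::F, x::AbstractArray) where {F}
    old = maybe_reduce_blas_threads(x)
    try
        return f()
    finally
        reset_blas_threads(old)   # restore even if `f` throws
    end
end

# e.g. with_reduced_blas_threads(() -> NNlib.conv(x, weight, cdims), weight)
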
88 changes: 47 additions & 41 deletions test/conv_tests.jl
@@ -16,62 +16,68 @@
return _expand(Val(2 * N), pad)
end

anonact = x -> gelu(x)

@testset "$mode" for (mode, aType, on_gpu) in MODES
# These are not all possible combinations but rather a representative set to keep
# CI timings in check.
# Most of the actual tests happen upstream in Lux.
@testset "$(Tw) x $(Tx)" for (Tw, Tx) in [
(Float16, Float16), (Float32, Float16), (Float32, Float32),
(Float32, Float64), (Float64, Float64)]
for hasbias in (true, false),
activation in (identity, tanh, tanh_fast, sigmoid,
sigmoid_fast, relu, gelu, x -> gelu(x)),
(kernel, padding, stride, groups) in (
((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1),
((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2))
@testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for (Tw, Tx) in [
(Float16, Float16), (Float32, Float16), (Float32, Float32),
(Float32, Float64), (Float64, Float64)],
hasbias in (true, false),
activation in (
identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact),
(kernel, padding, stride, groups) in (
((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1),
((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2))

weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType
x = __generate_fixed_array(
Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |> aType
bias = hasbias ?
aType(__generate_fixed_array(
Tx, ntuple(Returns(1), length(kernel))..., 8, 1)) : nothing
weight = _convfilter(Tw, kernel, 4 => 8; groups) |> aType
x = __generate_fixed_array(Tx, ntuple(Returns(3), length(kernel))..., 4, 2) |>
aType
bias = hasbias ?
aType(__generate_fixed_array(
Tx, ntuple(Returns(1), length(kernel))..., 8, 1)) : nothing

cdims = DenseConvDims(
x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride),
dilation=1, groups)
cdims = DenseConvDims(
x, weight; stride, padding=_calc_padding(padding, kernel, 1, stride),
dilation=1, groups)

y = fused_conv_bias_activation(activation, weight, x, bias, cdims)
y = fused_conv_bias_activation(activation, weight, x, bias, cdims)

y_generic = LuxLib.__generic_conv_bias_activation(
activation, weight, x, bias, cdims)
y_generic = LuxLib.__generic_conv_bias_activation(
activation, weight, x, bias, cdims)

@test y ≈ y_generic
@test eltype(y) == promote_type(Tw, Tx)
fp16 = Tx == Float16 || Tw == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3
# Operation reordering has an effect on the accuracy of the results
@test y≈y_generic atol=atol rtol=rtol
@test eltype(y) == promote_type(Tw, Tx)

@inferred fused_conv_bias_activation(activation, weight, x, bias, cdims)
@jet fused_conv_bias_activation(activation, weight, x, bias, cdims)
@inferred fused_conv_bias_activation(activation, weight, x, bias, cdims)
@jet fused_conv_bias_activation(activation, weight, x, bias, cdims)

# FIXME: GPU compilation of the gradients for mixed precision seems broken
Tw !== Tx && on_gpu && continue
# FIXME: GPU compilation of the gradients for mixed precision seems broken
Tw !== Tx && on_gpu && continue

__f = (σ, w, x, b, cdims) -> sum(
abs2, fused_conv_bias_activation(σ, w, x, b, cdims))
__f = (σ, w, x, b, cdims) -> sum(
abs2, fused_conv_bias_activation(σ, w, x, b, cdims))

if mode != "AMDGPU"
if mode != "AMDGPU" && activation !== anonact
@inferred Zygote.gradient(__f, activation, weight, x, bias, cdims)
else
try
@inferred Zygote.gradient(__f, activation, weight, x, bias, cdims)
else
try
@inferred Zygote.gradient(__f, activation, weight, x, bias, cdims)
@test true
catch
@test_broken false
end
@test true
catch
@test_broken false
end

fp16 = Tx == Float16 || Tw == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3
end
if mode === "AMDGPU"
@eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_tracker=true skip_finite_differences=$(Tx !=
Tw)
else
# Finite differencing doesn't work well for mixed precision because of how LuxTestUtils
# is implemented.
@eval @test_gradients $__f $activation $weight $x $bias $cdims gpu_testing=$on_gpu soft_fail=$fp16 atol=$atol rtol=$rtol skip_finite_differences=$(Tx !=
Tw)
2 changes: 1 addition & 1 deletion test/groupnorm_tests.jl
@@ -74,7 +74,7 @@ end
end
end

@testitem "Group Normalization Generic Fallback" tags=[:nworkers, :normalization] setup=[
@testitem "Group Normalization Generic Fallback" tags=[:singleworker, :normalization] setup=[
SharedTestSetup, GroupNormSetup] begin
@testset "$mode" for (mode, aType, on_gpu) in MODES
@testset "eltype $T, size $sz, ngroups $groups, $act" for T in (
3 changes: 1 addition & 2 deletions test/runtests.jl
@@ -4,13 +4,12 @@ const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all")
@info "Running tests for group: $LUXLIB_TEST_GROUP"

if LUXLIB_TEST_GROUP == "all"
# Instance Normalization Tests causes stalling on CUDA CI
ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker])

ReTestItems.runtests(@__DIR__; tags=[:nworkers])
else
tag = Symbol(LUXLIB_TEST_GROUP)
# Instance Normalization Tests causes stalling on CUDA CI

ReTestItems.runtests(@__DIR__; nworkers=0, tags=[:singleworker, tag])

ReTestItems.runtests(@__DIR__; tags=[:nworkers, tag])
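
Assuming the suite is driven from the repository root with the test environment active, one illustrative way to exercise a single group via the LUXLIB_TEST_GROUP switch above (the group name "normalization" is only an example):

withenv("LUXLIB_TEST_GROUP" => "normalization") do
    include("test/runtests.jl")   # runs the :singleworker pass, then the :nworkers pass, filtered by the tag
end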
