From 90a0043badea2fab2edf2262aaf12f9fc1be81d8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 4 Sep 2023 18:48:24 -0400 Subject: [PATCH] Add `bias_act!` (#457) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * sometimes-in-place bias_act * update after dropout PR * add to docs * also fix two unrelated docstring which just told you what the function was called without explaining anything * tidy & un-comment * comment out 2nd path again * add Returns for 1.6 * upgrade tests * more tests * skip hardσ tests * Update test/bias_act.jl --- docs/src/reference.md | 2 + src/NNlib.jl | 3 ++ src/bias_act.jl | 107 +++++++++++++++++++++++++++++++++++++++ src/dropout.jl | 21 -------- src/utils.jl | 64 ++++++++++++++++++++++-- test/bias_act.jl | 114 ++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 7 files changed, 287 insertions(+), 25 deletions(-) create mode 100644 src/bias_act.jl create mode 100644 test/bias_act.jl diff --git a/docs/src/reference.md b/docs/src/reference.md index fe0ad2d8e..c01db6b24 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -75,6 +75,7 @@ pad_zeros ## Convolution `Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally. + `NNlib.conv` supports complex datatypes on CPU and CUDA devices. !!! AMDGPU MIOpen supports only cross-correlation (flipkernel=true). @@ -152,4 +153,5 @@ ctc_loss logsumexp NNlib.glu NNlib.within_gradient +bias_act! ``` diff --git a/src/NNlib.jl b/src/NNlib.jl index 8b0d3d5d5..8450a0261 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -72,6 +72,9 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, include("conv_bias_act.jl") export conv_bias_act, conv_bias_act! +include("bias_act.jl") +export bias_act! + include("fold.jl") include("ctc.jl") diff --git a/src/bias_act.jl b/src/bias_act.jl new file mode 100644 index 000000000..ef7fb29d9 --- /dev/null +++ b/src/bias_act.jl @@ -0,0 +1,107 @@ + +using NNlib: fast_act, tanh_fast +using ChainRulesCore + +const RCR = RuleConfig{>:HasReverseMode} + +# This just saves typing `only.(only.(` many times: +@inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x))) + +# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` +# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. +struct NotaNumber <: Real end + +""" + bias_act!(σ, x, b) + +This is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh` +with `sigmoid_fast` & `tanh_fast`. +It will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`. + +When used within a gradient, it will overwrite only when `σ` has +a method of `derivatives_given_output` which does not need the input at all. +Such methods are defined by e.g. `@scalar_rule relu(x) Ω > 0` where the derivative +contains only `Ω` (the output) not `x`. + +!!! warning + This is not safe to use if `x` is still needed for the gradient + of some other function. Incorrect use will give silently wrong answers. + It is intended mainly for Flux layers, in which the previous operation is + known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer. 
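+
+For illustration, a minimal sketch of the intended pattern (the names and sizes here are made up):
+
+    w, b = randn(Float32, 5, 10), randn(Float32, 5)
+    x = randn(Float32, 10, 16)
+    y = bias_act!(relu, w * x, b)   # same values as relu.(w * x .+ b), but may re-use the buffer `w * x`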
+""" +bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) = + _fast_broadcast!(fast_act(σ, x)∘(+), x, b) # works around a SIMD bug + +function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) + b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") + _fast_broadcast!(fast_act(σ, x), x) +end + +function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) + b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") + x # pass-through +end + +function bias_act!(σ::Function, x::AbstractArray, b) + b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") + fast_act(σ, x).(x .+ b) # fallback +end + +function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B} + biasgrad = if eltype(B) !== Bool + # Summing over ndims(x)+1 is a trick to make b_dims type-stable + dims = ntuple(d -> size(b,d)==1 ? d : N+1, N) + _biasgrad(dx) = reshape(sum(dx; dims), size(b)) + else + Returns(NoTangent()) + end + + # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ + if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) + Ω = bias_act!(σ, x, b) # now x === Ω, when x isa StridedArray{<:AbstractFloat} + function bias_act!_fastback(Δ) + # Tempting to overwrite x again, but only safe if you call pullback at most once, + # TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340 + # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592 + dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ) + return (NoTangent(), NoTangent(), dx, biasgrad(dx)) + end + return Ω, bias_act!_fastback + + # # Slower path: can't overwrite x, but can use derivatives_given_output + # # This case is WRONG and tests fail, but not sure why + # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) + # Ω2 = fast_act(σ, x).(x) .+ b + # @show σ b + # function bias_act!_back2(Δ) + # dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ) + # return (NoTangent(), NoTangent(), dx, biasgrad(dx)) + # end + # return Ω2, bias_act!_back2 + + # Fallback path: let AD handle the broadcast + else + Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b)) + @inline function bias_act!_slowback(Δ) + _, _, dx = back(Δ) + return (NoTangent(), NoTangent(), dx, biasgrad(dx)) + end + return Ω3, bias_act!_slowback + end +end + +# Two easy cases with identity +function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B} + dims = ntuple(d -> size(b,d)==1 ? d : N+1, N) + biasgrad(dx) = reshape(sum(dx; dims), size(b)) + function bias_act!_idback(Δ) + dx = unthunk(Δ) + return (NoTangent(), NoTangent(), dx, biasgrad(dx)) + end + return bias_act!(identity, x, b), bias_act!_idback +end +function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N} + bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent()) + return x, bias_act!_trivial +end + diff --git a/src/dropout.jl b/src/dropout.jl index 86bcb6c6f..02673cf03 100644 --- a/src/dropout.jl +++ b/src/dropout.jl @@ -125,27 +125,6 @@ end # and seems about as fast, but needs a method `copy(::CUDA.RNG)` and careful checking. # https://github.com/FluxML/NNlib.jl/pull/454#issuecomment-1369357402 -""" - _fast_broadcast!(f, x, y, z...) 
- -This does `x .= f.(x, y, z...)`, but works around -an issue with broadcasting that prevents SIMD in such cases. -Can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed. - -Not intended for general use. Does not check sizes! -""" -function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function} - bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...)) - @simd ivdep for I in eachindex(bc) - @inbounds x[I] = bc[I] - end - return x -end -function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function} - # CUDA does not suffer from this bug - broadcast!(f, x, x, yz...) -end - """ _rng_from_array(x) diff --git a/src/utils.jl b/src/utils.jl index 6d16b1cb1..3d23e7383 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -53,15 +53,21 @@ ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), No """ safe_div(x, y) -Safely divide `x` by `y`. If `y` is zero, return `x` directly. +Returns `x/y` unless `y==0`, in which case it just returns `x`. +(Used internally by `scatter`.) """ safe_div(x, y) = ifelse(iszero(y), x, x/y) """ maximum_dims(dims) -Return the maximum value for each dimension. An array of dimensions `dims` is accepted. -The maximum of each dimension in the element is computed. +Given an array of `CartesianIndex{N}` or `NTuple{N,Int}`, +returns a tuple containing the maximum of all the 1st entries, +all the 2nd entries, and so on up to `N`. + +Given an array of integers, returns `(maximum(dims),)`. + +(These arguments are what [`scatter`](@ref NNlib.scatter) understands.) """ maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), ) maximum_dims(dims::AbstractArray{NTuple{N, T}}) where {N,T} = ntuple(i -> maximum(x->x[i], dims), N) @@ -105,4 +111,54 @@ function reverse_indices(idx::AbstractArray{<:Any,N}) where N return reverse_indices!(rev, idx) end -unsqueeze(x) = reshape(x, 1, size(x)...) +unsqueeze(x) = reshape(x, 1, size(x)...) + + +""" + _fast_broadcast!(f, x, y, z...) + +This does `x .= f.(x, y, z...)`, but works around +an issue with broadcasting that prevents SIMD in such cases. +Can perhaps be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed. + +Has an `rrule` to avoid mutation within derivatives. + +!!! warning + Not intended for general use. + Uses `@inbounds` but does not check sizes! + Assumes that `f` has no derivative! +""" +function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function} + bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...)) + @simd ivdep for I in eachindex(bc) + @inbounds x[I] = bc[I] + end + return x +end +function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function} + # CUDA does not suffer from this bug + broadcast!(f, x, x, yz...) +end + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f::F, x::AbstractArray, ys...) where {F<:Function} + rrule_via_ad(cfg, broadcast, f, x, ys...) +end + +# Could get this from Compat.jl instead +# https://github.com/JuliaLang/julia/pull/39794 +if VERSION < v"1.7.0-DEV.793" + struct Returns{V} <: Function + value::V + Returns{V}(value) where {V} = new{V}(value) + Returns(value) = new{Core.Typeof(value)}(value) + end + + (obj::Returns)(args...; kw...) 
= obj.value + function Base.show(io::IO, obj::Returns) + show(io, typeof(obj)) + print(io, "(") + show(io, obj.value) + print(io, ")") + end +end + diff --git a/test/bias_act.jl b/test/bias_act.jl new file mode 100644 index 000000000..5d1b316d9 --- /dev/null +++ b/test/bias_act.jl @@ -0,0 +1,114 @@ +using NNlib, Zygote, ChainRulesCore, Test +using Zygote: ForwardDiff + +ACTIVATION_FUNCTIONS = + [@eval($a) for a in NNlib.ACTIVATIONS] + +@testset "bias_act!" begin + x = randn(3,4) + b = randn(3) + @test @inferred(bias_act!(identity, x, false)) === x # pass-through + @test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b) + @test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b) + @test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b) + @test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x) + + # Check that it does overwrite: + x32 = rand(Float32, 3, 4); x32copy = copy(x32) + @test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b) + @test x32 ≈ cbrt.(x32copy .+ b) + + x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias + @test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy) + @test x32 ≈ tanh.(x32copy) + + x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule + y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b) + @test y ≈ x32 ≈ relu.(x32copy .+ b) + + x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias + y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false) + @test y ≈ x32 ≈ relu.(x32copy) + + # Check that it doesn't try to overwrite non-float arrays: + xint = rand(-3:3, 3, 4) + bint = rand(-2:2, 3) + @test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint + @test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint) + @test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint) + + # Reject bias===true so that Bool means one thing: + @test_throws Exception bias_act!(identity, rand(3), true) + @test_throws Exception bias_act!(cbrt, rand(3), true) + @test_throws Exception bias_act!(cbrt, rand(1:3, 3), true) + + @testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt], + ACTIVATION_FUNCTIONS, + [x->x, x -> 1/(x^2+2), x -> leakyrelu(x, 0.33)]) + # Only some of these go the fast path, `cbrt` is an example of a function NNlib knows nothing about. + fun == rrelu && continue # this one is randomised! 
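+        # (The overwriting fast path is only taken when `derivatives_given_output(Ω, fun, x)` needs just the
+        #  output `Ω`, as for `relu`; other functions fall back to `rrule_via_ad` on the broadcast.)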
+        fun == hardσ && continue # this one has heisenbugs, not solved by discontinuity-avoidance code below
+
+        @test bias_act!(fun, copy(x), b) ≈ fun.(x .+ b)
+        @test bias_act!(fun, copy(x), false) ≈ fun.(x)
+
+        gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)
+        gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps())
+        gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps())
+        if !(gx ≈ gxplus ≈ gxminus)
+            @warn "skipping gradient tests due to discontinuity" fun x b
+            continue
+        end
+        @test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]
+
+        gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)
+        gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .+ eps())
+        gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
+        if !(gx2 ≈ gx2plus ≈ gx2minus)
+            @warn "skipping gradient tests due to discontinuity" fun x
+            continue
+        end
+        @test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1]
+
+        gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)
+        @test gb ≈ Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1]
+
+        @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), false) == (nothing,)
+        @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,)
+    end
+
+    @testset "gradient for fast_broadcast!" begin
+        # Gradient definition is just to disable mutation inside 2nd order AD
+        gx = ForwardDiff.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)
+        @test gx ≈ Zygote.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)[1]
+
+        # relu should take the fast path
+        g2 = ForwardDiff.gradient(x) do x
+            sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
+        end
+        @test_skip gx ≈ Zygote.gradient(x) do x # Here global variable b causes an error
+            sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
+        end
+        # Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).
+        # [5] (::typeof(∂(accum_global)))(Δ::Nothing)
+        @test g2 ≈ Zygote.gradient(x, b) do x, b
+            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(relu, copy(x), b)), x, b)[1])
+        end[1]
+
+        g3 = ForwardDiff.gradient(x) do x
+            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])
+        end
+        @test g3 ≈ Zygote.gradient(x, b) do x, b
+            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])
+        end[1]
+
+        # Anon function sure to take the generic path
+        g4 = ForwardDiff.gradient(x) do x
+            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])
+        end
+        @test g4 ≈ Zygote.gradient(x, b) do x, b
+            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])
+        end[1]
+    end
+end
+
diff --git a/test/runtests.jl b/test/runtests.jl
index 31db3a84f..ece02b0ed 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -127,6 +127,7 @@ end
 
     @testset "Activation Functions" begin
         include("activations.jl")
+        include("bias_act.jl")
     end
 
     @testset "Attention" begin