Add bias_act! (#457)
* sometimes-in-place bias_act

* update after dropout PR

* add to docs

* also fix two unrelated docstrings which just told you what the function was called without explaining anything

* tidy & un-comment

* comment out 2nd path again

* add Returns for 1.6

* upgrade tests

* more tests

* skip hardσ tests

* Update test/bias_act.jl
mcabbott authored Sep 4, 2023
1 parent 2b548b6 commit 90a0043
Showing 7 changed files with 287 additions and 25 deletions.
2 changes: 2 additions & 0 deletions docs/src/reference.md
@@ -75,6 +75,7 @@ pad_zeros
## Convolution

`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.

`NNlib.conv` supports complex datatypes on CPU and CUDA devices.
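
For illustration, a minimal sketch of a complex-valued convolution (the array sizes here are arbitrary):

```julia
using NNlib

x = randn(ComplexF32, 8, 8, 2, 1)  # width × height × channels × batch
w = randn(ComplexF32, 3, 3, 2, 4)  # kernel width × height × input channels × output channels
y = conv(x, w)                     # 6×6×4×1 Array{ComplexF32, 4}
```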

!!! AMDGPU MIOpen supports only cross-correlation (flipkernel=true).
@@ -152,4 +153,5 @@ ctc_loss
logsumexp
NNlib.glu
NNlib.within_gradient
bias_act!
```
3 changes: 3 additions & 0 deletions src/NNlib.jl
@@ -72,6 +72,9 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
include("conv_bias_act.jl")
export conv_bias_act, conv_bias_act!

include("bias_act.jl")
export bias_act!

include("fold.jl")

include("ctc.jl")
107 changes: 107 additions & 0 deletions src/bias_act.jl
@@ -0,0 +1,107 @@

using NNlib: fast_act, tanh_fast
using ChainRulesCore

const RCR = RuleConfig{>:HasReverseMode}

# This just saves typing `only.(only.(` many times:
@inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))

# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
struct NotaNumber <: Real end
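
# Illustration only: a hypothetical helper (not part of the NNlib API) spelling out
# the check used in the rrule below. If inference succeeds with `NotaNumber` standing
# in for the input, the derivative of `σ` cannot depend on `x`, so overwriting `x` is
# safe; otherwise `_return_type` returns `Union{}`, which is not concrete.
can_overwrite(σ::F, ::Type{T}) where {F,T} =
    isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
# e.g. `can_overwrite(relu, Float64)` should be `true`, since relu's rule needs only Ω,
# whereas an arbitrary closure like `x -> x/2` has no `derivatives_given_output` method,
# so the check fails and the slower path is taken.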

"""
bias_act!(σ, x, b)
This is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh`
with `sigmoid_fast` & `tanh_fast`.
It will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`.
When used within a gradient, it will overwrite only when `σ` has
a method of `derivatives_given_output` which does not need the input at all.
Such methods are defined by e.g. `@scalar_rule relu(x) Ω > 0` where the derivative
contains only `Ω` (the output) not `x`.
!!! warning
This is not safe to use if `x` is still needed for the gradient
of some other function. Incorrect use will give silently wrong answers.
It is intended mainly for Flux layers, in which the previous operation is
known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer.
"""
bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
_fast_broadcast!(fast_act(σ, x)∘(+), x, b) # works around a SIMD bug

function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool)
b === true && error("bias=true is not accepted; layer constructors should guarantee this")
_fast_broadcast!(fast_act(σ, x), x)
end

function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool)
b === true && error("bias=true is not accepted; layer constructors should guarantee this")
x # pass-through
end

function bias_act!(σ::Function, x::AbstractArray, b)
b === true && error("bias=true is not accepted; layer constructors should guarantee this")
fast_act(σ, x).(x .+ b) # fallback
end
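
# Minimal usage sketch (illustrative only, sizes are arbitrary): fuse the bias add and
# activation of a Dense-like layer, overwriting the temporary `W * x`:
#
#     W, b = randn(Float32, 5, 3), randn(Float32, 5)
#     x = randn(Float32, 3, 8)
#     y = bias_act!(relu, W * x, b)   # same values as relu.(W*x .+ b), but in place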

function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
biasgrad = if eltype(B) !== Bool
# Summing over ndims(x)+1 is a trick to make b_dims type-stable
dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
_biasgrad(dx) = reshape(sum(dx; dims), size(b))
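# (Illustration: for `x` of size (3,4) and `b` of size (3,), this gives dims == (3,2);
#  summing `dx` over dimension 2 plus the nonexistent trailing dimension 3 keeps the
#  tuple length fixed at N, and the reshape returns a length-3 bias gradient.)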
else
Returns(NoTangent())
end

# Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ
if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
Ω = bias_act!(σ, x, b) # now x === Ω, when x isa StridedArray{<:AbstractFloat}
function bias_act!_fastback(Δ)
# Tempting to overwrite x again, but only safe if you call pullback at most once,
# TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592
dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ)
return (NoTangent(), NoTangent(), dx, biasgrad(dx))
end
return Ω, bias_act!_fastback

# # Slower path: can't overwrite x, but can use derivatives_given_output
# # This case is WRONG and tests fail, but not sure why
# elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
# Ω2 = fast_act(σ, x).(x) .+ b
# @show σ b
# function bias_act!_back2(Δ)
# dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
# return (NoTangent(), NoTangent(), dx, biasgrad(dx))
# end
# return Ω2, bias_act!_back2

# Fallback path: let AD handle the broadcast
else
Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b))
@inline function bias_act!_slowback(Δ)
_, _, dx = back(Δ)
return (NoTangent(), NoTangent(), dx, biasgrad(dx))
end
return Ω3, bias_act!_slowback
end
end

# Two easy cases with identity
function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B}
dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
biasgrad(dx) = reshape(sum(dx; dims), size(b))
function bias_act!_idback(Δ)
dx = unthunk(Δ)
return (NoTangent(), NoTangent(), dx, biasgrad(dx))
end
return bias_act!(identity, x, b), bias_act!_idback
end
function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
return x, bias_act!_trivial
end

21 changes: 0 additions & 21 deletions src/dropout.jl
@@ -125,27 +125,6 @@ end
# and seems about as fast, but needs a method `copy(::CUDA.RNG)` and careful checking.
# https://github.com/FluxML/NNlib.jl/pull/454#issuecomment-1369357402

"""
_fast_broadcast!(f, x, y, z...)
This does `x .= f.(x, y, z...)`, but works around
an issue with broadcasting that prevents SIMD in such cases.
Can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.
Not intended for general use. Does not check sizes!
"""
function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
@simd ivdep for I in eachindex(bc)
@inbounds x[I] = bc[I]
end
return x
end
function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
# CUDA does not suffer from this bug
broadcast!(f, x, x, yz...)
end


"""
_rng_from_array(x)
64 changes: 60 additions & 4 deletions src/utils.jl
@@ -53,15 +53,21 @@ ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), No
"""
safe_div(x, y)
Safely divide `x` by `y`. If `y` is zero, return `x` directly.
Returns `x/y` unless `y==0`, in which case it just returns `x`.
(Used internally by `scatter`.)
"""
safe_div(x, y) = ifelse(iszero(y), x, x/y)

"""
maximum_dims(dims)
Return the maximum value for each dimension. An array of dimensions `dims` is accepted.
The maximum of each dimension in the element is computed.
Given an array of `CartesianIndex{N}` or `NTuple{N,Int}`,
returns a tuple containing the maximum of all the 1st entries,
all the 2nd entries, and so on up to `N`.
Given an array of integers, returns `(maximum(dims),)`.
(These arguments are what [`scatter`](@ref NNlib.scatter) understands.)
"""
maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), )
maximum_dims(dims::AbstractArray{NTuple{N, T}}) where {N,T} = ntuple(i -> maximum(x->x[i], dims), N)
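# For example: maximum_dims([2, 5, 3]) == (5,)
#              maximum_dims([(1, 2), (3, 1)]) == (3, 2)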
@@ -105,4 +111,54 @@ function reverse_indices(idx::AbstractArray{<:Any,N}) where N
return reverse_indices!(rev, idx)
end

unsqueeze(x) = reshape(x, 1, size(x)...)
unsqueeze(x) = reshape(x, 1, size(x)...)


"""
_fast_broadcast!(f, x, y, z...)
This does `x .= f.(x, y, z...)`, but works around
an issue with broadcasting that prevents SIMD in such cases.
Can perhaps be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.
Has an `rrule` to avoid mutation within derivatives.
!!! warning
Not intended for general use.
Uses `@inbounds` but does not check sizes!
Assumes that `f` has no derivative!
"""
function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
@simd ivdep for I in eachindex(bc)
@inbounds x[I] = bc[I]
end
return x
end
function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
# CUDA does not suffer from this bug
broadcast!(f, x, x, yz...)
end
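
# Illustration: `_fast_broadcast!(+, x, b)` overwrites `x` with `x .+ b` and returns it,
# i.e. the same result as `broadcast!(+, x, x, b)`, but with SIMD restored on `Array`.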

function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f::F, x::AbstractArray, ys...) where {F<:Function}
rrule_via_ad(cfg, broadcast, f, x, ys...)
end

# Could get this from Compat.jl instead
# https://github.com/JuliaLang/julia/pull/39794
if VERSION < v"1.7.0-DEV.793"
struct Returns{V} <: Function
value::V
Returns{V}(value) where {V} = new{V}(value)
Returns(value) = new{Core.Typeof(value)}(value)
end

(obj::Returns)(args...; kw...) = obj.value
function Base.show(io::IO, obj::Returns)
show(io, typeof(obj))
print(io, "(")
show(io, obj.value)
print(io, ")")
end
end
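
# Illustration: `Returns(NoTangent())` is a callable that ignores all of its arguments
# and always gives back `NoTangent()`; the `bias_act!` rrule uses it as the bias
# gradient whenever the bias is a `Bool`.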

114 changes: 114 additions & 0 deletions test/bias_act.jl
@@ -0,0 +1,114 @@
using NNlib, Zygote, ChainRulesCore, Test
using Zygote: ForwardDiff

ACTIVATION_FUNCTIONS =
[@eval($a) for a in NNlib.ACTIVATIONS]

@testset "bias_act!" begin
x = randn(3,4)
b = randn(3)
@test @inferred(bias_act!(identity, x, false)) === x # pass-through
@test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b)
@test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b)
@test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b)
@test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x)

# Check that it does overwrite:
x32 = rand(Float32, 3, 4); x32copy = copy(x32)
@test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b)
@test x32 ≈ cbrt.(x32copy .+ b)

x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias
@test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy)
@test x32 ≈ tanh.(x32copy)

x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b)
@test y ≈ x32 ≈ relu.(x32copy .+ b)

x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false)
@test y ≈ x32 ≈ relu.(x32copy)

# Check that it doesn't try to overwrite non-float arrays:
xint = rand(-3:3, 3, 4)
bint = rand(-2:2, 3)
@test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint
@test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint)
@test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint)

# Reject bias===true so that Bool means one thing:
@test_throws Exception bias_act!(identity, rand(3), true)
@test_throws Exception bias_act!(cbrt, rand(3), true)
@test_throws Exception bias_act!(cbrt, rand(1:3, 3), true)

@testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt],
ACTIVATION_FUNCTIONS,
[x->x, x -> 1/(x^2+2), x -> leakyrelu(x, 0.33)])
# Only some of these take the fast path; `cbrt` is an example of a function NNlib knows nothing about.
fun == rrelu && continue # this one is randomised!
fun == hardσ && continue # this one has heisenbugs, not solved by discontinuity-avoidance code below

@test bias_act!(fun, copy(x), b) ≈ fun.(x .+ b)
@test bias_act!(fun, copy(x), false) ≈ fun.(x)

gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)
gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps())
gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps())
if !(gx ≈ gxplus ≈ gxminus)
@warn "skipping gradient tests due to discontinuity" fun x b
continue
end
@test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]

gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)
gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .+ eps())
gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
if !(gx2 ≈ gx2plus ≈ gx2minus)
@warn "skipping gradient tests due to discontinuity" fun x
continue
end
@test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1]

gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)
@test gb ≈ Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1]

@test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), false) == (nothing,)
@test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,)
end

@testset "gradient for fast_broadcast!" begin
# Gradient definition is just to disable mutation inside 2nd order AD
gx = ForwardDiff.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)
@test gx ≈ Zygote.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)[1]

# relu should take the fast path
g2 = ForwardDiff.gradient(x) do x
sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
end
@test_skip gx ≈ Zygote.gradient(x) do x # Here global variable b causes an error
sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
end
# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).
# [5] (::typeof(∂(accum_global)))(Δ::Nothing)
@test g2 ≈ Zygote.gradient(x, b) do x, b
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(relu, copy(x), b)), x, b)[1])
end[1]

g3 = ForwardDiff.gradient(x) do x
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])
end
@test g3 ≈ Zygote.gradient(x, b) do x, b
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])
end[1]

# Anon function sure to take the generic path
g4 = ForwardDiff.gradient(x) do x
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])
end
@test g4 ≈ Zygote.gradient(x, b) do x, b
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])
end[1]
end
end

1 change: 1 addition & 0 deletions test/runtests.jl
@@ -127,6 +127,7 @@ end

@testset "Activation Functions" begin
include("activations.jl")
include("bias_act.jl")
end

@testset "Attention" begin
