-
-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* sometimes-in-place bias_act * update after dropout PR * add to docs * also fix two unrelated docstring which just told you what the function was called without explaining anything * tidy & un-comment * comment out 2nd path again * add Returns for 1.6 * upgrade tests * more tests * skip hardσ tests * Update test/bias_act.jl
- Loading branch information
Showing
7 changed files
with
287 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
|
||
using NNlib: fast_act, tanh_fast | ||
using ChainRulesCore | ||
|
||
const RCR = RuleConfig{>:HasReverseMode} | ||
|
||
# This just saves typing `only.(only.(` many times: | ||
@inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x))) | ||
|
||
# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` | ||
# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. | ||
struct NotaNumber <: Real end | ||
|
||
""" | ||
bias_act!(σ, x, b) | ||
This is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh` | ||
with `sigmoid_fast` & `tanh_fast`. | ||
It will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`. | ||
When used within a gradient, it will overwrite only when `σ` has | ||
a method of `derivatives_given_output` which does not need the input at all. | ||
Such methods are defined by e.g. `@scalar_rule relu(x) Ω > 0` where the derivative | ||
contains only `Ω` (the output) not `x`. | ||
!!! warning | ||
This is not safe to use if `x` is still needed for the gradient | ||
of some other function. Incorrect use will give silently wrong answers. | ||
It is intended mainly for Flux layers, in which the previous operation is | ||
known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer. | ||
""" | ||
bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) = | ||
_fast_broadcast!(fast_act(σ, x)∘(+), x, b) # works around a SIMD bug | ||
|
||
function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) | ||
b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") | ||
_fast_broadcast!(fast_act(σ, x), x) | ||
end | ||
|
||
function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) | ||
b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") | ||
x # pass-through | ||
end | ||
|
||
function bias_act!(σ::Function, x::AbstractArray, b) | ||
b === true && error("bias=true is not accepted; layer constructors shoud guarantee this") | ||
fast_act(σ, x).(x .+ b) # fallback | ||
end | ||
|
||
function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B} | ||
biasgrad = if eltype(B) !== Bool | ||
# Summing over ndims(x)+1 is a trick to make b_dims type-stable | ||
dims = ntuple(d -> size(b,d)==1 ? d : N+1, N) | ||
_biasgrad(dx) = reshape(sum(dx; dims), size(b)) | ||
else | ||
Returns(NoTangent()) | ||
end | ||
|
||
# Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ | ||
if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) | ||
Ω = bias_act!(σ, x, b) # now x === Ω, when x isa StridedArray{<:AbstractFloat} | ||
function bias_act!_fastback(Δ) | ||
# Tempting to overwrite x again, but only safe if you call pullback at most once, | ||
# TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340 | ||
# https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592 | ||
dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ) | ||
return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
end | ||
return Ω, bias_act!_fastback | ||
|
||
# # Slower path: can't overwrite x, but can use derivatives_given_output | ||
# # This case is WRONG and tests fail, but not sure why | ||
# elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) | ||
# Ω2 = fast_act(σ, x).(x) .+ b | ||
# @show σ b | ||
# function bias_act!_back2(Δ) | ||
# dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ) | ||
# return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
# end | ||
# return Ω2, bias_act!_back2 | ||
|
||
# Fallback path: let AD handle the broadcast | ||
else | ||
Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b)) | ||
@inline function bias_act!_slowback(Δ) | ||
_, _, dx = back(Δ) | ||
return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
end | ||
return Ω3, bias_act!_slowback | ||
end | ||
end | ||
|
||
# Two easy cases with identity | ||
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B} | ||
dims = ntuple(d -> size(b,d)==1 ? d : N+1, N) | ||
biasgrad(dx) = reshape(sum(dx; dims), size(b)) | ||
function bias_act!_idback(Δ) | ||
dx = unthunk(Δ) | ||
return (NoTangent(), NoTangent(), dx, biasgrad(dx)) | ||
end | ||
return bias_act!(identity, x, b), bias_act!_idback | ||
end | ||
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N} | ||
bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent()) | ||
return x, bias_act!_trivial | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
using NNlib, Zygote, ChainRulesCore, Test | ||
using Zygote: ForwardDiff | ||
|
||
ACTIVATION_FUNCTIONS = | ||
[@eval($a) for a in NNlib.ACTIVATIONS] | ||
|
||
@testset "bias_act!" begin | ||
x = randn(3,4) | ||
b = randn(3) | ||
@test @inferred(bias_act!(identity, x, false)) === x # pass-through | ||
@test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b) | ||
@test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b) | ||
@test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b) | ||
@test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x) | ||
|
||
# Check that it does overwrite: | ||
x32 = rand(Float32, 3, 4); x32copy = copy(x32) | ||
@test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b) | ||
@test x32 ≈ cbrt.(x32copy .+ b) | ||
|
||
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias | ||
@test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy) | ||
@test x32 ≈ tanh.(x32copy) | ||
|
||
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule | ||
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b) | ||
@test y ≈ x32 ≈ relu.(x32copy .+ b) | ||
|
||
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias | ||
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false) | ||
@test y ≈ x32 ≈ relu.(x32copy) | ||
|
||
# Check that it doesn't try to overwrite non-float arrays: | ||
xint = rand(-3:3, 3, 4) | ||
bint = rand(-2:2, 3) | ||
@test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint | ||
@test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint) | ||
@test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint) | ||
|
||
# Reject bias===true so that Bool means one thing: | ||
@test_throws Exception bias_act!(identity, rand(3), true) | ||
@test_throws Exception bias_act!(cbrt, rand(3), true) | ||
@test_throws Exception bias_act!(cbrt, rand(1:3, 3), true) | ||
|
||
@testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt], | ||
ACTIVATION_FUNCTIONS, | ||
[x->x, x -> 1/(x^2+2), x -> leakyrelu(x, 0.33)]) | ||
# Only some of these go the fast path, `cbrt` is an example of a function NNlib knows nothing about. | ||
fun == rrelu && continue # this one is randomised! | ||
fun == hardσ && continue # this one has heisenbugs, not solved by discontinuity-avoidance code below | ||
|
||
@test bias_act!(fun, copy(x), b) ≈ fun.(x .+ b) | ||
@test bias_act!(fun, copy(x), false) ≈ fun.(x) | ||
|
||
gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x) | ||
gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps()) | ||
gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps()) | ||
if !(gx ≈ gxplus ≈ gxminus) | ||
@warn "skipping gradient tests due to discontinuity" fun x b | ||
continue | ||
end | ||
@test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1] | ||
|
||
gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x) | ||
gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps()) | ||
gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps()) | ||
if !(gx2 ≈ gx2plus ≈ gx2minus) | ||
@warn "skipping gradient tests due to discontinuity" fun x | ||
continue | ||
end | ||
@test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1] | ||
|
||
gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b) | ||
@test gb ≈ Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1] | ||
|
||
@test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), false) == (nothing,) | ||
@test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,) | ||
end | ||
|
||
@testset "gradient for fast_broadcast!" begin | ||
# Gradient definition is just to disable mutation inside 2nd order AD | ||
gx = ForwardDiff.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x) | ||
@test gx ≈ Zygote.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)[1] | ||
|
||
# relu should take the fast path | ||
g2 = ForwardDiff.gradient(x) do x | ||
sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1]) | ||
end | ||
@test_skip gx ≈ Zygote.gradient(x) do x # Here global variable b causes an error | ||
sum(abs2,Zygote. gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1]) | ||
end | ||
# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)). | ||
# [5] (::typeof(∂(accum_global)))(Δ::Nothing) | ||
@test g2 ≈ Zygote.gradient(x, b) do x, b | ||
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(relu, copy(x), b)), x, b)[1]) | ||
end[1] | ||
|
||
g3 = ForwardDiff.gradient(x) do x | ||
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1]) | ||
end | ||
@test g3 ≈ Zygote.gradient(x, b) do x, b | ||
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1]) | ||
end[1] | ||
|
||
# Anon function sure to take the generic path | ||
g4 = ForwardDiff.gradient(x) do x | ||
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1]) | ||
end | ||
@test g4 ≈ Zygote.gradient(x, b) do x, b | ||
sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1]) | ||
end[1] | ||
end | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters