From 0bcd7d14dce592637e969040ac880861716bb69c Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 18 Sep 2022 20:24:47 -0400
Subject: [PATCH] Update to use LuxLib

---
 Project.toml             |   4 +-
 src/Lux.jl               |   2 +-
 src/autodiff.jl          |  46 ------------
 src/core.jl              |   3 +-
 src/deprecated.jl        |  57 +++++++++++++++
 src/layers/containers.jl |   2 +-
 src/layers/dropout.jl    |  11 +--
 src/layers/normalize.jl  | 143 +++++++++++++------------------------
 src/layers/recurrent.jl  |   4 +-
 src/nnlib.jl             | 150 ---------------------------------------
 src/utils.jl             |  35 ++++-----
 test/Project.toml        |   1 -
 test/autodiff.jl         |  13 ----
 test/layers/normalize.jl |  20 +++---
 14 files changed, 141 insertions(+), 350 deletions(-)

diff --git a/Project.toml b/Project.toml
index 47ad43004e..4c30a5c077 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,9 +11,9 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -29,8 +29,8 @@ ChainRulesCore = "1"
 ComponentArrays = "0.13"
 FillArrays = "0.13"
 Functors = "0.2, 0.3"
+LuxLib = "0.1"
 NNlib = "0.8"
-NNlibCUDA = "0.2"
 Optimisers = "0.2"
 Requires = "1"
 Setfield = "0.8, 1"
diff --git a/src/Lux.jl b/src/Lux.jl
index dba9f3d28a..4b2fe35ab3 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -5,7 +5,7 @@ using CUDA
 using CUDA.CUDNN
 # Neural Network Backend
 using NNlib
-import NNlibCUDA: batchnorm, ∇batchnorm, CUDNNFloat
+import LuxLib ## In v0.5 we can starting `using`. For v0.4, there will be naming conflicts
 # Julia StdLibs
 using Random, Statistics, LinearAlgebra, SparseArrays
 # Parameter Manipulation
diff --git a/src/autodiff.jl b/src/autodiff.jl
index 3a3469b0da..466075bb9b 100644
--- a/src/autodiff.jl
+++ b/src/autodiff.jl
@@ -1,9 +1,5 @@
 # Non Differentiable Functions
 ChainRulesCore.@non_differentiable replicate(::Any)
-ChainRulesCore.@non_differentiable update_statistics(::Any, ::Any, ::Any, ::Any, ::Any,
-                                                     ::Any, ::Any)
-ChainRulesCore.@non_differentiable generate_dropout_mask(::Any, ::Any, ::Any, ::Any)
-ChainRulesCore.@non_differentiable _get_reshape_dims(::Any, ::Any)
 ChainRulesCore.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)
 ChainRulesCore.@non_differentiable glorot_normal(::Any...)
 ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
@@ -12,50 +8,8 @@ ChainRulesCore.@non_differentiable istraining(::Any)
 ChainRulesCore.@non_differentiable _get_norm_except_dims(::Any, ::Any)
 ChainRulesCore.@non_differentiable _affine(::Any)
 ChainRulesCore.@non_differentiable _track_stats(::Any)
-ChainRulesCore.@non_differentiable _copy_autodiff_barrier(::Any)
-
-# NNlib Functions
-function ChainRulesCore.rrule(::typeof(_batchnorm), g::CuArray{T}, b::CuArray{T},
-                              x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}},
-                              running_mean, running_var, momentum, epsilon,
-                              training) where {T <: CUDNNFloat}
-    y = _batchnorm(g, b, x, running_mean, running_var, momentum, epsilon, training)
-    function _batchnorm_pullback(dy)
-        dg, db, dx = ∇batchnorm(g, b, x, unthunk(dy), running_mean, running_var, momentum;
-                                eps=epsilon, training=training)
-        return NoTangent(), dg, db, dx, NoTangent(), NoTangent(), NoTangent(), NoTangent(),
-               NoTangent()
-    end
-    return y, _batchnorm_pullback
-end
-
-function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, x::AbstractArray{T, N},
-                              p::T, q::T, dims, t::Val{training}) where {T, N, training}
-    y, mask, rng = dropout(rng, x, p, q, dims, t)
-    function dropout_pullback((dy, dmask, drng))
-        return (NoTangent(), NoTangent(), dy .* mask, NoTangent(), NoTangent(), NoTangent(),
-                NoTangent())
-    end
-    return (y, mask, rng), dropout_pullback
-end
 
 # Utilities
-
-function ChainRulesCore.rrule(::typeof(_reshape_into_proper_shape), ::Nothing, y)
-    function _reshape_into_proper_shape_pullback(dx)
-        return NoTangent(), NoTangent(), NoTangent()
-    end
-    return nothing, _reshape_into_proper_shape_pullback
-end
-
-function ChainRulesCore.rrule(::typeof(_reshape_into_proper_shape), x, y)
-    res = _reshape_into_proper_shape(x, y)
-    function _reshape_into_proper_shape_pullback(dx)
-        return NoTangent(), reshape(dx, size(x)), NoTangent()
-    end
-    return res, _reshape_into_proper_shape_pullback
-end
-
 function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1},
                               nt2::NamedTuple{F2}) where {F1, F2}
     y = merge(nt1, nt2)
diff --git a/src/core.jl b/src/core.jl
index 4658848e22..45c4317ef5 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -78,8 +78,7 @@ function setup(rng::AbstractRNG, l::AbstractExplicitLayer)
 end
 
 """
-    apply(model::AbstractExplicitLayer, x, ps::Union{ComponentArray,NamedTuple},
-          st::NamedTuple)
+    apply(model::AbstractExplicitLayer, x, ps, st::NamedTuple)
 
 Simply calls `model(x, ps, st)`
 """
diff --git a/src/deprecated.jl b/src/deprecated.jl
index 58e31f47a8..576ef7f3ee 100644
--- a/src/deprecated.jl
+++ b/src/deprecated.jl
@@ -130,3 +130,60 @@ Computes `x .* y`. Dispatches to CUDNN if possible.
                  " v0.5. Use `x .* y` instead.", :elementwise_mul)
     return x .* y
 end
+
+# Dropout
+"""
+    dropout(rng::AbstractRNG, x, p, q, dims, ::Val{training})
+    dropout(rng::AbstractRNG, x, mask, p, q, dims, t::Val{training}, ::Val{update_mask})
+
+If `training` then dropout is applied on `x` with probability `p` along `dims`. If `mask` is
+passed it is used if `update_mask` is false. If `update_mask` is true then the mask is
+generated and used.
+
+!!! warning
+
+    This function has been deprecated and will be removed in v0.5. Use `LuxLib.dropout`
+    instead.
+"""
+@inline function dropout(rng::AbstractRNG, x, p, q, dims, t::Val)
+    # Deprecated Functionality (Remove in v0.5)
+    Base.depwarn("`Lux.dropout` has been deprecated and will be removed in v0.5. Use " *
+                 "`LuxLib.dropout` instead.", :dropout)
+
+    return LuxLib.dropout(rng, x, p, t; invp=q, dims)
+end
+
+@inline function dropout(rng::AbstractRNG, x, mask, p, q, dims, t::Val, um::Val)
+    # Deprecated Functionality (Remove in v0.5)
+    Base.depwarn("`Lux.dropout` has been deprecated and will be removed in v0.5. Use " *
+                 "`LuxLib.dropout` instead.", :dropout)
+
+    return (LuxLib.dropout(rng, x, mask, p, t, um; invp=q, dims)..., Val(false))
+end
+
+# Normalization Implementation
+"""
+    normalization(x, running_mean, running_var, scale, bias, activation, reduce_dims,
+                  ::Val{training}, momentum, epsilon)
+
+Performs BatchNorm/GroupNorm based on input configuration
+
+!!! warning
+
+    This function has been deprecated and will be removed in v0.5. Use
+    `LuxLib.(batch/group)norm` instead.
+"""
+@inline function normalization(x::AbstractArray{T, N},
+                               running_mean::Union{Nothing, AbstractVector{T}},
+                               running_var::Union{Nothing, AbstractVector{T}},
+                               scale::Union{Nothing, AbstractVector{T}},
+                               bias::Union{Nothing, AbstractVector{T}}, activation,
+                               reduce_dims, t::Val, momentum::T=T(0.1),
+                               epsilon::T=T(1e-5)) where {T, N}
+    # Deprecated Functionality (Remove in v0.5)
+    Base.depwarn("`Lux.normalization` has been deprecated and will be removed in v0.5. " *
+                 "Use `LuxLib.(batch/group)norm` instead.", :normalization)
+
+    return activation.(LuxLib._normalization(x, running_mean, running_var, scale, bias,
+                                             reduce_dims, t, momentum, epsilon))
+end
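The deprecation shim above forwards the old positional `Lux.dropout` API to `LuxLib.dropout`. A minimal sketch of that mapping follows; the array shape, the probability, and the variable names are illustrative and not part of the patch.

```julia
# Sketch of the mapping performed by the deprecated shim above.
# `q` is the inverse keep probability that old Lux code carried alongside `p`.
using Random
import LuxLib

rng = Random.default_rng()
x = randn(rng, Float32, 4, 8)   # (features, batch), illustrative
p = 0.5f0                        # drop probability
q = 1 / (1 - p)                  # inverse keep probability, i.e. `invp`

# Deprecated call (still works in v0.4, emits a depwarn):
#   y, mask, rng2 = Lux.dropout(rng, x, p, q, :, Val(true))
# Equivalent LuxLib call made by the shim:
y, mask, rng2 = LuxLib.dropout(rng, x, p, Val(true); invp=q, dims=:)
```

As in the shim, `LuxLib.dropout` returns the output, the mask, and the (possibly replicated) RNG.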
diff --git a/src/layers/containers.jl b/src/layers/containers.jl
index c03b8549ec..923cc20ed7 100644
--- a/src/layers/containers.jl
+++ b/src/layers/containers.jl
@@ -404,7 +404,7 @@ function _flatten_model(layers::Union{AbstractVector, Tuple})
         if f isa Tuple || f isa AbstractVector
             append!(new_layers, f)
         elseif f isa Function
-            if !hasmethod(f, (Any, Union{ComponentArray, NamedTuple}, NamedTuple))
+            if !hasmethod(f, (Any, Any, NamedTuple))
                 if f === identity
                     continue
                 else
diff --git a/src/layers/dropout.jl b/src/layers/dropout.jl
index 0f9375f5bb..660147323a 100644
--- a/src/layers/dropout.jl
+++ b/src/layers/dropout.jl
@@ -49,8 +49,8 @@ function Dropout(p; dims=:)
 end
 
 function (d::Dropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
-    y, _, rng = dropout(st.rng, x, d.p, d.q, d.dims, st.training)
-    return y, merge(st, (rng=rng,))
+    y, _, rng = LuxLib.dropout(st.rng, x, d.p, st.training; invp=d.q, d.dims)
+    return y, merge(st, (; rng))
 end
 
 function Base.show(io::IO, d::Dropout)
@@ -114,9 +114,10 @@ function VariationalHiddenDropout(p; dims=:)
 end
 
 function (d::VariationalHiddenDropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
-    y, mask, rng, update_mask = dropout(st.rng, x, st.mask, d.p, d.q, d.dims, st.training,
-                                        st.update_mask)
-    return y, merge(st, (mask=mask, rng=rng, update_mask=update_mask))
+    _mask = st.mask === nothing ? x : st.mask
+    y, mask, rng = LuxLib.dropout(st.rng, x, _mask, d.p, st.training, st.update_mask;
+                                  invp=d.q, d.dims)
+    return y, merge(st, (; mask, rng, update_mask=Val(false)))
 end
 
 function Base.show(io::IO, d::VariationalHiddenDropout)
diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl
index 3a5fbe5cbb..9f94955c57 100644
--- a/src/layers/normalize.jl
+++ b/src/layers/normalize.jl
@@ -21,19 +21,19 @@ slice and normalises the input accordingly.
 
 ## Keyword Arguments
 
-  - If `affine=true`, it also applies a shift and a rescale to the input through to
-    learnable per-channel bias and scale parameters.
-
-      + `init_bias`: Controls how the `bias` is initiliazed
-      + `init_scale`: Controls how the `scale` is initiliazed
-
   - If `track_stats=true`, accumulates mean and variance statistics in training phase
     that will be used to renormalize the input in test phase.
+
   - `epsilon`: a value added to the denominator for numerical stability
   - `momentum`: the value used for the `running_mean` and `running_var` computation
   - `allow_fast_activation`: If `true`, then certain activations can be approximated with
     a faster version. The new activation function will be given by
     `NNlib.fast_act(activation)`
+  - If `affine=true`, it also applies a shift and a rescale to the input through to
+    learnable per-channel bias and scale parameters.
+
+      + `init_bias`: Controls how the `bias` is initiliazed
+      + `init_scale`: Controls how the `scale` is initiliazed
 
 ## Inputs
 
@@ -103,7 +103,7 @@ function initialparameters(rng::AbstractRNG, l::BatchNorm)
     if _affine(l)
         return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
     else
-        return (scale=nothing, bias=nothing)
+        return NamedTuple()
     end
 end
 
@@ -112,53 +112,26 @@ function initialstates(rng::AbstractRNG, l::BatchNorm)
         return (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs),
                 training=Val(true))
     else
-        return (running_mean=nothing, running_var=nothing, training=Val(true))
+        return (; training=Val(true))
     end
 end
 
 parameterlength(l::BatchNorm) = _affine(l) ? (l.chs * 2) : 0
 statelength(l::BatchNorm) = (_track_stats(l) ? 2 * l.chs : 0) + 1
 
-function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
-    x_normalized, xmean, xvar = normalization(x, st.running_mean, st.running_var, ps.scale,
-                                              ps.bias, BN.activation,
-                                              collect([1:(N - 2); N]), st.training,
-                                              BN.momentum, BN.epsilon)
+function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple)
+    y, stats = LuxLib.batchnorm(x, _getproperty(ps, Val(:scale)),
+                                _getproperty(ps, Val(:bias)),
+                                _getproperty(st, Val(:running_mean)),
+                                _getproperty(st, Val(:running_var)); BN.momentum,
+                                BN.epsilon, st.training)
 
-    st = merge(st, (running_mean=xmean, running_var=xvar))
-
-    return x_normalized, st
-end
-
-function _batchnorm(scale, bias, x, running_mean, running_var, momentum, epsilon, training)
-    return batchnorm(scale, bias, x, running_mean, running_var, momentum; eps=epsilon,
-                     training=training)
-end
-
-function (BN::BatchNorm)(x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}}, ps,
-                         st::NamedTuple) where {T <: Union{Float32, Float64}}
-    # NNlibCUDA silently updates running_mean and running_var so copying them
-    if istraining(st)
-        running_mean2 = _track_stats(BN) ? _copy_autodiff_barrier(st.running_mean) : nothing
-        running_var2 = _track_stats(BN) ? _copy_autodiff_barrier(st.running_var) : nothing
-    else
-        if _track_stats(BN)
-            running_mean2 = _copy_autodiff_barrier(st.running_mean)
-            running_var2 = _copy_autodiff_barrier(st.running_var)
-        else
-            N = ndims(x)
-            reduce_dims = collect([1:(N - 2); N])
-            running_mean2 = mean(x; dims=reduce_dims)
-            running_var2 = var(x; mean=running_mean2, dims=reduce_dims, corrected=false)
-        end
-    end
-    res = BN.activation.(_batchnorm(_affine(BN) ? ps.scale : nothing,
-                                    _affine(BN) ? ps.bias : nothing, x, running_mean2,
-                                    running_var2, BN.momentum, BN.epsilon, istraining(st)))
     if _track_stats(BN)
-        st = merge(st, (running_mean=running_mean2, running_var=running_var2))
+        @set! st.running_mean = stats.running_mean
+        @set! st.running_var = stats.running_var
     end
-    return res, st
+
+    return BN.activation.(y), st
 end
 
 function Base.show(io::IO, l::BatchNorm)
@@ -187,21 +160,21 @@ end
 
 ## Keyword Arguments
 
-  - If `affine=true`, it also applies a shift and a rescale to the input through to
-    learnable per-channel bias and scale parameters.
-
-      + `init_bias`: Controls how the `bias` is initiliazed
-      + `init_scale`: Controls how the `scale` is initiliazed
-
   - If `track_stats=true`, accumulates mean and variance statistics in training phase
     that will be used to renormalize the input in test phase. **(This feature has been
    deprecated and will be removed in v0.5)**
+
   - `epsilon`: a value added to the denominator for numerical stability
   - `momentum`: the value used for the `running_mean` and `running_var` computation
    **(This feature has been deprecated and will be removed in v0.5)**
   - `allow_fast_activation`: If `true`, then certain activations can be approximated with
     a faster version. The new activation function will be given by
     `NNlib.fast_act(activation)`
+  - If `affine=true`, it also applies a shift and a rescale to the input through to
+    learnable per-channel bias and scale parameters.
+
+      + `init_bias`: Controls how the `bias` is initiliazed
+      + `init_scale`: Controls how the `scale` is initiliazed
 
 ## Inputs
 
@@ -300,7 +273,7 @@ function initialstates(rng::AbstractRNG, l::GroupNorm)
         (running_mean=zeros32(rng, l.groups), running_var=ones32(rng, l.groups),
          training=Val(true))
     else
-        (running_mean=nothing, running_var=nothing, training=Val(true))
+        (; training=Val(true))
    end
 end
 
@@ -308,17 +281,19 @@ parameterlength(l::GroupNorm) = _affine(l) ? (l.chs * 2) : 0
 
 statelength(l::GroupNorm) = (_track_stats(l) ? 2 * l.groups : 0) + 1
 
-function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
-    sz = size(x)
-    x_ = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ GN.groups, GN.groups, sz[N])
-
-    x_normalized, xmean, xvar = normalization(x_, st.running_mean, st.running_var, ps.scale,
-                                              ps.bias, GN.activation, collect(1:(N - 1)),
-                                              st.training, GN.momentum, GN.epsilon)
+function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple)
+    y, stats = LuxLib.groupnorm(x, _getproperty(ps, Val(:scale)),
+                                _getproperty(ps, Val(:bias)),
+                                _getproperty(st, Val(:running_mean)),
+                                _getproperty(st, Val(:running_var)); GN.groups, GN.epsilon,
+                                GN.momentum, st.training)
 
-    st = merge(st, (running_mean=xmean, running_var=xvar))
+    if _track_stats(GN)
+        @set! st.running_mean = stats.running_mean
+        @set! st.running_var = stats.running_var
+    end
 
-    return reshape(x_normalized, sz), st
+    return GN.activation.(y), st
 end
 
 function Base.show(io::IO, l::GroupNorm)
@@ -479,6 +454,9 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`.
 
 ## Keyword Arguments
 
+  - `allow_fast_activation`: If `true`, then certain activations can be approximated with
+    a faster version. The new activation function will be given by
+    `NNlib.fast_act(activation)`
   - `epsilon`: a value added to the denominator for numerical stability.
   - `dims`: Dimensions to normalize the array over.
   - If `affine=true`, it also applies a shift and a rescale to the input through to
@@ -487,10 +465,6 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`.
 
       + `init_bias`: Controls how the `bias` is initiliazed
       + `init_scale`: Controls how the `scale` is initiliazed
 
-  - `allow_fast_activation`: If `true`, then certain activations can be approximated with
-    a faster version. The new activation function will be given by
-    `NNlib.fast_act(activation)`
-
 ## Inputs
 
   - `x`: AbstractArray
@@ -508,7 +482,7 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`.
       + `bias`: Bias of shape `(shape..., 1)`
       + `scale`: Scale of shape `(shape..., 1)`
 """
-struct LayerNorm{affine, F1, N, T, F2, F3, D} <: AbstractExplicitLayer
+struct LayerNorm{affine, F1, N, T, F2, F3, D} <: AbstractNormalizationLayer{affine, false}
     shape::NTuple{N, Int}
     activation::F1
     epsilon::T
@@ -526,8 +500,8 @@ function LayerNorm(shape::NTuple{N, <:Int}, activation=identity; epsilon::T=1.0f
                      init_bias, init_scale, dims)
 end
 
-function initialparameters(rng::AbstractRNG, ln::LayerNorm{affine}) where {affine}
-    if affine
+function initialparameters(rng::AbstractRNG, ln::LayerNorm)
+    if _affine(ln)
         return (bias=ln.init_bias(rng, ln.shape..., 1),
                 scale=ln.init_scale(rng, ln.shape..., 1))
     else
@@ -535,36 +509,15 @@ function initialparameters(rng::AbstractRNG, ln::LayerNorm{affin
     end
 end
 
-@generated function (l::LayerNorm{affine, F1})(x::AbstractArray, ps,
-                                               st::NamedTuple) where {affine, F1}
-    calls = []
-    push!(calls, :(_mean = mean(x; dims=l.dims);
-                   _var = var(x; corrected=false, mean=_mean)))
-
-    if affine
-        if F1 == typeof(identity)
-            push!(calls,
-                  :(return ps.scale .* (x .- _mean) ./ sqrt.(_var .+ l.epsilon) .+ ps.bias,
-                    st))
-        else
-            push!(calls,
-                  :(return l.activation.(ps.scale .* (x .- _mean) ./
-                                         sqrt.(_var .+ l.epsilon) .+ ps.bias), st))
-        end
-    else
-        if F1 == typeof(identity)
-            push!(calls, :(return (x .- _mean) ./ sqrt.(_var .+ l.epsilon), st))
-        else
-            push!(calls,
-                  :(return l.activation.((x .- _mean) ./ sqrt.(_var .+ l.epsilon)), st))
-        end
-    end
-    return Expr(:block, calls...)
+function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple)
+    y = l.activation.(LuxLib.layernorm(x, _getproperty(ps, Val(:scale)),
+                                       _getproperty(ps, Val(:bias)); l.dims, l.epsilon))
+    return y, st
 end
 
-function Base.show(io::IO, l::LayerNorm{affine}) where {affine}
+function Base.show(io::IO, l::LayerNorm)
     print(io, "LayerNorm($(l.shape)")
     (l.activation == identity) || print(io, ", $(l.activation)")
-    print(io, ", affine=$(affine), dims=$(l.dims)")
+    print(io, ", affine=$(_affine(l)), dims=$(l.dims)")
     return print(io, ")")
 end
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 47f22165aa..2dd5b791f4 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -338,7 +338,7 @@ function (lstm::LSTMCell{true})((x, (hidden_state,
                                        memory))::Tuple{<:AbstractMatrix,
                                                        Tuple{<:AbstractMatrix,
                                                              <:AbstractMatrix}},
-                                ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
+                                ps, st::NamedTuple)
     g = ps.weight_i * x .+ ps.weight_h * hidden_state .+ ps.bias
     input, forget, cell, output = multigate(g, Val(4))
     memory_new = @. sigmoid_fast(forget) * memory + sigmoid_fast(input) * tanh_fast(cell)
@@ -350,7 +350,7 @@ function (lstm::LSTMCell{false})((x, (hidden_state,
                                         memory))::Tuple{<:AbstractMatrix,
                                                         Tuple{<:AbstractMatrix,
                                                               <:AbstractMatrix}},
-                                 ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
+                                 ps, st::NamedTuple)
     g = ps.weight_i * x .+ ps.weight_h * hidden_state
     input, forget, cell, output = multigate(g, Val(4))
     memory_new = @. sigmoid_fast(forget) * memory + sigmoid_fast(input) * tanh_fast(cell)
diff --git a/src/nnlib.jl b/src/nnlib.jl
index c11cd2a086..0c5cebc030 100644
--- a/src/nnlib.jl
+++ b/src/nnlib.jl
@@ -1,103 +1,3 @@
-# TODO(@avik-pal): Eventually we want to move all these functions and their adjoints to NNlib.jl
-
-# Normalization Implementation
-@inline function update_statistics(x::AbstractArray{T, N},
-                                   running_mean::AbstractArray{T, N},
-                                   running_var::AbstractArray{T, N},
-                                   batchmean::AbstractArray{T, N},
-                                   batchvar::AbstractArray{T, N}, momentum::T,
-                                   reduce_dims) where {T, N}
-    sx = size(x)
-    m = T(prod((sx[i] for i in reduce_dims)))
-    if reduce_dims[end] != N
-        batchmean = mean(batchmean; dims=N)
-        batchvar = mean(batchvar; dims=N)
-    end
-    running_mean = @. (1 - momentum) * running_mean + momentum * batchmean
-    running_var = @. (1 - momentum) * running_var + momentum * batchvar * (m / (m - one(m)))
-    return (running_mean, running_var)
-end
-
-"""
-    normalization(x, running_mean, running_var, scale, bias, activation, reduce_dims,
-                  ::Val{training}, momentum, epsilon)
-
-Performs BatchNorm/GroupNorm/InstanceNorm based on input configuration
-
-!!! note
-
-    Detailed docs are WIP
-"""
-@inline function normalization(x::AbstractArray{T, N},
-                               running_mean::Union{Nothing, AbstractVector{T}},
-                               running_var::Union{Nothing, AbstractVector{T}},
-                               scale::Union{Nothing, AbstractVector{T}},
-                               bias::Union{Nothing, AbstractVector{T}}, activation,
-                               reduce_dims, t::Val, momentum::T=T(0.1),
-                               epsilon::T=T(1e-5)) where {T, N}
-    running_mean_reshaped = _reshape_into_proper_shape(running_mean, x)
-    running_var_reshaped = _reshape_into_proper_shape(running_var, x)
-    scale_reshaped = _reshape_into_proper_shape(scale, x)
-    bias_reshaped = _reshape_into_proper_shape(bias, x)
-    x_norm, running_mean_, running_var_ = normalization_forward(x, running_mean_reshaped,
-                                                                running_var_reshaped,
-                                                                scale_reshaped,
-                                                                bias_reshaped, activation,
-                                                                reduce_dims, t, momentum,
-                                                                epsilon)
-    return x_norm, _safe_vec(running_mean_), _safe_vec(running_var_)
-end
-
-@generated function normalization_forward(x::AbstractArray{T, N}, running_mean::RM,
-                                          running_var::RV, scale::S, bias::B, activation::A,
-                                          reduce_dims, ::Val{training},
-                                          momentum::T=T(0.1f0),
-                                          epsilon::T=T(1.0f-5)) where {RM, RV, S, B, T, N,
-                                                                       A, training}
-    calls = []
-    if !training
-        if RM == Nothing
-            expr = :(batchmean = mean(x; dims=reduce_dims);
-                     batchvar = var(x; mean=batchmean, dims=reduce_dims, corrected=false))
-        else
-            expr = :(batchmean = running_mean;
-                     batchvar = running_var)
-        end
-        push!(calls, expr)
-    else
-        expr = :(batchmean = mean(x; dims=reduce_dims);
-                 batchvar = var(x; mean=batchmean, dims=reduce_dims, corrected=false))
-        push!(calls, expr)
-
-        if RM != Nothing
-            push!(calls,
-                  :((running_mean, running_var) = update_statistics(x, running_mean,
-                                                                    running_var, batchmean,
-                                                                    batchvar, momentum,
-                                                                    reduce_dims)))
-        end
-    end
-
-    expr = if S != Nothing
-        if A == typeof(identity)
-            :(result = scale .* (x .- batchmean) ./ sqrt.(batchvar .+ epsilon) .+ bias)
-        else
-            :(result = activation.(scale .* (x .- batchmean) ./
-                                   sqrt.(batchvar .+ epsilon) .+ bias))
-        end
-    else
-        if A == typeof(identity)
-            :(result = (x .- batchmean) ./ sqrt.(batchvar .+ epsilon))
-        else
-            :(result = activation.((x .- batchmean) ./ sqrt.(batchvar .+ epsilon)))
-        end
-    end
-    push!(calls, expr)
-    push!(calls, :(return (result, running_mean, running_var)))
-
-    return Expr(:block, calls...)
-end
-
 # Convolution
 @inline conv_wrapper(x, weight, cdims) = conv(x, weight, cdims)
 
@@ -105,56 +5,6 @@ end
     return conv(copy(x), weight, cdims)
 end
 
-# Dropout
-@inline _dropout_shape(s, ::Colon) = size(s)
-@inline function _dropout_shape(s, dims)
-    return tuple((i ∉ dims ? 1 : si for (i, si) in enumerate(size(s)))...)
-end
-
-@inline _dropout_kernel(y::T, p, q) where {T} = y > p ? q : zero(T)
-
-@inline function generate_dropout_mask(rng::AbstractRNG, x, p, q; dims=:)
-    realfptype = float(real(eltype(x)))
-    y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
-    y .= _dropout_kernel.(y, p, q)
-    return y
-end
-
-"""
-    dropout(rng::AbstractRNG, x, p, q, dims, ::Val{training})
-    dropout(rng::AbstractRNG, x, mask, p, q, dims, t::Val{training}, ::Val{update_mask})
-
-If `training` then dropout is applied on `x` with probability `p` along `dims`. If `mask` is
-passed it is used if `update_mask` is false. If `update_mask` is true then the mask is
-generated and used.
-"""
-@inline @generated function dropout(rng::AbstractRNG, x, p, q, dims,
-                                    ::Val{training}) where {training}
-    if training
-        return :(rng = replicate(rng);
-                 mask = generate_dropout_mask(rng, x, p, q; dims);
-                 return (x .* ignore_derivatives(mask), mask, rng))
-    else
-        return :(return (x, x, rng))
-    end
-end
-
-@inline @generated function dropout(rng::AbstractRNG, x, mask, p, q, dims, t::Val{training},
-                                    ::Val{update_mask}) where {training, update_mask}
-    if update_mask
-        return :((y, mask, rng) = dropout(rng, x, p, q, dims, t);
-                 return (y, mask, rng, Val(false)))
-    else
-        if training
-            return :(size(x, ndims(x)) != size(mask, ndims(x)) &&
-                         return (dropout(rng, x, p, q, dims, t)..., Val(false));
-                     return (x .* ignore_derivatives(mask), mask, rng, Val(false)))
-        else
-            return :(return (x, mask, rng, Val(false)))
-        end
-    end
-end
-
 # Adaptive Pooling
 @inline function compute_adaptive_pooling_dims(x::AbstractArray, outsize)
     insize = size(x)[1:(end - 2)]
diff --git a/src/utils.jl b/src/utils.jl
index e2888b05ea..7a1aa2d201 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -142,25 +142,6 @@ end
 # Getting typename
 get_typename(::T) where {T} = Base.typename(T).wrapper
 
-# For Normalization
-@inline @generated _safe_vec(x::T) where {T} = hasmethod(vec, (T,)) ? :(vec(x)) : :x
-
-@inline @inbounds function _get_reshape_dims(sx::NTuple{N, <:Int},
-                                             ly::Int)::typeof(sx) where {N}
-    if ly == sx[N - 1]
-        return ntuple(i -> i == N - 1 ? ly : 1, N)
-    elseif N > 2 && ly == sx[N - 1] * sx[N - 2]
-        return ntuple(i -> i == (N - 1) || i == (N - 2) ? sx[i] : 1, N)
-    else
-        error("Invalid Dimensions")
-    end
-end
-
-@inline _reshape_into_proper_shape(x::Nothing, y)::Nothing = x
-@inline _reshape_into_proper_shape(x, y)::typeof(y) = reshape(x,
-                                                              _get_reshape_dims(size(y),
-                                                                                length(x)))
-
 # RNN Utilities
 @inline _gate(h::Int, n::Int) = (1:h) .+ h * (n - 1)
 @inline _gate(x::AbstractVector, h::Int, n::Int) = view(x, _gate(h, n))
@@ -190,10 +171,20 @@ Split up `x` into `N` equally sized chunks (along dimension `1`).
 # Val utilities
 get_known(::Val{T}) where {T} = T
 
-# Copy and don't allow gradient propagation
-_copy_autodiff_barrier(x) = copy(x)
-
 # Indexing into NamedTuple
 function _index_namedtuple(nt::NamedTuple{fields}, idxs::AbstractArray) where {fields}
     return NamedTuple{fields[idxs]}(values(nt)[idxs])
 end
+
+# If doesn't have a property, return nothing
+@generated function _getproperty(x::NamedTuple{names}, ::Val{v}) where {names, v}
+    if v in names
+        return :(x.$v)
+    else
+        return :(nothing)
+    end
+end
+
+function _getproperty(x::ComponentArray, ::Val{prop}) where {prop}
+    return prop in propertynames(x) ? getproperty(x, prop) : nothing
+end
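The `_getproperty` helper added above is what lets the LuxLib-backed layers pass `nothing` for absent `scale`/`bias`/running statistics (e.g. when `affine=false` or `track_stats=false`). A small sketch of its behaviour; the NamedTuples below are made up and `_getproperty` is internal, so it is reached as `Lux._getproperty`.

```julia
# Illustrative values only; Lux._getproperty is the internal helper defined above.
using Lux

ps_affine = (scale=ones(Float32, 4), bias=zeros(Float32, 4))
ps_empty  = NamedTuple()

Lux._getproperty(ps_affine, Val(:scale))  # returns the 4-element scale vector
Lux._getproperty(ps_empty, Val(:scale))   # returns nothing
```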
diff --git a/test/Project.toml b/test/Project.toml
index a8331265bb..dabb10125b 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -7,7 +7,6 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
-Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
diff --git a/test/autodiff.jl b/test/autodiff.jl
index ed32ac2f7f..c921dde374 100644
--- a/test/autodiff.jl
+++ b/test/autodiff.jl
@@ -30,16 +30,3 @@ end
 
     @test gs_x_1 == gs_x_2
 end
-
-@testset "_reshape_into_proper_shape" begin
-    x = randn(rng, Float32, 3, 2)
-    y = randn(rng, Float32, 2, 2, 6, 2)
-
-    @test size(Lux._reshape_into_proper_shape(x, y)) == (1, 1, 6, 1)
-    @inferred Lux._reshape_into_proper_shape(x, y)
-
-    gs_1 = Zygote.gradient(x -> sum(Lux._reshape_into_proper_shape(x, y)), x)[1]
-    gs_2 = Zygote.gradient(x -> sum(reshape(x, (1, 1, 6, 1))), x)[1]
-
-    @test gs_1 == gs_2
-end
diff --git a/test/layers/normalize.jl b/test/layers/normalize.jl
index 08735ff476..0b75d4d0fb 100644
--- a/test/layers/normalize.jl
+++ b/test/layers/normalize.jl
@@ -45,7 +45,7 @@ Random.seed!(rng, 0)
        x_ = m(x, ps, st_)[1]
 
        @test isapprox(x_[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5)
-        @inferred m(x, ps, st)
+        @inferred first(m(x, ps, st))
 
        run_JET_tests(m, x, ps, st)
 
@@ -57,7 +57,7 @@ Random.seed!(rng, 0)
        x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0]
        println(m)
        ps, st = Lux.setup(rng, m)
-        @inferred m(x, ps, st)
+        @inferred first(m(x, ps, st))
        run_JET_tests(m, x, ps, st)
 
        if affine
@@ -79,7 +79,7 @@ Random.seed!(rng, 0)
        @test isapprox(y,
                       sigmoid.((x .- st_.running_mean) ./
                                sqrt.(st_.running_var .+ m.epsilon)), atol=1.0e-7)
-        @inferred m(x, ps, st)
+        @inferred first(m(x, ps, st))
        run_JET_tests(m, x, ps, st)
 
        if affine
@@ -97,7 +97,7 @@ Random.seed!(rng, 0)
        st = Lux.testmode(st)
        m(x, ps, st)
        @test (@allocated m(x, ps, st)) < 100_000_000
-        @inferred m(x, ps, st)
+        @inferred first(m(x, ps, st))
        run_JET_tests(m, x, ps, st)
    end
 
@@ -165,7 +165,7 @@ end
              sqrt.(reshape(st_.running_var, 1, 1, 2, 1) .+ 1.0f-5)
 
        @test y≈reshape(out, size(x)) atol=1.0e-5
-        @inferred m(x, ps, st)
+        @inferred first(m(x, ps, st))
        run_JET_tests(m, x, ps, st)
        test_gradient_correctness_fdm(ps -> sum(first(m(x, ps, st))), ps; atol=1.0f-3,
                                      rtol=1.0f-3)
@@ -175,7 +175,7 @@ end
        x = randn(rng, Float32, 3, 2, 1)
        println(m)
        ps, st = Lux.setup(rng, m)
-        @inferred m(x, ps, st)
+        @inferred first(m(x, ps, st))
        run_JET_tests(m, x, ps, st)
 
        if affine
@@ -194,7 +194,7 @@ end
        st = Lux.testmode(st)
        y, st_ = m(x, ps, st)
 
-        @inferred m(x, ps, st)
+        @inferred first(m(x, ps, st))
        run_JET_tests(m, x, ps, st)
 
        if affine
@@ -212,7 +212,7 @@ end
        st = Lux.testmode(st)
        m(x, ps, st)
        @test (@allocated m(x, ps, st)) < 100_000_000
-        @inferred m(x, ps, st)
+        @inferred first(m(x, ps, st))
        run_JET_tests(m, x, ps, st)
    end
 
@@ -353,7 +353,7 @@ end
        println(ln)
        ps, st = Lux.setup(rng, ln)
 
-        @inferred ln(x, ps, st)
+        @inferred first(ln(x, ps, st))
        y, st_ = ln(x, ps, st)
 
        @test isapprox(mean(y), 0; atol=1.0f-3, rtol=1.0f-3)
@@ -374,7 +374,7 @@ end
        println(ln)
        ps, st = Lux.setup(rng, ln)
 
-        @inferred ln(x, ps, st)
+        @inferred first(ln(x, ps, st))
        y, st_ = ln(x, ps, st)
 
        run_JET_tests(ln, x, ps, st)
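
As an end-to-end sketch of the call pattern these tests exercise after the patch (the layer width and input data below are illustrative, not taken from the test files):

```julia
# Illustrative usage of the LuxLib-backed BatchNorm path.
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

m = BatchNorm(2)                # 2 channels; affine and track_stats enabled by default
ps, st = Lux.setup(rng, m)

x = randn(rng, Float32, 2, 4)   # (channels, batch)
y, st_ = m(x, ps, st)           # forward pass now dispatches to LuxLib.batchnorm

# Inference mode, as exercised by `Lux.testmode(st)` in the tests above:
st_test = Lux.testmode(st_)
y_test, _ = Lux.apply(m, x, ps, st_test)
```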