
Commit

Update to use LuxLib
avik-pal committed Sep 25, 2022
1 parent 076d030 commit 0bcd7d1
Showing 14 changed files with 141 additions and 350 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -11,9 +11,9 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -29,8 +29,8 @@ ChainRulesCore = "1"
ComponentArrays = "0.13"
FillArrays = "0.13"
Functors = "0.2, 0.3"
LuxLib = "0.1"
NNlib = "0.8"
NNlibCUDA = "0.2"
Optimisers = "0.2"
Requires = "1"
Setfield = "0.8, 1"
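To mirror the same dependency swap in a development environment (purely illustrative; only the package names and the "0.1" compat bound come from the diff above, the rest is an assumption):

using Pkg
Pkg.rm("NNlibCUDA")                      # no longer a direct dependency of Lux
Pkg.add(name="LuxLib", version="0.1")    # replaces the NNlibCUDA-backed kernels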
2 changes: 1 addition & 1 deletion src/Lux.jl
@@ -5,7 +5,7 @@ using CUDA
using CUDA.CUDNN
# Neural Network Backend
using NNlib
import NNlibCUDA: batchnorm, ∇batchnorm, CUDNNFloat
import LuxLib ## In v0.5 we can start `using` it. For v0.4, there will be naming conflicts
# Julia StdLibs
using Random, Statistics, LinearAlgebra, SparseArrays
# Parameter Manipulation
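Lux itself defines deprecated functions with the same names as LuxLib's (see src/deprecated.jl below), which is why the package is pulled in with `import` and every call stays qualified. A small sketch of qualified use outside of Lux, following the call shape the wrappers forward to (input values are made up):

import LuxLib
using Random

rng = Random.default_rng()
x = randn(Float32, 3, 5)
# p = 0.5, training enabled; `invp` and `dims` are the keywords used throughout this commit.
y, mask, rng2 = LuxLib.dropout(rng, x, 0.5f0, Val(true); invp=2.0f0, dims=:)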
46 changes: 0 additions & 46 deletions src/autodiff.jl
@@ -1,9 +1,5 @@
# Non Differentiable Functions
ChainRulesCore.@non_differentiable replicate(::Any)
ChainRulesCore.@non_differentiable update_statistics(::Any, ::Any, ::Any, ::Any, ::Any,
::Any, ::Any)
ChainRulesCore.@non_differentiable generate_dropout_mask(::Any, ::Any, ::Any, ::Any)
ChainRulesCore.@non_differentiable _get_reshape_dims(::Any, ::Any)
ChainRulesCore.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)
ChainRulesCore.@non_differentiable glorot_normal(::Any...)
ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
@@ -12,50 +8,8 @@ ChainRulesCore.@non_differentiable istraining(::Any)
ChainRulesCore.@non_differentiable _get_norm_except_dims(::Any, ::Any)
ChainRulesCore.@non_differentiable _affine(::Any)
ChainRulesCore.@non_differentiable _track_stats(::Any)
ChainRulesCore.@non_differentiable _copy_autodiff_barrier(::Any)

# NNlib Functions
function ChainRulesCore.rrule(::typeof(_batchnorm), g::CuArray{T}, b::CuArray{T},
x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}},
running_mean, running_var, momentum, epsilon,
training) where {T <: CUDNNFloat}
y = _batchnorm(g, b, x, running_mean, running_var, momentum, epsilon, training)
function _batchnorm_pullback(dy)
dg, db, dx = ∇batchnorm(g, b, x, unthunk(dy), running_mean, running_var, momentum;
eps=epsilon, training=training)
return NoTangent(), dg, db, dx, NoTangent(), NoTangent(), NoTangent(), NoTangent(),
NoTangent()
end
return y, _batchnorm_pullback
end

function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, x::AbstractArray{T, N},
p::T, q::T, dims, t::Val{training}) where {T, N, training}
y, mask, rng = dropout(rng, x, p, q, dims, t)
function dropout_pullback((dy, dmask, drng))
return (NoTangent(), NoTangent(), dy .* mask, NoTangent(), NoTangent(), NoTangent(),
NoTangent())
end
return (y, mask, rng), dropout_pullback
end

# Utilities

function ChainRulesCore.rrule(::typeof(_reshape_into_proper_shape), ::Nothing, y)
function _reshape_into_proper_shape_pullback(dx)
return NoTangent(), NoTangent(), NoTangent()
end
return nothing, _reshape_into_proper_shape_pullback
end

function ChainRulesCore.rrule(::typeof(_reshape_into_proper_shape), x, y)
res = _reshape_into_proper_shape(x, y)
function _reshape_into_proper_shape_pullback(dx)
return NoTangent(), reshape(dx, size(x)), NoTangent()
end
return res, _reshape_into_proper_shape_pullback
end

function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1},
nt2::NamedTuple{F2}) where {F1, F2}
y = merge(nt1, nt2)
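The deleted rules all followed the standard ChainRulesCore pattern: compute the primal, then return it together with a pullback that maps the output cotangent back to input cotangents, using `NoTangent()` for non-differentiable arguments. With this commit, the batchnorm and dropout rules ship inside LuxLib instead of Lux. A toy illustration of the pattern (the function `_double` is hypothetical, not part of Lux):

using ChainRulesCore

_double(x) = 2 .* x

function ChainRulesCore.rrule(::typeof(_double), x)
    y = _double(x)
    # Pullback: no tangent for the function itself; the tangent w.r.t. `x` is 2 .* dy.
    _double_pullback(dy) = (NoTangent(), 2 .* unthunk(dy))
    return y, _double_pullback
end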
3 changes: 1 addition & 2 deletions src/core.jl
@@ -78,8 +78,7 @@ function setup(rng::AbstractRNG, l::AbstractExplicitLayer)
end

"""
apply(model::AbstractExplicitLayer, x, ps::Union{ComponentArray,NamedTuple},
st::NamedTuple)
apply(model::AbstractExplicitLayer, x, ps, st::NamedTuple)
Simply calls `model(x, ps, st)`
"""
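The loosened docstring only drops the `Union{ComponentArray, NamedTuple}` annotation on `ps`; the behaviour is unchanged. A minimal usage sketch of `setup`/`apply` (layer choice and data are arbitrary):

using Lux, Random

rng = Random.default_rng()
model = Dense(2, 3, tanh)
ps, st = Lux.setup(rng, model)
x = randn(Float32, 2, 4)
y, st_new = Lux.apply(model, x, ps, st)   # equivalent to model(x, ps, st)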
57 changes: 57 additions & 0 deletions src/deprecated.jl
@@ -130,3 +130,60 @@ Computes `x .* y`. Dispatches to CUDNN if possible.
" v0.5. Use `x .* y` instead.", :elementwise_mul)
return x .* y
end

# Dropout
"""
dropout(rng::AbstractRNG, x, p, q, dims, ::Val{training})
dropout(rng::AbstractRNG, x, mask, p, q, dims, t::Val{training}, ::Val{update_mask})
If `training`, dropout is applied to `x` with probability `p` along `dims`. If `update_mask`
is false, the supplied `mask` is reused; if `update_mask` is true, a fresh mask is generated
and used.
!!! warning
This function has been deprecated and will be removed in v0.5. Use `LuxLib.dropout`
instead.
"""
@inline function dropout(rng::AbstractRNG, x, p, q, dims, t::Val)
# Deprecated Functionality (Remove in v0.5)
Base.depwarn("`Lux.dropout` has been deprecated and will be removed in v0.5. Use " *
"`LuxLib.dropout` instead.", :dropout)

return LuxLib.dropout(rng, x, p, t; invp=q, dims)
end

@inline function dropout(rng::AbstractRNG, x, mask, p, q, dims, t::Val, um::Val)
# Deprecated Functionality (Remove in v0.5)
Base.depwarn("`Lux.dropout` has been deprecated and will be removed in v0.5. Use " *
"`LuxLib.dropout` instead.", :dropout)

return (LuxLib.dropout(rng, x, mask, p, t, um; invp=q, dims)..., Val(false))
end

# Normalization Implementation
"""
normalization(x, running_mean, running_var, scale, bias, activation, reduce_dims,
::Val{training}, momentum, epsilon)
Performs BatchNorm/GroupNorm based on input configuration
!!! warning
This function has been deprecated and will be removed in v0.5. Use
`LuxLib.(batch/group)norm` instead.
"""
@inline function normalization(x::AbstractArray{T, N},
running_mean::Union{Nothing, AbstractVector{T}},
running_var::Union{Nothing, AbstractVector{T}},
scale::Union{Nothing, AbstractVector{T}},
bias::Union{Nothing, AbstractVector{T}}, activation,
reduce_dims, t::Val, momentum::T=T(0.1),
epsilon::T=T(1e-5)) where {T, N}
# Deprecated Functionality (Remove in v0.5)
Base.depwarn("`Lux.normalization` has been deprecated and will be removed in v0.5. " *
"Use `LuxLib.(batch/group)norm` instead.", :normalization)

return activation.(LuxLib._normalization(x, running_mean, running_var, scale, bias,
reduce_dims, t, momentum, epsilon))
end
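Taken together, the wrappers spell out the migration path: the old positional `q` becomes the `invp` keyword, `dims` becomes a keyword, and the mask-reuse variant takes explicit `Val`s for `training` and `update_mask`, returning `(y, mask, rng)` without echoing an update flag. Normalization callers should move to `LuxLib.batchnorm`/`LuxLib.groupnorm`; the wrapper above merely forwards to the internal `LuxLib._normalization` and broadcasts the activation over the result. A sketch of the mask-reuse call, mirroring the second wrapper (shapes and values are invented):

using Random
import LuxLib

rng = Random.default_rng()
x = randn(Float32, 10, 32)
mask = ones(Float32, 10, 32)   # a previously generated dropout mask
p, q = 0.5f0, 2.0f0            # q plays the role of the old positional scaling argument

# Old (deprecated): Lux.dropout(rng, x, mask, p, q, :, Val(true), Val(false))
# New: the supplied `mask` is reused because update_mask = Val(false)
y, mask2, rng2 = LuxLib.dropout(rng, x, mask, p, Val(true), Val(false); invp=q, dims=:)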
2 changes: 1 addition & 1 deletion src/layers/containers.jl
@@ -404,7 +404,7 @@ function _flatten_model(layers::Union{AbstractVector, Tuple})
if f isa Tuple || f isa AbstractVector
append!(new_layers, f)
elseif f isa Function
if !hasmethod(f, (Any, Union{ComponentArray, NamedTuple}, NamedTuple))
if !hasmethod(f, (Any, Any, NamedTuple))
if f === identity
continue
else
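The relaxed check only asks whether `f` has some 3-argument `(input, parameters, states)` method, without constraining the parameter type to `ComponentArray`/`NamedTuple`. A minimal illustration of what it distinguishes (the callables here are made up, not part of the commit):

my_layer_like(x, ps, st::NamedTuple) = (x, st)

hasmethod(my_layer_like, (Any, Any, NamedTuple))   # true  -> treated as a layer-style callable
hasmethod(tanh, (Any, Any, NamedTuple))            # false -> handled by the branch shown above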
11 changes: 6 additions & 5 deletions src/layers/dropout.jl
@@ -49,8 +49,8 @@ function Dropout(p; dims=:)
end

function (d::Dropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
y, _, rng = dropout(st.rng, x, d.p, d.q, d.dims, st.training)
return y, merge(st, (rng=rng,))
y, _, rng = LuxLib.dropout(st.rng, x, d.p, st.training; invp=d.q, d.dims)
return y, merge(st, (; rng))
end

function Base.show(io::IO, d::Dropout)
@@ -114,9 +114,10 @@ function VariationalHiddenDropout(p; dims=:)
end

function (d::VariationalHiddenDropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
y, mask, rng, update_mask = dropout(st.rng, x, st.mask, d.p, d.q, d.dims, st.training,
st.update_mask)
return y, merge(st, (mask=mask, rng=rng, update_mask=update_mask))
_mask = st.mask === nothing ? x : st.mask
y, mask, rng = LuxLib.dropout(st.rng, x, _mask, d.p, st.training, st.update_mask;
invp=d.q, d.dims)
return y, merge(st, (; mask, rng, update_mask=Val(false)))
end

function Base.show(io::IO, d::VariationalHiddenDropout)
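With the layers now backed by LuxLib, the user-facing behaviour is unchanged: state still carries the RNG (and, for VariationalHiddenDropout, the mask). A quick usage sketch for Dropout (values arbitrary; assumes the usual `Lux.setup`/`Lux.testmode` workflow):

using Lux, Random

rng = Random.default_rng()
d = Dropout(0.5f0)
ps, st = Lux.setup(rng, d)              # st holds its own `rng` and `training = Val(true)`
x = ones(Float32, 4, 8)

y, st = d(x, ps, st)                    # roughly half the entries zeroed, survivors rescaled
y_eval, _ = d(x, ps, Lux.testmode(st))  # dropout disabled: y_eval == x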