Commit

Update to use LuxLib

avik-pal committed Sep 24, 2022
1 parent 191d63d commit 8f2985f
Showing 10 changed files with 118 additions and 345 deletions.
5 changes: 2 additions & 3 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.4.23"
version = "0.4.24"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -11,9 +11,9 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -30,7 +30,6 @@ ComponentArrays = "0.13"
FillArrays = "0.13"
Functors = "0.2, 0.3"
NNlib = "0.8"
NNlibCUDA = "0.2"
Optimisers = "0.2"
Requires = "1"
Setfield = "0.8, 1"
2 changes: 1 addition & 1 deletion src/Lux.jl
@@ -5,7 +5,7 @@ using CUDA
using CUDA.CUDNN
# Neural Network Backend
using NNlib
import NNlibCUDA: batchnorm, ∇batchnorm, CUDNNFloat
import LuxLib ## In v0.5 we can start `using` it. For v0.4, there will be naming conflicts
# Julia StdLibs
using Random, Statistics, LinearAlgebra, SparseArrays
# Parameter Manipulation
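
Because Lux v0.4 still defines its own (now-deprecated) `dropout` and `normalization`, the commit pulls the new backend in with a qualified `import LuxLib` rather than `using LuxLib`. As a minimal sketch of what a qualified call looks like — using the `LuxLib.dropout` signature that appears later in this diff; the array shape and probability are illustrative assumptions, not part of the commit:

using Random
import LuxLib  # qualified access avoids clashing with the deprecated Lux names

rng = Random.default_rng()
x = randn(Float32, 4, 8)   # illustrative input
p = 0.5f0                  # drop probability

# Returns (scaled output, mask, updated rng); Val(true) selects training mode.
y, mask, rng = LuxLib.dropout(rng, x, p, Val(true); dims=:, invp=inv(1 - p))
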
46 changes: 0 additions & 46 deletions src/autodiff.jl
@@ -1,9 +1,5 @@
# Non Differentiable Functions
ChainRulesCore.@non_differentiable replicate(::Any)
ChainRulesCore.@non_differentiable update_statistics(::Any, ::Any, ::Any, ::Any, ::Any,
::Any, ::Any)
ChainRulesCore.@non_differentiable generate_dropout_mask(::Any, ::Any, ::Any, ::Any)
ChainRulesCore.@non_differentiable _get_reshape_dims(::Any, ::Any)
ChainRulesCore.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)
ChainRulesCore.@non_differentiable glorot_normal(::Any...)
ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
@@ -12,50 +8,8 @@ ChainRulesCore.@non_differentiable istraining(::Any)
ChainRulesCore.@non_differentiable _get_norm_except_dims(::Any, ::Any)
ChainRulesCore.@non_differentiable _affine(::Any)
ChainRulesCore.@non_differentiable _track_stats(::Any)
ChainRulesCore.@non_differentiable _copy_autodiff_barrier(::Any)

# NNlib Functions
function ChainRulesCore.rrule(::typeof(_batchnorm), g::CuArray{T}, b::CuArray{T},
x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}},
running_mean, running_var, momentum, epsilon,
training) where {T <: CUDNNFloat}
y = _batchnorm(g, b, x, running_mean, running_var, momentum, epsilon, training)
function _batchnorm_pullback(dy)
dg, db, dx = ∇batchnorm(g, b, x, unthunk(dy), running_mean, running_var, momentum;
eps=epsilon, training=training)
return NoTangent(), dg, db, dx, NoTangent(), NoTangent(), NoTangent(), NoTangent(),
NoTangent()
end
return y, _batchnorm_pullback
end

function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, x::AbstractArray{T, N},
p::T, q::T, dims, t::Val{training}) where {T, N, training}
y, mask, rng = dropout(rng, x, p, q, dims, t)
function dropout_pullback((dy, dmask, drng))
return (NoTangent(), NoTangent(), dy .* mask, NoTangent(), NoTangent(), NoTangent(),
NoTangent())
end
return (y, mask, rng), dropout_pullback
end

# Utilities

function ChainRulesCore.rrule(::typeof(_reshape_into_proper_shape), ::Nothing, y)
function _reshape_into_proper_shape_pullback(dx)
return NoTangent(), NoTangent(), NoTangent()
end
return nothing, _reshape_into_proper_shape_pullback
end

function ChainRulesCore.rrule(::typeof(_reshape_into_proper_shape), x, y)
res = _reshape_into_proper_shape(x, y)
function _reshape_into_proper_shape_pullback(dx)
return NoTangent(), reshape(dx, size(x)), NoTangent()
end
return res, _reshape_into_proper_shape_pullback
end

function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1},
nt2::NamedTuple{F2}) where {F1, F2}
y = merge(nt1, nt2)
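
The hand-written rrules above (for the cuDNN batchnorm wrapper, dropout, and the reshape helper) can be deleted because LuxLib now provides those primitives along with the AD rules they need. A quick sanity check of the kind one might run after this change — assuming Zygote is available; the shape, probability, and `invp` value are illustrative:

using Random
import LuxLib
import Zygote

rng = Random.default_rng()
x = randn(Float32, 4, 8)

# Gradients flow through LuxLib.dropout without Lux defining any rrule itself.
loss(z) = sum(first(LuxLib.dropout(rng, z, 0.5f0, Val(true); dims=:, invp=2.0f0)))
grad = only(Zygote.gradient(loss, x))
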
57 changes: 57 additions & 0 deletions src/deprecated.jl
@@ -130,3 +130,60 @@ Computes `x .* y`. Dispatches to CUDNN if possible.
" v0.5. Use `x .* y` instead.", :elementwise_mul)
return x .* y
end

# Dropout
"""
dropout(rng::AbstractRNG, x, p, q, dims, ::Val{training})
dropout(rng::AbstractRNG, x, mask, p, q, dims, t::Val{training}, ::Val{update_mask})
If `training`, dropout is applied to `x` with probability `p` along `dims`. If a `mask` is
passed and `update_mask` is false, that mask is reused; if `update_mask` is true, a new
mask is generated and used.
!!! warning
This function has been deprecated and will be removed in v0.5. Use `LuxLib.dropout`
instead.
"""
@inline function dropout(rng::AbstractRNG, x, p, q, dims, t::Val)
# Deprecated Functionality (Remove in v0.5)
Base.depwarn("`Lux.dropout` has been deprecated and will be removed in v0.5. Use " *
"`LuxLib.dropout` instead.", :dropout)

return LuxLib.dropout(rng, x, p, t; invp=q, dims)
end

@inline function dropout(rng::AbstractRNG, x, mask, p, q, dims, t::Val, um::Val)
# Deprecated Functionality (Remove in v0.5)
Base.depwarn("`Lux.dropout` has been deprecated and will be removed in v0.5. Use " *
"`LuxLib.dropout` instead.", :dropout)

return (LuxLib.dropout(rng, x, mask, p, t, um; invp=q, dims)..., Val(false))
end

# Normalization Implementation
"""
normalization(x, running_mean, running_var, scale, bias, activation, reduce_dims,
::Val{training}, momentum, epsilon)
Performs BatchNorm/GroupNorm based on the input configuration.
!!! warning
This function has been deprecated and will be removed in v0.5. Use
`LuxLib.(batch/group)norm` instead.
"""
@inline function normalization(x::AbstractArray{T, N},
running_mean::Union{Nothing, AbstractVector{T}},
running_var::Union{Nothing, AbstractVector{T}},
scale::Union{Nothing, AbstractVector{T}},
bias::Union{Nothing, AbstractVector{T}}, activation,
reduce_dims, t::Val, momentum::T=T(0.1),
epsilon::T=T(1e-5)) where {T, N}
# Deprecated Functionality (Remove in v0.5)
Base.depwarn("`Lux.normalization` has been deprecated and will be removed in v0.5. " *
"Use `LuxLib.(batch/group)norm` instead.", :normalization)

return activation.(LuxLib._normalization(x, running_mean, running_var, scale, bias,
reduce_dims, t, momentum, epsilon))
end
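
The deprecations above amount to a mechanical argument reshuffle: the old positional `q` (inverse keep probability) and `dims` become keyword arguments of `LuxLib.dropout`, and the normalization entry points map onto `LuxLib.batchnorm` / `LuxLib.groupnorm`. A hedged before/after sketch for the dropout case — values are illustrative, and the deprecated call is shown commented out so the snippet only needs LuxLib:

using Random
import LuxLib

rng = Random.default_rng()
x = randn(Float32, 3, 5)
p = 0.3f0
q = inv(1 - p)   # the `q` stored by Lux's Dropout layer

# Deprecated (still works in v0.4, emits a depwarn):
# y, mask, rng = Lux.dropout(rng, x, p, q, :, Val(true))

# Replacement:
y, mask, rng = LuxLib.dropout(rng, x, p, Val(true); invp=q, dims=:)
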
10 changes: 5 additions & 5 deletions src/layers/dropout.jl
@@ -49,8 +49,8 @@ function Dropout(p; dims=:)
end

function (d::Dropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
y, _, rng = dropout(st.rng, x, d.p, d.q, d.dims, st.training)
return y, merge(st, (rng=rng,))
y, _, rng = LuxLib.dropout(st.rng, x, d.p, st.training; invp=d.q, d.dims)
return y, merge(st, (; rng))
end

function Base.show(io::IO, d::Dropout)
@@ -114,9 +114,9 @@ function VariationalHiddenDropout(p; dims=:)
end

function (d::VariationalHiddenDropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
y, mask, rng, update_mask = dropout(st.rng, x, st.mask, d.p, d.q, d.dims, st.training,
st.update_mask)
return y, merge(st, (mask=mask, rng=rng, update_mask=update_mask))
y, mask, rng = LuxLib.dropout(st.rng, x, st.mask, d.p, st.training, st.update_mask;
invp=d.q, d.dims)
return y, merge(st, (; mask, rng, update_mask=Val(false)))
end

function Base.show(io::IO, d::VariationalHiddenDropout)
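
From a user's point of view the `Dropout` layer API is unchanged; only its internals now call `LuxLib.dropout`. A short usage sketch, assuming Lux is loaded; the probability and input shape are illustrative:

using Lux, Random

rng = Random.default_rng()
model = Dropout(0.5f0)           # same constructor as before
ps, st = Lux.setup(rng, model)   # st carries the rng and training flag

x = randn(Float32, 4, 8)
y, st = model(x, ps, st)         # forward pass; st.rng is advanced
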