Merge pull request #113 from avik-pal/deprecate
Deprecate `elementwise_*` and `applyactivation`
avik-pal authored Jul 30, 2022
2 parents 449e67d + a92d27f commit 11ac3e4
Showing 8 changed files with 72 additions and 166 deletions.
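
In practice, the deprecation means downstream code should switch from the Lux helpers to plain Julia broadcasting. A minimal migration sketch (assuming Lux v0.4.12, where the helpers still work but emit a deprecation warning and are slated for removal in v0.5):

using Lux

x = rand(Float32, 4, 4)
y = rand(Float32, 4, 4)

# Deprecated as of this release:
Lux.elementwise_add(x, y)     # logs a depwarn, returns x .+ y
Lux.elementwise_mul(x, y)     # logs a depwarn, returns x .* y
Lux.applyactivation(tanh, x)  # logs a depwarn, returns tanh.(x)

# Replacement: ordinary broadcasting
x .+ y
x .* y
tanh.(x)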
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.4.11"
version = "0.4.12"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
28 changes: 2 additions & 26 deletions src/autodiff.jl
@@ -27,8 +27,8 @@ function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, x::AbstractAr
p::T, q::T, dims, t::Val{training}) where {T, N, training}
y, mask, rng = dropout(rng, x, p, q, dims, t)
function dropout_pullback((dy, dmask, drng))
return (NoTangent(), NoTangent(), elementwise_mul(dy, mask), NoTangent(),
NoTangent(), NoTangent(), NoTangent())
return (NoTangent(), NoTangent(), dy .* mask, NoTangent(), NoTangent(), NoTangent(),
NoTangent())
end
return (y, mask, rng), dropout_pullback
end
@@ -89,30 +89,6 @@ function ChainRulesCore.rrule(::typeof(collect), v::Vector)
return y, collect_pullback
end

# Activation rrules
function ChainRulesCore.rrule(::typeof(applyactivation), f::cudnnValidActivationTypes,
x::CuArray{T}) where {T <: CUDNNFloat}
mode = getCUDNNActivationMode(f)
y = CUDNN.cudnnActivationForward(x; mode)
function applyactivation_pullback(Δ)
return NoTangent(), NoTangent(), cudnnActivationBackward(y, Δ, x; mode), NoTangent()
end
return y, applyactivation_pullback
end

# Elementwise Functions
function ChainRulesCore.rrule(::typeof(elementwise_add), x, y) where {T}
z = elementwise_add(x, y)
_elementwise_add_pullback(Δ) = (NoTangent(), elementwise_add_pullback(x, y, Δ)...)
return z, _elementwise_add_pullback
end

function ChainRulesCore.rrule(::typeof(elementwise_mul), x, y) where {T}
z = elementwise_mul(x, y)
_elementwise_mul_pullback(Δ) = (NoTangent(), elementwise_mul_pullback(x, y, Δ)...)
return z, _elementwise_mul_pullback
end

# Zygote Fixes
function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
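For context, the pullbacks above follow the usual ChainRulesCore pattern: return one tangent per argument, plus NoTangent() for the function itself. A minimal sketch for an elementwise product (illustrative code, not part of Lux; it assumes both arrays have the same size):

using ChainRulesCore

elemmul(x, y) = x .* y

function ChainRulesCore.rrule(::typeof(elemmul), x::AbstractArray, y::AbstractArray)
    z = elemmul(x, y)
    function elemmul_pullback(Δ)
        # Tangent for the function itself, then one tangent per positional argument.
        # Assumes size(x) == size(y); mismatched shapes would also require a
        # sum over the broadcasted dimensions.
        return NoTangent(), Δ .* y, Δ .* x
    end
    return z, elemmul_pullback
end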
45 changes: 45 additions & 0 deletions src/deprecated.jl
@@ -85,3 +85,48 @@ function ActivationFunction(f)
:ActivationFunction)
return WrappedFunction(Base.Fix1(broadcast, f))
end

"""
applyactivation(f::Function, x::AbstractArray)
Apply the function `f` on `x` elementwise, i.e. `f.(x)`. Dispatches to CUDNN if possible.
!!! warning
This function has been deprecated. Use `f.(x)` instead.
"""
@inline function applyactivation(f::Function, x::AbstractArray)
Base.depwarn("`Lux.applyactivation` has been deprecated and will be removed in" *
" v0.5. Directly apply broadcasting instead.", :applyactivation)
return f.(x)
end

"""
elementwise_add(x, y)
Computes `x .+ y`. Dispatches to CUDNN if possible.
!!! warning
This function has been deprecated. Use `x .+ y` instead.
"""
@inline function elementwise_add(x, y)
Base.depwarn("`Lux.elementwise_add` has been deprecated and will be removed in" *
" v0.5. Use `x .+ y` instead.", :elementwise_add)
return x .+ y
end

"""
elementwise_mul(x, y)
Computes `x .* y`. Dispatches to CUDNN if possible.
!!! warning
This function has been deprecated. Use `x .* y` instead.
"""
@inline function elementwise_mul(x, y)
Base.depwarn("`Lux.elementwise_mul` has been deprecated and will be removed in" *
" v0.5. Use `x .* y` instead.", :elementwise_mul)
return x .* y
end
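
The definitions above follow the standard Base.depwarn pattern: keep the old entry point working, but log a warning that points at the replacement. A standalone sketch with illustrative names (not part of Lux):

module MyPkg

@inline function old_add(x, y)
    Base.depwarn("`MyPkg.old_add` has been deprecated and will be removed in the " *
                 "next breaking release. Use `x .+ y` instead.", :old_add)
    return x .+ y
end

end # module

# Deprecation warnings are muted by default; start Julia with `--depwarn=yes`
# to see the message.
MyPkg.old_add([1, 2], [3, 4])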
30 changes: 13 additions & 17 deletions src/layers/basic.jl
@@ -602,7 +602,7 @@ end
statelength(d::Dense) = 0

@inline function (d::Dense{false})(x::AbstractVecOrMat, ps, st::NamedTuple)
return applyactivation(d.activation, ps.weight * x), st
return d.activation.(ps.weight * x), st
end

@inline function (d::Dense{false, typeof(identity)})(x::AbstractVecOrMat, ps,
@@ -613,8 +613,7 @@
@inline function (d::Dense{false})(x::AbstractArray, ps, st::NamedTuple)
sz = size(x)
x_reshaped = reshape(x, sz[1], :)
return reshape(applyactivation(d.activation, ps.weight * x_reshaped), d.out_dims,
sz[2:end]...), st
return reshape(d.activation.(ps.weight * x_reshaped), d.out_dims, sz[2:end]...), st
end

@inline function (d::Dense{false, typeof(identity)})(x::AbstractArray, ps, st::NamedTuple)
@@ -624,34 +623,32 @@ end
end

@inline function (d::Dense{true})(x::AbstractVector, ps, st::NamedTuple)
return applyactivation(d.activation, elementwise_add(ps.weight * x, vec(ps.bias))), st
return d.activation.(ps.weight * x .+ vec(ps.bias)), st
end

@inline function (d::Dense{true, typeof(identity)})(x::AbstractVector, ps, st::NamedTuple)
return elementwise_add(ps.weight * x, vec(ps.bias)), st
return ps.weight * x .+ vec(ps.bias), st
end

@inline function (d::Dense{true})(x::AbstractMatrix, ps, st::NamedTuple)
return applyactivation(d.activation, elementwise_add(ps.weight * x, ps.bias)), st
return d.activation.(ps.weight * x .+ ps.bias), st
end

@inline function (d::Dense{true, typeof(identity)})(x::AbstractMatrix, ps, st::NamedTuple)
return elementwise_add(ps.weight * x, ps.bias), st
return ps.weight * x .+ ps.bias, st
end

@inline function (d::Dense{true})(x::AbstractArray, ps, st::NamedTuple)
sz = size(x)
x_reshaped = reshape(x, sz[1], :)
return (reshape(applyactivation(d.activation,
elementwise_add(ps.weight * x_reshaped, ps.bias)),
d.out_dims, sz[2:end]...), st)
return (reshape(d.activation.(ps.weight * x_reshaped .+ ps.bias), d.out_dims,
sz[2:end]...), st)
end

@inline function (d::Dense{true, typeof(identity)})(x::AbstractArray, ps, st::NamedTuple)
sz = size(x)
x_reshaped = reshape(x, sz[1], :)
return (reshape(elementwise_add(ps.weight * x_reshaped, ps.bias), d.out_dims,
sz[2:end]...), st)
return (reshape(ps.weight * x_reshaped .+ ps.bias, d.out_dims, sz[2:end]...), st)
end

"""
@@ -727,18 +724,17 @@ parameterlength(d::Scale{bias}) where {bias} = (1 + bias) * prod(d.dims)
statelength(d::Scale) = 0

function (d::Scale{true})(x::AbstractArray, ps, st::NamedTuple)
return applyactivation(d.activation,
elementwise_add(elementwise_mul(ps.weight, x), ps.bias)), st
return d.activation.(ps.weight .* x .+ ps.bias), st
end

function (d::Scale{true, typeof(identity)})(x::AbstractArray, ps, st::NamedTuple)
return elementwise_add(elementwise_mul(ps.weight, x), ps.bias), st
return ps.weight .* x .+ ps.bias, st
end

function (d::Scale{false})(x::AbstractArray, ps, st::NamedTuple)
return applyactivation(d.activation, elementwise_mul(ps.weight, x)), st
return d.activation.(ps.weight .* x), st
end

function (d::Scale{false, typeof(identity)})(x::AbstractArray, ps, st::NamedTuple)
return elementwise_mul(ps.weight, x), st
return ps.weight .* x, st
end
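
With these changes, Dense and Scale simply broadcast the activation over weight * x .+ bias. A usage sketch, assuming the standard Lux 0.4 calling convention (Lux.setup plus layer(x, ps, st)):

using Lux, Random

rng = Random.default_rng()
d = Dense(2, 3, tanh)              # weight is 3×2, bias is stored as a 3×1 matrix
ps, st = Lux.setup(rng, d)

x = randn(rng, Float32, 2, 5)      # a batch of five inputs
y, _ = d(x, ps, st)

# Matches the broadcast form the layer now uses internally:
@assert y ≈ tanh.(ps.weight * x .+ ps.bias)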
5 changes: 2 additions & 3 deletions src/layers/conv.jl
@@ -124,14 +124,13 @@ end
@inline function (c::Conv{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation,
groups=c.groups)
return applyactivation(c.activation, conv_wrapper(x, ps.weight, cdims)), st
return c.activation.(conv_wrapper(x, ps.weight, cdims)), st
end

@inline function (c::Conv{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation,
groups=c.groups)
return applyactivation(c.activation,
elementwise_add(conv_wrapper(x, ps.weight, cdims), ps.bias)), st
return c.activation.(conv_wrapper(x, ps.weight, cdims) .+ ps.bias), st
end

function Base.show(io::IO, l::Conv)
7 changes: 3 additions & 4 deletions src/layers/normalize.jl
@@ -149,10 +149,9 @@ function (BN::BatchNorm{affine, track_stats})(x::Union{CuArray{T, 2}, CuArray{T,
running_var2 = var(x; mean=running_mean2, dims=reduce_dims, corrected=false)
end
end
res = applyactivation(BN.activation,
batchnorm(affine ? ps.scale : nothing, affine ? ps.bias : nothing,
x, running_mean2, running_var2, BN.momentum;
eps=BN.epsilon, training=istraining(st)))
res = BN.activation.(batchnorm(affine ? ps.scale : nothing, affine ? ps.bias : nothing,
x, running_mean2, running_var2, BN.momentum;
eps=BN.epsilon, training=istraining(st)))
if track_stats
st = merge(st, (running_mean=running_mean2, running_var=running_var2))
end
111 changes: 2 additions & 109 deletions src/nnlib.jl
@@ -133,7 +133,7 @@ generated and used.
if training
return :(rng = replicate(rng);
mask = generate_dropout_mask(rng, x, p, q; dims);
return (elementwise_mul(x, ignore_derivatives(mask)), mask, rng))
return (x .* ignore_derivatives(mask), mask, rng))
else
return :(return (x, x, rng))
end
@@ -148,8 +148,7 @@ end
if training
return :(size(x, ndims(x)) != size(mask, ndims(x)) &&
return (dropout(rng, x, p, q, dims, t)..., Val(false));
return (elementwise_mul(x, ignore_derivatives(mask)), mask, rng,
Val(false)))
return (x .* ignore_derivatives(mask), mask, rng, Val(false)))
else
return :(return (x, mask, rng, Val(false)))
end
@@ -164,109 +163,3 @@ end
pad = 0
return PoolDims(x, k; padding=pad, stride=stride)
end

# CUDNN Constants
const cudnnValidActivationTypes = Union{typeof(tanh), typeof(sigmoid), typeof(relu),
typeof(elu), typeof(tanh_fast), typeof(sigmoid_fast)
}

# Activation Functions
## I think this is handled by NNlibCUDA. But currently leaving here for
## benchmarking larger models
function getCUDNNActivationMode(::Union{typeof(tanh), typeof(tanh_fast)})
return CUDNN.CUDNN_ACTIVATION_TANH
end
function getCUDNNActivationMode(::Union{typeof(sigmoid), typeof(sigmoid_fast)})
return CUDNN.CUDNN_ACTIVATION_SIGMOID
end
getCUDNNActivationMode(::Union{typeof(relu)}) = CUDNN.CUDNN_ACTIVATION_RELU
getCUDNNActivationMode(::Union{typeof(elu)}) = CUDNN.CUDNN_ACTIVATION_ELU

"""
applyactivation(f::Function, x::AbstractArray)
Apply the function `f` on `x` elementwise, i.e. `f.(x)`. Dispatches to CUDNN if possible.
"""
@inline applyactivation(f::Function, x::AbstractArray) = f.(x)
@inline function applyactivation(f::cudnnValidActivationTypes, x::CuArray{<:CUDNNFloat})
return CUDNN.cudnnActivationForward(x; mode=getCUDNNActivationMode(f))
end
@inline applyactivation(::typeof(identity), x::AbstractArray) = x

# Dispatch Certain Broadcasted Functions to CUDNN
@inline function broadcast_shape_pullback(x, Δ)
sx = size(x)
sΔ = size(Δ)
sx == sΔ && return Δ
return sum(Δ; dims=findall(sx .!= sΔ))
end

@inline isvalidtensorop(x1, x2) = false
@inline function isvalidtensorop(x1::CuArray{N, T},
x2::CuArray{N, T}) where {N, T <: CUDNNFloat}
return ndims(x1) <= 5 &&
(all(size(x2, i) == size(x1, i) || size(x2, i) == 1 for i in 1:ndims(x2)))
end

"""
elementwise_add(x, y)
Computes `x .+ y`. Dispatches to CUDNN if possible
"""
@inline elementwise_add(x, y) = x .+ y
@inline function elementwise_add(x::CuArray, y::CuArray)
!isvalidtensorop(x, y) && return x .+ y
return cudnnOpTensorWithDefaults(x, y; op=CUDNN.CUDNN_OP_TENSOR_ADD)
end

@inline function elementwise_add_pullback(x, y, Δ)
return broadcast_shape_pullback(x, Δ), broadcast_shape_pullback(y, Δ)
end

"""
elementwise_mul(x, y)
Computes `x .* y`. Dispatches to CUDNN if possible
"""
@inline elementwise_mul(x, y) = x .* y
@inline function elementwise_mul(x::CuArray, y::CuArray)
!isvalidtensorop(x, y) && return x .* y
return cudnnOpTensorWithDefaults(x, y; op=CUDNN.CUDNN_OP_TENSOR_MUL)
end

@inline function elementwise_mul_pullback(x, y, Δ)
return broadcast_shape_pullback(x, elementwise_mul(Δ, y)),
broadcast_shape_pullback(y, elementwise_mul(Δ, x))
end

# CUDNN Helpers
function cudnnOpTensorWithDefaults(x1, x2; y=similar(x1),
op::CUDNN.cudnnOpTensorOp_t=CUDNN.CUDNN_OP_TENSOR_ADD,
compType::DataType=(eltype(x1) <: Float64 ? Float64 :
Float32),
nanOpt::CUDNN.cudnnNanPropagation_t=CUDNN.CUDNN_NOT_PROPAGATE_NAN,
opTensorDesc::CUDNN.cudnnOpTensorDescriptor=CUDNN.cudnnOpTensorDescriptor(op,
CUDNN.cudnnDataType(compType),
nanOpt),
alpha1::Real=1, alpha2::Real=1, beta::Real=0,
x1Desc::CUDNN.cudnnTensorDescriptor=CUDNN.cudnnTensorDescriptor(x1),
x2Desc::CUDNN.cudnnTensorDescriptor=CUDNN.cudnnTensorDescriptor(x2),
yDesc::CUDNN.cudnnTensorDescriptor=CUDNN.cudnnTensorDescriptor(y))
T = eltype(x1)
alpha1, alpha2, beta = CUDNN.scalingParameter.((T,), (alpha1, alpha2, beta))
return CUDNN.cudnnOpTensorAD(x1, x2; opTensorDesc, alpha1, x1Desc, alpha2, x2Desc, beta,
yDesc, y)
end

function cudnnActivationBackward(y::CuArray{T}, Δ::CuArray{T}, x::CuArray{T};
mode) where {T}
Δx = similar(x)
desc = CUDNN.cudnnActivationDescriptor(mode, CUDNN.CUDNN_NOT_PROPAGATE_NAN, Cdouble(1))
CUDNN.cudnnActivationBackward(CUDNN.handle(), desc, CUDNN.scalingParameter(T, 1),
CUDNN.cudnnTensorDescriptor(y), y,
CUDNN.cudnnTensorDescriptor(Δ), Δ,
CUDNN.cudnnTensorDescriptor(x), x,
CUDNN.scalingParameter(T, 0),
CUDNN.cudnnTensorDescriptor(Δx), Δx)
return Δx
end
10 changes: 4 additions & 6 deletions test/nnlib.jl
@@ -13,9 +13,12 @@ include("test_utils.jl")

# On CPU the fallback should always work
@test Lux.elementwise_add(x, y) == x .+ y
@test_deprecated Lux.elementwise_add(x, y)
@test Lux.elementwise_mul(x, y) == x .* y
@test_deprecated Lux.elementwise_mul(x, y)
@test Lux.applyactivation(tanh, x) == tanh.(x)
@test Lux.applyactivation(custom_activation, x) == custom_activation.(x)
@test_deprecated Lux.applyactivation(tanh, x)

if T <: Real
# Gradient for complex outputs are not defined
@@ -29,12 +32,7 @@

@test Lux.elementwise_add(x_g, y_g) == x_g .+ y_g
@test Lux.elementwise_mul(x_g, y_g) == x_g .* y_g
if T <: Real
@test Lux.applyactivation(tanh, x_g) == tanh.(x_g)
else
## See https://github.com/FluxML/NNlibCUDA.jl/issues/47
@test_broken Lux.applyactivation(tanh, x_g) == tanh.(x_g)
end
@test Lux.applyactivation(tanh, x_g) == tanh.(x_g)
# Custom Activation test
@test Lux.applyactivation(custom_activation, x_g) == custom_activation.(x_g)
end
