This repository has been archived by the owner on Nov 4, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow fusing activation into normalization
- Loading branch information
Showing
11 changed files
with
103 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "LuxLib" | ||
uuid = "82251201-b29d-42c6-8e01-566dec8acb11" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "0.3.13" | ||
version = "0.3.14" | ||
|
||
[deps] | ||
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
fast_activation!!(σ::F, x) where {F} | ||
Compute `σ.(x)` with the best possible implementation available. If it is possible to | ||
rewrite `x` in-place, it does so. If `x` is an immutable array, it falls back to the | ||
generic implementation. | ||
!!! note | ||
This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be | ||
done by the user if needed. | ||
## Arguments | ||
- `σ`: Activation function | ||
- `x`: Input array | ||
## Returns | ||
- Output Array with the same size as `x` | ||
""" | ||
@inline function fast_activation!!(σ::F, x::AbstractArray) where {F} | ||
σ === identity && return x | ||
ArrayInterface.can_setindex(x) && __fast_activation_impl!(σ, x) | ||
return σ.(x) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Specialized Implementation based off NNlib._fast_broadcast with added logic from | ||
# ArrayInterface | ||
# If we enter here, we already know that we can setindex into the array | ||
@inline function __fast_activation_impl!(σ::F, x::AbstractArray) where {F} | ||
if ArrayInterface.fast_scalar_indexing(x) | ||
bc = Broadcast.instantiate(Broadcast.broadcasted(σ, x)) | ||
@simd ivdep for I in eachindex(bc) | ||
@inbounds x[I] = bc[I] | ||
end | ||
else | ||
@. x = σ(x) | ||
end | ||
return x | ||
end | ||
|
||
function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, | ||
::typeof(__fast_activation_impl!), σ::F, x::AbstractArray{T}) where {F, T} | ||
σ === identity && return x, @closure(Δ->(CRC.NoTangent(), CRC.NoTangent(), Δ)) | ||
|
||
if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber})) | ||
__fast_activation_impl!(σ, x) | ||
∇__fast_activation_impl_no_cached = @closure Δ -> begin | ||
∂x = only_derivative.(x, σ, NotaNumber()) .* CRC.unthunk(Δ) | ||
return CRC.NoTangent(), CRC.NoTangent(), ∂x | ||
end | ||
return x, ∇__fast_activation_impl_no_cached | ||
end | ||
|
||
if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T})) | ||
y = @. σ(x) | ||
∇__fast_activation_impl_cached_crc = @closure Δ -> begin | ||
∂z = only_derivative.(y, σ, x) .* CRC.unthunk(Δ) | ||
return CRC.NoTangent(), CRC.NoTangent(), ∂z | ||
end | ||
return z, ∇__fast_activation_impl_cached_crc | ||
end | ||
|
||
y, pb_f = CRC.rrule_via_ad(cfg, broadcast, σ, x) | ||
∇__fast_activation_impl_cached = @closure Δ -> begin | ||
_, _, ∂x = pb_f(Δ) | ||
return CRC.NoTangent(), CRC.NoTangent(), ∂x | ||
end | ||
return y, ∇__fast_activation_impl_cached | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters