Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement to the @compact API #584

Merged
merged 5 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions docs/src/api/Lux/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ CurrentModule = Lux

All features listed on this page are **experimental** which means:

1. No SemVer Guarantees. We use code here to iterate fast and most users should wait for
these features to be marked non-experimental.
1. No SemVer Guarantees. We use code here to iterate fast. That said, historically we have
never broken any code in this module and have always provided a deprecation period.
2. Expect edge-cases and report them. It will help us move these features out of
experimental sooner.
3. None of the features are exported.
Expand Down Expand Up @@ -74,8 +74,14 @@ Lux.Experimental.DebugLayer
Lux.Experimental.share_parameters
```

## StatefulLuxLayer

[`Lux.StatefulLuxLayer`](@ref) used to be part of experimental features, but has been
promoted to stable API. It is now available via `Lux.StatefulLuxLayer`. Change all uses of
`Lux.Experimental.StatefulLuxLayer` to `Lux.StatefulLuxLayer`.

## Compact Layer API

```@docs
Lux.Experimental.@compact
```
[`Lux.@compact`](@ref) used to be part of experimental features, but has been promoted to
stable API. It is now available via `Lux.@compact`. Change all uses of
`Lux.Experimental.@compact` to `Lux.@compact`.
5 changes: 5 additions & 0 deletions docs/src/api/Lux/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ Lux.f64
StatefulLuxLayer
```

## Compact Layer

```@docs
@compact
```

## Truncated Stacktraces

Expand Down
7 changes: 3 additions & 4 deletions docs/src/introduction/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ standard AD and Optimisers API.

```@example quickstart
# Get the device determined by Lux
device = gpu_device()
dev = gpu_device()

# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> device
ps, st = Lux.setup(rng, model) .|> dev

# Dummy Input
x = rand(rng, Float32, 128, 2) |> device
x = rand(rng, Float32, 128, 2) |> dev

# Run the model
y, st = Lux.apply(model, x, ps, st)
Expand All @@ -74,7 +74,6 @@ st_opt, ps = Optimisers.update(st_opt, ps, gs)
```@example custom_compact
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, LuxAMDGPU, Metal # Optional packages for GPU support
import Lux.Experimental: @compact
using Printf # For pretty printing
```

Expand Down
9 changes: 8 additions & 1 deletion src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ using PrecompileTools: @recompile_invalidations
inputsize, outputsize, update_state, trainmode, testmode, setup, apply,
display_name, replicate
using LuxDeviceUtils: get_device

# @compact specific
using MacroTools: block, combinedef, splitdef
using ConstructionBase: ConstructionBase
end

@reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers
Expand Down Expand Up @@ -56,6 +60,7 @@ include("contrib/contrib.jl")

# Helpful Functionalities
include("helpers/stateful.jl")
include("helpers/compact.jl")

# Transform to and from other frameworks
include("transform/types.jl")
Expand All @@ -70,7 +75,8 @@ include("distributed/public_api.jl")
include("deprecated.jl")

# Layers
export cpu, gpu
export cpu, gpu # deprecated

export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer
export Bilinear, Dense, Embedding, Scale
export Conv, ConvTranspose, CrossCor, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool,
Expand All @@ -83,6 +89,7 @@ export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell
export SamePad, TimeLastIndex, BatchLastIndex

export StatefulLuxLayer
export @compact, CompactLuxLayer

export f16, f32, f64

Expand Down
8 changes: 2 additions & 6 deletions src/contrib/contrib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@ module Experimental
import ..Lux
using ..Lux, LuxCore, LuxDeviceUtils, Random
using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer
import ..Lux: _merge, _pairs, initialstates, initialparameters, apply, NAME_TYPE,
_getproperty
import ..Lux: _merge, _pairs, initialstates, initialparameters, apply

using ADTypes: ADTypes
import ChainRulesCore as CRC
using ConcreteStructs: @concrete
import ConstructionBase: constructorof
using Functors: Functors, fmap, functor
using MacroTools: block, combinedef, splitdef
using Markdown: @doc_str
using Random: AbstractRNG, Random
using Setfield: Setfield
Expand All @@ -21,8 +18,7 @@ include("training.jl")
include("freeze.jl")
include("share_parameters.jl")
include("debug.jl")
include("stateful.jl")
include("compact.jl")
include("deprecated.jl")

end

Expand Down
10 changes: 9 additions & 1 deletion src/contrib/stateful.jl → src/contrib/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# Deprecated
macro compact(exs...)
Base.depwarn(
"Lux.Experimental.@compact` has been promoted out of `Lux.Experimental` and is now \
available in `Lux`. In other words this has been deprecated and will be removed \
in v0.6. Use `Lux.@compact` instead.",
Symbol("@compact"))
return Lux.__compact_macro_impl(exs...)
end

function StatefulLuxLayer(args...; kwargs...)
Base.depwarn(
"Lux.Experimental.StatefulLuxLayer` has been promoted out of `Lux.Experimental` \
Expand Down
8 changes: 6 additions & 2 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

Transfer `x` to CPU.

!!! warning
!!! danger

This function has been deprecated. Use [`cpu_device`](@ref) instead.
"""
Expand All @@ -19,7 +19,7 @@ end

Transfer `x` to GPU determined by the backend set using [`Lux.gpu_backend!`](@ref).

!!! warning
!!! danger

This function has been deprecated. Use [`gpu_device`](@ref) instead. Using this function
inside performance critical code will cause massive slowdowns due to type inference
Expand All @@ -41,6 +41,10 @@ end
An easy way to update `TruncatedStacktraces.VERBOSE` without having to load it manually.

Effectively does `TruncatedStacktraces.VERBOSE[] = disable`

!!! danger

This function is now deprecated and will be removed in v0.6.
"""
function disable_stacktrace_truncation!(; disable::Bool=true)
Base.depwarn("`disable_stacktrace_truncation!` is not needed anymore, as \
Expand Down
22 changes: 13 additions & 9 deletions src/contrib/compact.jl → src/helpers/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ Here is a linear model:

```julia
using Lux, Random
import Lux.Experimental: @compact

r = @compact(w=rand(3)) do x
return w .* x
Expand Down Expand Up @@ -123,6 +122,11 @@ used inside a `Chain`.
account for the total number of parameters printed at the bottom.
"""
macro compact(_exs...)
return __compact_macro_impl(_exs...)
end

# Needed for the deprecation path
function __compact_macro_impl(_exs...)
# check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
if isempty(_exs)
msg = "expects at least two expressions: a function and at least one keyword"
Expand Down Expand Up @@ -187,12 +191,13 @@ macro compact(_exs...)
fex_args = fex.args[1]
isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ")
catch e
@warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments"
@warn "Function stringifying does not yet handle all cases. Falling back to empty \
string for input arguments"
end
block = string(Base.remove_linenums!(fex).args[2])

# edit expressions
vars = map(ex -> ex.args[1], kwexs)
vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs)
fex = supportself(fex, vars)

# assemble
Expand All @@ -212,9 +217,8 @@ function supportself(fex::Expr, vars)
calls = []
for var in vars
push!(calls,
:($var = Lux.Experimental.__maybe_make_stateful(
Lux._getproperty($self, $(Val(var))),
Lux._getproperty($ps, $(Val(var))), Lux._getproperty($st, $(Val(var))))))
:($var = $(__maybe_make_stateful)($(_getproperty)($self, $(Val(var))),
$(_getproperty)($ps, $(Val(var))), $(_getproperty)($st, $(Val(var))))))
end
body = Expr(:let, Expr(:block, calls...), sdef[:body])
sdef[:body] = body
Expand All @@ -223,7 +227,7 @@ function supportself(fex::Expr, vars)
end

@inline function __maybe_make_stateful(layer::AbstractExplicitLayer, ps, st)
return Lux.StatefulLuxLayer(layer, ps, st)
return StatefulLuxLayer(layer, ps, st)
end
@inline __maybe_make_stateful(::Nothing, ps, st) = ps === nothing ? st : ps
@inline function __maybe_make_stateful(model::Union{AbstractVector, Tuple}, ps, st)
Expand Down Expand Up @@ -271,7 +275,7 @@ end
value_storage
end

function constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch}
function ConstructionBase.constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch}
return CompactLuxLayer{dispatch}
end

Expand All @@ -288,7 +292,7 @@ function __try_make_lux_layer(x::Union{AbstractVector, Tuple})
return __try_make_lux_layer(NamedTuple{Tuple(Symbol.(1:length(x)))}(x))
end
function __try_make_lux_layer(x)
function __maybe_convert_layer(l)
__maybe_convert_layer = @closure l -> begin
l isa AbstractExplicitLayer && return l
l isa Function && return WrappedFunction(l)
return l
Expand Down
Loading