From b3e9a5f06c096bb9ed4d7a266e866e48b32c8b3e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Apr 2024 11:16:57 -0400 Subject: [PATCH 1/5] Move compact out of experimental --- docs/src/api/Lux/contrib.md | 16 +++++++++++----- docs/src/api/Lux/utilities.md | 5 +++++ docs/src/introduction/index.md | 7 +++---- src/Lux.jl | 9 ++++++++- src/contrib/contrib.jl | 8 ++------ src/contrib/{stateful.jl => deprecated.jl} | 10 +++++++++- src/deprecated.jl | 8 ++++++-- src/{contrib => helpers}/compact.jl | 22 +++++++++++++--------- 8 files changed, 57 insertions(+), 28 deletions(-) rename src/contrib/{stateful.jl => deprecated.jl} (53%) rename src/{contrib => helpers}/compact.jl (95%) diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md index b4cd1ec24..bf9c9c238 100644 --- a/docs/src/api/Lux/contrib.md +++ b/docs/src/api/Lux/contrib.md @@ -6,8 +6,8 @@ CurrentModule = Lux All features listed on this page are **experimental** which means: -1. No SemVer Guarantees. We use code here to iterate fast and most users should wait for - these features to be marked non-experimental. +1. No SemVer Guarantees. We use code here to iterate fast. That said, historically we have + never broken any code in this module and have always provided a deprecation period. 2. Expect edge-cases and report them. It will help us move these features out of experimental sooner. 3. None of the features are exported. @@ -74,8 +74,14 @@ Lux.Experimental.DebugLayer Lux.Experimental.share_parameters ``` +## StatefulLuxLayer + +[`Lux.StatefulLuxLayer`](@ref) used to be part of experimental features, but has been +promoted to stable API. It is now available via `Lux.StatefulLuxLayer`. Change all uses of +`Lux.Experimental.StatefulLuxLayer` to `Lux.StatefulLuxLayer`. + ## Compact Layer API -```@docs -Lux.Experimental.@compact -``` +[`Lux.@compact`](@ref) used to be part of experimental features, but has been promoted to +stable API. It is now available via `Lux.@compact`. Change all uses of +`Lux.Experimental.@compact` to `Lux.@compact`. diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index acfd4e245..0c8aa7dc5 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -54,6 +54,11 @@ Lux.f64 StatefulLuxLayer ``` +## Compact Layer + +```@docs +@compact +``` ## Truncated Stacktraces diff --git a/docs/src/introduction/index.md b/docs/src/introduction/index.md index a17d6ebd1..59e54141a 100644 --- a/docs/src/introduction/index.md +++ b/docs/src/introduction/index.md @@ -48,13 +48,13 @@ standard AD and Optimisers API. 
```@example quickstart # Get the device determined by Lux -device = gpu_device() +dev = gpu_device() # Parameter and State Variables -ps, st = Lux.setup(rng, model) .|> device +ps, st = Lux.setup(rng, model) .|> dev # Dummy Input -x = rand(rng, Float32, 128, 2) |> device +x = rand(rng, Float32, 128, 2) |> dev # Run the model y, st = Lux.apply(model, x, ps, st) @@ -74,7 +74,6 @@ st_opt, ps = Optimisers.update(st_opt, ps, gs) ```@example custom_compact using Lux, Random, Optimisers, Zygote # using LuxCUDA, LuxAMDGPU, Metal # Optional packages for GPU support -import Lux.Experimental: @compact using Printf # For pretty printing ``` diff --git a/src/Lux.jl b/src/Lux.jl index e7ab1645d..69cabe120 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -26,6 +26,10 @@ using PrecompileTools: @recompile_invalidations inputsize, outputsize, update_state, trainmode, testmode, setup, apply, display_name, replicate using LuxDeviceUtils: get_device + + # @compact specific + using MacroTools: block, combinedef, splitdef + using ConstructionBase: ConstructionBase end @reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers @@ -56,6 +60,7 @@ include("contrib/contrib.jl") # Helpful Functionalities include("helpers/stateful.jl") +include("helpers/compact.jl") # Transform to and from other frameworks include("transform/types.jl") @@ -70,7 +75,8 @@ include("distributed/public_api.jl") include("deprecated.jl") # Layers -export cpu, gpu +export cpu, gpu # deprecated + export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer export Bilinear, Dense, Embedding, Scale export Conv, ConvTranspose, CrossCor, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, @@ -83,6 +89,7 @@ export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell export SamePad, TimeLastIndex, BatchLastIndex export StatefulLuxLayer +export @compact, CompactLuxLayer export f16, f32, f64 diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index b72d545f3..ea65a653b 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -3,15 +3,12 @@ module Experimental import ..Lux using ..Lux, LuxCore, LuxDeviceUtils, Random using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer -import ..Lux: _merge, _pairs, initialstates, initialparameters, apply, NAME_TYPE, - _getproperty +import ..Lux: _merge, _pairs, initialstates, initialparameters, apply using ADTypes: ADTypes import ChainRulesCore as CRC using ConcreteStructs: @concrete -import ConstructionBase: constructorof using Functors: Functors, fmap, functor -using MacroTools: block, combinedef, splitdef using Markdown: @doc_str using Random: AbstractRNG, Random using Setfield: Setfield @@ -21,8 +18,7 @@ include("training.jl") include("freeze.jl") include("share_parameters.jl") include("debug.jl") -include("stateful.jl") -include("compact.jl") +include("deprecated.jl") end diff --git a/src/contrib/stateful.jl b/src/contrib/deprecated.jl similarity index 53% rename from src/contrib/stateful.jl rename to src/contrib/deprecated.jl index 2ba8293c0..6496b6ddd 100644 --- a/src/contrib/stateful.jl +++ b/src/contrib/deprecated.jl @@ -1,4 +1,12 @@ -# Deprecated +macro compact(exs...) + Base.depwarn( + "Lux.Experimental.@compact` has been promoted out of `Lux.Experimental` and is now \ + available in `Lux`. In other words this has been deprecated and will be removed \ + in v0.6. Use `Lux.@compact` instead.", + Symbol("@compact")) + return Lux.__compact_macro_impl(exs...) +end + function StatefulLuxLayer(args...; kwargs...) 
Base.depwarn( "Lux.Experimental.StatefulLuxLayer` has been promoted out of `Lux.Experimental` \ diff --git a/src/deprecated.jl b/src/deprecated.jl index 4507de522..b5cb0aedd 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -4,7 +4,7 @@ Transfer `x` to CPU. -!!! warning +!!! danger This function has been deprecated. Use [`cpu_device`](@ref) instead. """ @@ -19,7 +19,7 @@ end Transfer `x` to GPU determined by the backend set using [`Lux.gpu_backend!`](@ref). -!!! warning +!!! danger This function has been deprecated. Use [`gpu_device`](@ref) instead. Using this function inside performance critical code will cause massive slowdowns due to type inference @@ -41,6 +41,10 @@ end An easy way to update `TruncatedStacktraces.VERBOSE` without having to load it manually. Effectively does `TruncatedStacktraces.VERBOSE[] = disable` + +!!! danger + + This function is now deprecated and will be removed in v0.6. """ function disable_stacktrace_truncation!(; disable::Bool=true) Base.depwarn("`disable_stacktrace_truncation!` is not needed anymore, as \ diff --git a/src/contrib/compact.jl b/src/helpers/compact.jl similarity index 95% rename from src/contrib/compact.jl rename to src/helpers/compact.jl index 4993e51e3..28ae3c2ed 100644 --- a/src/contrib/compact.jl +++ b/src/helpers/compact.jl @@ -33,7 +33,6 @@ Here is a linear model: ```julia using Lux, Random -import Lux.Experimental: @compact r = @compact(w=rand(3)) do x return w .* x @@ -123,6 +122,11 @@ used inside a `Chain`. account for the total number of parameters printed at the bottom. """ macro compact(_exs...) + return __compact_macro_impl(_exs...) +end + +# Needed for the deprecation path +function __compact_macro_impl(_exs...) # check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs if isempty(_exs) msg = "expects at least two expressions: a function and at least one keyword" @@ -187,12 +191,13 @@ macro compact(_exs...) fex_args = fex.args[1] isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ") catch e - @warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments" + @warn "Function stringifying does not yet handle all cases. Falling back to empty \ + string for input arguments" end block = string(Base.remove_linenums!(fex).args[2]) # edit expressions - vars = map(ex -> ex.args[1], kwexs) + vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs) fex = supportself(fex, vars) # assemble @@ -212,9 +217,8 @@ function supportself(fex::Expr, vars) calls = [] for var in vars push!(calls, - :($var = Lux.Experimental.__maybe_make_stateful( - Lux._getproperty($self, $(Val(var))), - Lux._getproperty($ps, $(Val(var))), Lux._getproperty($st, $(Val(var)))))) + :($var = $(__maybe_make_stateful)($(_getproperty)($self, $(Val(var))), + $(_getproperty)($ps, $(Val(var))), $(_getproperty)($st, $(Val(var)))))) end body = Expr(:let, Expr(:block, calls...), sdef[:body]) sdef[:body] = body @@ -223,7 +227,7 @@ function supportself(fex::Expr, vars) end @inline function __maybe_make_stateful(layer::AbstractExplicitLayer, ps, st) - return Lux.StatefulLuxLayer(layer, ps, st) + return StatefulLuxLayer(layer, ps, st) end @inline __maybe_make_stateful(::Nothing, ps, st) = ps === nothing ? 
st : ps @inline function __maybe_make_stateful(model::Union{AbstractVector, Tuple}, ps, st) @@ -271,7 +275,7 @@ end value_storage end -function constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch} +function ConstructionBase.constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch} return CompactLuxLayer{dispatch} end @@ -288,7 +292,7 @@ function __try_make_lux_layer(x::Union{AbstractVector, Tuple}) return __try_make_lux_layer(NamedTuple{Tuple(Symbol.(1:length(x)))}(x)) end function __try_make_lux_layer(x) - function __maybe_convert_layer(l) + __maybe_convert_layer = @closure l -> begin l isa AbstractExplicitLayer && return l l isa Function && return WrappedFunction(l) return l From 11289ed7760e7295a9515704a1e68707b1f85a71 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Apr 2024 16:36:35 -0400 Subject: [PATCH 2/5] Compat is more flexible and allows for custom parameters --- .github/workflows/SpellCheck.yml | 13 +++ Project.toml | 2 +- docs/src/introduction/index.md | 4 +- docs/src/manual/interface.md | 11 +- examples/NeuralODE/main.jl | 33 +++++- examples/SimpleRNN/main.jl | 31 +++++- src/helpers/compact.jl | 113 ++++++++++++++------- test/{contrib => helpers}/compact_tests.jl | 1 - 8 files changed, 157 insertions(+), 51 deletions(-) create mode 100644 .github/workflows/SpellCheck.yml rename test/{contrib => helpers}/compact_tests.jl (99%) diff --git a/.github/workflows/SpellCheck.yml b/.github/workflows/SpellCheck.yml new file mode 100644 index 000000000..ed4fe1779 --- /dev/null +++ b/.github/workflows/SpellCheck.yml @@ -0,0 +1,13 @@ +name: Spell Check + +on: [pull_request] + +jobs: + typos-check: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - name: Checkout Actions Repository + uses: actions/checkout@v4 + - name: Check spelling + uses: crate-ci/typos@v1.18.0 diff --git a/Project.toml b/Project.toml index f31e75d73..f2231f037 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.34" +version = "0.5.35" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/introduction/index.md b/docs/src/introduction/index.md index 59e54141a..bb8e36d44 100644 --- a/docs/src/introduction/index.md +++ b/docs/src/introduction/index.md @@ -61,8 +61,8 @@ y, st = Lux.apply(model, x, ps, st) # Gradients ## Pullback API to capture change in state -(l, st_), pb = pullback(p -> Lux.apply(model, x, p, st), ps) -gs = pb((one.(l), nothing))[1] +(l, st_), pb = pullback(Lux.apply, model, x, ps, st) +gs = pb((one.(l), nothing))[3] # Optimization st_opt = Optimisers.setup(Adam(0.0001f0), ps) diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index ed916e9fc..02ae7cfd5 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -20,6 +20,13 @@ First let's set the expectations straight. functionality in the core library (and officially supported ones) **must** adhere to the interface +!!! tip + + While writing out a custom struct and defining dispatches manually is a good way to + understand the interface, it is not the most concise way. We recommend using the + [`Lux.@compact`](@ref) macro to define layers which makes handling the states and + parameters downright trivial. + ## Layer Interface ### Singular Layer @@ -35,8 +42,8 @@ architecture cannot change. !!! tip - For people coming from Flux.jl background this might be weird. 
We recommend checking out - [the Flux to Lux migration guide](@ref migrate-from-flux) first before proceeding. + For people coming from Flux.jl background, this might be weird. We recommend checking + out [the Flux to Lux migration guide](@ref migrate-from-flux) first before proceeding. ```@example layer_interface using Lux, Random diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 57cce80dd..f8863017a 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -42,7 +42,23 @@ function loadmnist(batchsize, train_split) end # ## Define the Neural ODE Layer -# +# +# First we will use the [`@compact`](@ref) macro to define the Neural ODE Layer. + +function NeuralODECompact( + model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) + return @compact(; model, solver, tspan, kwargs...) do x, p + dudt(u, p, t) = vec(model(reshape(u, size(x)), p)) + ## Note the `p.model` here + prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model) + return solve(prob, solver; kwargs...) + end +end + +# We recommend using the compact macro for creating custom layers. The below implementation +# exists mostly for historical reasons when `@compact` was not part of the stable API. Also, +# it helps users understand how the layer interface of Lux works. + # The NeuralODE is a ContainerLayer, which stores a `model`. The parameters and states of # the NeuralODE are same as those of the underlying model. struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: @@ -154,6 +170,8 @@ function train(model_function; cpu::Bool=false, kwargs...) end end +train(NeuralODECompact) + train(NeuralODE) # We can also change the sensealg and train the model! `GaussAdjoint` allows you to use @@ -173,8 +191,9 @@ train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true) # ## Alternate Implementation using Stateful Layer -# Starting `v0.5.5`, Lux provides a `Lux.Experimental.StatefulLuxLayer` which can be used -# to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276). +# Starting `v0.5.5`, Lux provides a [`StatefulLuxLayer`](@ref) which can be used +# to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276). Using +# the `@compact` API avoids this problem entirely. struct StatefulNeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: Lux.AbstractExplicitContainerLayer{(:model,)} model::M @@ -189,7 +208,7 @@ function StatefulNeuralODE( end function (n::StatefulNeuralODE)(x, ps, st) - st_model = Lux.StatefulLuxLayer(n.model, ps, st) + st_model = StatefulLuxLayer(n.model, ps, st) dudt(u, p, t) = st_model(u, p) prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps) return solve(prob, n.solver; n.kwargs...), st_model.st @@ -217,5 +236,11 @@ x = gpu_device()(ones(Float32, 28, 28, 1, 3)); @code_warntype model_stateful(x, ps_stateful, st_stateful) +# Finally checking the compact model + +model_compact, ps_compact, st_compact = create_model(NeuralODECompact) + +@code_warntype model_compact(x, ps_compact, st_compact) + # Note, that we still recommend using this layer internally and not exposing this as the # default API to the users. 
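For readers following the `@compact` changes in this patch, the new two-argument `do x, p` form introduced above is easiest to see on a toy layer. The sketch below is not part of the patch (the layer sizes and the `dense` name are made up) and assumes a Lux version with this series applied:

```julia
using Lux, Random

# `p` is the parameter collection of the compact layer itself, so `p.dense` holds the
# parameters of the registered `dense` sub-layer and can be passed around explicitly,
# which is what SciML solvers expect (cf. `NeuralODECompact` above).
model = @compact(; dense=Dense(2 => 4, tanh)) do x, p
    return dense(x, p.dense)
end

ps, st = Lux.setup(Xoshiro(0), model)
y, st_ = model(rand(Float32, 2, 3), ps, st)
```

Inside the body, `dense` is wrapped in a `StatefulLuxLayer`, so calling it with an explicit parameter argument mirrors the `st_model(u, p)` pattern used in the stateful implementation above.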
diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index 116edd8cc..6ea8f92bc 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -99,6 +99,26 @@ function (s::SpiralClassifier)( return vec(y), st end +# ## Using the `@compact` API + +# We can also define the model using the [`Lux.@compact`](@ref) API, which is a more concise +# way of defining models. This macro automatically handles the boilerplate code for you and +# as such we recommend this way of defining custom layers + +function SpiralClassifierCompact(in_dims, hidden_dims, out_dims) + lstm_cell = LSTMCell(in_dims => hidden_dims) + classifier = Dense(hidden_dims => out_dims, sigmoid) + return @compact(; lstm_cell=lstm_cell, + classifier=classifier) do x::AbstractArray{T, 3} where {T} + x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2))) + y, carry = lstm_cell(x_init) + for x in x_rest + y, carry = lstm_cell((x, carry)) + end + return vec(classifier(y)) + end +end + # ## Defining Accuracy, Loss and Optimiser # Now let's define the binarycrossentropy loss. Typically it is recommended to use @@ -125,12 +145,12 @@ accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred) # ## Training the Model -function main() +function main(model_type) ## Get the dataloaders (train_loader, val_loader) = get_dataloaders() ## Create the model - model = SpiralClassifier(2, 8, 1) + model = model_type(2, 8, 1) rng = Xoshiro(0) dev = gpu_device() @@ -164,7 +184,12 @@ function main() return (train_state.parameters, train_state.states) |> cpu_device() end -ps_trained, st_trained = main() +ps_trained, st_trained = main(SpiralClassifier) +nothing #hide + +# We can also train the compact model with the exact same code! + +ps_trained2, st_trained2 = main(SpiralClassifierCompact) nothing #hide # ## Saving the Model diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl index 28ae3c2ed..b8e6088ca 100644 --- a/src/helpers/compact.jl +++ b/src/helpers/compact.jl @@ -13,6 +13,9 @@ end @compact(kw...) do x ... end + @compact(kw...) do x, p + ... + end @compact(forward::Function; name=nothing, dispatch=nothing, parameters...) Creates a layer by specifying some `parameters`, in the form of keywords, and (usually as a @@ -21,12 +24,23 @@ Creates a layer by specifying some `parameters`, in the form of keywords, and (u be used within the body of the `forward` function. Note that unlike typical Lux models, the forward function doesn't need to explicitly manage states. +Defining the version with `p` allows you to access the parameters in the forward pass. This +is useful when using it with SciML tools which require passing in the parameters explicitly. + ## Reserved Kwargs: 1. `name`: The name of the layer. 2. `dispatch`: The constructed layer has the type `Lux.Experimental.CompactLuxLayer{dispatch}` which can be used for custom dispatches. +!!! tip + + Check the Lux tutorials for more examples of using `@compact`. + +If you are passing in kwargs by splatting them, they will be passed as is to the function +body. This means if your splatted kwargs contain a lux layer that won't be registered +in the CompactLuxLayer. + ## Examples Here is a linear model: @@ -161,32 +175,18 @@ function __compact_macro_impl(_exs...) kwexs = (kwexs1..., kwexs2...) 
# check if user has named layer - name_idx = findfirst(ex -> ex.args[1] == :name, kwexs) - name = nothing - if name_idx !== nothing && kwexs[name_idx].args[2] !== nothing - if length(kwexs) == 1 - throw(LuxCompactModelParsingException("expects keyword arguments")) - end - name = kwexs[name_idx].args[2] - # remove name from kwexs (a tuple) - kwexs = (kwexs[1:(name_idx - 1)]..., kwexs[(name_idx + 1):end]...) - end + name, kwexs = __extract_reserved_kwarg(kwexs, :name) # check if user has provided a custom dispatch - dispatch_idx = findfirst(ex -> ex.args[1] == :dispatch, kwexs) - dispatch = nothing - if dispatch_idx !== nothing && kwexs[dispatch_idx].args[2] !== nothing - if length(kwexs) == 1 - throw(LuxCompactModelParsingException("expects keyword arguments")) - end - dispatch = kwexs[dispatch_idx].args[2] - # remove dispatch from kwexs (a tuple) - kwexs = (kwexs[1:(dispatch_idx - 1)]..., kwexs[(dispatch_idx + 1):end]...) - end + dispatch, kwexs = __extract_reserved_kwarg(kwexs, :dispatch) + + # Extract splatted kwargs + splat_idxs = findall(ex -> ex.head == :..., kwexs) + splatted_kwargs = map(first ∘ Base.Fix2(getproperty, :args), kwexs[splat_idxs]) + kwexs = filter(ex -> ex.head != :..., kwexs) # make strings layer = "@compact" - setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs)) input = try fex_args = fex.args[1] isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ") @@ -198,28 +198,43 @@ function __compact_macro_impl(_exs...) # edit expressions vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs) - fex = supportself(fex, vars) + fex = supportself(fex, vars, splatted_kwargs) + + display(fex) # assemble - return esc(:($CompactLuxLayer{$dispatch}( - $fex, $name, ($layer, $input, $block), $setup; $(kwexs...)))) + return esc(:($CompactLuxLayer{$dispatch}($fex, $name, ($layer, $input, $block), + (($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...)))) +end + +function __extract_reserved_kwarg(kwexs, sym::Symbol) + idx = findfirst(ex -> ex.args[1] == sym, kwexs) + val = nothing + if idx !== nothing && kwexs[idx].args[2] !== nothing + length(kwexs) == 1 && + throw(LuxCompactModelParsingException("expects keyword arguments")) + val = kwexs[idx].args[2] + kwexs = (kwexs[1:(idx - 1)]..., kwexs[(idx + 1):end]...) + end + return val, kwexs end -function supportself(fex::Expr, vars) +function supportself(fex::Expr, vars, splatted_kwargs) @gensym self ps st curried_f res # To avoid having to manipulate fex's arguments and body explicitly, we split the input # function body and add the required arguments to the function definition. sdef = splitdef(fex) - if length(sdef[:args]) != 1 - throw(LuxCompactModelParsingException("expects exactly 1 argument")) - end - args = [self, sdef[:args]..., ps, st] + custom_param = length(sdef[:args]) == 2 + length(sdef[:args]) > 2 && + throw(LuxCompactModelParsingException("expects at most 2 arguments")) + args = [self, sdef[:args][1], ps, st] calls = [] for var in vars push!(calls, :($var = $(__maybe_make_stateful)($(_getproperty)($self, $(Val(var))), $(_getproperty)($ps, $(Val(var))), $(_getproperty)($st, $(Val(var)))))) end + custom_param && push!(calls, :($(sdef[:args][2]) = $ps)) body = Expr(:let, Expr(:block, calls...), sdef[:body]) sdef[:body] = body sdef[:args] = args @@ -229,7 +244,7 @@ end @inline function __maybe_make_stateful(layer::AbstractExplicitLayer, ps, st) return StatefulLuxLayer(layer, ps, st) end -@inline __maybe_make_stateful(::Nothing, ps, st) = ps === nothing ? 
st : ps +@inline __maybe_make_stateful(::Nothing, ps, st) = ifelse(ps === nothing, st, ps) @inline function __maybe_make_stateful(model::Union{AbstractVector, Tuple}, ps, st) return map(i -> __maybe_make_stateful(model[i], ps[i], st[i]), eachindex(model)) end @@ -248,13 +263,13 @@ end function ValueStorage(; kwargs...) ps_init_fns, st_init_fns = [], [] for (key, val) in pairs(kwargs) - push!(val isa AbstractArray ? ps_init_fns : st_init_fns, key => () -> val) + push!(val isa AbstractArray ? ps_init_fns : st_init_fns, key => Returns(val)) end return ValueStorage(NamedTuple(ps_init_fns), NamedTuple(st_init_fns)) end function (v::ValueStorage)(x, ps, st) - throw(ArgumentError("ValueStorage isn't meant to be used as a layer!!!")) + throw(ArgumentError("`ValueStorage` isn't meant to be used as a layer!!!")) end function initialparameters(::AbstractRNG, v::ValueStorage) @@ -273,6 +288,7 @@ end setup_strings layers value_storage + stored_kwargs end function ConstructionBase.constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch} @@ -300,27 +316,48 @@ function __try_make_lux_layer(x) return fmap(__maybe_convert_layer, x) end -function CompactLuxLayer{dispatch}(f::Function, name::NAME_TYPE, str::Tuple, - setup_str::NamedTuple; kws...) where {dispatch} +function CompactLuxLayer{dispatch}( + f::F, name::NAME_TYPE, str::Tuple, splatted_kwargs; kws...) where {F, dispatch} layers, others = [], [] + setup_strings = NamedTuple() for (name, val) in pairs(kws) + is_lux_layer = false if val isa AbstractExplicitLayer + is_lux_layer = true push!(layers, name => val) elseif LuxCore.contains_lux_layer(val) # TODO: Rearrange Tuple and Vectors to NamedTuples for proper CA.jl support - # FIXME: This might lead to incorrect constructions? If the function is a closure over the provided keyword arguments? + # FIXME: This might lead to incorrect constructions? If the function is a + # closure over the provided keyword arguments? val = __try_make_lux_layer(val) if LuxCore.check_fmap_condition( !Base.Fix2(isa, AbstractExplicitLayer), nothing, val) - throw(LuxCompactModelParsingException("A container `$(name) = $(val)` is found which combines Lux layers with non-Lux layers. This is not supported.")) + throw(LuxCompactModelParsingException("A container `$(name) = $(val)` is \ + found which combines Lux layers \ + with non-Lux layers. 
This is not \ + supported.")) end + is_lux_layer = true push!(layers, name => val) else push!(others, name => val) end + + if is_lux_layer + setup_strings = merge(setup_strings, NamedTuple((name => val,))) + else + setup_strings = merge(setup_strings, + NamedTuple((name => sprint( + show, val; context=(:compact => true, :limit => true)),))) + end end - return CompactLuxLayer{dispatch}( - f, name, str, setup_str, NamedTuple((; layers...)), ValueStorage(; others...)) + + for (kw_name, kw_val) in zip(splatted_kwargs[1], splatted_kwargs[2]) + push!(others, kw_name => kw_val) + end + + return CompactLuxLayer{dispatch}(f, name, str, setup_strings, NamedTuple((; layers...)), + ValueStorage(; others...), nothing) end function (m::CompactLuxLayer)(x, ps, st::NamedTuple{fields}) where {fields} diff --git a/test/contrib/compact_tests.jl b/test/helpers/compact_tests.jl similarity index 99% rename from test/contrib/compact_tests.jl rename to test/helpers/compact_tests.jl index bc2a97ad5..bd7f35cf9 100644 --- a/test/contrib/compact_tests.jl +++ b/test/helpers/compact_tests.jl @@ -1,6 +1,5 @@ @testitem "@compact" setup=[SharedTestSetup] begin using ComponentArrays - import Lux.Experimental: @compact rng = get_stable_rng(12345) From c45052439794dcefe7cf04ca3ca894be6d37f7b3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Apr 2024 16:59:13 -0400 Subject: [PATCH 3/5] Fix typos --- docs/src/manual/interface.md | 2 +- docs/src/manual/migrate_from_flux.md | 2 +- docs/src/tutorials/index.md | 4 ++-- src/distributed/public_api.jl | 2 +- src/layers/normalize.jl | 16 ++++++++-------- src/transform/flux.jl | 2 +- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index 02ae7cfd5..948c9a0ad 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -87,7 +87,7 @@ reconstruction of the parameters and states. println("Parameter Length: ", Lux.parameterlength(l), "; State Length: ", Lux.statelength(l)) -# But still recommened to define these +# But still recommended to define these Lux.parameterlength(l::Linear) = l.out_dims * l.in_dims + l.out_dims Lux.statelength(::Linear) = 0 diff --git a/docs/src/manual/migrate_from_flux.md b/docs/src/manual/migrate_from_flux.md index 36acda7a7..6a3248257 100644 --- a/docs/src/manual/migrate_from_flux.md +++ b/docs/src/manual/migrate_from_flux.md @@ -99,7 +99,7 @@ end # `A` is not trainable Optimisers.trainable(f::FluxLinear) = (B=f.B,) -# Needed so that both `A` and `B` can be transfered between devices +# Needed so that both `A` and `B` can be transferred between devices Flux.@functor FluxLinear (l::FluxLinear)(x) = l.A * l.B * x diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index 36b4afb4f..25a73ad55 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -59,7 +59,7 @@ const advanced = [ } ]; -const thrid_party = [ +const third_party = [ { href: "https://docs.sciml.ai/Overview/stable/showcase/pinngpu/", src: "../pinn.gif", @@ -114,7 +114,7 @@ of them are non-functional and we will try to get them updated. ::: - + ::: tip diff --git a/src/distributed/public_api.jl b/src/distributed/public_api.jl index 945518fd9..ae7eda7c1 100644 --- a/src/distributed/public_api.jl +++ b/src/distributed/public_api.jl @@ -179,7 +179,7 @@ function __reduce! end CRC.@non_differentiable reduce!(::Any...) -# syncronize! +# synchronize! 
""" synchronize!!(backend::AbstractLuxDistributedBackend, ps; root::Int=0) diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 38b587c5d..5da8e5c5e 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -33,8 +33,8 @@ slice and normalises the input accordingly. - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initiliazed - + `init_scale`: Controls how the `scale` is initiliazed + + `init_bias`: Controls how the `bias` is initialiazed + + `init_scale`: Controls how the `scale` is initialiazed ## Inputs @@ -167,8 +167,8 @@ end - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initiliazed - + `init_scale`: Controls how the `scale` is initiliazed + + `init_bias`: Controls how the `bias` is initialiazed + + `init_scale`: Controls how the `scale` is initialiazed ## Inputs @@ -265,8 +265,8 @@ accordingly. - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initiliazed - + `init_scale`: Controls how the `scale` is initiliazed + + `init_bias`: Controls how the `bias` is initialiazed + + `init_scale`: Controls how the `scale` is initialiazed ## Inputs @@ -506,8 +506,8 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initiliazed - + `init_scale`: Controls how the `scale` is initiliazed + + `init_bias`: Controls how the `bias` is initialiazed + + `init_scale`: Controls how the `scale` is initialiazed ## Inputs diff --git a/src/transform/flux.jl b/src/transform/flux.jl index e64d42c7d..430c0652e 100644 --- a/src/transform/flux.jl +++ b/src/transform/flux.jl @@ -5,7 +5,7 @@ Convert a Flux Model to Lux Model. !!! warning - This always ingores the `active` field of some of the Flux layers. This is almost never + This always ignores the `active` field of some of the Flux layers. This is almost never going to be supported. ## Keyword Arguments From 3981a2c56b4d20fdd8de004d6ef131a4ed08182a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Apr 2024 17:25:14 -0400 Subject: [PATCH 4/5] Update Hypernet example to use the compact layer --- examples/DDIM/README.md | 2 +- examples/GravitationalWaveForm/main.jl | 12 +++--- examples/HyperNet/main.jl | 40 ++++++++----------- src/layers/containers.jl | 2 +- src/layers/normalize.jl | 16 ++++---- test/distributed/common_distributedtest.jl | 2 +- .../synchronize_distributedtest.jl | 2 +- 7 files changed, 34 insertions(+), 42 deletions(-) diff --git a/examples/DDIM/README.md b/examples/DDIM/README.md index d5bb55040..6e3dd073f 100644 --- a/examples/DDIM/README.md +++ b/examples/DDIM/README.md @@ -11,7 +11,7 @@ The model generates images from Gaussian noises by denoising iterativel # Usage Install Julia and instantiate `Project.toml`. -Follwoing scripts are tested on a single NVIDIA Tesla T4 instance. +Following scripts are tested on a single NVIDIA Tesla T4 instance. ## Dataset Download and extract `Dataset images` from [102 Category Flower Dataset](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/). 
diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl index 63e6fc37f..16658484e 100644 --- a/examples/GravitationalWaveForm/main.jl +++ b/examples/GravitationalWaveForm/main.jl @@ -183,11 +183,11 @@ end function RelativisticOrbitModel(u, (p, M, e), t) χ, ϕ = u - numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2 + number = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2 denom = sqrt((p - 2)^2 - 4 * e^2) - χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom) - ϕ̇ = numer / (M * (p^(3 / 2)) * denom) + χ̇ = number * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom) + ϕ̇ = number / (M * (p^(3 / 2)) * denom) return [χ̇, ϕ̇] end @@ -260,11 +260,11 @@ function ODE_model(u, nn_params, t) ## it, however, in general, we should use `st` to store the state of the neural network. y = 1 .+ nn_model([first(u)], nn_params) - numer = (1 + e * cos(χ))^2 + number = (1 + e * cos(χ))^2 denom = M * (p^(3 / 2)) - χ̇ = (numer / denom) * y[1] - ϕ̇ = (numer / denom) * y[2] + χ̇ = (number / denom) * y[1] + ϕ̇ = (number / denom) * y[2] return [χ̇, ϕ̇] end diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index b6b4401f5..824f02ab0 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -31,33 +31,25 @@ function load_datasets(n_train=1024, n_eval=32, batchsize=256) end # ## Implement a HyperNet Layer -struct HyperNet{W <: Lux.AbstractExplicitLayer, C <: Lux.AbstractExplicitLayer, A} <: - Lux.AbstractExplicitContainerLayer{(:weight_generator, :core_network)} - weight_generator::W - core_network::C - ca_axes::A -end - -function HyperNet(w::Lux.AbstractExplicitLayer, c::Lux.AbstractExplicitLayer) - ca_axes = Lux.initialparameters(Random.default_rng(), c) |> ComponentArray |> getaxes - return HyperNet(w, c, ca_axes) -end - -function Lux.initialparameters(rng::AbstractRNG, h::HyperNet) - return (weight_generator=Lux.initialparameters(rng, h.weight_generator),) +function HyperNet(weight_generator::Lux.AbstractExplicitLayer, + core_network::Lux.AbstractExplicitLayer) + ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |> + ComponentArray |> + getaxes + return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y) + ## Generate the weights + ps_new = ComponentArray(vec(weight_generator(x)), ca_axes) + return core_network(y, ps_new) + end end -function (hn::HyperNet)(x, ps, st::NamedTuple) - ps_new, st_ = hn.weight_generator(x, ps.weight_generator, st.weight_generator) - @set! st.weight_generator = st_ - return ComponentArray(vec(ps_new), hn.ca_axes), st -end +# Defining functions on the CompactLuxLayer requires some understanding of how the layer +# is structured, as such we don't recommend doing it unless you are familiar with the +# internals. In this case, we simply write it to ignore the initialization of the +# `core_network` parameters. -function (hn::HyperNet)((x, y)::T, ps, st::NamedTuple) where {T <: Tuple} - ps_ca, st = hn(x, ps, st) - pred, st_ = hn.core_network(y, ps_ca, st.core_network) - @set! 
st.core_network = st_ - return pred, st +function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet}) + return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),) end # ## Create and Initialize the HyperNet diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 478468a73..a7384b88e 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -508,7 +508,7 @@ outputsize(c::Chain) = outputsize(c.layers[end]) This contains a number of internal layers, each of which receives the same input. Its output is the elementwise maximum of the the internal layers' outputs. -Maxout over linear dense layers satisfies the univeral approximation theorem. See [1]. +Maxout over linear dense layers satisfies the universal approximation theorem. See [1]. See also [`Parallel`](@ref) to reduce with other operators. diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 5da8e5c5e..bb22b71d8 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -33,8 +33,8 @@ slice and normalises the input accordingly. - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initialiazed - + `init_scale`: Controls how the `scale` is initialiazed + + `init_bias`: Controls how the `bias` is initialized + + `init_scale`: Controls how the `scale` is initialized ## Inputs @@ -167,8 +167,8 @@ end - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initialiazed - + `init_scale`: Controls how the `scale` is initialiazed + + `init_bias`: Controls how the `bias` is initialized + + `init_scale`: Controls how the `scale` is initialized ## Inputs @@ -265,8 +265,8 @@ accordingly. - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initialiazed - + `init_scale`: Controls how the `scale` is initialiazed + + `init_bias`: Controls how the `bias` is initialized + + `init_scale`: Controls how the `scale` is initialized ## Inputs @@ -506,8 +506,8 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initialiazed - + `init_scale`: Controls how the `scale` is initialiazed + + `init_bias`: Controls how the `bias` is initialized + + `init_scale`: Controls how the `scale` is initialized ## Inputs diff --git a/test/distributed/common_distributedtest.jl b/test/distributed/common_distributedtest.jl index c5a53c5e3..8ac8e5314 100644 --- a/test/distributed/common_distributedtest.jl +++ b/test/distributed/common_distributedtest.jl @@ -19,7 +19,7 @@ nworkers = DistributedUtils.total_workers(backend) @test rank < nworkers # Test the communication primitives -## broacast! +## broadcast! for arrType in (Array, aType) sendbuf = (rank == 0) ? 
arrType(ones(512)) : arrType(zeros(512)) recvbuf = arrType(zeros(512)) diff --git a/test/distributed/synchronize_distributedtest.jl b/test/distributed/synchronize_distributedtest.jl index f29130426..2b49b5b14 100644 --- a/test/distributed/synchronize_distributedtest.jl +++ b/test/distributed/synchronize_distributedtest.jl @@ -80,7 +80,7 @@ gs = DistributedUtils.synchronize!!(backend, gs; root) @test all(gs[1][2] .== 1) @test all(gs[2] .== 1) -# Miscelleneous +# Miscellaneous x = nothing x = DistributedUtils.synchronize!!(backend, x; root) @test x === nothing From 72fc49f9c9d66a3e171100a8393073f77787fe2e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Apr 2024 18:23:23 -0400 Subject: [PATCH 5/5] Prettify the printing --- .typos.toml | 2 ++ examples/GravitationalWaveForm/main.jl | 14 +++++++------- examples/NeuralODE/main.jl | 6 +++--- examples/SimpleRNN/main.jl | 3 +-- src/helpers/compact.jl | 24 +++++++++++++++++++----- test/helpers/compact_tests.jl | 6 +++--- 6 files changed, 35 insertions(+), 20 deletions(-) create mode 100644 .typos.toml diff --git a/.typos.toml b/.typos.toml new file mode 100644 index 000000000..e2b3e6f9a --- /dev/null +++ b/.typos.toml @@ -0,0 +1,2 @@ +[default.extend-words] +numer = "numer" \ No newline at end of file diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl index 16658484e..7d63a5ce9 100644 --- a/examples/GravitationalWaveForm/main.jl +++ b/examples/GravitationalWaveForm/main.jl @@ -163,7 +163,7 @@ function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where { m₁ = mass_ratio * m₂ orbit₁, orbit₂ = one2two(orbit, m₁, m₂) - waveform = h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2) + waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂) else waveform = h_22_strain_one_body(dt, orbit) end @@ -183,11 +183,11 @@ end function RelativisticOrbitModel(u, (p, M, e), t) χ, ϕ = u - number = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2 + numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2 denom = sqrt((p - 2)^2 - 4 * e^2) - χ̇ = number * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom) - ϕ̇ = number / (M * (p^(3 / 2)) * denom) + χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom) + ϕ̇ = numer / (M * (p^(3 / 2)) * denom) return [χ̇, ϕ̇] end @@ -260,11 +260,11 @@ function ODE_model(u, nn_params, t) ## it, however, in general, we should use `st` to store the state of the neural network. y = 1 .+ nn_model([first(u)], nn_params) - number = (1 + e * cos(χ))^2 + numer = (1 + e * cos(χ))^2 denom = M * (p^(3 / 2)) - χ̇ = (number / denom) * y[1] - ϕ̇ = (number / denom) * y[2] + χ̇ = (numer / denom) * y[1] + ϕ̇ = (numer / denom) * y[2] return [χ̇, ϕ̇] end diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index f8863017a..2901534a9 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -236,11 +236,11 @@ x = gpu_device()(ones(Float32, 28, 28, 1, 3)); @code_warntype model_stateful(x, ps_stateful, st_stateful) +# Note, that we still recommend using this layer internally and not exposing this as the +# default API to the users. + # Finally checking the compact model model_compact, ps_compact, st_compact = create_model(NeuralODECompact) @code_warntype model_compact(x, ps_compact, st_compact) - -# Note, that we still recommend using this layer internally and not exposing this as the -# default API to the users. 
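The `dispatch=:HyperNet` pattern from the HyperNet rewrite above generalizes to any `@compact` layer. Below is a minimal sketch, not from the patch, with a made-up `:MyLayer` tag and sizes, assuming the `CompactLuxLayer` export added earlier in this series:

```julia
using Lux, Random

# `dispatch=:MyLayer` makes the constructed layer a `CompactLuxLayer{:MyLayer}`, so
# Lux functions can be specialized on it without writing a custom struct.
model = @compact(; dense=Dense(2 => 2), dispatch=:MyLayer) do x
    return dense(x)
end

# Specialized parameter initialization, analogous to how the HyperNet version skips
# initializing the `core_network` parameters. Registered sub-layers live in `m.layers`.
function Lux.initialparameters(rng::AbstractRNG, m::CompactLuxLayer{:MyLayer})
    return (; dense=Lux.initialparameters(rng, m.layers.dense))
end

ps, st = Lux.setup(Xoshiro(0), model)
y, _ = model(randn(Xoshiro(1), Float32, 2, 3), ps, st)
```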
diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index 6ea8f92bc..0f1791408 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -108,8 +108,7 @@ end function SpiralClassifierCompact(in_dims, hidden_dims, out_dims) lstm_cell = LSTMCell(in_dims => hidden_dims) classifier = Dense(hidden_dims => out_dims, sigmoid) - return @compact(; lstm_cell=lstm_cell, - classifier=classifier) do x::AbstractArray{T, 3} where {T} + return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T} x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2))) y, carry = lstm_cell(x_init) for x in x_rest diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl index b8e6088ca..399d80f9e 100644 --- a/src/helpers/compact.jl +++ b/src/helpers/compact.jl @@ -200,8 +200,6 @@ function __compact_macro_impl(_exs...) vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs) fex = supportself(fex, vars, splatted_kwargs) - display(fex) - # assemble return esc(:($CompactLuxLayer{$dispatch}($fex, $name, ($layer, $input, $block), (($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...)))) @@ -346,9 +344,8 @@ function CompactLuxLayer{dispatch}( if is_lux_layer setup_strings = merge(setup_strings, NamedTuple((name => val,))) else - setup_strings = merge(setup_strings, - NamedTuple((name => sprint( - show, val; context=(:compact => true, :limit => true)),))) + setup_strings = merge( + setup_strings, NamedTuple((name => __kwarg_descriptor(val),))) end end @@ -410,3 +407,20 @@ function Lux._big_show(io::IO, obj::CompactLuxLayer, indent::Int=0, name=nothing end return end + +function __kwarg_descriptor(val) + val isa Number && return string(val) + val isa AbstractArray && return sprint(Base.array_summary, val, axes(val)) + val isa Tuple && return "(" * join(map(__kwarg_descriptor, val), ", ") * ")" + if val isa NamedTuple + fields = fieldnames(typeof(val)) + strs = [] + for fname in fields[1:min(length(fields), 3)] + internal_val = getfield(val, fname) + push!(strs, "$fname = $(__kwarg_descriptor(internal_val))") + end + return "@NamedTuple{$(join(strs, ", "))" * (length(fields) > 3 ? ", ..." : "") * "}" + end + val isa Function && return sprint(show, val; context=(:compact => true, :limit => true)) + return lazy"$(nameof(typeof(val)))(...)" +end diff --git a/test/helpers/compact_tests.jl b/test/helpers/compact_tests.jl index bd7f35cf9..caa37fe3d 100644 --- a/test/helpers/compact_tests.jl +++ b/test/helpers/compact_tests.jl @@ -180,7 +180,7 @@ return w(x .* s) end expected_string = """@compact( - x = randn(32), + x = 32-element Vector{Float64}, w = Dense(32 => 32), # 1_056 parameters ) do s return w(x .* s) @@ -197,8 +197,8 @@ end expected_string = """@compact( w1 = Model(32)(), # 1_024 parameters - w2 = randn(32, 32), - w3 = randn(32), + w2 = 32×32 Matrix{Float64}, + w3 = 32-element Vector{Float64}, ) do x return w2 * w1(x) end # Total: 2_080 parameters,
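The updated test expectations above translate to the following behavior when a compact layer is displayed. A sketch, not from the patch, with made-up sizes; the printed output is approximate:

```julia
using Lux, Random

w = randn(32)                          # a raw Float64 parameter array
model = @compact(; w=w, dense=Dense(32 => 32)) do x
    return dense(w .* x)
end

display(model)
# Shown roughly as follows, with the array summarized by `__kwarg_descriptor`
# instead of being printed element-wise:
# @compact(
#     w = 32-element Vector{Float64},
#     dense = Dense(32 => 32),         # 1_056 parameters
# ) do x
#     return dense(w .* x)
# end       # Total: 1_088 parameters, ...
```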