From 944bb0d6606aae66699e2188ffaaa75bf844246e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Apr 2024 18:23:23 -0400 Subject: [PATCH] Prettify the printing --- examples/NeuralODE/main.jl | 6 +++--- examples/SimpleRNN/main.jl | 3 +-- src/helpers/compact.jl | 24 +++++++++++++++++++----- test/helpers/compact_tests.jl | 6 +++--- 4 files changed, 26 insertions(+), 13 deletions(-) diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index f8863017a2..2901534a94 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -236,11 +236,11 @@ x = gpu_device()(ones(Float32, 28, 28, 1, 3)); @code_warntype model_stateful(x, ps_stateful, st_stateful) +# Note, that we still recommend using this layer internally and not exposing this as the +# default API to the users. + # Finally checking the compact model model_compact, ps_compact, st_compact = create_model(NeuralODECompact) @code_warntype model_compact(x, ps_compact, st_compact) - -# Note, that we still recommend using this layer internally and not exposing this as the -# default API to the users. diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index 6ea8f92bc3..0f17914083 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -108,8 +108,7 @@ end function SpiralClassifierCompact(in_dims, hidden_dims, out_dims) lstm_cell = LSTMCell(in_dims => hidden_dims) classifier = Dense(hidden_dims => out_dims, sigmoid) - return @compact(; lstm_cell=lstm_cell, - classifier=classifier) do x::AbstractArray{T, 3} where {T} + return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T} x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2))) y, carry = lstm_cell(x_init) for x in x_rest diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl index b8e6088ca1..399d80f9e8 100644 --- a/src/helpers/compact.jl +++ b/src/helpers/compact.jl @@ -200,8 +200,6 @@ function __compact_macro_impl(_exs...) vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs) fex = supportself(fex, vars, splatted_kwargs) - display(fex) - # assemble return esc(:($CompactLuxLayer{$dispatch}($fex, $name, ($layer, $input, $block), (($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...)))) @@ -346,9 +344,8 @@ function CompactLuxLayer{dispatch}( if is_lux_layer setup_strings = merge(setup_strings, NamedTuple((name => val,))) else - setup_strings = merge(setup_strings, - NamedTuple((name => sprint( - show, val; context=(:compact => true, :limit => true)),))) + setup_strings = merge( + setup_strings, NamedTuple((name => __kwarg_descriptor(val),))) end end @@ -410,3 +407,20 @@ function Lux._big_show(io::IO, obj::CompactLuxLayer, indent::Int=0, name=nothing end return end + +function __kwarg_descriptor(val) + val isa Number && return string(val) + val isa AbstractArray && return sprint(Base.array_summary, val, axes(val)) + val isa Tuple && return "(" * join(map(__kwarg_descriptor, val), ", ") * ")" + if val isa NamedTuple + fields = fieldnames(typeof(val)) + strs = [] + for fname in fields[1:min(length(fields), 3)] + internal_val = getfield(val, fname) + push!(strs, "$fname = $(__kwarg_descriptor(internal_val))") + end + return "@NamedTuple{$(join(strs, ", "))" * (length(fields) > 3 ? ", ..." : "") * "}" + end + val isa Function && return sprint(show, val; context=(:compact => true, :limit => true)) + return lazy"$(nameof(typeof(val)))(...)" +end diff --git a/test/helpers/compact_tests.jl b/test/helpers/compact_tests.jl index bd7f35cf90..caa37fe3d0 100644 --- a/test/helpers/compact_tests.jl +++ b/test/helpers/compact_tests.jl @@ -180,7 +180,7 @@ return w(x .* s) end expected_string = """@compact( - x = randn(32), + x = 32-element Vector{Float64}, w = Dense(32 => 32), # 1_056 parameters ) do s return w(x .* s) @@ -197,8 +197,8 @@ end expected_string = """@compact( w1 = Model(32)(), # 1_024 parameters - w2 = randn(32, 32), - w3 = randn(32), + w2 = 32×32 Matrix{Float64}, + w3 = 32-element Vector{Float64}, ) do x return w2 * w1(x) end # Total: 2_080 parameters,