Prettify the printing
avik-pal committed Apr 12, 2024
1 parent 3981a2c commit 944bb0d
Showing 4 changed files with 26 additions and 13 deletions.
6 changes: 3 additions & 3 deletions examples/NeuralODE/main.jl
@@ -236,11 +236,11 @@ x = gpu_device()(ones(Float32, 28, 28, 1, 3));
 
 @code_warntype model_stateful(x, ps_stateful, st_stateful)
 
-# Note, that we still recommend using this layer internally and not exposing this as the
-# default API to the users.
-
 # Finally checking the compact model
 
 model_compact, ps_compact, st_compact = create_model(NeuralODECompact)
 
 @code_warntype model_compact(x, ps_compact, st_compact)
+
+# Note, that we still recommend using this layer internally and not exposing this as the
+# default API to the users.
3 changes: 1 addition & 2 deletions examples/SimpleRNN/main.jl
@@ -108,8 +108,7 @@ end
 function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
     lstm_cell = LSTMCell(in_dims => hidden_dims)
     classifier = Dense(hidden_dims => out_dims, sigmoid)
-    return @compact(; lstm_cell=lstm_cell,
-        classifier=classifier) do x::AbstractArray{T, 3} where {T}
+    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
         x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
         y, carry = lstm_cell(x_init)
         for x in x_rest
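
The rewritten `@compact` call relies on Julia's keyword-argument shorthand: in a call, `f(; a, b)` is equivalent to `f(; a=a, b=b)` when local variables of those names are in scope. A minimal sketch of the equivalence (`describe` is a hypothetical function, not part of Lux):

describe(; lstm_cell, classifier) = (lstm_cell, classifier)

lstm_cell = "LSTMCell"
classifier = "Dense"
describe(; lstm_cell=lstm_cell, classifier=classifier)  # explicit form
describe(; lstm_cell, classifier)                       # shorthand, same result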
24 changes: 19 additions & 5 deletions src/helpers/compact.jl
@@ -200,8 +200,6 @@ function __compact_macro_impl(_exs...)
     vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs)
     fex = supportself(fex, vars, splatted_kwargs)
 
-    display(fex)
-
     # assemble
     return esc(:($CompactLuxLayer{$dispatch}($fex, $name, ($layer, $input, $block),
         (($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...))))
@@ -346,9 +344,8 @@ function CompactLuxLayer{dispatch}(
         if is_lux_layer
             setup_strings = merge(setup_strings, NamedTuple((name => val,)))
         else
-            setup_strings = merge(setup_strings,
-                NamedTuple((name => sprint(
-                    show, val; context=(:compact => true, :limit => true)),)))
+            setup_strings = merge(
+                setup_strings, NamedTuple((name => __kwarg_descriptor(val),)))
         end
     end

@@ -410,3 +407,20 @@ function Lux._big_show(io::IO, obj::CompactLuxLayer, indent::Int=0, name=nothing)
     end
     return
 end
+
+function __kwarg_descriptor(val)
+    val isa Number && return string(val)
+    val isa AbstractArray && return sprint(Base.array_summary, val, axes(val))
+    val isa Tuple && return "(" * join(map(__kwarg_descriptor, val), ", ") * ")"
+    if val isa NamedTuple
+        fields = fieldnames(typeof(val))
+        strs = []
+        for fname in fields[1:min(length(fields), 3)]
+            internal_val = getfield(val, fname)
+            push!(strs, "$fname = $(__kwarg_descriptor(internal_val))")
+        end
+        return "@NamedTuple{$(join(strs, ", "))" * (length(fields) > 3 ? ", ..." : "") * "}"
+    end
+    val isa Function && return sprint(show, val; context=(:compact => true, :limit => true))
+    return lazy"$(nameof(typeof(val)))(...)"
+end
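
To give a feel for the new summaries, here is a sketch of what `__kwarg_descriptor` returns for a few representative inputs (it is an internal helper, so calling it as `Lux.__kwarg_descriptor` is for illustration only; expected outputs shown as comments):

using Lux

Lux.__kwarg_descriptor(randn(32))            # "32-element Vector{Float64}"
Lux.__kwarg_descriptor((1, 2.0))             # "(1, 2.0)"
Lux.__kwarg_descriptor((; a=1, b=randn(2)))  # "@NamedTuple{a = 1, b = 2-element Vector{Float64}}"
Lux.__kwarg_descriptor(tanh)                 # "tanh"

Arrays are summarized by type and shape rather than by contents, which keeps the `@compact` display short even for large weight matrices.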
6 changes: 3 additions & 3 deletions test/helpers/compact_tests.jl
@@ -180,7 +180,7 @@
         return w(x .* s)
     end
     expected_string = """@compact(
-        x = randn(32),
+        x = 32-element Vector{Float64},
         w = Dense(32 => 32), # 1_056 parameters
     ) do s
         return w(x .* s)
@@ -197,8 +197,8 @@
     end
     expected_string = """@compact(
         w1 = Model(32)(), # 1_024 parameters
-        w2 = randn(32, 32),
-        w3 = randn(32),
+        w2 = 32×32 Matrix{Float64},
+        w3 = 32-element Vector{Float64},
     ) do x
         return w2 * w1(x)
    end # Total: 2_080 parameters,
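
As a usage sketch of the prettified output (assuming Lux's exported `@compact` and `Dense`; exact alignment and parameter counts follow the expected strings above):

using Lux, Random

w2 = randn(32, 32)
w3 = randn(32)
model = @compact(; w1=Dense(32 => 32), w2, w3) do x
    return w2 * w1(x)
end
# Displaying `model` now renders the raw arrays as type-and-shape
# summaries (w2 = 32×32 Matrix{Float64}, w3 = 32-element Vector{Float64})
# in place of the earlier `w2 = randn(32, 32)` style.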
