Allow fmapping over the model #144

Merged · 1 commit · Sep 1, 2022

5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
# v0.4

## v0.4.20

- Introduces `Lux.@layer_map` and `Lux.layer_map` for mapping over layers.
- Allows `fmap`-ping over layers.

## v0.4.19

- Generic Container layers (like `Chain`, `Parallel`, etc.) can now use custom naming for
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.4.19"
version = "0.4.20"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
7 changes: 7 additions & 0 deletions docs/src/api/contrib.md
@@ -39,6 +39,13 @@ refactors being planned (See [this issue](https://github.com/EnzymeAD/Enzyme.jl/
Once the package is stable and we have the necessary backend support, we will be dropping
the VJP rules in this module.

## Map over Layer

```@docs
Lux.layer_map
Lux.@layer_map
```

## Index

```@index
1 change: 1 addition & 0 deletions src/Lux.jl
@@ -49,6 +49,7 @@ function __init__()
end

# Experimental
include("contrib/map.jl")
include("contrib/training.jl")

# Deprecations
114 changes: 114 additions & 0 deletions src/contrib/map.jl
@@ -0,0 +1,114 @@
using Functors: functor

"""
@layer_map func layer ps st

See the documentation of [`Lux.layer_map`](@ref) for more details. This macro eliminates
the need to set the layer name, and uses the variable name as the starting point.

## Example

```julia
using Lux, Random, Setfield

c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3),
dense_2=Dense(3 => 5)),
dense_3=Dense(5 => 1))

rng = Random.default_rng()
ps, st = Lux.setup(rng, c)

# Makes parameters of Dense Layers inside Chain zero
function zero_dense_params(l, ps, st, name)
if l isa Dense
println("zeroing params of $name")
@set! ps.weight = zero.(ps.weight)
@set! ps.bias = zero.(ps.bias)
end
return l, ps, st
end

Lux.@layer_map zero_dense_params c ps st
```
"""
macro layer_map(f, l, ps, st)
quote
layer_map($(esc(f)), $(esc(l)), $(esc(ps)), $(esc(st)), $(string(l)))
end
end

"""
layer_map(f::Function, l::AbstractExplicitLayer, ps, st::NamedTuple,
name::String="model")

Map the function `f` over the model `l`, with the parameters `ps` and states `st`. This is
different from `Functors.fmap` since it zips the layers, parameters, and states and invokes
the function on all of them together.

## Call Signature for `f`

- Must take 4 inputs -- the `AbstractExplicitLayer`, the corresponding parameters, the
  corresponding states, and the name of the layer.
- Must return a tuple of 3 elements -- the (updated) `AbstractExplicitLayer`, the new
  parameters, and the new states.

!!! tip

We recommend using the macro `Lux.@layer_map` instead of this function. It automatically
sets the `name` of the layer to be the variable name.

## Example

```julia
using Lux, Random, Setfield

c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3),
dense_2=Dense(3 => 5)),
dense_3=Dense(5 => 1))

rng = Random.default_rng()
ps, st = Lux.setup(rng, c)

# Makes parameters of Dense Layers inside Chain zero
function zero_dense_params(l, ps, st, name)
if l isa Dense
println("zeroing params of $name")
@set! ps.weight = zero.(ps.weight)
@set! ps.bias = zero.(ps.bias)
end
return l, ps, st
end

Lux.layer_map(zero_dense_params, c, ps, st)
```
"""
function layer_map(f::Function, l::AbstractExplicitLayer, ps, st::NamedTuple,
name::String="model")
l_c, l_re = Functors.functor(l)
ps_c, ps_re = Functors.functor(ps)
st_c, st_re = Functors.functor(st)

length(l_c) == 0 && return f(l, ps, st, name)

l_c_ = l_c isa Tuple ? l_c[1] : l_c
ks = keys(l_c_)

l_c_new, ps_c_new, st_c_new = [], [], []
for k in ks
l_c_new_, ps_c_new_, st_c_new_ = layer_map(f, getproperty(l_c_, k),
getproperty(ps_c, k),
getproperty(st_c, k),
join((name, k), "."))
push!(l_c_new, k => l_c_new_)
push!(ps_c_new, k => ps_c_new_)
push!(st_c_new, k => st_c_new_)
end
l_c_new = (; l_c_new...)
l_c_new = l_c isa Tuple ? (l_c_new,) : l_c_new

l_new = l_re(l_c_new)
ps_new = ps_re((; ps_c_new...))
st_new = st_re((; st_c_new...))

return l_new, ps_new, st_new
end
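
A quick hedged sketch of the naming scheme implemented above (reusing `c`, `ps`, `st` from the docstring example; the `collect_names` helper is hypothetical, not part of this PR): `layer_map` calls `f` only at leaf layers, and builds `name` by joining the keys along the recursion with `"."`, starting from `"model"`, or from the variable name when `@layer_map` is used.

```julia
# Hypothetical helper that just prints the `name` each leaf layer receives.
collect_names(l, ps, st, name) = (println(name); (l, ps, st))

Lux.layer_map(collect_names, c, ps, st)
# model.chain.dense_1
# model.chain.bn
# model.chain.dense_2
# model.dense_3

Lux.@layer_map collect_names c ps st
# c.chain.dense_1
# c.chain.bn
# c.chain.dense_2
# c.dense_3
```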
22 changes: 21 additions & 1 deletion src/core.jl
@@ -101,7 +101,13 @@ Abstract Container Type for certain Lux Layers. `layers` is a tuple containing f
for the layer, and constructs the parameters and states using those.

Users implementing their custom layer can extend the same functions as in
[`AbstractExplicitLayer`](@ref)
[`AbstractExplicitLayer`](@ref).

!!! tip

Advanced structure manipulation of these layers post construction is possible via
`Functors.fmap`. For a more flexible interface, we recommend using the experimental
feature [`Lux.@layer_map`](@ref).
"""
abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end

@@ -125,6 +131,20 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers}
return sum(statelength, getfield.((l,), layers))
end

# Make AbstractExplicitContainerLayer Functors-compatible
function Functors.functor(::Type{<:AbstractExplicitContainerLayer},
x::AbstractExplicitContainerLayer{layers}) where {layers}
_children = getproperty.((x,), layers)
function layer_reconstructor(z)
l = x
for (child, name) in zip(z, layers)
l = Setfield.set(l, Setfield.PropertyLens{name}(), child)
end
return l
end
return _children, layer_reconstructor
end

# Test Mode
"""
testmode(st::NamedTuple)
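
As a side note, a minimal hedged sketch of what the `functor` overload above enables (the `widen` helper and the `in_dims`/`out_dims` fields of `Dense` are assumptions for illustration, not part of this PR): `fmap` now recurses into the fields named in the `layers` type parameter of an `AbstractExplicitContainerLayer` and rebuilds the container via `Setfield`, while non-container layers such as `Dense` stay leaves.

```julia
using Functors, Lux

c = Chain(; dense_1=Dense(2 => 3), dense_2=Dense(3 => 5))

# Assumed field names: `Dense` stores its sizes as `in_dims` and `out_dims`.
widen(l) = l isa Dense ? Dense(l.in_dims => 2 * l.out_dims) : l

c_wide = fmap(widen, c)  # Chain with Dense(2 => 6) and Dense(3 => 10)
```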
1 change: 1 addition & 0 deletions test/Project.toml
@@ -13,6 +13,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
42 changes: 42 additions & 0 deletions test/contrib/map.jl
@@ -0,0 +1,42 @@
using Lux, Random, Setfield, Test

c = Parallel(+;
chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
dense_3=Dense(5 => 1))

rng = Random.default_rng()
ps, st = Lux.setup(rng, c)

function zero_dense_params(l, ps, st, name)
if l isa Dense && occursin("model.chain", name)
@set! ps.weight = zero.(ps.weight)
@set! ps.bias = zero.(ps.bias)
end
return l, ps, st
end

c_, ps_, st_ = Lux.layer_map(zero_dense_params, c, ps, st)

@test ps_.chain.dense_1.weight == zeros(3, 2)
@test ps_.chain.dense_1.bias == zeros(3, 1)
@test ps_.chain.dense_2.weight == zeros(5, 3)
@test ps_.chain.dense_2.bias == zeros(5, 1)
@test ps_.dense_3.weight != zeros(1, 5)
@test ps_.dense_3.bias == zeros(1, 1)

function zero_dense_params(l, ps, st, name)
if l isa Dense && occursin("c.chain", name)
@set! ps.weight = zero.(ps.weight)
@set! ps.bias = zero.(ps.bias)
end
return l, ps, st
end

c_, ps_, st_ = Lux.@layer_map zero_dense_params c ps st

@test ps_.chain.dense_1.weight == zeros(3, 2)
@test ps_.chain.dense_1.bias == zeros(3, 1)
@test ps_.chain.dense_2.weight == zeros(5, 3)
@test ps_.chain.dense_2.bias == zeros(5, 1)
@test ps_.dense_3.weight != zeros(1, 5)
@test ps_.dense_3.bias == zeros(1, 1)
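
Note on the `dense_3` assertions (assuming Lux's default `init_bias=zeros32`): `ps_.dense_3.bias == zeros(1, 1)` holds even though `dense_3` lies outside `"model.chain"` / `"c.chain"` and is never zeroed by `zero_dense_params`, simply because `Dense` initializes its bias to zeros by default; the `weight != zeros(1, 5)` check is what verifies `dense_3` was left untouched.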
15 changes: 14 additions & 1 deletion test/core.jl
@@ -1,4 +1,4 @@
using Lux, Random, Test
using Functors, Lux, Random, Test

rng = Random.default_rng()
Random.seed!(rng, 0)
@@ -53,3 +53,16 @@ end
@test_deprecated Lux.trainmode(st, true)
@test_deprecated Lux.testmode(st, true)
end

@testset "Functors Compatibility" begin
c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), dense_2=Dense(3 => 5)),
dense_3=Dense(5 => 1))

@test_nowarn fmap(println, c)

l = Dense(2 => 2)
new_model = fmap(x -> l, c)
@test new_model.layers.chain.layers.dense_1 == l
@test new_model.layers.chain.layers.dense_2 == l
@test new_model.layers.dense_3 == l
end
7 changes: 5 additions & 2 deletions test/runtests.jl
@@ -1,4 +1,4 @@
using SafeTestsets, Test, Pkg
using Pkg, SafeTestsets, Test

const GROUP = get(ENV, "GROUP", "All")

@@ -43,7 +43,10 @@ end

@time @safetestset "Automatic Differentiation" begin include("autodiff.jl") end

@testset "Experimental" begin @time @safetestset "Training" begin include("contrib/training.jl") end end
@testset "Experimental" begin
@time @safetestset "Map" begin include("contrib/map.jl") end
@time @safetestset "Training" begin include("contrib/training.jl") end
end
end
else
dev_subpkg(group)
Expand Down