Allow fmapping over the model
avik-pal committed Sep 1, 2022
1 parent 3018707 · commit f671c7f
Showing 10 changed files with 211 additions and 5 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
# v0.4

## v0.4.20

- Introduces `Lux.@layer_map` and `Lux.layer_map` for mapping over layers.
- Allows `fmap`-ping over layers.

## v0.4.19

- Generic Container layers (like `Chain`, `Parallel`, etc.) can now use custom naming for
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.4.19"
version = "0.4.20"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
7 changes: 7 additions & 0 deletions docs/src/api/contrib.md
@@ -39,6 +39,13 @@ refactors being planned (See [this issue](https://github.com/EnzymeAD/Enzyme.jl/
Once the package is stable and we have the necessary backend support, we will be dropping
the VJP rules in this module.

## Map over Layer

```@docs
Lux.layer_map
Lux.@layer_map
```

## Index

```@index
1 change: 1 addition & 0 deletions src/Lux.jl
@@ -49,6 +49,7 @@ function __init__()
end

# Experimental
include("contrib/map.jl")
include("contrib/training.jl")

# Deprecations
114 changes: 114 additions & 0 deletions src/contrib/map.jl
@@ -0,0 +1,114 @@
using Functors: functor

"""
    @layer_map func layer ps st

See the documentation of [`Lux.layer_map`](@ref) for more details. This macro eliminates
the need to set the layer name manually; the variable name is used as the starting point.

## Example
```julia
using Lux, Random, Setfield
c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3),
                              dense_2=Dense(3 => 5)),
             dense_3=Dense(5 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, c)
# Makes parameters of Dense Layers inside Chain zero
function zero_dense_params(l, ps, st, name)
    if l isa Dense
        println("zeroing params of $name")
        @set! ps.weight = zero.(ps.weight)
        @set! ps.bias = zero.(ps.bias)
    end
    return l, ps, st
end
Lux.@layer_map zero_dense_params c ps st
```
"""
macro layer_map(f, l, ps, st)
    quote
        layer_map($(esc(f)), $(esc(l)), $(esc(ps)), $(esc(st)), $(string(l)))
    end
end
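A quick way to see what the macro does is to expand it: it merely stringifies the layer variable's name and forwards everything to `layer_map`. A minimal sketch (not part of the commit; `zero_dense_params`, `c`, `ps`, and `st` are the names from the docstring example and need not be defined for expansion):

```julia
using Lux

# Macro expansion does not evaluate the arguments; this should print a call
# equivalent to:
#   Lux.layer_map(zero_dense_params, c, ps, st, "c")
@macroexpand Lux.@layer_map zero_dense_params c ps st
```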

"""
    layer_map(f::Function, l::AbstractExplicitLayer, ps, st::NamedTuple,
              name::String="model")

Map the function `f` over the model `l`, with the parameters `ps` and states `st`. This is
different from `Functors.fmap` since it zips the layers, parameters, and states and invokes
the function on all of them together.

## Call Signature for `f`

- Must take 4 inputs -- the layer (`AbstractExplicitLayer`), the corresponding parameters,
  the corresponding states, and the name of the layer.
- Must return a tuple of 3 elements -- the new layer (`AbstractExplicitLayer`), the new
  parameters, and the new states.

!!! tip

    We recommend using the macro `Lux.@layer_map` instead of this function. It
    automatically sets the `name` of the layer to the variable name.

## Example
```julia
using Lux, Random, Setfield
c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3),
                              dense_2=Dense(3 => 5)),
             dense_3=Dense(5 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, c)
# Makes parameters of Dense Layers inside Chain zero
function zero_dense_params(l, ps, st, name)
    if l isa Dense
        println("zeroing params of $name")
        @set! ps.weight = zero.(ps.weight)
        @set! ps.bias = zero.(ps.bias)
    end
    return l, ps, st
end
Lux.layer_map(zero_dense_params, c, ps, st)
```
"""
function layer_map(f::Function, l::AbstractExplicitLayer, ps, st::NamedTuple,
                   name::String="model")
    l_c, l_re = functor(l)
    ps_c, ps_re = functor(ps)
    st_c, st_re = functor(st)

    length(l_c) == 0 && return f(l, ps, st, name)

    l_c_ = l_c isa Tuple ? l_c[1] : l_c
    ks = keys(l_c_)

    l_c_new, ps_c_new, st_c_new = [], [], []
    for k in ks
        l_c_new_, ps_c_new_, st_c_new_ = layer_map(f, getproperty(l_c_, k),
                                                   getproperty(ps_c, k),
                                                   getproperty(st_c, k),
                                                   join((name, k), "."))
        push!(l_c_new, k => l_c_new_)
        push!(ps_c_new, k => ps_c_new_)
        push!(st_c_new, k => st_c_new_)
    end
    l_c_new = (; l_c_new...)
    l_c_new = l_c isa Tuple ? (l_c_new,) : l_c_new

    l_new = l_re(l_c_new)
    ps_new = ps_re((; ps_c_new...))
    st_new = st_re((; st_c_new...))

    return l_new, ps_new, st_new
end
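As an aside, the `join((name, k), ".")` call above is what produces the dotted key paths checked in the tests further down. A minimal sketch of the traversal order (not part of the commit; `print_name` is a hypothetical helper):

```julia
using Lux, Random

c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3),
                              dense_2=Dense(3 => 5)),
             dense_3=Dense(5 => 1))
ps, st = Lux.setup(Random.default_rng(), c)

# Hypothetical callback: print each leaf's name, return everything unchanged.
print_name(l, ps, st, name) = (println(name); (l, ps, st))

Lux.layer_map(print_name, c, ps, st)
# Root defaults to "model", so this should print:
#   model.chain.dense_1, model.chain.bn, model.chain.dense_2, model.dense_3

Lux.@layer_map print_name c ps st
# Same traversal rooted at the variable name: c.chain.dense_1, c.chain.bn, ...
```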
22 changes: 21 additions & 1 deletion src/core.jl
@@ -101,7 +101,13 @@ Abstract Container Type for certain Lux Layers. `layers` is a tuple containing f
for the layer, and constructs the parameters and states using those.
Users implementing their custom layer can extend the same functions as in
[`AbstractExplicitLayer`](@ref)
[`AbstractExplicitLayer`](@ref).

!!! tip

    Advanced structure manipulation of these layers post-construction is possible via
    `Functors.fmap`. For a more flexible interface, we recommend using the experimental
    feature [`Lux.@layer_map`](@ref).
"""
abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end

@@ -125,6 +131,20 @@ function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers}
    return sum(statelength, getfield.((l,), layers))
end

# Make AbstractExplicit Layers Functor Compatible
function Functors.functor(::Type{<:AbstractExplicitContainerLayer},
                          x::AbstractExplicitContainerLayer{layers}) where {layers}
    _children = getproperty.((x,), layers)
    function layer_reconstructor(z)
        l = x
        for (child, name) in zip(z, layers)
            l = Setfield.set(l, Setfield.PropertyLens{name}(), child)
        end
        return l
    end
    return _children, layer_reconstructor
end
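With this `functor` overload in place, container layers participate in the wider Functors.jl machinery. A minimal sketch of what that enables (not part of the commit; it mirrors the new `fmap` test further down, and `replace_leaf` is a hypothetical helper). Note that non-container layers such as `Dense` remain leaves, so the mapped function receives whole layers rather than arrays:

```julia
using Functors, Lux

c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), dense_2=Dense(3 => 5)),
             dense_3=Dense(5 => 1))

# Swap every leaf layer for a fixed `Dense(2 => 2)`; the Parallel/Chain
# nesting and the layer names are rebuilt around the new leaves.
replace_leaf(l) = l isa Dense ? Dense(2 => 2) : l
c2 = fmap(replace_leaf, c)
c2.chain.dense_1  # Dense(2 => 2)
```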

# Test Mode
"""
testmode(st::NamedTuple)
1 change: 1 addition & 0 deletions test/Project.toml
@@ -13,6 +13,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
42 changes: 42 additions & 0 deletions test/contrib/map.jl
@@ -0,0 +1,42 @@
using Lux, Random, Setfield, Test

c = Parallel(+;
             chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
             dense_3=Dense(5 => 1))

rng = Random.default_rng()
ps, st = Lux.setup(rng, c)

function zero_dense_params(l, ps, st, name)
    if l isa Dense && occursin("model.chain", name)
        @set! ps.weight = zero.(ps.weight)
        @set! ps.bias = zero.(ps.bias)
    end
    return l, ps, st
end

c_, ps_, st_ = Lux.layer_map(zero_dense_params, c, ps, st)

@test ps_.chain.dense_1.weight == zeros(3, 2)
@test ps_.chain.dense_1.bias == zeros(3, 1)
@test ps_.chain.dense_2.weight == zeros(5, 3)
@test ps_.chain.dense_2.bias == zeros(5, 1)
@test ps_.dense_3.weight != zeros(1, 5)
@test ps_.dense_3.bias == zeros(1, 1)

function zero_dense_params(l, ps, st, name)
    if l isa Dense && occursin("c.chain", name)
        @set! ps.weight = zero.(ps.weight)
        @set! ps.bias = zero.(ps.bias)
    end
    return l, ps, st
end

c_, ps_, st_ = Lux.@layer_map zero_dense_params c ps st

@test ps_.chain.dense_1.weight == zeros(3, 2)
@test ps_.chain.dense_1.bias == zeros(3, 1)
@test ps_.chain.dense_2.weight == zeros(5, 3)
@test ps_.chain.dense_2.bias == zeros(5, 1)
@test ps_.dense_3.weight != zeros(1, 5)
@test ps_.dense_3.bias == zeros(1, 1)
15 changes: 14 additions & 1 deletion test/core.jl
@@ -1,4 +1,4 @@
using Lux, Random, Test
using Functors, Lux, Random, Test

rng = Random.default_rng()
Random.seed!(rng, 0)
@@ -53,3 +53,16 @@ end
    @test_deprecated Lux.trainmode(st, true)
    @test_deprecated Lux.testmode(st, true)
end

@testset "Functors Compatibility" begin
c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), dense_2=Dense(3 => 5)),
dense_3=Dense(5 => 1))

@test_nowarn fmap(println, c)

l = Dense(2 => 2)
new_model = fmap(x -> l, c)
@test new_model.chain.dense_1 == l
@test new_model.chain.dense_2 == l
@test new_model.dense_3 == l
end
7 changes: 5 additions & 2 deletions test/runtests.jl
@@ -1,4 +1,4 @@
using SafeTestsets, Test, Pkg
using Pkg, SafeTestsets, Test

const GROUP = get(ENV, "GROUP", "All")

@@ -43,7 +43,10 @@ end

@time @safetestset "Automatic Differentiation" begin include("autodiff.jl") end

@testset "Experimental" begin @time @safetestset "Training" begin include("contrib/training.jl") end end
@testset "Experimental" begin
@time @safetestset "Map" begin include("contrib/map.jl") end
@time @safetestset "Training" begin include("contrib/training.jl") end
end
end
else
dev_subpkg(group)
