Skip to content

Commit

Permalink
Move CA stuff into an extension
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 17, 2023
1 parent 887f1f9 commit 1acb52e
Show file tree
Hide file tree
Showing 20 changed files with 535 additions and 102 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
version:
- "1.6"
- "1.8"
- "~1.9.0-0"
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ model_weights
docs/docs
docs/site

scripts
scripts
test_ext
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.4.36"
version = "0.4.37"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Expand All @@ -26,11 +27,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"

[extensions]
Flux2LuxExt = ["Flux", "Optimisers"]
LuxComponentArraysExt = "ComponentArrays"
LuxFluxTransformExt = "Flux"

[compat]
Adapt = "3"
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Documenter, DocumenterMarkdown, LuxCore, Lux, LuxLib, Pkg

import Flux, Optimisers # Load weak dependencies
import Flux # Load weak dependencies

function _setup_subdir_pkgs_index_file(subpkg)
src_file = joinpath(dirname(@__DIR__), "lib", subpkg, "README.md")
Expand Down
2 changes: 1 addition & 1 deletion docs/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ nav:
- "Functional": "api/functional.md"
- "Core": "api/core.md"
- "Utilities": "api/utilities.md"
- "Flux2Lux": "api/flux2lux.md"
- "Flux & Lux InterOp": "api/flux2lux.md"
- "Experimental": "api/contrib.md"
- "Frameworks":
- "Boltz": "lib/Boltz/index.md"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/api/flux2lux.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
CurrentModule = Lux
```

Accessing these functions require manually loading `Flux` and `Optimisers`, i.e.,
`using Flux, Optimisers` must be present somewhere in the code for these to be used.
Accessing these functions requires manually loading `Flux`, i.e., `using Flux` must be
present somewhere in the code for them to be used.

## Functions

Expand Down
5 changes: 5 additions & 0 deletions docs/src/manual/migrate_from_flux.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ implemented soon. If you **really** need that functionality, check out the next

We don't recommend this method, but here is a way to compose Flux with Lux.

!!! tip

    Starting from `v0.4.37`, if you have `using Flux` in your code, Lux will automatically
    provide a function `transform` that can convert Flux layers to Lux layers.

```julia
using Lux, NNlib, Random, Optimisers
import Flux
Expand Down
14 changes: 1 addition & 13 deletions docs/src/manual/precompilation.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ using Preferences, UUIDs

Preferences.@set_preferences!(UUID("b2108857-7c20-44ae-9111-449ecde12c47"),
"LuxSnoopPrecompile", false)
# Preferences.@set_preferences!(UUID("b2108857-7c20-44ae-9111-449ecde12c47"),
# "LuxPrecompileComponentArrays", false)
```

If `LuxSnoopPrecompile` is set to `false`, then `Lux` will not use `SnoopPrecompile.jl`:
Expand All @@ -66,20 +64,10 @@ julia> @time_imports using Lux
119.0 ms Lux 6.41% compilation time
```

The other option is to just disable compilation of `ComponentArrays.jl` codepaths. This is
desirable if you are not planning to use Lux with any of the SciML Packages. This can be
done by setting `LuxPrecompileComponentArrays` to `false`:
If you have `LuxSnoopPrecompile` set to `true`:

```julia-repl
julia> @time_imports using Lux
3366.4 ms Lux 0.22% compilation time
```

If you have both the `LuxSnoopPrecompile` and `LuxPrecompileComponentArrays` set to `true`:

```julia-repl
julia> @time_imports using Lux
5738.5 ms Lux 0.13% compilation time
```
41 changes: 41 additions & 0 deletions ext/LuxComponentArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
module LuxComponentArraysExt

# Package extension loaded when ComponentArrays.jl is available. Collects all
# of Lux's ComponentArray-specific method definitions (previously in
# src/utils.jl, src/autodiff.jl, src/contrib/*) in one place.
#
# On Julia with native package extensions (`Base.get_extension`, 1.9+) the weak
# dependency is loaded directly; on older versions this file is `include`d from
# a Requires.jl `@require` block, where the package is visible as a parent-module
# binding (`..ComponentArrays`).
isdefined(Base, :get_extension) ? (using ComponentArrays) : (using ..ComponentArrays)

using Functors, Lux, Optimisers, Zygote

# Safe property lookup: return `nothing` instead of throwing when `prop` is not
# a component of `x` (mirrors Lux's NamedTuple `_getproperty` behavior).
@inline function Lux._getproperty(x::ComponentArray, ::Val{prop}) where {prop}
    return prop in propertynames(x) ? getproperty(x, prop) : nothing
end

# Make ComponentArrays traversable by Functors.jl: decompose into a NamedTuple
# of the named components, and reconstruct via the `ComponentArray` constructor.
function Functors.functor(::Type{<:ComponentArray}, c)
    return NamedTuple{propertynames(c)}(getproperty.((c,), propertynames(c))),
           ComponentArray
end

# Zygote Fixes
# Accumulate gradients on the flat underlying data, then rewrap with the axes
# of the first argument so the result stays a ComponentArray.
function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
    return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
end

# Optimisers Fixes
# Set up the optimiser state on the flat data vector rather than the structured
# array (ComponentArrays are plain AbstractVectors underneath).
Optimisers.setup(opt::AbstractRule, ps::ComponentArray) = Optimisers.setup(opt, getdata(ps))

# Update on flat data and rewrap the new parameters with the original axes.
function Optimisers.update(tree, ps::ComponentArray, gs::ComponentArray)
    tree, ps_new = Optimisers.update(tree, getdata(ps), getdata(gs))
    return tree, ComponentArray(ps_new, getaxes(ps))
end

# In-place variant. NOTE(review): restricted to `tree::Optimisers.Leaf` while
# `update` above accepts any tree — presumably intentional to avoid ambiguity
# with Optimisers' generic tree method, but worth confirming.
function Optimisers.update!(tree::Optimisers.Leaf, ps::ComponentArray, gs::ComponentArray)
    tree, ps_new = Optimisers.update!(tree, getdata(ps), getdata(gs))
    return tree, ComponentArray(ps_new, getaxes(ps))
end

# Freezing
# Allow merging frozen parameters stored as ComponentArrays with regular
# NamedTuples by converting to NamedTuple first (see src/contrib/freeze.jl).
Lux._merge(nt1::ComponentArray, nt2::NamedTuple) = merge(NamedTuple(nt1), nt2)
Lux._merge(nt1::NamedTuple, nt2::ComponentArray) = merge(nt1, NamedTuple(nt2))

# Parameter Sharing
# Structure of a ComponentArray is the structure of its NamedTuple form
# (see src/contrib/share_parameters.jl).
Lux._parameter_structure(ps::ComponentArray) = Lux._parameter_structure(NamedTuple(ps))

end
37 changes: 12 additions & 25 deletions ext/Flux2LuxExt.jl → ext/LuxFluxTransformExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
module Flux2LuxExt

@static if isdefined(Base, :get_extension)
import Flux
using Optimisers
else
import ..Flux
using ..Optimisers
end
module LuxFluxTransformExt

using Lux, Random
isdefined(Base, :get_extension) ? (import Flux) : (import ..Flux)
using Lux, Random, Optimisers
import Lux: transform, FluxLayer

struct FluxModelConversionError <: Exception
msg::String
Expand Down Expand Up @@ -42,12 +36,6 @@ API internally.
- `p`: Flattened parameters of the `layer`
"""
struct FluxLayer{L, RE, I} <: Lux.AbstractExplicitLayer
layer::L
re::RE
init_parameters::I
end

function FluxLayer(l)
p, re = Optimisers.destructure(l)
p_ = copy(p)
Expand All @@ -68,7 +56,7 @@ Convert a Flux Model to Lux Model.
!!! warning
`transform` always ignores the `active` field of some of the Flux layers. This is
almost never going to be supported on Flux2Lux.
almost never going to be supported.
## Arguments
Expand All @@ -88,7 +76,8 @@ Convert a Flux Model to Lux Model.
# Examples
```julia
using Flux2Lux, Lux, Metalhead, Random
import Flux
using Lux, Metalhead, Random
m = ResNet(18)
m2 = transform(m.layers)
Expand Down Expand Up @@ -339,8 +328,8 @@ function transform(l::Flux.BatchNorm; preserve_ps_st::Bool=false,
if preserve_ps_st
if l.track_stats
force_preserve && return FluxLayer(l)
@warn """Preserving the state of `Flux.BatchNorm` is currently not supported by
Flux2Lux. Ignoring the state.""" maxlog=1
@warn """Preserving the state of `Flux.BatchNorm` is currently not supported.
Ignoring the state.""" maxlog=1
end
if l.affine
return BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum,
Expand All @@ -358,8 +347,8 @@ function transform(l::Flux.GroupNorm; preserve_ps_st::Bool=false,
if preserve_ps_st
if l.track_stats
force_preserve && return FluxLayer(l)
@warn """Preserving the state of `Flux.GroupNorm` is currently not supported by
Flux2Lux. Ignoring the state.""" maxlog=1
@warn """Preserving the state of `Flux.GroupNorm` is currently not supported.
Ignoring the state.""" maxlog=1
end
if l.affine
return GroupNorm(l.chs, l.G, l.λ; l.affine, l.track_stats, epsilon=l.ϵ,
Expand All @@ -379,6 +368,4 @@ function transform(l::T; kwargs...) where {T <: _INVALID_TRANSFORMATION_TYPES}
throw(FluxModelConversionError("Transformation of type $(T) is not supported."))
end

export transform, FluxLayer

end
end
29 changes: 20 additions & 9 deletions src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,11 @@ using Random, Statistics, LinearAlgebra, SparseArrays
using Functors, Setfield
import Adapt: adapt, adapt_storage
# Arrays
using FillArrays, ComponentArrays
using FillArrays
# Automatic Differentiation
using ChainRulesCore, Zygote
# Optional Dependency
using Requires
# Docstrings
using Markdown
# Optimisers + ComponentArrays
using Optimisers

# LuxCore
using LuxCore
Expand Down Expand Up @@ -48,17 +44,21 @@ include("layers/display.jl")
# AutoDiff
include("autodiff.jl")

# Flux to Lux
# Extensions
if !isdefined(Base, :get_extension)
using Requires
end

# Module init: on Julia versions without the native package-extension mechanism
# (`Base.get_extension`, introduced in 1.9), fall back to Requires.jl so the
# extension code is lazily `include`d the first time the user loads the weak
# dependency. On 1.9+ this function is a no-op and Pkg loads the extensions.
function __init__()
    @static if !isdefined(Base, :get_extension)
        # Handling ComponentArrays
        @require ComponentArrays="b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" begin
            include("../ext/LuxComponentArraysExt.jl")
        end

        # Flux InterOp
        @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
            include("../ext/LuxFluxTransformExt.jl")
        end
    end
end
Expand Down Expand Up @@ -92,4 +92,15 @@ export NoOpLayer, ReshapeLayer, SelectDim, FlattenLayer, WrappedFunction, Activa
export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell
export SamePad

# Extension Exports: Flux
# Declare the empty generic function here so it can be exported from Lux
# itself; its methods are added by the LuxFluxTransformExt extension, which is
# loaded only when the user also loads Flux.
function transform end

# Wrapper layer used by `transform` when a Flux layer must be preserved as-is.
# Fields (populated in the extension via `Optimisers.destructure`):
#   layer           — the wrapped Flux layer
#   re              — reconstruction closure returned by `destructure`
#   init_parameters — initializer for the flattened parameters of `layer`
struct FluxLayer{L, RE, I} <: Lux.AbstractExplicitLayer
    layer::L
    re::RE
    init_parameters::I
end

export transform, FluxLayer

end
5 changes: 0 additions & 5 deletions src/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,6 @@ function ChainRulesCore.rrule(::typeof(copy), x)
return copy(x), copy_pullback
end

# Zygote Fixes
function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
end

# Adapt Interface
function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
return Array(x), d -> (NoTangent(), CUDA.cu(d))
Expand Down
2 changes: 0 additions & 2 deletions src/contrib/freeze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ function initialstates(rng::AbstractRNG, l::FrozenLayer{which_params}) where {wh
end

_merge(nt1::NamedTuple, nt2::NamedTuple) = merge(nt1, nt2)
_merge(nt1::ComponentArray, nt2::NamedTuple) = merge(NamedTuple(nt1), nt2)
_merge(nt1::NamedTuple, nt2::ComponentArray) = merge(nt1, NamedTuple(nt2))

function (f::FrozenLayer)(x, ps, st::NamedTuple)
y, st_ = f.layer(x, _merge(ps, st.frozen_params), st.states)
Expand Down
9 changes: 3 additions & 6 deletions src/contrib/share_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,14 @@ function _safe_update_parameter(ps, lens, new_ps)
msg = "The structure of the new parameters must be the same as the " *
"old parameters for lens $(lens)!!! The new parameters have a structure: " *
"$new_ps_st while the old parameters have a structure: $ps_st."
if !(new_ps isa Union{ComponentArray, AbstractArray, NamedTuple, Tuple, Number})
msg = msg *
"This could potentially be caused since `_parameter_structure` is not" *
" appropriately defined for type $(typeof(new_ps))."
end
msg = msg *
" This could potentially be caused since `_parameter_structure` is not" *
" appropriately defined for type $(typeof(new_ps))."
throw(ArgumentError(msg))
end
return Setfield.set(ps, lens, new_ps)
end

_parameter_structure(ps::ComponentArray) = _parameter_structure(NamedTuple(ps))
_parameter_structure(ps::AbstractArray) = size(ps)
_parameter_structure(::Number) = 1
_parameter_structure(ps) = fmap(_parameter_structure, ps)
Expand Down
11 changes: 0 additions & 11 deletions src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,5 @@ if Preferences.@load_preference("LuxSnoopPrecompile", true)

layer(x, ps, st)
Zygote.gradient(p -> sum(layer(x, ps, st)[1]), ps)

# ComponentArrays
if Preferences.@load_preference("LuxPrecompileComponentArrays", true)
ps, st = setup(rng, layer)
ps = ps |> ComponentArray |> dev
st = st |> dev
x = rand(rng, Float32, x_size...) |> dev

layer(x, ps, st)
Zygote.gradient(p -> sum(layer(x, ps, st)[1]), ps)
end
end
end
22 changes: 0 additions & 22 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,24 +171,6 @@ function _calc_padding(::SamePad, k::NTuple{N, T}, dilation, stride) where {N, T
return Tuple(mapfoldl(i -> [cld(i, 2), fld(i, 2)], vcat, pad_amt))
end

# Handling ComponentArrays
function Functors.functor(::Type{<:ComponentArray}, c)
return NamedTuple{propertynames(c)}(getproperty.((c,), propertynames(c))),
ComponentArray
end

Optimisers.setup(opt::AbstractRule, ps::ComponentArray) = Optimisers.setup(opt, getdata(ps))

function Optimisers.update(tree, ps::ComponentArray, gs::ComponentArray)
tree, ps_new = Optimisers.update(tree, getdata(ps), getdata(gs))
return tree, ComponentArray(ps_new, getaxes(ps))
end

function Optimisers.update!(tree::Optimisers.Leaf, ps::ComponentArray, gs::ComponentArray)
tree, ps_new = Optimisers.update!(tree, getdata(ps), getdata(gs))
return tree, ComponentArray(ps_new, getaxes(ps))
end

# Getting typename
get_typename(::T) where {T} = Base.typename(T).wrapper

Expand Down Expand Up @@ -235,10 +217,6 @@ end
end
end

@inline function _getproperty(x::ComponentArray, ::Val{prop}) where {prop}
return prop in propertynames(x) ? getproperty(x, prop) : nothing
end

@inline function _eachslice(x::T, ::Val{dims}) where {T <: AbstractArray, dims}
return [selectdim(x, dims, i) for i in axes(x, dims)]
end
Loading

0 comments on commit 1acb52e

Please sign in to comment.