diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index e80876b204..4e9d160bd0 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -18,9 +18,10 @@ jobs:
       fail-fast: false
       matrix:
         group:
-          - Lux # Core Framework
-          - Boltz # Prebuilt Models using Lux
-          - LuxLib # Backend of Lux
+          - Lux      # Core Framework
+          - Boltz    # Prebuilt Models using Lux
+          - LuxLib   # Backend of Lux
+          - Flux2Lux # Flux2Lux Converter
         version:
           - '1.6' # JET tests are disabled on 1.6
           - '1.7'
diff --git a/.github/workflows/CINightly.yml b/.github/workflows/CINightly.yml
index 9adc7ea90b..761f9667fa 100644
--- a/.github/workflows/CINightly.yml
+++ b/.github/workflows/CINightly.yml
@@ -18,9 +18,10 @@ jobs:
       fail-fast: false
       matrix:
         group:
-          - Lux # Core Framework
-          - Boltz # Prebuilt Models using Lux
-          - LuxLib # Backend of Lux
+          - Lux      # Core Framework
+          - Boltz    # Prebuilt Models using Lux
+          - LuxLib   # Backend of Lux
+          - Flux2Lux # Flux2Lux Converter
         version:
           - 'nightly' # merge even if tests fail
     steps:
diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml
index acfa793e01..1a422df37c 100644
--- a/.github/workflows/Documentation.yml
+++ b/.github/workflows/Documentation.yml
@@ -30,7 +30,7 @@ jobs:
             ${{ runner.os }}-test-
             ${{ runner.os }}-
       - name: Install documentation dependencies
-        run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path=joinpath(pwd(), "lib/LuxLib"))); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
+        run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path=joinpath(pwd(), "lib/Flux2Lux"))); Pkg.develop(PackageSpec(path=joinpath(pwd(), "lib/LuxLib"))); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
       - name: Install examples dependencies
         run: julia --project=examples -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
       - name: Build and deploy
diff --git a/docs/Project.toml b/docs/Project.toml
index c05cc0cf62..7367b2e3be 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,6 +1,7 @@
 [deps]
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433"
+Flux2Lux = "ab51a4a6-c8c3-4b1f-af31-4b52a21037df"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
diff --git a/docs/make.jl b/docs/make.jl
index df8e991801..c49be79ec6 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,10 +1,11 @@
-using Documenter, DocumenterMarkdown, Lux, LuxLib, Pkg
+using Documenter, DocumenterMarkdown, Flux2Lux, Lux, LuxLib, Pkg
 
 deployconfig = Documenter.auto_detect_deploy_system()
 Documenter.post_status(deployconfig; type="pending", repo="github.com/avik-pal/Lux.jl.git")
 
 makedocs(; sitename="Lux", authors="Avik Pal et al.", clean=true, doctest=true,
-         modules=[Lux, LuxLib], strict=[
+         modules=[Flux2Lux, Lux, LuxLib],
+         strict=[
              :doctest,
              :linkcheck,
              :parse_error,
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 8f278f02d8..b9dc908408 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -116,6 +116,9 @@ nav:
   - "Frameworks":
       - "Boltz": "lib/Boltz/index.md"
       - "FluxMPI": "https://avik-pal.github.io/FluxMPI.jl/dev/"
+      - "Flux2Lux":
+          - "Introduction": "lib/Flux2Lux/index.md"
+          - "API Reference": "lib/Flux2Lux/api.md"
       - "NNlib": "https://fluxml.ai/Flux.jl/stable/models/nnlib/"
   - "LuxLib":
       - "Introduction": "lib/LuxLib/index.md"
diff --git a/docs/src/lib/Flux2Lux/api.md b/docs/src/lib/Flux2Lux/api.md
new file mode 100644
index 0000000000..d178eeba50
--- /dev/null
+++ b/docs/src/lib/Flux2Lux/api.md
@@ -0,0 +1,21 @@
+```@meta
+CurrentModule = Flux2Lux
+```
+
+## Functions
+
+```@docs
+transform
+```
+
+## Layers
+
+```@docs
+FluxLayer
+```
+
+## Index
+
+```@index
+Pages = ["api.md"]
+```
diff --git a/docs/src/lib/Flux2Lux/index.md b/docs/src/lib/Flux2Lux/index.md
new file mode 100644
index 0000000000..30013a245c
--- /dev/null
+++ b/docs/src/lib/Flux2Lux/index.md
@@ -0,0 +1,40 @@
+# Flux2Lux
+
+[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning)
+[![Project Status: Active – The project has reached a stable, usable state and is being actively developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
+[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/lib/Flux2Lux/)
+[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/lib/Flux2Lux/)
+
+[![CI](https://github.com/avik-pal/Lux.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/avik-pal/Lux.jl/actions/workflows/CI.yml)
+[![CI Nightly](https://github.com/avik-pal/Lux.jl/actions/workflows/CINightly.yml/badge.svg)](https://github.com/avik-pal/Lux.jl/actions/workflows/CINightly.yml)
+[![codecov](https://codecov.io/gh/avik-pal/Lux.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/avik-pal/Lux.jl)
+[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/Flux2Lux)](https://pkgs.genieframework.com?packages=Flux2Lux)
+
+[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
+[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)
+
+Flux2Lux is a package that allows you to convert Flux.jl models to Lux.jl models.
+
+## Difference from `Lux.transform`
+
+`Lux.transform` has been deprecated in favor of `Flux2Lux.jl`. This package is a strict
+superset of its predecessor: it adds keyword arguments such as `preserve_ps_st` and
+`force_preserve`. See the documentation of `Flux2Lux.transform` for more details.
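+
+A minimal usage sketch (hypothetical toy model; any Flux model converts the same way):
+
+```julia
+import Flux
+using Flux2Lux, Lux, Random
+
+fmodel = Flux.Chain(Flux.Dense(2 => 3, tanh), Flux.Dense(3 => 1))
+
+# Fresh, Lux-initialized parameters:
+lmodel = transform(fmodel)
+
+# Reuse the parameters and states stored in the Flux model:
+lmodel = transform(fmodel; preserve_ps_st=true)
+
+ps, st = Lux.setup(Random.default_rng(), lmodel)
+y, _ = lmodel(randn(Float32, 2, 1), ps, st)
+```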
diff --git a/lib/Flux2Lux/LICENSE b/lib/Flux2Lux/LICENSE
new file mode 100644
index 0000000000..1f70fe7580
--- /dev/null
+++ b/lib/Flux2Lux/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Avik Pal and contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/lib/Flux2Lux/Project.toml b/lib/Flux2Lux/Project.toml
new file mode 100644
index 0000000000..fa1692b0f2
--- /dev/null
+++ b/lib/Flux2Lux/Project.toml
@@ -0,0 +1,16 @@
+name = "Flux2Lux"
+uuid = "ab51a4a6-c8c3-4b1f-af31-4b52a21037df"
+authors = ["Avik Pal and contributors"]
+version = "0.1.0"
+
+[deps]
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+
+[compat]
+Flux = "0.13"
+Lux = "0.4.26"
+Optimisers = "0.2"
+julia = "1.6"
diff --git a/lib/Flux2Lux/README.md b/lib/Flux2Lux/README.md
new file mode 100644
index 0000000000..30013a245c
--- /dev/null
+++ b/lib/Flux2Lux/README.md
@@ -0,0 +1,40 @@
+# Flux2Lux
+
+[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning)
+[![Project Status: Active – The project has reached a stable, usable state and is being actively developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
+[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/lib/Flux2Lux/)
+[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/lib/Flux2Lux/)
+
+[![CI](https://github.com/avik-pal/Lux.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/avik-pal/Lux.jl/actions/workflows/CI.yml)
+[![CI Nightly](https://github.com/avik-pal/Lux.jl/actions/workflows/CINightly.yml/badge.svg)](https://github.com/avik-pal/Lux.jl/actions/workflows/CINightly.yml)
+[![codecov](https://codecov.io/gh/avik-pal/Lux.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/avik-pal/Lux.jl)
+[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/Flux2Lux)](https://pkgs.genieframework.com?packages=Flux2Lux)
+
+[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
+[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)
+
+Flux2Lux is a package that allows you to convert Flux.jl models to Lux.jl models.
+
+## Difference from `Lux.transform`
+
+`Lux.transform` has been deprecated in favor of `Flux2Lux.jl`. This package is a strict
+superset of its predecessor: it adds keyword arguments such as `preserve_ps_st` and
+`force_preserve`. See the documentation of `Flux2Lux.transform` for more details.
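+
+A minimal usage sketch (hypothetical toy model; any Flux model converts the same way):
+
+```julia
+import Flux
+using Flux2Lux, Lux, Random
+
+fmodel = Flux.Chain(Flux.Dense(2 => 3, tanh), Flux.Dense(3 => 1))
+
+# Fresh, Lux-initialized parameters:
+lmodel = transform(fmodel)
+
+# Reuse the parameters and states stored in the Flux model:
+lmodel = transform(fmodel; preserve_ps_st=true)
+
+ps, st = Lux.setup(Random.default_rng(), lmodel)
+y, _ = lmodel(randn(Float32, 2, 1), ps, st)
+```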
diff --git a/lib/Flux2Lux/src/Flux2Lux.jl b/lib/Flux2Lux/src/Flux2Lux.jl
new file mode 100644
index 0000000000..449c5468f3
--- /dev/null
+++ b/lib/Flux2Lux/src/Flux2Lux.jl
@@ -0,0 +1,394 @@
+module Flux2Lux
+
+# Don't do `using`; both Lux and Flux have very similar exports.
+import Flux
+using Lux, Optimisers, Random
+
+struct FluxModelConversionError <: Exception
+    msg::String
+end
+
+function Base.showerror(io::IO, e::FluxModelConversionError)
+    return print(io, "FluxModelConversionError(", e.msg, ")")
+end
+
+"""
+    FluxLayer(layer)
+
+Serves as a compatibility layer between Flux and Lux. It uses the `Optimisers.destructure`
+API internally.
+
+!!! warning
+
+    Lux was written to overcome the limitations of `destructure` + `Flux`. It is
+    recommended to rewrite your model in Lux instead of using this layer.
+
+!!! warning
+
+    Introducing this layer in your model will lead to type instabilities, given the way
+    `Optimisers.destructure` works.
+
+## Arguments
+
+  - `layer`: Flux layer
+
+## Parameters
+
+  - `p`: Flattened parameters of the `layer`
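+
+## Example
+
+A minimal usage sketch (the wrapped `Flux.Dense` is just an illustrative choice):
+
+```julia
+import Flux
+using Flux2Lux, Lux, Random
+
+l = FluxLayer(Flux.Dense(2 => 2))
+ps, st = Lux.setup(Random.default_rng(), l)
+y, st_ = l(randn(Float32, 2, 1), ps, st)
+```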
+"""
+struct FluxLayer{L, RE, I} <: Lux.AbstractExplicitLayer
+    layer::L
+    re::RE
+    init_parameters::I
+end
+
+function FluxLayer(l)
+    p, re = Optimisers.destructure(l)
+    p_ = copy(p)
+    return FluxLayer(l, re, () -> p_)
+end
+
+Lux.initialparameters(::AbstractRNG, l::FluxLayer) = (p=l.init_parameters(),)
+
+(l::FluxLayer)(x, ps, st) = l.re(ps.p)(x), st
+
+Base.show(io::IO, l::FluxLayer) = print(io, "FluxLayer($(l.layer))")
+
+"""
+    transform(l; preserve_ps_st::Bool=false, force_preserve::Bool=false)
+
+Convert a Flux model to a Lux model.
+
+!!! warning
+
+    `transform` always ignores the `active` field of the Flux layers that have one. This
+    is almost never going to be supported in Flux2Lux.
+
+## Arguments
+
+  - `l`: Flux layer or any generic Julia function / object.
+
+## Keyword Arguments
+
+  - `preserve_ps_st`: Set to `true` to preserve the parameters and states of the layer.
+
+  - `force_preserve`: Some of the transformations with parameter and state preservation
+    haven't been implemented yet. In those cases, if `force_preserve` is `false`, a warning
+    is printed and a core Lux layer is returned; otherwise, a [`FluxLayer`](@ref) is
+    created.
+
+# Examples
+
+```julia
+using Flux2Lux, Lux, Metalhead, Random
+
+m = ResNet(18)
+m2 = transform(m.layers)
+
+x = randn(Float32, 224, 224, 3, 1);
+
+ps, st = Lux.setup(Random.default_rng(), m2);
+
+m2(x, ps, st)
+```
+"""
+function transform(l::T; preserve_ps_st::Bool=false, kwargs...) where {T}
+    @warn """Transformation for type $T not implemented. Using `FluxLayer` as
+             a fallback.""" maxlog=1
+
+    if !preserve_ps_st
+        @warn """`FluxLayer` uses the parameters and states of the `layer`. It is not
+                 possible to NOT preserve the parameters and states. Ignoring this keyword
+                 argument.""" maxlog=1
+    end
+
+    return FluxLayer(l)
+end
+
+transform(l::Function; kwargs...) = WrappedFunction(l)
+
+function transform(l::Flux.Chain; kwargs...)
+    fn = x -> transform(x; kwargs...)
+    layers = map(fn, l.layers)
+    if layers isa NamedTuple
+        return Chain(layers; disable_optimizations=true)
+    else
+        return Chain(layers...; disable_optimizations=true)
+    end
+end
+
+function transform(l::Flux.Dense; preserve_ps_st::Bool=false, kwargs...)
+    out_dims, in_dims = size(l.weight)
+    if preserve_ps_st
+        return Dense(in_dims => out_dims, l.σ; init_weight=(args...) -> copy(l.weight),
+                     init_bias=(args...) -> reshape(copy(l.bias), out_dims, 1),
+                     use_bias=!(l.bias isa Bool))
+    else
+        return Dense(in_dims => out_dims, l.σ; use_bias=!(l.bias isa Bool))
+    end
+end
+
+function transform(l::Flux.Scale; preserve_ps_st::Bool=false, kwargs...)
+    if preserve_ps_st
+        return Scale(size(l.scale), l.σ; init_weight=(args...) -> copy(l.scale),
+                     init_bias=(args...) -> copy(l.bias), use_bias=!(l.bias isa Bool))
+    else
+        return Scale(size(l.scale), l.σ; use_bias=!(l.bias isa Bool))
+    end
+end
+
+transform(l::Flux.Maxout; kwargs...) = Maxout(transform.(l.layers; kwargs...)...)
+
+function transform(l::Flux.SkipConnection; kwargs...)
+    connection = l.connection isa Function ? l.connection :
+                 transform(l.connection; kwargs...)
+    return SkipConnection(transform(l.layers; kwargs...), connection)
+end
+
+function transform(l::Flux.Bilinear; preserve_ps_st::Bool=false, kwargs...)
+    out, in1, in2 = size(l.weight)
+    if preserve_ps_st
+        return Bilinear((in1, in2) => out, l.σ; init_weight=(args...) -> copy(l.weight),
+                        init_bias=(args...) -> copy(l.bias), use_bias=!(l.bias isa Bool))
+    else
+        return Bilinear((in1, in2) => out, l.σ; use_bias=!(l.bias isa Bool))
+    end
+end
+
+function transform(l::Flux.Parallel; kwargs...)
+    fn = x -> transform(x; kwargs...)
+    layers = map(fn, l.layers)
+    if layers isa NamedTuple
+        return Parallel(l.connection; layers...)
+    else
+        return Parallel(l.connection, layers...)
+    end
+end
+
+function transform(l::Flux.PairwiseFusion; kwargs...)
+    fn = x -> transform(x; kwargs...)
+    layers = map(fn, l.layers)
+    if layers isa NamedTuple
+        return PairwiseFusion(l.connection; layers...)
+    else
+        return PairwiseFusion(l.connection, layers...)
+    end
+end
+
+function transform(l::Flux.Embedding; preserve_ps_st::Bool=true, kwargs...)
+    out_dims, in_dims = size(l.weight)
+    if preserve_ps_st
+        return Embedding(in_dims => out_dims; init_weight=(args...) -> copy(l.weight))
+    else
+        return Embedding(in_dims => out_dims)
+    end
+end
+
+function transform(l::Flux.Conv; preserve_ps_st::Bool=false, kwargs...)
+    k = size(l.weight)[1:(end - 2)]
+    in_chs, out_chs = size(l.weight)[(end - 1):end]
+    groups = l.groups
+    pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
+    if preserve_ps_st
+        _bias = reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
+        return Conv(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation, groups,
+                    use_bias=!(l.bias isa Bool), init_weight=(args...) -> copy(l.weight),
+                    init_bias=(args...) -> _bias)
+    else
+        return Conv(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation, groups,
+                    use_bias=!(l.bias isa Bool))
+    end
+end
+
+function transform(l::Flux.ConvTranspose; preserve_ps_st::Bool=false, kwargs...)
+    k = size(l.weight)[1:(end - 2)]
+    in_chs, out_chs = size(l.weight)[(end - 1):end]
+    groups = l.groups
+    pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
+    if preserve_ps_st
+        return ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation,
+                             groups, use_bias=!(l.bias isa Bool),
+                             init_weight=(args...) -> copy(l.weight),
+                             init_bias=(args...) -> copy(l.bias))
+    else
+        return ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation,
+                             groups, use_bias=!(l.bias isa Bool))
+    end
+end
+
+function transform(l::Flux.CrossCor; preserve_ps_st::Bool=false, kwargs...)
+    k = size(l.weight)[1:(end - 2)]
+    in_chs, out_chs = size(l.weight)[(end - 1):end]
+    pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
+    if preserve_ps_st
+        return CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation,
+                        use_bias=!(l.bias isa Bool),
+                        init_weight=(args...) -> copy(l.weight),
+                        init_bias=(args...) -> copy(l.bias))
+    else
+        return CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation,
+                        use_bias=!(l.bias isa Bool))
+    end
+end
+
+transform(l::Flux.AdaptiveMaxPool; kwargs...) = AdaptiveMaxPool(l.out)
+
+transform(l::Flux.AdaptiveMeanPool; kwargs...) = AdaptiveMeanPool(l.out)
+
+transform(::Flux.GlobalMaxPool; kwargs...) = GlobalMaxPool()
+
+transform(::Flux.GlobalMeanPool; kwargs...) = GlobalMeanPool()
+
+function transform(l::Flux.MaxPool; kwargs...)
+    pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
+    return MaxPool(l.k; l.stride, pad)
+end
+
+function transform(l::Flux.MeanPool; kwargs...)
+    pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
+    return MeanPool(l.k; l.stride, pad)
+end
+
+transform(l::Flux.Dropout; kwargs...) = Dropout(l.p; l.dims)
+
+# Flux's LayerNorm fuses the normalization and the learnable diagonal scale; Lux keeps
+# them separate, so the conversion emits a `Chain` of the two.
+function transform(l::Flux.LayerNorm; kwargs...)
+    return Chain(; layernorm=LayerNorm(l.size; epsilon=l.ϵ, affine=false),
+                 scale=transform(l.diag; kwargs...))
+end
+
+transform(::typeof(identity); kwargs...) = NoOpLayer()
+
+transform(::typeof(Flux.flatten); kwargs...) = FlattenLayer()
+
+transform(l::Flux.PixelShuffle; kwargs...) = PixelShuffle(l.r)
+
+function transform(l::Flux.Upsample{mode}; kwargs...) where {mode}
+    return Upsample{mode, typeof(l.scale), typeof(l.size)}(l.scale, l.size)
+end
+
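+# Returns a function that ignores its arguments and always returns the captured value
+# `x`; used below to turn stored Flux parameters into Lux initializers.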
+_const_return_anon_function(x) = (args...) -> x
+
+function transform(l::Flux.RNNCell; preserve_ps_st::Bool=false, force_preserve::Bool=false)
+    out_dims, in_dims = size(l.Wi)
+    if preserve_ps_st
+        if force_preserve
+            throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) from Flux uses " *
+                                           "a `reset!` mechanism which hasn't been " *
+                                           "extensively tested with `FluxLayer`. Rewrite " *
+                                           "the model manually to use `RNNCell`."))
+        end
+        @warn """Preserving Parameters: `Wh` & `Wi` for `Flux.RNNCell` is ambiguous in Lux
+                 and hence not supported. Ignoring these parameters.""" maxlog=1
+        return RNNCell(in_dims => out_dims, l.σ; init_bias=(args...) -> copy(l.b),
+                       init_state=(args...) -> copy(l.state0))
+    else
+        return RNNCell(in_dims => out_dims, l.σ)
+    end
+end
+
+function transform(l::Flux.LSTMCell; preserve_ps_st::Bool=false, force_preserve::Bool=false)
+    _out_dims, in_dims = size(l.Wi)
+    out_dims = _out_dims ÷ 4
+    if preserve_ps_st
+        if force_preserve
+            throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) from Flux uses " *
+                                           "a `reset!` mechanism which hasn't been " *
+                                           "extensively tested with `FluxLayer`. Rewrite " *
+                                           "the model manually to use `LSTMCell`."))
+        end
+        @warn """Preserving Parameters: `Wh` & `Wi` for `Flux.LSTMCell` is ambiguous in Lux
+                 and hence not supported. Ignoring these parameters.""" maxlog=1
+        bs = Lux.multigate(l.b, Val(4))
+        _s, _m = copy.(l.state0)
+        return LSTMCell(in_dims => out_dims; init_bias=_const_return_anon_function.(bs),
+                        init_state=(args...) -> _s, init_memory=(args...) -> _m)
+    else
+        return LSTMCell(in_dims => out_dims)
+    end
+end
+
+function transform(l::Flux.GRUCell; preserve_ps_st::Bool=false, force_preserve::Bool=false)
+    _out_dims, in_dims = size(l.Wi)
+    out_dims = _out_dims ÷ 3
+    if preserve_ps_st
+        if force_preserve
+            throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) from Flux uses " *
+                                           "a `reset!` mechanism which hasn't been " *
+                                           "extensively tested with `FluxLayer`. Rewrite " *
+                                           "the model manually to use `GRUCell`."))
+        end
+        @warn """Preserving Parameters: `Wh` & `Wi` for `Flux.GRUCell` is ambiguous in Lux
+                 and hence not supported. Ignoring these parameters.""" maxlog=1
+        bs = Lux.multigate(l.b, Val(3))
+        return GRUCell(in_dims => out_dims; init_bias=_const_return_anon_function.(bs),
+                       init_state=(args...) -> copy(l.state0))
+    else
+        return GRUCell(in_dims => out_dims)
+    end
+end
+
Rewrite " * + "the model manually to use `GRUCell`.")) + end + @warn """Preserving Parameters: `Wh` & `Wi` for `Flux.GRUCell` is ambiguous in Lux + and hence not supported. Ignoring these parameters.""" maxlog=1 + bs = Lux.multigate(l.b, Val(3)) + return GRUCell(in_dims => out_dims; init_bias=_const_return_anon_function.(bs), + init_state=(args...) -> copy(l.state0)) + else + return GRUCell(in_dims => out_dims) + end +end + +function transform(l::Flux.BatchNorm; preserve_ps_st::Bool=false, + force_preserve::Bool=false) + if preserve_ps_st + if l.track_stats + force_preserve && return FluxLayer(l) + @warn """Preserving the state of `Flux.BatchNorm` is currently not supported by + Flux2Lux. Ignoring the state.""" maxlog=1 + end + if l.affine + return BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum, + init_bias=(args...) -> copy(l.β), + init_scale=(args...) -> copy(l.γ)) + else + return BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum) + end + end + return BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum) +end + +function transform(l::Flux.GroupNorm; preserve_ps_st::Bool=false, + force_preserve::Bool=false) + if preserve_ps_st + if l.track_stats + force_preserve && return FluxLayer(l) + @warn """Preserving the state of `Flux.GroupNorm` is currently not supported by + Flux2Lux. Ignoring the state.""" maxlog=1 + end + if l.affine + return GroupNorm(l.chs, l.G, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, + l.momentum, init_bias=(args...) -> copy(l.β), + init_scale=(args...) -> copy(l.γ)) + else + return GroupNorm(l.chs, l.G, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, + l.momentum) + end + end + return GroupNorm(l.chs, l.G, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum) +end + +_INVALID_TRANSFORMATION_TYPES = Union{<:Flux.Recur} + +function transform(l::T; kwargs...) 
where {T <: _INVALID_TRANSFORMATION_TYPES} + throw(FluxModelConversionError("Transformation of type $(T) is not supported.")) +end + +export transform, FluxLayer + +end diff --git a/lib/Flux2Lux/test/Project.toml b/lib/Flux2Lux/test/Project.toml new file mode 100644 index 0000000000..96a939e6da --- /dev/null +++ b/lib/Flux2Lux/test/Project.toml @@ -0,0 +1,8 @@ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +julia = "1.6" diff --git a/lib/Flux2Lux/test/runtests.jl b/lib/Flux2Lux/test/runtests.jl new file mode 100644 index 0000000000..0fda0b0d38 --- /dev/null +++ b/lib/Flux2Lux/test/runtests.jl @@ -0,0 +1,55 @@ +import Flux +using Flux2Lux, Lux, Random, Test + +@testset "Flux2Lux.jl" begin + @testset "Containers" begin + + end + + @testset "Linear" begin + + end + + @testset "Convolutions" begin + + end + + @testset "Pooling" begin + + end + + @testset "Upsampling" begin + + end + + @testset "Recurrent" begin + + end + + @testset "Normalize" begin + + end + + @testset "Dropout" begin + + end + + @testset "Custom Layer" begin + struct CustomFluxLayer + weight + bias + end + + Flux.@functor CustomFluxLayer + + (c::CustomFluxLayer)(x) = c.weight .* x .+ c.bias + + c = CustomFluxLayer(randn(10), randn(10)) + x = randn(10) + + c_lux = transform(c) + ps, st = Lux.setup(Random.default_rng(), c_lux) + + @test c(x) ≈ c_lux(x, ps, st)[1] + end +end diff --git a/src/transform.jl b/src/transform.jl index 463225e62a..9df21bfc5b 100644 --- a/src/transform.jl +++ b/src/transform.jl @@ -1,3 +1,5 @@ +# TODO(@avik-pal): Deprecation warnings once Flux2Lux.jl is registered. + import .Flux """ diff --git a/test/runtests.jl b/test/runtests.jl index 8797efc4f1..90a51bf4d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,7 @@ function activate_subpkg_env(subpkg) end groups = if GROUP == "All" - ["Lux", "Boltz", "LuxLib"] + ["Lux", "Boltz", "LuxLib", "Flux2Lux"] else [GROUP] end