diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml
index 300694473b..3a1edd8af3 100644
--- a/.github/workflows/CompatHelper.yml
+++ b/.github/workflows/CompatHelper.yml
@@ -27,7 +27,7 @@ jobs:
       - name: "Run CompatHelper"
         run: |
           import CompatHelper
-          CompatHelper.main(; subdirs=["", "examples", "lib/Boltz", "lib/LuxLib"])
+          CompatHelper.main(; subdirs=["", "examples", "lib/Boltz", "lib/LuxLib", "lib/Flux2Lux"])
         shell: julia --color=yes {0}
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/lib/Boltz/Project.toml b/lib/Boltz/Project.toml
index 53a2958b3d..8b22e1dbac 100644
--- a/lib/Boltz/Project.toml
+++ b/lib/Boltz/Project.toml
@@ -1,13 +1,14 @@
 name = "Boltz"
 uuid = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
 authors = ["Avik Pal and contributors"]
-version = "0.1.3"
+version = "0.1.4"
 
 [deps]
 Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+Flux2Lux = "ab51a4a6-c8c3-4b1f-af31-4b52a21037df"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
@@ -19,7 +20,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 CUDA = "3"
 ChainRulesCore = "1.15"
 JLD2 = "0.4"
-Lux = "0.4.26" # Boltz Needs some of the Newer Layers
+Flux2Lux = "0.1"
+Lux = "0.4.26"
 Metalhead = "0.7"
 NNlib = "0.8"
 julia = "1.6"
diff --git a/lib/Boltz/src/Boltz.jl b/lib/Boltz/src/Boltz.jl
index af91f02829..3a39809dae 100644
--- a/lib/Boltz/src/Boltz.jl
+++ b/lib/Boltz/src/Boltz.jl
@@ -7,7 +7,7 @@ using Artifacts, JLD2, LazyArtifacts
 
 # TODO(@avik-pal): We want to have generic Lux implementaions for Metalhead models
 #                  We can automatically convert several Metalhead.jl models to Lux
-using Metalhead
+using Flux2Lux, Metalhead
 
 # Mark certain parts of layers as non-differentiable
 import ChainRulesCore
diff --git a/lib/Boltz/src/utils.jl b/lib/Boltz/src/utils.jl
index b8851e17af..256243c13f 100644
--- a/lib/Boltz/src/utils.jl
+++ b/lib/Boltz/src/utils.jl
@@ -1,57 +1,37 @@
 """
-    fast_chunk(x::AbstractArray, ::Val{n}, ::Val{dim})
+    _fast_chunk(x::AbstractArray, ::Val{n}, ::Val{dim})
 
 Type-stable and faster version of `MLUtils.chunk`
 """
-@inline fast_chunk(h::Int, n::Int) = (1:h) .+ h * (n - 1)
-@inline function fast_chunk(x::AbstractArray, h::Int, n::Int, ::Val{dim}) where {dim}
-    return selectdim(x, dim, fast_chunk(h, n))
+@inline _fast_chunk(h::Int, n::Int) = (1:h) .+ h * (n - 1)
+@inline function _fast_chunk(x::AbstractArray, h::Int, n::Int, ::Val{dim}) where {dim}
+    return selectdim(x, dim, _fast_chunk(h, n))
 end
-@inline function fast_chunk(x::CuArray, h::Int, n::Int, ::Val{dim}) where {dim}
-    # NOTE(@avik-pal): Most CuArray dispatches rely on a contiguous memory layout. Copying
-    #                  might be slow but allows us to use the faster and more reliable
-    #                  dispatches.
-    return copy(selectdim(x, dim, fast_chunk(h, n)))
+
+# NOTE(@avik-pal): Most CuArray dispatches rely on a contiguous memory layout. Copying
+#                  might be slow but allows us to use the faster and more reliable
+#                  dispatches.
+@inline function _fast_chunk(x::CuArray, h::Int, n::Int, ::Val{dim}) where {dim}
+    return copy(selectdim(x, dim, _fast_chunk(h, n)))
 end
-@inline function fast_chunk(x::AbstractArray, ::Val{N}, d::Val{D}) where {N, D}
-    return fast_chunk.((x,), size(x, D) ÷ N, 1:N, d)
+@inline function _fast_chunk(x::AbstractArray, ::Val{N}, d::Val{D}) where {N, D}
+    return _fast_chunk.((x,), size(x, D) ÷ N, 1:N, d)
 end
 
 """
-    flatten_spatial(x::AbstractArray{T, 4})
+    _flatten_spatial(x::AbstractArray{T, 4})
 
 Flattens the first 2 dimensions of `x`, and permutes the remaining dimensions to (2, 1, 3)
 """
-@inline function flatten_spatial(x::AbstractArray{T, 4}) where {T}
+@inline function _flatten_spatial(x::AbstractArray{T, 4}) where {T}
     return permutedims(reshape(x, (:, size(x, 3), size(x, 4))), (2, 1, 3))
 end
 
 """
-    seconddimmean(x)
+    _seconddimmean(x)
 
 Computes the mean of `x` along dimension `2`
 """
-@inline seconddimmean(x) = dropdims(mean(x; dims=2); dims=2)
-
-"""
-    normalise(x::AbstractArray, activation; dims=ndims(x), epsilon=ofeltype(x, 1e-5))
-
-Normalises the array `x` to have a mean of 0 and standard deviation of 1, and applies the
-activation function `activation` to the result.
-"""
-@inline function normalise(x::AbstractArray, ::typeof(identity); dims=ndims(x),
-                           epsilon=ofeltype(x, 1e-5))
-    xmean = mean(x; dims=dims)
-    xstd = std(x; dims=dims, mean=xmean, corrected=false)
-    return @. (x - xmean) / (xstd + epsilon)
-end
-
-@inline function normalise(x::AbstractArray, activation; dims=ndims(x),
-                           epsilon=ofeltype(x, 1e-5))
-    xmean = mean(x; dims=dims)
-    xstd = std(x; dims=dims, mean=xmean, corrected=false)
-    return @. activation((x - xmean) / (xstd + epsilon))
-end
+@inline _seconddimmean(x) = dropdims(mean(x; dims=2); dims=2)
 
 # Model construction utilities
 function assert_name_present_in(name, possibilities)
@@ -60,19 +40,19 @@ end
 
 # TODO(@avik-pal): Starting v0.2 we should be storing only the parameters and some of the
 #                  states. Fields like rng don't need to be stored explicitly.
-get_pretrained_weights_path(name::Symbol) = get_pretrained_weights_path(string(name))
-function get_pretrained_weights_path(name::String)
+_get_pretrained_weights_path(name::Symbol) = _get_pretrained_weights_path(string(name))
+function _get_pretrained_weights_path(name::String)
     try
         return @artifact_str(name)
     catch LoadError
-        throw(ArgumentError("No pretrained weights available for `$name`"))
+        throw(ArgumentError("no pretrained weights available for `$name`"))
     end
 end
 
-function initialize_model(name::Symbol, model; pretrained::Bool=false, rng=nothing, seed=0,
-                          kwargs...)
+function _initialize_model(name::Symbol, model; pretrained::Bool=false, rng=nothing, seed=0,
+                           kwargs...)
     if pretrained
-        path = get_pretrained_weights_path(name)
+        path = _get_pretrained_weights_path(name)
         ps = load(joinpath(path, "$name.jld2"), "parameters")
         st = load(joinpath(path, "$name.jld2"), "states")
     else
diff --git a/lib/Boltz/src/vision/metalhead.jl b/lib/Boltz/src/vision/metalhead.jl
index 07e1a4b9cc..999cca3388 100644
--- a/lib/Boltz/src/vision/metalhead.jl
+++ b/lib/Boltz/src/vision/metalhead.jl
@@ -1,55 +1,67 @@
-function alexnet(name::Symbol; kwargs...)
+function alexnet(name::Symbol; pretrained=false, kwargs...)
     assert_name_present_in(name, (:alexnet,))
-    model = Lux.transform(AlexNet().layers)
-    return initialize_model(name, model; kwargs...)
+    model = transform(AlexNet().layers)
+
+    # Compatibility with pretrained weights
+    if pretrained
+        model = Chain(model[1], model[2])
+    end
+
+    return _initialize_model(name, model; pretrained, kwargs...)
 end
 
-function resnet(name::Symbol; kwargs...)
+function resnet(name::Symbol; pretrained=false, kwargs...)
     assert_name_present_in(name, (:resnet18, :resnet34, :resnet50, :resnet101, :resnet152))
     model = if name == :resnet18
-        Lux.transform(ResNet(18).layers)
+        transform(ResNet(18).layers)
     elseif name == :resnet34
-        Lux.transform(ResNet(34).layers)
+        transform(ResNet(34).layers)
     elseif name == :resnet50
-        Lux.transform(ResNet(50).layers)
+        transform(ResNet(50).layers)
     elseif name == :resnet101
-        Lux.transform(ResNet(101).layers)
+        transform(ResNet(101).layers)
     elseif name == :resnet152
-        Lux.transform(ResNet(152).layers)
+        transform(ResNet(152).layers)
     end
-    return initialize_model(name, model; kwargs...)
+
+    # Compatibility with pretrained weights
+    if pretrained
+        model = Chain(model[1], model[2])
+    end
+
+    return _initialize_model(name, model; pretrained, kwargs...)
 end
 
 function resnext(name::Symbol; kwargs...)
     assert_name_present_in(name, (:resnext50, :resnext101, :resnext152))
     model = if name == :resnext50
-        Lux.transform(ResNeXt(50).layers)
+        transform(ResNeXt(50).layers)
     elseif name == :resnext101
-        Lux.transform(ResNeXt(101).layers)
+        transform(ResNeXt(101).layers)
     elseif name == :resnext152
-        Lux.transform(ResNeXt(152).layers)
+        transform(ResNeXt(152).layers)
     end
-    return initialize_model(name, model; kwargs...)
+    return _initialize_model(name, model; kwargs...)
 end
 
 function googlenet(name::Symbol; kwargs...)
     assert_name_present_in(name, (:googlenet,))
-    model = Lux.transform(GoogLeNet().layers)
-    return initialize_model(name, model; kwargs...)
+    model = transform(GoogLeNet().layers)
+    return _initialize_model(name, model; kwargs...)
 end
 
 function densenet(name::Symbol; kwargs...)
     assert_name_present_in(name, (:densenet121, :densenet161, :densenet169, :densenet201))
     model = if name == :densenet121
-        Lux.transform(DenseNet(121).layers)
+        transform(DenseNet(121).layers)
     elseif name == :densenet161
-        Lux.transform(DenseNet(161).layers)
+        transform(DenseNet(161).layers)
     elseif name == :densenet169
-        Lux.transform(DenseNet(169).layers)
+        transform(DenseNet(169).layers)
     elseif name == :densenet201
-        Lux.transform(DenseNet(201).layers)
+        transform(DenseNet(201).layers)
     end
-    return initialize_model(name, model; kwargs...)
+    return _initialize_model(name, model; kwargs...)
 end
 
 function mobilenet(name::Symbol; kwargs...)
@@ -57,19 +69,19 @@ function mobilenet(name::Symbol; kwargs...)
                            (:mobilenet_v1, :mobilenet_v2, :mobilenet_v3_small, :mobilenet_v3_large))
     model = if name == :mobilenet_v1
-        Lux.transform(MobileNetv1().layers)
+        transform(MobileNetv1().layers)
     elseif name == :mobilenet_v2
-        Lux.transform(MobileNetv2().layers)
+        transform(MobileNetv2().layers)
     elseif name == :mobilenet_v3_small
-        Lux.transform(MobileNetv3(:small).layers)
+        transform(MobileNetv3(:small).layers)
     elseif name == :mobilenet_v3_large
-        Lux.transform(MobileNetv3(:large).layers)
+        transform(MobileNetv3(:large).layers)
     end
-    return initialize_model(name, model; kwargs...)
+    return _initialize_model(name, model; kwargs...)
 end
 
 function convmixer(name::Symbol; kwargs...)
     assert_name_present_in(name, (:base, :large, :small))
-    model = Lux.transform(ConvMixer(name).layers)
-    return initialize_model(name, model; kwargs...)
+    model = transform(ConvMixer(name).layers)
+    return _initialize_model(name, model; kwargs...)
 end
diff --git a/lib/Boltz/src/vision/vgg.jl b/lib/Boltz/src/vision/vgg.jl
index e18beb2697..5e316e6d0c 100644
--- a/lib/Boltz/src/vision/vgg.jl
+++ b/lib/Boltz/src/vision/vgg.jl
@@ -1,5 +1,5 @@
 """
-    vgg_block(input_filters, output_filters, depth, batchnorm)
+    _vgg_block(input_filters, output_filters, depth, batchnorm)
 
 A VGG block of convolution layers ([reference](https://arxiv.org/abs/1409.1556v6)).
 
@@ -10,7 +10,7 @@ A VGG block of convolution layers ([reference](https://arxiv.org/abs/1409.1556v6
   - `depth`: number of convolution/convolution + batch norm layers
   - `batchnorm`: set to `true` to include batch normalization after each convolution
 """
-function vgg_block(input_filters, output_filters, depth, batchnorm)
+function _vgg_block(input_filters, output_filters, depth, batchnorm)
     k = (3, 3)
     p = (1, 1)
     layers = []
@@ -26,22 +26,22 @@ function vgg_block(input_filters, output_filters, depth, batchnorm)
 end
 
 """
-    vgg_convolutional_layers(config, batchnorm, inchannels)
+    _vgg_convolutional_layers(config, batchnorm, inchannels)
 
 Create VGG convolution layers ([reference](https://arxiv.org/abs/1409.1556v6)).
 
 # Arguments
 
   - `config`: vector of tuples `(output_channels, num_convolutions)` for each block
-    (see [`Metalhead.vgg_block`](#))
+    (see [`Metalhead._vgg_block`](#))
   - `batchnorm`: set to `true` to include batch normalization after each convolution
   - `inchannels`: number of input channels
 """
-function vgg_convolutional_layers(config, batchnorm, inchannels)
+function _vgg_convolutional_layers(config, batchnorm, inchannels)
     layers = []
     input_filters = inchannels
     for c in config
-        push!(layers, vgg_block(input_filters, c..., batchnorm))
+        push!(layers, _vgg_block(input_filters, c..., batchnorm))
         push!(layers, MaxPool((2, 2); stride=2))
         input_filters, _ = c
     end
@@ -49,7 +49,7 @@
 end
 
 """
-    vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
+    _vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
 
 Create VGG classifier (fully connected) layers
 ([reference](https://arxiv.org/abs/1409.1556v6)).
@@ -57,12 +57,12 @@ Create VGG classifier (fully connected) layers
 # Arguments
 
   - `imsize`: tuple `(width, height, channels)` indicating the size after the convolution
-    layers (see [`Metalhead.vgg_convolutional_layers`](#))
+    layers (see [`Metalhead._vgg_classifier_layers`](#))
   - `nclasses`: number of output classes
   - `fcsize`: input and output size of the intermediate fully connected layer
   - `dropout`: the dropout level between each fully connected layer
 """
-function vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
+function _vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
     return Chain(FlattenLayer(), Dense(Int(prod(imsize)), fcsize, relu), Dropout(dropout),
                  Dense(fcsize, fcsize, relu), Dropout(dropout), Dense(fcsize, nclasses))
 end
@@ -80,12 +80,12 @@ Create a VGG model ([reference](https://arxiv.org/abs/1409.1556v6)).
   - `batchnorm`: set to `true` to use batch normalization after each convolution
   - `nclasses`: number of output classes
   - `fcsize`: intermediate fully connected layer size
-    (see [`Metalhead.vgg_classifier_layers`](#))
+    (see [`Metalhead._vgg_classifier_layers`](#))
   - `dropout`: dropout level between fully connected layers
 """
 function vgg(imsize; config, inchannels, batchnorm=false, nclasses, fcsize, dropout)
-    conv = vgg_convolutional_layers(config, batchnorm, inchannels)
-    class = vgg_classifier_layers((7, 7, 512), nclasses, fcsize, dropout)
+    conv = _vgg_convolutional_layers(config, batchnorm, inchannels)
+    class = _vgg_classifier_layers((7, 7, 512), nclasses, fcsize, dropout)
     return Chain(Chain(conv), class)
 end
 
@@ -125,5 +125,5 @@ function vgg(name::Symbol; kwargs...)
         vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[19]], inchannels=3, batchnorm=true,
             nclasses=1000, fcsize=4096, dropout=0.5f0)
     end
-    return initialize_model(name, model; kwargs...)
+    return _initialize_model(name, model; kwargs...)
 end
diff --git a/lib/Boltz/src/vision/vit.jl b/lib/Boltz/src/vision/vit.jl
index 005473ea7d..aa0613d239 100644
--- a/lib/Boltz/src/vision/vit.jl
+++ b/lib/Boltz/src/vision/vit.jl
@@ -32,7 +32,7 @@ function (m::MultiHeadAttention)(x::AbstractArray{T, 3}, ps, st) where {T}
     qkv, st_qkv = m.qkv_layer(x_reshaped, ps.qkv_layer, st.qkv_layer)
     qkv_reshaped = reshape(qkv, nfeatures ÷ m.number_heads, m.number_heads, seq_len,
                            3 * batch_size)
-    query, key, value = fast_chunk(qkv_reshaped, Val(3), Val(4))
+    query, key, value = _fast_chunk(qkv_reshaped, Val(3), Val(4))
 
     scale = convert(T, sqrt(size(query, 1) / m.number_heads))
     key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.number_heads,
@@ -144,7 +144,7 @@ function patch_embedding(imsize::Tuple{<:Int, <:Int}=(224, 224); in_channels=3,
            "Image dimensions must be divisible by the patch size."
 
     return Chain(Conv(patch_size, in_channels => embed_planes; stride=patch_size),
-                 flatten ? flatten_spatial : identity, norm_layer(embed_planes))
+                 flatten ? _flatten_spatial : identity, norm_layer(embed_planes))
 end
 
 # ViT Implementation
@@ -163,7 +163,7 @@ function vision_transformer(; imsize::Tuple{<:Int, <:Int}=(256, 256), in_channel
                       transformer_encoder(embed_planes, depth, number_heads; mlp_ratio,
                                           dropout_rate),
                       ((pool == :class) ? WrappedFunction(x -> x[:, 1, :]) :
-                       WrappedFunction(seconddimmean)); disable_optimizations=true),
+                       WrappedFunction(_seconddimmean)); disable_optimizations=true),
                 Chain(LayerNorm((embed_planes,); affine=true),
                       Dense(embed_planes, num_classes, tanh); disable_optimizations=true);
                 disable_optimizations=true)
@@ -182,5 +182,5 @@ const VIT_CONFIGS = Dict(:tiny => (depth=12, embed_planes=192, number_heads=3),
 function vision_transformer(name::Symbol; kwargs...)
     assert_name_present_in(name, keys(VIT_CONFIGS))
     model = vision_transformer(; VIT_CONFIGS[name]..., kwargs...)
-    return initialize_model(name, model; kwargs...)
+    return _initialize_model(name, model; kwargs...)
 end
diff --git a/src/transform.jl b/src/transform.jl
index 9df21bfc5b..07803ebfa4 100644
--- a/src/transform.jl
+++ b/src/transform.jl
@@ -1,5 +1,3 @@
-# TODO(@avik-pal): Deprecation warnings once Flux2Lux.jl is registered.
-
 import .Flux
 
 """
@@ -7,6 +5,11 @@ import .Flux
 
 Convert a Flux Model to Lux Model.
 
+!!! tip
+
+    It is recommended to use the package `Flux2Lux` instead of this function. It supports
+    conversion of a wider variety of Flux models.
+
 # Examples
 
 ```julia
 m = Flux.Chain(Flux.Dense(2 => 3, relu), Flux.Dense(3 => 2));
 
 x = randn(Float32, 2, 32);
 
 m2 = Lux.transform(m);
 
@@ -22,66 +25,54 @@ ps, st = Lux.setup(Random.default_rng(), m2);
 m2(x, ps, st)
 ```
 """
-transform(::T) where {T} = error("Transformation for type $T not implemented")
+function transform(x)
+    Base.depwarn("`Lux.transform` has been deprecated in favor of the package `Flux2Lux`. " *
+                 "This function will be removed in v0.5", :transform)
+    return _transform(x)
+end
+
+_transform(::T) where {T} = error("Transformation for type $T not implemented")
 
-transform(model::Flux.Chain) = Chain(transform.(model.layers)...)
+_transform(model::Flux.Chain) = Chain(_transform.(model.layers)...)
 
-function transform(model::Flux.BatchNorm)
+function _transform(model::Flux.BatchNorm)
     return BatchNorm(model.chs, model.λ; affine=model.affine, track_stats=model.track_stats,
                      epsilon=model.ϵ, momentum=model.momentum)
 end
 
-function transform(model::Flux.Conv)
+function _transform(model::Flux.Conv)
+    in_chs = size(model.weight, ndims(model.weight) - 1) * model.groups
     return Conv(size(model.weight)[1:(end - 2)],
-                size(model.weight, ndims(model.weight) - 1) * model.groups => size(model.weight,
-                                                                                   ndims(model.weight)),
-                model.σ; stride=model.stride, pad=model.pad,
+                in_chs => size(model.weight, ndims(model.weight)), model.σ;
+                stride=model.stride, pad=model.pad,
                 bias=model.bias isa Bool ? model.bias : !(model.bias isa Flux.Zeros),
                 dilation=model.dilation, groups=model.groups)
 end
 
-function transform(model::Flux.SkipConnection)
-    return SkipConnection(transform(model.layers), model.connection)
+function _transform(model::Flux.SkipConnection)
+    return SkipConnection(_transform(model.layers), model.connection)
 end
 
-function transform(model::Flux.Dense)
-    return Dense(size(model.weight, 2), size(model.weight, 1), model.σ)
-end
+_transform(model::Flux.Dense) = Dense(size(model.weight, 2), size(model.weight, 1), model.σ)
 
-function transform(model::Flux.MaxPool)
-    return MaxPool(model.k, model.pad, model.stride)
-end
+_transform(model::Flux.MaxPool) = MaxPool(model.k, model.pad, model.stride)
 
-function transform(model::Flux.MeanPool)
-    return MeanPool(model.k, model.pad, model.stride)
-end
+_transform(model::Flux.MeanPool) = MeanPool(model.k, model.pad, model.stride)
 
-function transform(::Flux.GlobalMaxPool)
-    return GlobalMaxPool()
-end
+_transform(::Flux.GlobalMaxPool) = GlobalMaxPool()
 
-function transform(::Flux.GlobalMeanPool)
-    return GlobalMeanPool()
-end
+_transform(::Flux.GlobalMeanPool) = GlobalMeanPool()
 
-function transform(p::Flux.AdaptiveMaxPool)
-    return AdaptiveMaxPool(p.out)
-end
+_transform(p::Flux.AdaptiveMaxPool) = AdaptiveMaxPool(p.out)
 
-function transform(p::Flux.AdaptiveMeanPool)
-    return AdaptiveMeanPool(p.out)
-end
+_transform(p::Flux.AdaptiveMeanPool) = AdaptiveMeanPool(p.out)
 
-function transform(model::Flux.Parallel)
-    return Parallel(model.connection, transform.(model.layers)...)
-end
+_transform(model::Flux.Parallel) = Parallel(model.connection, _transform.(model.layers)...)
 
-function transform(d::Flux.Dropout)
-    return Dropout(Float32(d.p); dims=d.dims)
-end
+_transform(d::Flux.Dropout) = Dropout(Float32(d.p); dims=d.dims)
 
-transform(::typeof(identity)) = NoOpLayer()
+_transform(::typeof(identity)) = NoOpLayer()
 
-transform(::typeof(Flux.flatten)) = FlattenLayer()
+_transform(::typeof(Flux.flatten)) = FlattenLayer()
 
-transform(f::Function) = WrappedFunction(f)
+_transform(f::Function) = WrappedFunction(f)
diff --git a/test/runtests.jl b/test/runtests.jl
index 2ef51ffb31..6b861dff65 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -16,8 +16,11 @@ else
 end
 
 cross_dependencies = Dict("Lux" => [_get_lib_path("LuxLib")],
-                          "Boltz" => [_get_lib_path("LuxLib"), dirname(@__DIR__)],
-                          "LuxLib" => [],
+                          "Boltz" => [
+                              _get_lib_path("LuxLib"),
+                              dirname(@__DIR__),
+                              _get_lib_path("Flux2Lux"),
+                          ], "LuxLib" => [],
                           "Flux2Lux" => [_get_lib_path("LuxLib"), dirname(@__DIR__)])
 
 const OVERRIDE_INTER_DEPENDENCIES = get(ENV, "OVERRIDE_INTER_DEPENDENCIES", "false") ==
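Migration sketch for downstream users of the now-deprecated `Lux.transform`. This is an illustrative example, not part of the diff: it assumes `Flux2Lux` is installed and exports a `transform` that behaves like the old `Lux.transform`; the model and input shapes are made up for the demonstration.

```julia
using Flux, Flux2Lux, Lux, Random

# A small Flux model, mirroring the example in the `Lux.transform` docstring.
m = Flux.Chain(Flux.Dense(2 => 3, relu), Flux.Dense(3 => 2))
x = randn(Float32, 2, 32)

# The old call `Lux.transform(m)` still works but now emits a deprecation warning;
# the recommended path is the Flux2Lux package.
m2 = Flux2Lux.transform(m)

# The converted model is used like any other Lux model.
ps, st = Lux.setup(Random.default_rng(), m2)
y, st_new = m2(x, ps, st)
```

The Boltz constructors touched above take the same path internally: the unqualified `transform` calls now resolve to `Flux2Lux.transform` via `using Flux2Lux`, so a call such as `Boltz.alexnet(:alexnet; pretrained=true)` is expected to keep returning the converted model together with its parameters and states.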