Start using Flux2Lux (#177)
avik-pal authored Nov 17, 2022
1 parent f903d8e commit 1be3bf7
Showing 9 changed files with 121 additions and 133 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
@@ -27,7 +27,7 @@ jobs:
- name: "Run CompatHelper"
run: |
import CompatHelper
CompatHelper.main(; subdirs=["", "examples", "lib/Boltz", "lib/LuxLib"])
CompatHelper.main(; subdirs=["", "examples", "lib/Boltz", "lib/LuxLib", "lib/Flux2Lux"])
shell: julia --color=yes {0}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
6 changes: 4 additions & 2 deletions lib/Boltz/Project.toml
@@ -1,13 +1,14 @@
name = "Boltz"
uuid = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.3"
version = "0.1.4"

[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Flux2Lux = "ab51a4a6-c8c3-4b1f-af31-4b52a21037df"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
@@ -19,7 +20,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
CUDA = "3"
ChainRulesCore = "1.15"
JLD2 = "0.4"
Lux = "0.4.26" # Boltz Needs some of the Newer Layers
Flux2Lux = "0.1"
Lux = "0.4.26"
Metalhead = "0.7"
NNlib = "0.8"
julia = "1.6"
2 changes: 1 addition & 1 deletion lib/Boltz/src/Boltz.jl
@@ -7,7 +7,7 @@ using Artifacts, JLD2, LazyArtifacts

# TODO(@avik-pal): We want to have generic Lux implementations for Metalhead models
# We can automatically convert several Metalhead.jl models to Lux
using Metalhead
using Flux2Lux, Metalhead

# Mark certain parts of layers as non-differentiable
import ChainRulesCore
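For context, this import is what lets the vision model constructors below call `transform` (exported by Flux2Lux) instead of `Lux.transform` to convert the Flux layers built by Metalhead into Lux layers. A minimal usage sketch, not part of this diff, assuming the standard Lux setup/apply workflow:

using Flux2Lux, Lux, Metalhead, Random

flux_layers = Metalhead.AlexNet().layers  # a Flux model
lux_model = transform(flux_layers)        # Flux2Lux.transform returns a Lux layer

# Lux layers are stateless; parameters and states live outside the model.
rng = Random.default_rng()
ps, st = Lux.setup(rng, lux_model)

x = randn(Float32, 224, 224, 3, 1)
y, st_new = lux_model(x, ps, st)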
62 changes: 21 additions & 41 deletions lib/Boltz/src/utils.jl
@@ -1,57 +1,37 @@
"""
fast_chunk(x::AbstractArray, ::Val{n}, ::Val{dim})
_fast_chunk(x::AbstractArray, ::Val{n}, ::Val{dim})
Type-stable and faster version of `MLUtils.chunk`
"""
@inline fast_chunk(h::Int, n::Int) = (1:h) .+ h * (n - 1)
@inline function fast_chunk(x::AbstractArray, h::Int, n::Int, ::Val{dim}) where {dim}
return selectdim(x, dim, fast_chunk(h, n))
@inline _fast_chunk(h::Int, n::Int) = (1:h) .+ h * (n - 1)
@inline function _fast_chunk(x::AbstractArray, h::Int, n::Int, ::Val{dim}) where {dim}
return selectdim(x, dim, _fast_chunk(h, n))
end
@inline function fast_chunk(x::CuArray, h::Int, n::Int, ::Val{dim}) where {dim}
# NOTE(@avik-pal): Most CuArray dispatches rely on a contiguous memory layout. Copying
# might be slow but allows us to use the faster and more reliable
# dispatches.
return copy(selectdim(x, dim, fast_chunk(h, n)))
# NOTE(@avik-pal): Most CuArray dispatches rely on a contiguous memory layout. Copying
# might be slow but allows us to use the faster and more reliable
# dispatches.
@inline function _fast_chunk(x::CuArray, h::Int, n::Int, ::Val{dim}) where {dim}
return copy(selectdim(x, dim, _fast_chunk(h, n)))
end
@inline function fast_chunk(x::AbstractArray, ::Val{N}, d::Val{D}) where {N, D}
return fast_chunk.((x,), size(x, D) ÷ N, 1:N, d)
@inline function _fast_chunk(x::AbstractArray, ::Val{N}, d::Val{D}) where {N, D}
return _fast_chunk.((x,), size(x, D) ÷ N, 1:N, d)
end
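The renamed `_fast_chunk(x, Val(n), Val(dim))` keeps the original behaviour: it splits `x` into `n` equal views along dimension `dim`. An illustrative sketch, not part of the diff:

x = reshape(collect(1:12), 3, 4)
a, b = _fast_chunk(x, Val(2), Val(2))  # two 3×2 views along dimension 2
a == [1 4; 2 5; 3 6]                   # columns 1:2
b == [7 10; 8 11; 9 12]                # columns 3:4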

"""
flatten_spatial(x::AbstractArray{T, 4})
_flatten_spatial(x::AbstractArray{T, 4})
Flattens the first 2 dimensions of `x`, and permutes the remaining dimensions to (2, 1, 3)
"""
@inline function flatten_spatial(x::AbstractArray{T, 4}) where {T}
@inline function _flatten_spatial(x::AbstractArray{T, 4}) where {T}
return permutedims(reshape(x, (:, size(x, 3), size(x, 4))), (2, 1, 3))
end

"""
seconddimmean(x)
_seconddimmean(x)
Computes the mean of `x` along dimension `2`
"""
@inline seconddimmean(x) = dropdims(mean(x; dims=2); dims=2)

"""
normalise(x::AbstractArray, activation; dims=ndims(x), epsilon=ofeltype(x, 1e-5))
Normalises the array `x` to have a mean of 0 and standard deviation of 1, and applies the
activation function `activation` to the result.
"""
@inline function normalise(x::AbstractArray, ::typeof(identity); dims=ndims(x),
epsilon=ofeltype(x, 1e-5))
xmean = mean(x; dims=dims)
xstd = std(x; dims=dims, mean=xmean, corrected=false)
return @. (x - xmean) / (xstd + epsilon)
end

@inline function normalise(x::AbstractArray, activation; dims=ndims(x),
epsilon=ofeltype(x, 1e-5))
xmean = mean(x; dims=dims)
xstd = std(x; dims=dims, mean=xmean, corrected=false)
return @. activation((x - xmean) / (xstd + epsilon))
end
@inline _seconddimmean(x) = dropdims(mean(x; dims=2); dims=2)
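Shape-wise, the two remaining helpers behave as follows (an illustrative sketch, not part of the diff):

x = randn(Float32, 16, 16, 32, 8)  # (W, H, C, N)
y = _flatten_spatial(x)
size(y) == (32, 256, 8)            # (C, W*H, N) after the reshape + permute
z = _seconddimmean(y)
size(z) == (32, 8)                 # mean over the flattened spatial dimension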

# Model construction utilities
function assert_name_present_in(name, possibilities)
@@ -60,19 +40,19 @@ end

# TODO(@avik-pal): Starting v0.2 we should be storing only the parameters and some of the
# states. Fields like rng don't need to be stored explicitly.
get_pretrained_weights_path(name::Symbol) = get_pretrained_weights_path(string(name))
function get_pretrained_weights_path(name::String)
_get_pretrained_weights_path(name::Symbol) = _get_pretrained_weights_path(string(name))
function _get_pretrained_weights_path(name::String)
try
return @artifact_str(name)
catch LoadError
throw(ArgumentError("No pretrained weights available for `$name`"))
throw(ArgumentError("no pretrained weights available for `$name`"))
end
end

function initialize_model(name::Symbol, model; pretrained::Bool=false, rng=nothing, seed=0,
kwargs...)
function _initialize_model(name::Symbol, model; pretrained::Bool=false, rng=nothing, seed=0,
kwargs...)
if pretrained
path = get_pretrained_weights_path(name)
path = _get_pretrained_weights_path(name)
ps = load(joinpath(path, "$name.jld2"), "parameters")
st = load(joinpath(path, "$name.jld2"), "states")
else
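For reference, the artifact consumed here is expected to be a JLD2 file with "parameters" and "states" entries matching the two `load` calls above. A hedged sketch of how such a file could be produced (the path and variable names are hypothetical):

using JLD2
# ps, st: trained parameters and states for, e.g., :alexnet
jldsave(joinpath(artifact_dir, "alexnet.jld2"); parameters=ps, states=st)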
68 changes: 40 additions & 28 deletions lib/Boltz/src/vision/metalhead.jl
@@ -1,75 +1,87 @@
function alexnet(name::Symbol; kwargs...)
function alexnet(name::Symbol; pretrained=false, kwargs...)
assert_name_present_in(name, (:alexnet,))
model = Lux.transform(AlexNet().layers)
return initialize_model(name, model; kwargs...)
model = transform(AlexNet().layers)

# Compatibility with pretrained weights
if pretrained
model = Chain(model[1], model[2])
end

return _initialize_model(name, model; pretrained, kwargs...)
end

function resnet(name::Symbol; kwargs...)
function resnet(name::Symbol; pretrained=false, kwargs...)
assert_name_present_in(name, (:resnet18, :resnet34, :resnet50, :resnet101, :resnet152))
model = if name == :resnet18
Lux.transform(ResNet(18).layers)
transform(ResNet(18).layers)
elseif name == :resnet34
Lux.transform(ResNet(34).layers)
transform(ResNet(34).layers)
elseif name == :resnet50
Lux.transform(ResNet(50).layers)
transform(ResNet(50).layers)
elseif name == :resnet101
Lux.transform(ResNet(101).layers)
transform(ResNet(101).layers)
elseif name == :resnet152
Lux.transform(ResNet(152).layers)
transform(ResNet(152).layers)
end
return initialize_model(name, model; kwargs...)

# Compatibility with pretrained weights
if pretrained
model = Chain(model[1], model[2])
end

return _initialize_model(name, model; pretrained, kwargs...)
end

function resnext(name::Symbol; kwargs...)
assert_name_present_in(name, (:resnext50, :resnext101, :resnext152))
model = if name == :resnext50
Lux.transform(ResNeXt(50).layers)
transform(ResNeXt(50).layers)
elseif name == :resnext101
Lux.transform(ResNeXt(101).layers)
transform(ResNeXt(101).layers)
elseif name == :resnext152
Lux.transform(ResNeXt(152).layers)
transform(ResNeXt(152).layers)
end
return initialize_model(name, model; kwargs...)
return _initialize_model(name, model; kwargs...)
end

function googlenet(name::Symbol; kwargs...)
assert_name_present_in(name, (:googlenet,))
model = Lux.transform(GoogLeNet().layers)
return initialize_model(name, model; kwargs...)
model = transform(GoogLeNet().layers)
return _initialize_model(name, model; kwargs...)
end

function densenet(name::Symbol; kwargs...)
assert_name_present_in(name, (:densenet121, :densenet161, :densenet169, :densenet201))
model = if name == :densenet121
Lux.transform(DenseNet(121).layers)
transform(DenseNet(121).layers)
elseif name == :densenet161
Lux.transform(DenseNet(161).layers)
transform(DenseNet(161).layers)
elseif name == :densenet169
Lux.transform(DenseNet(169).layers)
transform(DenseNet(169).layers)
elseif name == :densenet201
Lux.transform(DenseNet(201).layers)
transform(DenseNet(201).layers)
end
return initialize_model(name, model; kwargs...)
return _initialize_model(name, model; kwargs...)
end

function mobilenet(name::Symbol; kwargs...)
assert_name_present_in(name,
(:mobilenet_v1, :mobilenet_v2, :mobilenet_v3_small,
:mobilenet_v3_large))
model = if name == :mobilenet_v1
Lux.transform(MobileNetv1().layers)
transform(MobileNetv1().layers)
elseif name == :mobilenet_v2
Lux.transform(MobileNetv2().layers)
transform(MobileNetv2().layers)
elseif name == :mobilenet_v3_small
Lux.transform(MobileNetv3(:small).layers)
transform(MobileNetv3(:small).layers)
elseif name == :mobilenet_v3_large
Lux.transform(MobileNetv3(:large).layers)
transform(MobileNetv3(:large).layers)
end
return initialize_model(name, model; kwargs...)
return _initialize_model(name, model; kwargs...)
end

function convmixer(name::Symbol; kwargs...)
assert_name_present_in(name, (:base, :large, :small))
model = Lux.transform(ConvMixer(name).layers)
return initialize_model(name, model; kwargs...)
model = transform(ConvMixer(name).layers)
return _initialize_model(name, model; kwargs...)
end
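Downstream, the constructor API is unchanged apart from the explicit `pretrained` keyword. A hedged usage sketch (the `model, ps, st` return convention is assumed from `_initialize_model`):

using Boltz

model, ps, st = resnet(:resnet18)                   # randomly initialized
model, ps, st = alexnet(:alexnet; pretrained=true)  # weights loaded from the JLD2 artifact

x = randn(Float32, 224, 224, 3, 1)
y, _ = model(x, ps, st)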
26 changes: 13 additions & 13 deletions lib/Boltz/src/vision/vgg.jl
@@ -1,5 +1,5 @@
"""
vgg_block(input_filters, output_filters, depth, batchnorm)
_vgg_block(input_filters, output_filters, depth, batchnorm)
A VGG block of convolution layers ([reference](https://arxiv.org/abs/1409.1556v6)).
@@ -10,7 +10,7 @@ A VGG block of convolution layers ([reference](https://arxiv.org/abs/1409.1556v6)).
- `depth`: number of convolution/convolution + batch norm layers
- `batchnorm`: set to `true` to include batch normalization after each convolution
"""
function vgg_block(input_filters, output_filters, depth, batchnorm)
function _vgg_block(input_filters, output_filters, depth, batchnorm)
k = (3, 3)
p = (1, 1)
layers = []
@@ -26,43 +26,43 @@ function vgg_block(input_filters, output_filters, depth, batchnorm)
end
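The body of `_vgg_block` is collapsed in this hunk, but its contract follows the docstring above: a Chain of `depth` 3×3 convolutions (pad 1), each optionally followed by batch normalization. A hypothetical usage sketch:

block = _vgg_block(3, 64, 2, true)
# expected: a Chain of two Conv((3, 3), ...; pad=(1, 1)) layers, each followed by a BatchNorm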

"""
vgg_convolutional_layers(config, batchnorm, inchannels)
_vgg_convolutional_layers(config, batchnorm, inchannels)
Create VGG convolution layers ([reference](https://arxiv.org/abs/1409.1556v6)).
# Arguments
- `config`: vector of tuples `(output_channels, num_convolutions)` for each block
(see [`Metalhead.vgg_block`](#))
(see [`Metalhead._vgg_block`](#))
- `batchnorm`: set to `true` to include batch normalization after each convolution
- `inchannels`: number of input channels
"""
function vgg_convolutional_layers(config, batchnorm, inchannels)
function _vgg_convolutional_layers(config, batchnorm, inchannels)
layers = []
input_filters = inchannels
for c in config
push!(layers, vgg_block(input_filters, c..., batchnorm))
push!(layers, _vgg_block(input_filters, c..., batchnorm))
push!(layers, MaxPool((2, 2); stride=2))
input_filters, _ = c
end
return Chain(layers...)
end

"""
vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
_vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
Create VGG classifier (fully connected) layers
([reference](https://arxiv.org/abs/1409.1556v6)).
# Arguments
- `imsize`: tuple `(width, height, channels)` indicating the size after the convolution
layers (see [`Metalhead.vgg_convolutional_layers`](#))
layers (see [`Metalhead._vgg_convolutional_layers`](#))
- `nclasses`: number of output classes
- `fcsize`: input and output size of the intermediate fully connected layer
- `dropout`: the dropout level between each fully connected layer
"""
function vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
function _vgg_classifier_layers(imsize, nclasses, fcsize, dropout)
return Chain(FlattenLayer(), Dense(Int(prod(imsize)), fcsize, relu), Dropout(dropout),
Dense(fcsize, fcsize, relu), Dropout(dropout), Dense(fcsize, nclasses))
end
@@ -80,12 +80,12 @@ Create a VGG model ([reference](https://arxiv.org/abs/1409.1556v6)).
- `batchnorm`: set to `true` to use batch normalization after each convolution
- `nclasses`: number of output classes
- `fcsize`: intermediate fully connected layer size
(see [`Metalhead.vgg_classifier_layers`](#))
(see [`Metalhead._vgg_classifier_layers`](#))
- `dropout`: dropout level between fully connected layers
"""
function vgg(imsize; config, inchannels, batchnorm=false, nclasses, fcsize, dropout)
conv = vgg_convolutional_layers(config, batchnorm, inchannels)
class = vgg_classifier_layers((7, 7, 512), nclasses, fcsize, dropout)
conv = _vgg_convolutional_layers(config, batchnorm, inchannels)
class = _vgg_classifier_layers((7, 7, 512), nclasses, fcsize, dropout)
return Chain(Chain(conv), class)
end

@@ -125,5 +125,5 @@ function vgg(name::Symbol; kwargs...)
vgg((224, 224); config=VGG_CONV_CONFIG[VGG_CONFIG[19]], inchannels=3,
batchnorm=true, nclasses=1000, fcsize=4096, dropout=0.5f0)
end
return initialize_model(name, model; kwargs...)
return _initialize_model(name, model; kwargs...)
end
8 changes: 4 additions & 4 deletions lib/Boltz/src/vision/vit.jl
@@ -32,7 +32,7 @@ function (m::MultiHeadAttention)(x::AbstractArray{T, 3}, ps, st) where {T}
qkv, st_qkv = m.qkv_layer(x_reshaped, ps.qkv_layer, st.qkv_layer)
qkv_reshaped = reshape(qkv, nfeatures ÷ m.number_heads, m.number_heads, seq_len,
3 * batch_size)
query, key, value = fast_chunk(qkv_reshaped, Val(3), Val(4))
query, key, value = _fast_chunk(qkv_reshaped, Val(3), Val(4))

scale = convert(T, sqrt(size(query, 1) / m.number_heads))
key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.number_heads,
@@ -144,7 +144,7 @@ function patch_embedding(imsize::Tuple{<:Int, <:Int}=(224, 224); in_channels=3,
"Image dimensions must be divisible by the patch size."

return Chain(Conv(patch_size, in_channels => embed_planes; stride=patch_size),
flatten ? flatten_spatial : identity, norm_layer(embed_planes))
flatten ? _flatten_spatial : identity, norm_layer(embed_planes))
end

# ViT Implementation
@@ -163,7 +163,7 @@ function vision_transformer(; imsize::Tuple{<:Int, <:Int}=(256, 256), in_channel
transformer_encoder(embed_planes, depth, number_heads; mlp_ratio,
dropout_rate),
((pool == :class) ? WrappedFunction(x -> x[:, 1, :]) :
WrappedFunction(seconddimmean)); disable_optimizations=true),
WrappedFunction(_seconddimmean)); disable_optimizations=true),
Chain(LayerNorm((embed_planes,); affine=true),
Dense(embed_planes, num_classes, tanh); disable_optimizations=true);
disable_optimizations=true)
@@ -182,5 +182,5 @@ const VIT_CONFIGS = Dict(:tiny => (depth=12, embed_planes=192, number_heads=3),
function vision_transformer(name::Symbol; kwargs...)
assert_name_present_in(name, keys(VIT_CONFIGS))
model = vision_transformer(; VIT_CONFIGS[name]..., kwargs...)
return initialize_model(name, model; kwargs...)
return _initialize_model(name, model; kwargs...)
end
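The ViT entry point follows the same pattern; a hedged sketch (configuration keys such as `:tiny` come from `VIT_CONFIGS`, and the return convention is again assumed from `_initialize_model`):

model, ps, st = vision_transformer(:tiny)
x = randn(Float32, 256, 256, 3, 1)  # default imsize is (256, 256)
y, _ = model(x, ps, st)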

2 comments on commit 1be3bf7

@avik-pal
Member Author


@JuliaRegistrator register subdir=lib/Boltz

@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/72387

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a Boltz-v0.1.4 -m "<description of version>" 1be3bf75b84256ce5a1358f37df35ef3bd3d08ca
git push origin Boltz-v0.1.4
