
Update Boltz to use Lux extensions
avik-pal committed Feb 17, 2023
1 parent 1acb52e commit 852e878
Showing 10 changed files with 52 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Documentation.yml
@@ -45,7 +45,7 @@ jobs:
julia --code-coverage=user --project=docs/ --color=yes docs/make.jl
- uses: julia-actions/julia-processcoverage@v1
with:
-directories: src,lib/Boltz/src,lib/LuxLib/src
+directories: src,lib/Boltz/src,lib/LuxLib/src,ext
- uses: codecov/codecov-action@v3
with:
files: lcov.info
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -1,7 +1,7 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433"
-Flux2Lux = "ab51a4a6-c8c3-4b1f-af31-4b52a21037df"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -16,7 +16,7 @@ deployconfig = Documenter.auto_detect_deploy_system()
Documenter.post_status(deployconfig; type="pending", repo="github.com/avik-pal/Lux.jl.git")

makedocs(; sitename="Lux", authors="Avik Pal et al.", clean=true, doctest=true,
-modules=[Flux2Lux, Lux, LuxLib, LuxCore],
+modules=[Lux, LuxLib, LuxCore],
strict=[
:doctest,
:linkcheck,
4 changes: 2 additions & 2 deletions examples/Basics/main.jl
@@ -181,7 +181,7 @@ end
# efficient computationally than a VJP, and, conversely, a JVP is more efficient when the
# Jacobian matrix is a tall matrix.

-using ForwardDiff, Zygote, AbstractDifferentiation
+using ComponentArrays, ForwardDiff, Zygote, AbstractDifferentiation

# ### Gradients

@@ -252,7 +252,7 @@ Random.seed!(rng, 0)

# Let us initialize the parameters and states (in this case it is empty) for the model.
ps, st = Lux.setup(rng, model)
-ps = ps |> Lux.ComponentArray
+ps = ps |> ComponentArray

# Set problem dimensions.
n_samples = 20
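As a side note on the VJP/JVP remark in the hunk above, here is a small hedged illustration, not part of this commit; the function f and the shapes are made up for the example:

using ForwardDiff, Zygote

f(x) = [sum(abs2, x), prod(x)]        # a made-up R^3 -> R^2 function
x = rand(3)

J_fwd = ForwardDiff.jacobian(f, x)    # assembled from JVPs (forward mode), cheap when the Jacobian is tall
J_rev = Zygote.jacobian(f, x)[1]      # assembled from VJPs (reverse mode), cheap when the Jacobian is wide
@assert J_fwd ≈ J_rev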
6 changes: 3 additions & 3 deletions lib/Boltz/Project.toml
@@ -1,14 +1,14 @@
name = "Boltz"
uuid = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.4"
version = "0.1.5"

[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
-Flux2Lux = "ab51a4a6-c8c3-4b1f-af31-4b52a21037df"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
@@ -20,7 +20,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
CUDA = "3"
ChainRulesCore = "1.15"
JLD2 = "0.4"
Flux2Lux = "0.1"
Flux = "0.13"
Lux = "0.4.26"
Metalhead = "0.7"
NNlib = "0.8"
5 changes: 3 additions & 2 deletions lib/Boltz/src/Boltz.jl
@@ -7,10 +7,11 @@ using Artifacts, JLD2, LazyArtifacts

# TODO(@avik-pal): We want to have generic Lux implementations for Metalhead models
# We can automatically convert several Metalhead.jl models to Lux
-using Flux2Lux, Metalhead
+using Metalhead
+import Flux

# Mark certain parts of layers as non-differentiable
-import ChainRulesCore
+import ChainRulesCore as CRC

# Utility Functions
include("utils.jl")
2 changes: 1 addition & 1 deletion lib/Boltz/src/vision/vit.jl
@@ -71,7 +71,7 @@ ClassTokens(dim::Int; init=Lux.zeros32) = ClassTokens(dim, init)
Lux.initialparameters(rng::AbstractRNG, c::ClassTokens) = (token=c.init(rng, c.dim, 1, 1),)

_fill_like(y::AbstractArray{T, 3}) where {T} = fill!(similar(y, 1, 1, size(y, 3)), one(T))
-ChainRulesCore.@non_differentiable _fill_like(y)
+CRC.@non_differentiable _fill_like(y)

function (m::ClassTokens)(x::AbstractArray{T, 3}, ps, st) where {T}
# Generic Alternative: Repeat is extremely inefficient on GPUs and even in general
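For readers unfamiliar with the _fill_like trick referenced above, a minimal sketch of the idea with made-up shapes, not the layer code from this commit:

token = randn(Float32, 8, 1, 1)                           # stand-in for ps.token (embed_dim, 1, 1)
x = randn(Float32, 8, 16, 4)                              # (embed_dim, seq_len, batch)
ones_batch = fill!(similar(x, 1, 1, size(x, 3)), 1.0f0)   # what _fill_like builds, marked non-differentiable
tokens = token .* ones_batch                              # broadcast to (8, 1, 4) without calling repeat
y = cat(tokens, x; dims=2)                                # prepend the class token -> (8, 17, 4)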
9 changes: 3 additions & 6 deletions src/Lux.jl
@@ -15,6 +15,7 @@ import Adapt: adapt, adapt_storage
using FillArrays
# Automatic Differentiation
using ChainRulesCore, Zygote
+import ChainRulesCore as CRC
# Docstrings
using Markdown

@@ -52,14 +53,10 @@
function __init__()
@static if !isdefined(Base, :get_extension)
# Handling ComponentArrays
-@require ComponentArrays="b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" begin
-include("../ext/LuxComponentArraysExt.jl")
-end
+@require ComponentArrays="b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" begin include("../ext/LuxComponentArraysExt.jl") end

# Flux InterOp
-@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
-include("../ext/LuxFluxTransformExt.jl")
-end
+@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin include("../ext/LuxFluxTransformExt.jl") end
end
end

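A hedged usage sketch of what this split buys, assuming Julia >= 1.9; the extension name is taken from the file path included above:

using Lux
using Flux    # on Julia >= 1.9, loading the weak dependency triggers the extension

# Returns the extension module once it has loaded, or `nothing` on Julia < 1.9,
# where the Requires.jl-based @require fallback above is used instead.
Base.get_extension(Lux, :LuxFluxTransformExt)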
44 changes: 21 additions & 23 deletions src/autodiff.jl
@@ -1,21 +1,20 @@
# Non Differentiable Functions
-ChainRulesCore.@non_differentiable replicate(::Any)
-ChainRulesCore.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)
-ChainRulesCore.@non_differentiable glorot_normal(::Any...)
-ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
-ChainRulesCore.@non_differentiable kaiming_normal(::Any...)
-ChainRulesCore.@non_differentiable kaiming_uniform(::Any...)
-ChainRulesCore.@non_differentiable check_use_cuda()
-ChainRulesCore.@non_differentiable istraining(::Any)
-ChainRulesCore.@non_differentiable _get_norm_except_dims(::Any, ::Any)
-ChainRulesCore.@non_differentiable _affine(::Any)
-ChainRulesCore.@non_differentiable _track_stats(::Any)
-ChainRulesCore.@non_differentiable _conv_transpose_dims(::Any...)
-ChainRulesCore.@non_differentiable _calc_padding(::Any...)
+CRC.@non_differentiable replicate(::Any)
+CRC.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)
+CRC.@non_differentiable glorot_normal(::Any...)
+CRC.@non_differentiable glorot_uniform(::Any...)
+CRC.@non_differentiable kaiming_normal(::Any...)
+CRC.@non_differentiable kaiming_uniform(::Any...)
+CRC.@non_differentiable check_use_cuda()
+CRC.@non_differentiable istraining(::Any)
+CRC.@non_differentiable _get_norm_except_dims(::Any, ::Any)
+CRC.@non_differentiable _affine(::Any)
+CRC.@non_differentiable _track_stats(::Any)
+CRC.@non_differentiable _conv_transpose_dims(::Any...)
+CRC.@non_differentiable _calc_padding(::Any...)

# Utilities
-function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1},
-nt2::NamedTuple{F2}) where {F1, F2}
+function CRC.rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1, F2}
y = merge(nt1, nt2)
function merge_pullback(dy)
dnt1 = NamedTuple((f1 => (f1 in F2 ? NoTangent() : getproperty(dy, f1))
@@ -29,44 +28,43 @@ function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1},
return y, merge_pullback
end

-function ChainRulesCore.rrule(::typeof(vec), x::AbstractMatrix)
+function CRC.rrule(::typeof(vec), x::AbstractMatrix)
y = vec(x)
vec_pullback(dy) = NoTangent(), reshape(dy, size(x))
return y, vec_pullback
end

-function ChainRulesCore.rrule(::typeof(collect), v::Vector)
+function CRC.rrule(::typeof(collect), v::Vector)
y = collect(v)
function collect_pullback(dy)
return NoTangent(), dy
end
return y, collect_pullback
end

-function ChainRulesCore.rrule(::typeof(copy), x)
+function CRC.rrule(::typeof(copy), x)
copy_pullback(dy) = (NoTangent(), dy)
return copy(x), copy_pullback
end

# Adapt Interface
-function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
+function CRC.rrule(::Type{Array}, x::CUDA.CuArray)
return Array(x), d -> (NoTangent(), CUDA.cu(d))
end

-function ChainRulesCore.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor,
-x::CUDA.AbstractGPUArray)
+function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::CUDA.AbstractGPUArray)
return adapt_storage(to, x),
d -> (NoTangent(), NoTangent(), adapt_storage(LuxCUDAAdaptor(), d))
end

-function ChainRulesCore.rrule(::typeof(adapt_storage), to::LuxCUDAAdaptor, x::Array)
+function CRC.rrule(::typeof(adapt_storage), to::LuxCUDAAdaptor, x::Array)
return adapt_storage(to, x),
d -> (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), d))
end

# RNN Helpers
## Taken from https://github.com/FluxML/Flux.jl/blob/1f82da4bfa051c809f7f3ce7dd7aeb43be515b14/src/layers/recurrent.jl#L9
-function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, c::Val{N}) where {N}
+function CRC.rrule(::typeof(multigate), x::AbstractArray, c::Val{N}) where {N}
function multigate_pullback(dy)
dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
foreach(multigate(dx, c), dy) do dxᵢ, dyᵢ
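A small usage sketch, not from this commit, of how these rules are consumed; it uses the custom vec rule above, whose pullback reshapes the cotangent back to the original matrix shape:

using Zygote

x = rand(Float32, 3, 4)
g = Zygote.gradient(x -> sum(vec(x)), x)[1]
@assert size(g) == size(x)    # cotangent reshaped back to (3, 4) by vec_pullback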
20 changes: 16 additions & 4 deletions test/layers/conv.jl
@@ -458,14 +458,20 @@
display(layer)
ps, st = Lux.setup(rng, layer)
run_JET_tests(layer, y, ps, st; opt_broken=true)
-@inferred layer(y, ps, st)
+@static if VERSION >= v"1.7"
+# Inference broken in v1.6
+@inferred layer(y, ps, st)
+end
x_hat1 = layer(y, ps, st)[1]

layer = ConvTranspose((3, 3), 1 => 1; use_bias=false)
display(layer)
ps, st = Lux.setup(rng, layer)
run_JET_tests(layer, y, ps, st; opt_broken=true)
-@inferred layer(y, ps, st)
+@static if VERSION >= v"1.7"
+# Inference broken in v1.6
+@inferred layer(y, ps, st)
+end
x_hat2 = layer(y, ps, st)[1]

@test size(x_hat1) == size(x_hat2) == size(x)
Expand All @@ -475,7 +481,10 @@ end
ps, st = Lux.setup(rng, layer)
x = rand(Float32, 5, 5, 1, 1)
run_JET_tests(layer, x, ps, st; opt_broken=true)
-@inferred layer(x, ps, st)
+@static if VERSION >= v"1.7"
+# Inference broken in v1.6
+@inferred layer(x, ps, st)
+end
test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3,
rtol=1.0f-3)

@@ -484,7 +493,10 @@ end
display(layer)
ps, st = Lux.setup(rng, layer)
run_JET_tests(layer, x, ps, st; opt_broken=true)
-@inferred layer(x, ps, st)
+@static if VERSION >= v"1.7"
+# Inference broken in v1.6
+@inferred layer(x, ps, st)
+end
test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3,
rtol=1.0f-3)

