
Move Zygote into an extension
avik-pal committed Feb 20, 2023
1 parent 34f3f7d commit 6650103
Showing 11 changed files with 72 additions and 155 deletions.
7 changes: 3 additions & 4 deletions Project.toml
@@ -17,11 +17,9 @@ LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -31,11 +29,14 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LuxComponentArraysExt = "ComponentArrays"
LuxComponentArraysZygoteExt = ["ComponentArrays", "Zygote"]
LuxFillArraysExt = "FillArrays"
LuxFluxTransformExt = "Flux"
LuxZygoteExt = "Zygote"

[compat]
Adapt = "3"
@@ -49,10 +50,8 @@ LuxCore = "0.1"
LuxLib = "0.1.7"
NNlib = "0.8"
Optimisers = "0.2"
Preferences = "1.3"
Requires = "1"
Setfield = "0.8, 1"
SnoopPrecompile = "1"
Zygote = "0.6"
cuDNN = "1"
julia = "1.7"
1 change: 0 additions & 1 deletion docs/mkdocs.yml
@@ -106,7 +106,6 @@ nav:
- "Migrating from Flux to Lux": "manual/migrate_from_flux.md"
- "Freezing Model Parameters": "manual/freezing_parameters.md"
- "Dispatch on Custom Inputs": "manual/dispatch_custom_inputs.md"
- "Controlling Precompilation": "manual/precompilation.md"
- "API Reference":
- "Layers": "api/layers.md"
- "Functional": "api/functional.md"
73 changes: 0 additions & 73 deletions docs/src/manual/precompilation.md

This file was deleted.

7 changes: 1 addition & 6 deletions ext/LuxComponentArraysExt.jl
@@ -2,7 +2,7 @@ module LuxComponentArraysExt

isdefined(Base, :get_extension) ? (using ComponentArrays) : (using ..ComponentArrays)

using Functors, Lux, Optimisers, Zygote
using Functors, Lux, Optimisers
import ChainRulesCore as CRC

@inline function Lux._getproperty(x::ComponentArray, ::Val{prop}) where {prop}
@@ -14,11 +14,6 @@ function Functors.functor(::Type{<:ComponentArray}, c)
ComponentArray
end

# Zygote Fixes
function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
end

# Optimisers Fixes
Optimisers.setup(opt::AbstractRule, ps::ComponentArray) = Optimisers.setup(opt, getdata(ps))

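The Optimisers fix retained in this file builds optimizer state over the flat vector backing a ComponentArray. A hypothetical usage sketch (parameter names are illustrative, not from this diff):

using ComponentArrays, Optimisers

ps = ComponentArray(weight = randn(2, 3), bias = zeros(2))
opt_st = Optimisers.setup(Optimisers.Adam(0.01), ps)  # dispatches to the method above, over getdata(ps)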
15 changes: 15 additions & 0 deletions ext/LuxComponentArraysZygoteExt.jl
@@ -0,0 +1,15 @@
module LuxComponentArraysZygoteExt

if isdefined(Base, :get_extension)
    using ComponentArrays
    using Zygote
else
    using ..ComponentArrays
    using ..Zygote
end

function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
    return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
end

end
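
Zygote.accum is how Zygote merges gradient contributions, so this overload keeps the named-axis structure of ComponentArray gradients intact during accumulation. A small sketch of the guaranteed behavior (values chosen for exact float arithmetic; not from this diff):

using ComponentArrays, Zygote

g1 = ComponentArray(weight = [1.0, 2.0], bias = [0.25])
g2 = ComponentArray(weight = [0.5, 0.5], bias = [0.25])

g = Zygote.accum(g1, g2)      # hits the method defined in this extension
@assert g isa ComponentArray  # axes (weight, bias) are preserved
@assert g.weight == [1.5, 2.5]
@assert g.bias == [0.5]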
22 changes: 22 additions & 0 deletions ext/LuxZygoteExt.jl
@@ -0,0 +1,22 @@
module LuxZygoteExt

isdefined(Base, :get_extension) ? (using Zygote) : (using ..Zygote)

using Adapt, CUDA, Lux, Setfield

Adapt.adapt_storage(::Lux.LuxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))

Adapt.adapt_storage(::Lux.LuxCPUAdaptor, x::Zygote.OneElement) = x

function Lux.Training.compute_gradients(::Lux.Training.ZygoteVJP,
                                        objective_function::Function, data,
                                        ts::Lux.Training.TrainState)
    (loss, st, stats), back = Zygote.pullback(ps -> objective_function(ts.model, ps,
                                                                       ts.states, data),
                                              ts.parameters)
    grads = back((one(loss), nothing, nothing))[1]
    @set! ts.states = st
    return grads, loss, stats, ts
end

end
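
Once Zygote is loaded, Lux.Training.compute_gradients dispatches to the extension method above. A minimal training-step sketch; the TrainState(rng, model, opt) constructor call, loss function, and sizes are assumptions for illustration, not taken from this commit:

using Lux, Optimisers, Random, Zygote

rng = Random.MersenneTwister(0)
model = Lux.Dense(3, 2)
tstate = Lux.Training.TrainState(rng, model, Optimisers.Adam(0.01f0))

# The objective must return (loss, states, stats) to match the pullback above.
function objective(model, ps, st, data)
    x, y = data
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)
end

data = (randn(rng, Float32, 3, 4), randn(rng, Float32, 2, 4))
grads, loss, stats, tstate = Lux.Training.compute_gradients(Lux.Training.ZygoteVJP(),
                                                            objective, data, tstate)
tstate = Lux.Training.apply_gradients(tstate, grads)  # increments tstate.step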
18 changes: 10 additions & 8 deletions src/Lux.jl
@@ -11,7 +11,7 @@ using LinearAlgebra, Markdown, Random, SparseArrays, Statistics
using Functors, Setfield
import Adapt: adapt, adapt_storage
# Automatic Differentiation
using ChainRulesCore, Zygote
using ChainRulesCore
import ChainRulesCore as CRC

# LuxCore
@@ -57,22 +57,24 @@ end
function __init__()
    @static if !isdefined(Base, :get_extension)
        # Handling ComponentArrays
        @require ComponentArrays="b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" begin include("../ext/LuxComponentArraysExt.jl") end
        @require ComponentArrays="b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" begin
            include("../ext/LuxComponentArraysExt.jl")
            # This definitely needs to be upstreamed
            @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/LuxComponentArraysZygoteExt.jl") end
        end

        # Flux InterOp
        @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin include("../ext/LuxFluxTransformExt.jl") end

        # FillArrays
        @require FillArrays="1a297f60-69ca-5386-bcde-b61e274b549b" begin include("../ext/LuxFillArraysExt.jl") end

        # Automatic Differentiation
        ## Zygote InterOp
        @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/LuxZygoteExt.jl") end
    end
end

# Snoop Precompile
import SnoopPrecompile
import Preferences

SnoopPrecompile.@precompile_all_calls begin include("precompile.jl") end

# Data Transfer
export cpu, gpu
# Layers
4 changes: 1 addition & 3 deletions src/adapt.jl
@@ -4,12 +4,10 @@ struct LuxCPUAdaptor <: LuxDeviceAdaptor end
struct LuxCUDAAdaptor <: LuxDeviceAdaptor end

adapt_storage(::LuxCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(::LuxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))
adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng

function adapt_storage(::LuxCPUAdaptor,
                       x::Union{AbstractRange, Zygote.OneElement,
                                SparseArrays.AbstractSparseArray})
                       x::Union{AbstractRange, SparseArrays.AbstractSparseArray})
    return x
end
adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
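The Zygote.OneElement methods removed here (and re-added in LuxZygoteExt.jl) handle the lazy one-hot arrays Zygote produces for indexing gradients; they have no GPU counterpart, hence collect(x) before CUDA.cu. A short sketch of where such an object arises, assuming current Zygote behavior for scalar getindex:

using Zygote

g = Zygote.gradient(x -> x[2], [1.0, 2.0, 3.0])[1]
@show typeof(g)   # Zygote.OneElement{...}: lazy array with 1.0 at index 2
@show collect(g)  # [0.0, 1.0, 0.0] — a plain Array, safe to move to the GPU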
13 changes: 2 additions & 11 deletions src/contrib/training.jl
@@ -2,8 +2,8 @@ module Training

# NOTE(@avik-pal): In the long term this will be pulled out into its own package but
# currently all the dependencies are met by Lux itself.
import ..Lux
import Optimisers, Random, Setfield, Zygote
using ..Lux
using Optimisers, Random, Setfield

"""
TrainState
@@ -125,15 +125,6 @@ struct ZygoteVJP <: AbstractVJP end

backend(::ZygoteVJP) = :Zygote

function compute_gradients(::ZygoteVJP, objective_function::Function, data, ts::TrainState)
    (loss, st, stats), back = Zygote.pullback(ps -> objective_function(ts.model, ps,
                                                                       ts.states, data),
                                              ts.parameters)
    grads = back((one(loss), nothing, nothing))[1]
    Setfield.@set! ts.states = st
    return grads, loss, stats, ts
end

"""
EnzymeVJP <: AbstractVJP
29 changes: 0 additions & 29 deletions src/precompile.jl

This file was deleted.

38 changes: 18 additions & 20 deletions test/contrib/training.jl
@@ -1,9 +1,9 @@
import Lux, Optimisers, Random, Test
using Lux, Optimisers, Random, Test

include("../test_utils.jl")

function _get_TrainState()
    rng = Random.MersenneTwister(0)
    rng = MersenneTwister(0)

    model = Lux.Dense(3, 2)
    opt = Optimisers.Adam(0.01f0)
@@ -26,36 +26,34 @@ function test_TrainState_constructor()
    ps, st = Lux.setup(Lux.replicate(rng), model)
    opt_st = Optimisers.setup(opt, tstate.parameters)

    Test.@test tstate.model == model
    Test.@test tstate.parameters == ps
    Test.@test tstate.states == st
    Test.@test isapprox(tstate.optimizer_state, opt_st)
    Test.@test tstate.step == 0
    @test tstate.model == model
    @test tstate.parameters == ps
    @test tstate.states == st
    @test isapprox(tstate.optimizer_state, opt_st)
    @test tstate.step == 0

    return nothing
end

function test_abstract_vjp_interface()
    _, tstate, _, _, x = _get_TrainState()

    Test.@testset "NotImplemented" begin for vjp_rule in (Lux.Training.EnzymeVJP(),
                                                          Lux.Training.YotaVJP())
        Test.@test_throws ArgumentError Lux.Training.compute_gradients(vjp_rule,
                                                                       _loss_function, x,
                                                                       tstate)
    @testset "NotImplemented" begin for vjp_rule in (Lux.Training.EnzymeVJP(),
                                                     Lux.Training.YotaVJP())
        @test_throws ArgumentError Lux.Training.compute_gradients(vjp_rule, _loss_function,
                                                                  x, tstate)
    end end

    # Gradient Correctness should be tested in `test/autodiff.jl` and other parts of the
    # testing codebase. Here we only test that the API works.
    grads, _, _, _ = Test.@test_nowarn Lux.Training.compute_gradients(Lux.Training.ZygoteVJP(),
                                                                      _loss_function, x,
                                                                      tstate)
    tstate_ = Test.@test_nowarn Lux.Training.apply_gradients(tstate, grads)
    Test.@test tstate_.step == 1
    Test.@test tstate != tstate_
    grads, _, _, _ = @test_nowarn Lux.Training.compute_gradients(Lux.Training.ZygoteVJP(),
                                                                 _loss_function, x, tstate)
    tstate_ = @test_nowarn Lux.Training.apply_gradients(tstate, grads)
    @test tstate_.step == 1
    @test tstate != tstate_

    return nothing
end

Test.@testset "TrainState" begin test_TrainState_constructor() end
Test.@testset "AbstractVJP" begin test_abstract_vjp_interface() end
@testset "TrainState" begin test_TrainState_constructor() end
@testset "AbstractVJP" begin test_abstract_vjp_interface() end

2 comments on commit 6650103

@avik-pal (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/78134

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.4.39 -m "<description of version>" 6650103503bc8dd6017b6f70c165aa60745d0765
git push origin v0.4.39
