From e3ab45fa95b716c21332db3bf0b02525a0f5ea20 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Dec 2024 12:01:35 +0530 Subject: [PATCH] fix: update reactant version --- Project.toml | 4 ++-- docs/Project.toml | 2 +- docs/make.jl | 1 - docs/src/.vitepress/config.mts | 4 ---- docs/tutorials.jl | 1 - examples/ConvMixer/Project.toml | 2 +- examples/ConvMixer/main.jl | 2 +- ext/LuxReactantExt/patches.jl | 2 +- ext/LuxReactantExt/training.jl | 13 ++----------- lib/LuxCore/Project.toml | 2 +- lib/MLDataDevices/Project.toml | 2 +- 11 files changed, 10 insertions(+), 25 deletions(-) diff --git a/Project.toml b/Project.toml index 5f86abc9d..b0c81f269 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.4.2" +version = "1.4.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -110,7 +110,7 @@ NNlib = "0.9.24" Optimisers = "0.4.1" Preferences = "1.4.3" Random = "1.10" -Reactant = "0.2.8" +Reactant = "0.2.11" Reexport = "1.2.2" ReverseDiff = "1.15" SIMDTypes = "0.1" diff --git a/docs/Project.toml b/docs/Project.toml index 3eb44b24e..0ba60d55e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -56,7 +56,7 @@ Optimisers = "0.4.1" Pkg = "1.10" Printf = "1.10" Random = "1.10" -Reactant = "0.2.8" +Reactant = "0.2.11" StableRNGs = "1" StaticArrays = "1" WeightInitializers = "1" diff --git a/docs/make.jl b/docs/make.jl index 8d407f3d2..c9f2e98a3 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -30,7 +30,6 @@ pages = [ "tutorials/intermediate/2_BayesianNN.md", "tutorials/intermediate/3_HyperNet.md", "tutorials/intermediate/4_PINN2DPDE.md", - "tutorials/intermediate/5_ConditionalVAE.md", ], "Advanced" => [ "tutorials/advanced/1_GravitationalWaveForm.md" diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 35c573943..f785f6a31 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -218,10 +218,6 @@ export default defineConfig({ text: "Training a PINN on 2D PDE", link: "/tutorials/intermediate/4_PINN2DPDE", }, - { - text: "Conditional VAE for MNIST using Reactant", - link: "/tutorials/intermediate/5_ConditionalVAE", - } ], }, { diff --git a/docs/tutorials.jl b/docs/tutorials.jl index b9b9971d3..d9dad6510 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -11,7 +11,6 @@ const INTERMEDIATE_TUTORIALS = [ "BayesianNN/main.jl" => "CPU", "HyperNet/main.jl" => "CUDA", "PINN2DPDE/main.jl" => "CUDA", - "ConditionalVAE/main.jl" => "CUDA", ] const ADVANCED_TUTORIALS = [ "GravitationalWaveForm/main.jl" => "CPU", diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml index 11e2f29d3..04fec524d 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/ConvMixer/Project.toml @@ -39,7 +39,7 @@ PreferenceTools = "0.1.2" Printf = "1.10" ProgressBars = "1.5.1" Random = "1.10" -Reactant = "0.2.8" +Reactant = "0.2.11" StableRNGs = "1.0.2" Statistics = "1.10" Zygote = "0.6.70" diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index 08e5553e7..ac36b6f57 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -73,7 +73,7 @@ function accuracy(model, ps, st, dataloader) return total_correct / total end -Comonicon.@main function main(; batchsize::Int=64, hidden_dim::Int=256, depth::Int=8, +Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, backend::String="reactant") diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl index a22c26f31..f9f4519e0 100644 --- a/ext/LuxReactantExt/patches.jl +++ b/ext/LuxReactantExt/patches.jl @@ -1,4 +1,4 @@ -Utils.vec(x::AnyTracedRArray) = Reactant.materialize_traced_array(vec(x)) +Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(vec(x)) # XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index b745acdfb..a37cd54a5 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -36,20 +36,14 @@ for inplace in ("!", "") @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState) where {F} - @show 1213 - compiled_grad_and_step_function = @compile $(internal_fn)( objective_function, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) - @show Lux.Functors.fmap(typeof, ts.states) - grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function( objective_function, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) - @show Lux.Functors.fmap(typeof, st) - cache = TrainingBackendCache( backend, False(), nothing, (; compiled_grad_and_step_function)) @set! ts.cache = cache @@ -59,16 +53,13 @@ for inplace in ("!", "") @set! ts.optimizer_state = opt_state @set! ts.step = ts.step + 1 - @show Lux.Functors.fmap(typeof, ts.states) - return grads, loss, stats, ts end + # XXX: Should we add a check to ensure the inputs to this function is same as the one + # used in the compiled function? We can re-trigger the compilation with a warning @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data, ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F} - @show Lux.Functors.fmap(typeof, ts.parameters) - @show Lux.Functors.fmap(typeof, ts.states) - grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function( obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index acb9f2ec1..27d1562ec 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -40,7 +40,7 @@ EnzymeCore = "0.8.6" Functors = "0.5" MLDataDevices = "1.6" Random = "1.10" -Reactant = "0.2.6" +Reactant = "0.2.11" ReverseDiff = "1.15" Setfield = "1" Tracker = "0.2.36" diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 2bc461363..eef790884 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -66,7 +66,7 @@ Metal = "1" OneHotArrays = "0.2.5" Preferences = "1.4" Random = "1.10" -Reactant = "0.2.6" +Reactant = "0.2.11" RecursiveArrayTools = "3.8" ReverseDiff = "1.15" SparseArrays = "1.10"