Commit e3ab45f
fix: update reactant version
avik-pal committed Dec 20, 2024
1 parent d0fb186 commit e3ab45f
Showing 11 changed files with 10 additions and 25 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.4.2"
+version = "1.4.3"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -110,7 +110,7 @@ NNlib = "0.9.24"
 Optimisers = "0.4.1"
 Preferences = "1.4.3"
 Random = "1.10"
-Reactant = "0.2.8"
+Reactant = "0.2.11"
 Reexport = "1.2.2"
 ReverseDiff = "1.15"
 SIMDTypes = "0.1"
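For reference, Julia's Pkg reads a bare compat entry such as Reactant = "0.2.11" with caret semantics: any release in [0.2.11, 0.3.0) satisfies it, so this bump raises the floor without widening the ceiling. A quick way to confirm the bound (a sketch using Pkg's version-spec parser, which is internal rather than documented API):

    using Pkg: Versions

    spec = Versions.semver_spec("0.2.11")
    @show v"0.2.11" in spec   # true  -- the new floor
    @show v"0.2.15" in spec   # true  -- later patch releases still admissible
    @show v"0.3.0"  in spec   # false -- the next breaking release is excluded

The same reading applies to every other Reactant compat bump in this commit.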
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -56,7 +56,7 @@ Optimisers = "0.4.1"
 Pkg = "1.10"
 Printf = "1.10"
 Random = "1.10"
-Reactant = "0.2.8"
+Reactant = "0.2.11"
 StableRNGs = "1"
 StaticArrays = "1"
 WeightInitializers = "1"
1 change: 0 additions & 1 deletion docs/make.jl
@@ -30,7 +30,6 @@ pages = [
         "tutorials/intermediate/2_BayesianNN.md",
         "tutorials/intermediate/3_HyperNet.md",
         "tutorials/intermediate/4_PINN2DPDE.md",
-        "tutorials/intermediate/5_ConditionalVAE.md",
     ],
     "Advanced" => [
         "tutorials/advanced/1_GravitationalWaveForm.md"
4 changes: 0 additions & 4 deletions docs/src/.vitepress/config.mts
@@ -218,10 +218,6 @@ export default defineConfig({
           text: "Training a PINN on 2D PDE",
           link: "/tutorials/intermediate/4_PINN2DPDE",
         },
-        {
-          text: "Conditional VAE for MNIST using Reactant",
-          link: "/tutorials/intermediate/5_ConditionalVAE",
-        }
       ],
     },
     {
1 change: 0 additions & 1 deletion docs/tutorials.jl
@@ -11,7 +11,6 @@ const INTERMEDIATE_TUTORIALS = [
     "BayesianNN/main.jl" => "CPU",
     "HyperNet/main.jl" => "CUDA",
     "PINN2DPDE/main.jl" => "CUDA",
-    "ConditionalVAE/main.jl" => "CUDA",
 ]
 const ADVANCED_TUTORIALS = [
     "GravitationalWaveForm/main.jl" => "CPU",
2 changes: 1 addition & 1 deletion examples/ConvMixer/Project.toml
@@ -39,7 +39,7 @@ PreferenceTools = "0.1.2"
 Printf = "1.10"
 ProgressBars = "1.5.1"
 Random = "1.10"
-Reactant = "0.2.8"
+Reactant = "0.2.11"
 StableRNGs = "1.0.2"
 Statistics = "1.10"
 Zygote = "0.6.70"
2 changes: 1 addition & 1 deletion examples/ConvMixer/main.jl
@@ -73,7 +73,7 @@ function accuracy(model, ps, st, dataloader)
     return total_correct / total
 end

-Comonicon.@main function main(; batchsize::Int=64, hidden_dim::Int=256, depth::Int=8,
+Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
         patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
         clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
         backend::String="reactant")
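Aside from the compat bump, this example's default batch size grows from 64 to 512, which suits the Reactant/XLA backend's preference for larger compiled workloads. Since Comonicon.@main exposes main's keyword arguments as command-line options, the old value remains one flag away (invocation sketch; the exact flag spelling follows Comonicon's usual mapping and is an assumption here):

    julia --project=. main.jl --batchsize=64 --backend=reactant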
2 changes: 1 addition & 1 deletion ext/LuxReactantExt/patches.jl
@@ -1,4 +1,4 @@
-Utils.vec(x::AnyTracedRArray) = Reactant.materialize_traced_array(vec(x))
+Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(vec(x))

 # XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
 Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g
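This one-line patch tracks a reorganisation in Reactant: materialize_traced_array now lives in the Reactant.TracedUtils submodule instead of at the package top level, which is why the compat floor moves to 0.2.11 in the same commit. Pinning the floor makes a guarded lookup unnecessary, but if the extension ever had to span both layouts, something like this hypothetical shim (not part of this commit) would be one option:

    # Resolve the function from whichever location this Reactant version provides.
    const _materialize = isdefined(Reactant, :TracedUtils) ?
        Reactant.TracedUtils.materialize_traced_array :
        Reactant.materialize_traced_array

    Utils.vec(x::AnyTracedRArray) = _materialize(vec(x))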
13 changes: 2 additions & 11 deletions ext/LuxReactantExt/training.jl
@@ -36,20 +36,14 @@ for inplace in ("!", "")
     @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
             data, ts::Training.TrainState) where {F}
-        @show 1213
-
         compiled_grad_and_step_function = @compile $(internal_fn)(
             objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)

-        @show Lux.Functors.fmap(typeof, ts.states)
-
         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
             objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)

-        @show Lux.Functors.fmap(typeof, st)
-
         cache = TrainingBackendCache(
             backend, False(), nothing, (; compiled_grad_and_step_function))
         @set! ts.cache = cache
@@ -59,16 +53,13 @@ for inplace in ("!", "")
         @set! ts.optimizer_state = opt_state
         @set! ts.step = ts.step + 1

-        @show Lux.Functors.fmap(typeof, ts.states)
-
         return grads, loss, stats, ts
     end

     # XXX: Should we add a check to ensure the inputs to this function is same as the one
     # used in the compiled function? We can re-trigger the compilation with a warning
     @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
             ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
-        @show Lux.Functors.fmap(typeof, ts.parameters)
-        @show Lux.Functors.fmap(typeof, ts.states)
-
         grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
             obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
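The deletions here are leftover @show debug statements (each prints expression = value on every training step) plus a stray @show 1213 marker. What remains is the extension's two-path design: the first training-step call traces and compiles the step with @compile and stores the result in a TrainingBackendCache, while the specialised method below it reuses the cached compiled function and skips compilation entirely. A minimal self-contained sketch of that compile-once, reuse-thereafter pattern, with a plain Julia function standing in for the XLA-compiled step (illustrative names, not Lux API):

    mutable struct StepCache
        compiled::Union{Nothing,Function}
    end

    function train_step!(cache::StepCache, f::Function, args...)
        if cache.compiled === nothing
            cache.compiled = f          # stand-in for `@compile f(args...)`, paid once
        end
        return cache.compiled(args...)  # later calls reuse the compiled function
    end

    cache = StepCache(nothing)
    train_step!(cache, +, 1, 2)  # first call: "compiles" and returns 3
    train_step!(cache, +, 3, 4)  # second call: hits the cache and returns 7

The XXX comment in the hunk flags the real caveat of this caching: the compiled function is only valid for inputs matching the types and shapes it was traced with, and nothing yet re-triggers compilation when they change.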
2 changes: 1 addition & 1 deletion lib/LuxCore/Project.toml
@@ -40,7 +40,7 @@ EnzymeCore = "0.8.6"
 Functors = "0.5"
 MLDataDevices = "1.6"
 Random = "1.10"
-Reactant = "0.2.6"
+Reactant = "0.2.11"
 ReverseDiff = "1.15"
 Setfield = "1"
 Tracker = "0.2.36"
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
@@ -66,7 +66,7 @@ Metal = "1"
 OneHotArrays = "0.2.5"
 Preferences = "1.4"
 Random = "1.10"
-Reactant = "0.2.6"
+Reactant = "0.2.11"
 RecursiveArrayTools = "3.8"
 ReverseDiff = "1.15"
 SparseArrays = "1.10"
