From f14c794b88bfe7ea72922fcb5412adf7e5ef5ba5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 7 Jan 2025 16:11:02 -0500 Subject: [PATCH] fix: simpleRNN works with reactant --- examples/SimpleRNN/main.jl | 18 ++++++++++++++---- ext/LuxReactantExt/training.jl | 21 ++++++--------------- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/ext/LuxCoreReactantExt.jl | 5 ++++- 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index e1c53245f1..bc642d2f01 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -150,22 +150,32 @@ function main(model_type) for epoch in 1:25 ## Train the model + total_loss = 0.0f0 + total_samples = 0 for (x, y) in train_loader (_, loss, _, train_state) = Training.single_train_step!( ad, lossfn, (x, y), train_state ) - @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss + total_loss += loss * length(y) + total_samples += length(y) end + @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss/total_samples) ## Validate the model + total_acc = 0.0f0 + total_loss = 0.0f0 + total_samples = 0 + st_ = Lux.testmode(train_state.states) for (x, y) in val_loader ŷ, st_ = model_compiled(x, train_state.parameters, st_) ŷ, y = cdev(ŷ), cdev(y) - loss = lossfn(ŷ, y) - acc = accuracy(ŷ, y) - @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc + total_acc += accuracy(ŷ, y) * length(y) + total_loss += lossfn(ŷ, y) * length(y) + total_samples += length(y) end + + @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss/total_samples) (total_acc/total_samples) end return (train_state.parameters, train_state.states) |> cpu_device() diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index 1e811ec596..2462bd252b 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -13,22 +13,14 @@ function wrapped_objective_function( end function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F} - # XXX: Hacky workaround for https://github.com/LuxDL/Lux.jl/issues/1186 - # stats_wrapper = StatsAndNewStateWrapper(nothing, nothing) - # res = Enzyme.gradient( - # Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), - # Const(wrapped_objective_function), Const(objective_function), - # Const(model), ps, Const(st), Const(data), Const(stats_wrapper) - # ) - # loss, dps = res.val, res.derivs[3] - # return dps, loss, stats_wrapper.stats, stats_wrapper.st + stats_wrapper = StatsAndNewStateWrapper(nothing, nothing) res = Enzyme.gradient( Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), - Const(objective_function), Const(model), ps, Const(st), Const(data) + Const(wrapped_objective_function), Const(objective_function), + Const(model), ps, Const(st), Const(data), Const(stats_wrapper) ) - (loss, new_st, stats) = res.val - (_, dps, _, _) = res.derivs - return dps, loss, stats, new_st + loss, dps = res.val, res.derivs[3] + return dps, loss, stats_wrapper.stats, stats_wrapper.st end function maybe_dump_to_mlir_file!(f::F, args...) where {F} @@ -98,8 +90,7 @@ for inplace in ("!", "") return ts end - # XXX: Should we add a check to ensure the inputs to this function is same as the one - # used in the compiled function? We can re-trigger the compilation with a warning + # XXX: recompile with a warning if new input types are used @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState) where {F} maybe_dump_to_mlir_file!($(internal_fn), objective_function, ts.model, data, diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index acb9f2ec12..5b095d97d3 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "1.2.1" +version = "1.2.2" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/lib/LuxCore/ext/LuxCoreReactantExt.jl b/lib/LuxCore/ext/LuxCoreReactantExt.jl index 3ad0c0dc21..f6e7770964 100644 --- a/lib/LuxCore/ext/LuxCoreReactantExt.jl +++ b/lib/LuxCore/ext/LuxCoreReactantExt.jl @@ -1,6 +1,6 @@ module LuxCoreReactantExt -using LuxCore: AbstractLuxLayer +using LuxCore: AbstractLuxLayer, LuxCore using Reactant: Reactant # Avoid tracing though models since it won't contain anything useful @@ -10,4 +10,7 @@ function Reactant.make_tracer( return model end +LuxCore.replicate(rng::Reactant.TracedRNG) = copy(rng) +LuxCore.replicate(rng::Reactant.ConcreteRNG) = copy(rng) + end