Commit

fix: simpleRNN works with reactant
avik-pal committed Jan 7, 2025
1 parent ba68e64 commit 981a191
Showing 4 changed files with 25 additions and 21 deletions.
examples/SimpleRNN/main.jl (18 changes: 14 additions & 4 deletions)
@@ -150,22 +150,32 @@ function main(model_type)

     for epoch in 1:25
         ## Train the model
+        total_loss = 0.0f0
+        total_samples = 0
         for (x, y) in train_loader
             (_, loss, _, train_state) = Training.single_train_step!(
                 ad, lossfn, (x, y), train_state
             )
-            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
+            total_loss += loss * length(y)
+            total_samples += length(y)
         end
+        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss/total_samples)
 
         ## Validate the model
+        total_acc = 0.0f0
+        total_loss = 0.0f0
+        total_samples = 0
+
         st_ = Lux.testmode(train_state.states)
         for (x, y) in val_loader
             ŷ, st_ = model_compiled(x, train_state.parameters, st_)
             ŷ, y = cdev(ŷ), cdev(y)
-            loss = lossfn(ŷ, y)
-            acc = accuracy(ŷ, y)
-            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
+            total_acc += accuracy(ŷ, y) * length(y)
+            total_loss += lossfn(ŷ, y) * length(y)
+            total_samples += length(y)
         end
 
+        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss/total_samples) (total_acc/total_samples)
     end
 
     return (train_state.parameters, train_state.states) |> cpu_device()
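The hunk above replaces per-batch printing with a size-weighted running mean, so the reported epoch loss is exact even when batches differ in size. A standalone sketch of that accumulation pattern (`weighted_epoch_mean`, `mse`, and `batches` are hypothetical names, not part of the example):

# Weight each batch's mean metric by its sample count, then divide by the
# total count; averaging the per-batch means directly would over-weight a
# short final batch.
function weighted_epoch_mean(metric, batches)
    total, n = 0.0f0, 0
    for (x, y) in batches
        total += metric(x, y) * length(y)   # per-batch mean times batch size
        n += length(y)
    end
    return total / n
end

mse(x, y) = sum(abs2, x .- y) / length(y)
batches = [(rand(Float32, 8), rand(Float32, 8)), (rand(Float32, 3), rand(Float32, 3))]
println(weighted_epoch_mean(mse, batches))

This is the same reason the diff multiplies `loss` and `accuracy(ŷ, y)` by `length(y)` before accumulating.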
ext/LuxReactantExt/training.jl (21 changes: 6 additions & 15 deletions)
@@ -13,22 +13,14 @@ function wrapped_objective_function(
 end
 
 function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
-    # XXX: Hacky workaround for https://github.com/LuxDL/Lux.jl/issues/1186
-    # stats_wrapper = StatsAndNewStateWrapper(nothing, nothing)
-    # res = Enzyme.gradient(
-    #     Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
-    #     Const(wrapped_objective_function), Const(objective_function),
-    #     Const(model), ps, Const(st), Const(data), Const(stats_wrapper)
-    # )
-    # loss, dps = res.val, res.derivs[3]
-    # return dps, loss, stats_wrapper.stats, stats_wrapper.st
+    stats_wrapper = StatsAndNewStateWrapper(nothing, nothing)
     res = Enzyme.gradient(
         Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
-        Const(objective_function), Const(model), ps, Const(st), Const(data)
+        Const(wrapped_objective_function), Const(objective_function),
+        Const(model), ps, Const(st), Const(data), Const(stats_wrapper)
     )
-    (loss, new_st, stats) = res.val
-    (_, dps, _, _) = res.derivs
-    return dps, loss, stats, new_st
+    loss, dps = res.val, res.derivs[3]
+    return dps, loss, stats_wrapper.stats, stats_wrapper.st
 end
 
 function maybe_dump_to_mlir_file!(f::F, args...) where {F}
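The hunk above activates the previously commented-out path: the function handed to Enzyme returns only the scalar loss, while the non-differentiable extras (stats and the updated state) are written into a mutable wrapper passed as `Const`. A simplified sketch of that pattern under plain Enzyme, with a hypothetical `AuxWrapper` and toy `objective` (the real code goes through `Reactant.ReactantABI` and `StatsAndNewStateWrapper`, and relies on the stored values being inactive):

using Enzyme

# Mutable slots for forward-pass outputs that should not be return values
# of the differentiated function.
mutable struct AuxWrapper
    stats::Any
    st::Any
end

# Differentiated entry point: returns only the scalar loss; stats and the
# new state are stashed on the Const-wrapped `aux` instead.
function wrapped(objective, model, ps, st, data, aux::AuxWrapper)
    loss, new_st, stats = objective(model, ps, st, data)
    aux.stats = stats   # inactive values, so storing into Const memory is safe
    aux.st = new_st
    return loss
end

# Toy objective following Lux's (loss, state, stats) return convention.
objective(model, ps, st, data) = sum(abs2, ps .- data), st, (; n = length(data))

aux = AuxWrapper(nothing, nothing)
ps, data = [1.0, 2.0], [3.0, 5.0]
res = Enzyme.gradient(Enzyme.ReverseWithPrimal, wrapped,
    Const(objective), Const(nothing), ps, Const((;)), Const(data), Const(aux))
loss, dps = res.val, res.derivs[3]   # derivs[3] lines up with `ps`, as in the diff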
@@ -98,8 +90,7 @@ for inplace in ("!", "")
         return ts
     end
 
-    # XXX: Should we add a check to ensure the inputs to this function is same as the one
-    # used in the compiled function? We can re-trigger the compilation with a warning
+    # XXX: recompile with a warning if new input types are used
     @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
             data, ts::Training.TrainState) where {F}
         maybe_dump_to_mlir_file!($(internal_fn), objective_function, ts.model, data,
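The reworded XXX comment compresses an open question into a TODO: key the compiled thunk on the input types and recompile with a warning when unseen types arrive. A generic sketch of that idea (`compile`, `COMPILED`, and `cached_call` are stand-ins, not Lux internals):

# `compile` stands in for Reactant-style tracing/compilation; returning `f`
# unchanged keeps the sketch runnable on its own.
compile(f, args) = f

const COMPILED = Dict{Any,Any}()

function cached_call(f, args...)
    key = (f, typeof.(args))
    if !haskey(COMPILED, key)
        isempty(COMPILED) || @warn "New input types seen; recompiling" key
        COMPILED[key] = compile(f, args)
    end
    return COMPILED[key](args...)
end

cached_call(+, 1, 2)      # first use compiles silently
cached_call(+, 1.0, 2.0)  # unseen argument types trigger the warning path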
lib/LuxCore/Project.toml (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
 name = "LuxCore"
 uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.2.1"
+version = "1.2.2"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
lib/LuxCore/ext/LuxCoreReactantExt.jl (5 changes: 4 additions & 1 deletion)
@@ -1,6 +1,6 @@
 module LuxCoreReactantExt
 
-using LuxCore: AbstractLuxLayer
+using LuxCore: AbstractLuxLayer, LuxCore
 using Reactant: Reactant
 
 # Avoid tracing through models since it won't contain anything useful
@@ -10,4 +10,7 @@ function Reactant.make_tracer(
     return model
 end
 
+LuxCore.replicate(rng::Reactant.TracedRNG) = copy(rng)
+LuxCore.replicate(rng::Reactant.ConcreteRNG) = copy(rng)
+
 end
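The new methods give Reactant's RNG types the contract `LuxCore.replicate` already provides for stock RNGs via its generic `deepcopy` fallback: an independent stream whose use leaves the original untouched. A quick demonstration with a plain `Xoshiro` (assumes only LuxCore and the Random stdlib):

using LuxCore, Random

rng = Xoshiro(0)
upcoming = rand(copy(rng), 3)        # peek at rng's next draws without advancing it
r1 = LuxCore.replicate(rng)
r2 = LuxCore.replicate(rng)
@assert rand(r1, 3) == rand(r2, 3)   # replicas start from identical state
@assert rand(rng, 3) == upcoming     # the original stream was never advanced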
