Skip to content

Commit

Permalink
fix: workaround for #1186
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 8, 2025
1 parent e38900b commit 6557b64
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions ext/LuxReactantExt/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,22 @@ function wrapped_objective_function(
end

function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
stats_wrapper = StatsAndNewStateWrapper(nothing, nothing)
# XXX: Hacky workaround for https://github.com/LuxDL/Lux.jl/issues/1186
# stats_wrapper = StatsAndNewStateWrapper(nothing, nothing)
# res = Enzyme.gradient(
# Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
# Const(wrapped_objective_function), Const(objective_function),
# Const(model), ps, Const(st), Const(data), Const(stats_wrapper)
# )
# loss, dps = res.val, res.derivs[3]
# return dps, loss, stats_wrapper.stats, stats_wrapper.st
res = Enzyme.gradient(
Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
Const(wrapped_objective_function), Const(objective_function),
Const(model), ps, Const(st), Const(data), Const(stats_wrapper)
Const(objective_function), Const(model), ps, Const(st), Const(data)
)
loss, dps = res.val, res.derivs[3]
return dps, loss, stats_wrapper.stats, stats_wrapper.st
(loss, new_st, stats) = res.val
(_, dps, _, _) = res.derivs
return dps, loss, stats, new_st
end

function maybe_dump_to_mlir_file!(f::F, args...) where {F}
Expand Down

0 comments on commit 6557b64

Please sign in to comment.