Skip to content

Commit

Permalink
chore: bump minimum Reactant version (#1125)
Browse files Browse the repository at this point in the history
* chore: bump minimum Reactant version

* fix: manually `set_abi` for reactant
  • Loading branch information
avik-pal authored Dec 6, 2024
1 parent fa3ff80 commit 51c0e47
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ NNlib = "0.9.24"
Optimisers = "0.4.1"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.6"
Reactant = "0.2.8"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Optimisers = "0.4.1"
Pkg = "1.10"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.6"
Reactant = "0.2.8"
StableRNGs = "1"
StaticArrays = "1"
WeightInitializers = "1"
Expand Down
6 changes: 4 additions & 2 deletions ext/LuxReactantExt/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ function compute_gradients_internal_and_step(objective_function::F, model, data,
st, opt_state) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
opt_state, ps = Optimisers.update(opt_state, ps, dps)
return dps, ps, loss, stats, stₙ, opt_state
Expand All @@ -84,7 +85,8 @@ function compute_gradients_internal_and_step!(objective_function::F, model, data
st, opt_state) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
# XXX: Inplace updates not actually inplace
opt_state, ps = Optimisers.update!(opt_state, ps, dps)
Expand Down
6 changes: 3 additions & 3 deletions test/reactant/layer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ end
y_ra, _ = @jit model(x_ra, ps_ra, st_ra)
y, _ = model(x, ps, st)

@test y_ray atol=1e-3 rtol=1e-3
@test y_ray atol=1e-2 rtol=1e-2

@testset "gradient" begin
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
@test ∂x_ra∂x atol=1e-3 rtol=1e-3
@test check_approx(∂ps_ra, ∂ps; atol=1e-3, rtol=1e-3)
@test ∂x_ra∂x atol=1e-2 rtol=1e-2
@test check_approx(∂ps_ra, ∂ps; atol=1e-2, rtol=1e-2)
end
end
end
Expand Down
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@ const RETESTITEMS_NWORKER_THREADS = parse(

ReTestItems.runtests(Lux;
tags=(tag == "all" ? nothing : [Symbol(tag)]), testitem_timeout=2400,
nworkers, nworker_threads=RETESTITEMS_NWORKER_THREADS,
retries=tag == "reactant" ? 2 : 0
nworkers, nworker_threads=RETESTITEMS_NWORKER_THREADS
)
end
end
Expand Down

0 comments on commit 51c0e47

Please sign in to comment.