fix: init hidden state for reactant #1026

Merged · 7 commits · Nov 15, 2024
2 changes: 1 addition & 1 deletion Project.toml
@@ -98,7 +98,7 @@ NNlib = "0.9.24"
Optimisers = "0.3.4, 0.4"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.4"
Reactant = "0.2.6"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -54,7 +54,7 @@ Optimisers = "0.3.4, 0.4"
Pkg = "1.10"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.4"
Reactant = "0.2.6"
StableRNGs = "1"
StaticArrays = "1"
WeightInitializers = "1"
3 changes: 2 additions & 1 deletion ext/LuxReactantExt/LuxReactantExt.jl
@@ -2,13 +2,14 @@ module LuxReactantExt

using Enzyme: Enzyme, Const, Duplicated, Active
using Optimisers: Optimisers
using Reactant: Reactant, @compile, TracedRArray
using Reactant: Reactant, @compile, TracedRArray, TracedRNumber
using Setfield: @set!
using Static: False

using Lux: Lux, LuxOps, Training
using Lux.Training: TrainingBackendCache, ReactantBackend

include("patches.jl")
include("training.jl")

end
7 changes: 7 additions & 0 deletions ext/LuxReactantExt/patches.jl
@@ -0,0 +1,7 @@
# For some reason, xlogx and xlogy with boolean inputs sometimes lead to incorrect results.
# XXX: Remove these patches once https://github.com/EnzymeAD/Reactant.jl/pull/278 is merged and tagged.
LuxOps.xlogx(x::TracedRNumber{Bool}) = zero(x)

function LuxOps.xlogy(x::TracedRNumber, y::TracedRNumber)
return invoke(LuxOps.xlogy, Tuple{Number, Number}, float(x), float(y))
end
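
For context, these patches enforce the standard xlogx/xlogy conventions under Reactant tracing. A minimal plain-Julia sketch of the intended semantics (illustrative `ref_*` definitions, not the traced code paths in this PR):

    # Assumed reference semantics: x*log(x) with the convention 0*log(0) == 0, and x*log(y).
    ref_xlogx(x::Number) = iszero(x) ? zero(float(x)) : x * log(x)
    ref_xlogy(x::Number, y::Number) = iszero(x) ? zero(float(x)) : x * log(y)

    # Boolean inputs behave like 0 and 1, so xlogx is zero in both cases,
    # which is what the TracedRNumber{Bool} method above hard-codes:
    @assert ref_xlogx(false) == 0.0
    @assert ref_xlogx(true) == 0.0           # 1 * log(1) == 0
    # The xlogy patch promotes traced numbers to float before dispatching to
    # the generic Number method:
    @assert ref_xlogy(true, 2.0) ≈ log(2.0)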
2 changes: 1 addition & 1 deletion lib/LuxCore/Project.toml
@@ -38,7 +38,7 @@ EnzymeCore = "0.8.5"
Functors = "0.5"
MLDataDevices = "1.6"
Random = "1.10"
Reactant = "0.2.4"
Reactant = "0.2.6"
ReverseDiff = "1.15"
Setfield = "1"
Tracker = "0.2.36"
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
@@ -63,7 +63,7 @@ Metal = "1"
OneHotArrays = "0.2.5"
Preferences = "1.4"
Random = "1.10"
Reactant = "0.2.4"
Reactant = "0.2.6"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
9 changes: 5 additions & 4 deletions src/utils.jl
@@ -12,7 +12,6 @@ using Static: Static, StaticBool, StaticInteger, StaticSymbol
using StaticArraysCore: SMatrix, SVector

using LuxCore: LuxCore, AbstractLuxLayer
using MLDataDevices: get_device
using NNlib: NNlib

const CRC = ChainRulesCore
@@ -162,11 +161,13 @@ add!!(x::Number, y::Number) = x + y
add!!(::Nothing, ::Nothing) = nothing

function init_rnn_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix)
# TODO: Once we support moving `rng` to the device, we can directly initialize on the
# device
return rnn.init_state(rng, rnn.out_dims, Base.size(x, 2)) |> get_device(x)
y = similar(x, rnn.out_dims, Base.size(x, 2))
copyto!(y, rnn.init_state(rng, size(y)...))
return ArrayInterface.aos_to_soa(y)
end

@non_differentiable init_rnn_hidden_state(::Any...)

function init_trainable_rnn_hidden_state(hidden_state::AbstractVector, x::AbstractMatrix)
return repeat(hidden_state, 1, Base.size(x, 2))
end
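
The new implementation allocates the hidden state with `similar(x, ...)`, so it has the same array type and device as the input, fills it from the host RNG via `copyto!`, and normalizes the layout with `ArrayInterface.aos_to_soa`; compared with the removed `get_device(x)` transfer, this keeps initialization working when `x` is a Reactant traced array (the apparent motivation, per the PR title). A standalone sketch of the same idea, using a hypothetical `init_state` initializer (not the PR's code):

    using ArrayInterface, Random

    function init_hidden_like(rng::AbstractRNG, init_state, out_dims::Integer, x::AbstractMatrix)
        y = similar(x, out_dims, size(x, 2))       # same array type / device as x
        copyto!(y, init_state(rng, size(y)...))    # values produced by the host RNG
        return ArrayInterface.aos_to_soa(y)        # no-op for plain Arrays
    end

    x = rand(Float32, 3, 5)
    h = init_hidden_like(Random.default_rng(), (rng, dims...) -> randn(rng, Float32, dims...), 4, x)
    @assert size(h) == (4, 5)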
65 changes: 65 additions & 0 deletions test/reactant/layer_tests.jl
@@ -0,0 +1,65 @@
@testsetup module SharedReactantLayersTestSetup

using Lux, Reactant, Enzyme, Zygote

sumabs2(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

function ∇sumabs2_zygote(model, x, ps, st)
return Zygote.gradient((x, ps) -> sumabs2(model, x, ps, st), x, ps)
end

function ∇sumabs2_enzyme(model, x, ps, st)
dx = Enzyme.make_zero(x)
dps = Enzyme.make_zero(ps)
Enzyme.autodiff(
Enzyme.Reverse, sumabs2, Active,
Const(model), Duplicated(x, dx),
Duplicated(ps, dps), Const(st)
)
return dx, dps
end

export ∇sumabs2_zygote, ∇sumabs2_enzyme

end

@testitem "Recurrent Layers" tags=[:reactant] setup=[
SharedTestSetup, SharedReactantLayersTestSetup] skip=:(Sys.iswindows()) begin
using Reactant, Lux
using LuxTestUtils: check_approx

rng = StableRNG(123)

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
if mode == "amdgpu"
@warn "Skipping AMDGPU tests for Reactant"
continue
end

if ongpu
Reactant.set_default_backend("gpu")
else
Reactant.set_default_backend("cpu")
end

@testset for cell in (RNNCell, LSTMCell, GRUCell)
model = Recurrence(cell(4 => 4))
ps, st = Lux.setup(rng, model)
ps_ra, st_ra = (ps, st) |> Reactant.to_rarray
x = rand(Float32, 4, 16, 12)
x_ra = x |> Reactant.to_rarray

y_ra, _ = @jit model(x_ra, ps_ra, st_ra)
y, _ = model(x, ps, st)

@test y_ra≈y atol=1e-3 rtol=1e-3

@testset "gradient" begin
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
@test ∂x_ra≈∂x atol=1e-3 rtol=1e-3
@test check_approx(∂ps_ra, ∂ps; atol=1e-3, rtol=1e-3)
end
end
end
end
88 changes: 31 additions & 57 deletions test/reactant/loss_tests.jl
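
The changes below collapse the two-step `@compile`-then-call pattern into Reactant's one-step `@jit`, which compiles and immediately executes the call. A minimal standalone sketch of the two forms (not taken from these tests):

    using Reactant

    f(x) = sum(abs2, x)
    x_ra = Reactant.to_rarray(rand(Float32, 4))

    # Two-step form: compile once and keep the compiled callable around for reuse.
    f_compiled = @compile f(x_ra)
    f_compiled(x_ra)

    # One-step form used throughout the updated tests: compile and run in place.
    @jit f(x_ra)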
@@ -24,11 +24,8 @@
fn1(x) = LuxOps.xlogx.(x)
fn2(x, y) = LuxOps.xlogy.(x, y)

fn1_compiled = @compile fn1(x_ra)
@test fn1(x) ≈ fn1_compiled(x_ra)

fn2_compiled = @compile fn2(x_ra, y_ra)
@test fn2(x, y) ≈ fn2_compiled(x_ra, y_ra)
@test fn1(x) ≈ @jit(fn1(x_ra))
@test fn2(x, y) ≈ @jit(fn2(x_ra, y_ra))
end

@testset "Regression Loss" begin
@@ -43,14 +40,9 @@
loss_sum = eval(Symbol(loss * "Loss"))(; agg=sum)
loss_sum2 = eval(Symbol(loss * "Loss"))(; agg=(args...) -> sum(args...))

loss_mean_compiled = @compile loss_mean(ŷ_ra, y_ra)
@test loss_mean(ŷ, y) ≈ loss_mean_compiled(ŷ_ra, y_ra)

loss_sum_compiled = @compile loss_sum(ŷ_ra, y_ra)
@test loss_sum(ŷ, y) ≈ loss_sum_compiled(ŷ_ra, y_ra)

loss_sum2_compiled = @compile loss_sum2(ŷ_ra, y_ra)
@test loss_sum2(ŷ, y) ≈ loss_sum2_compiled(ŷ_ra, y_ra)
@test loss_mean(ŷ, y) ≈ @jit(loss_mean(ŷ_ra, y_ra))
@test loss_sum(ŷ, y) ≈ @jit(loss_sum(ŷ_ra, y_ra))
@test loss_sum2(ŷ, y) ≈ @jit(loss_sum2(ŷ_ra, y_ra))
end

@testset "MSLE" begin
@@ -61,8 +53,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

loss_msle = MSLELoss()
loss_msle_compiled = @compile loss_msle(ŷ_ra, y_ra)
@test loss_msle(ŷ, y) ≈ loss_msle_compiled(ŷ_ra, y_ra)
@test loss_msle(ŷ, y) ≈ @jit(loss_msle(ŷ_ra, y_ra))
end
end

@@ -75,39 +66,35 @@

@testset "CrossEntropyLoss" begin
celoss = CrossEntropyLoss()
celoss_compiled = @compile celoss(ŷ_ra, y_ra)
@test celoss(ŷ, y) ≈ celoss_compiled(ŷ_ra, y_ra)
@test celoss(ŷ, y) ≈ @jit(celoss(ŷ_ra, y_ra))

celoss_ls = CrossEntropyLoss(; label_smoothing=0.1)
celoss_ls_compiled = @compile celoss_ls(ŷ_ra, y_ra)
@test celoss_ls(ŷ, y) ≈ celoss_ls_compiled(ŷ_ra, y_ra)
@test celoss_ls(ŷ, y) ≈ @jit(celoss_ls(ŷ_ra, y_ra))

celoss_lp = CrossEntropyLoss(; logits=Val(true))
celoss_lp_compiled = @compile celoss_lp(log.(ŷ_ra), y_ra)
@test celoss_lp(log.(ŷ), y) ≈ celoss_lp_compiled(log.(ŷ_ra), y_ra)
logit_celoss_lp = (ŷ, y) -> celoss_lp(log.(ŷ), y)
@test logit_celoss_lp(ŷ, y) ≈ @jit(logit_celoss_lp(ŷ_ra, y_ra))

celoss_lp_ls = CrossEntropyLoss(; logits=Val(true), label_smoothing=0.1)
celoss_lp_ls_compiled = @compile celoss_lp_ls(log.(ŷ_ra), y_ra)
@test celoss_lp_ls(log.(ŷ), y) ≈ celoss_lp_ls_compiled(log.(ŷ_ra), y_ra)
logit_celoss_lp_ls = (ŷ, y) -> celoss_lp_ls(log.(ŷ), y)
@test logit_celoss_lp_ls(ŷ, y) ≈ @jit(logit_celoss_lp_ls(ŷ_ra, y_ra))
end

@testset "Binary CrossEntropyLoss" begin
bceloss = BinaryCrossEntropyLoss()
bceloss_compiled = @compile bceloss(ŷ_ra, y_ra)
@test bceloss(ŷ, y) ≈ bceloss_compiled(ŷ_ra, y_ra)
@test bceloss(ŷ, y) ≈ @jit(bceloss(ŷ_ra, y_ra))

bceloss_ls = BinaryCrossEntropyLoss(; label_smoothing=0.1)
bceloss_ls_compiled = @compile bceloss_ls(ŷ_ra, y_ra)
@test bceloss_ls(ŷ, y) ≈ bceloss_ls_compiled(ŷ_ra, y_ra)
@test bceloss_ls(ŷ, y) ≈ @jit(bceloss_ls(ŷ_ra, y_ra))

bceloss_lp = BinaryCrossEntropyLoss(; logits=Val(true))
bceloss_lp_compiled = @compile bceloss_lp(log.(ŷ_ra), y_ra)
@test bceloss_lp(log.(ŷ), y) ≈ bceloss_lp_compiled(log.(ŷ_ra), y_ra)
logit_bceloss_lp = (ŷ, y) -> bceloss_lp(log.(ŷ), y)
@test logit_bceloss_lp(ŷ, y) ≈ @jit(logit_bceloss_lp(ŷ_ra, y_ra))

bceloss_lp_ls = BinaryCrossEntropyLoss(;
logits=Val(true), label_smoothing=0.1)
bceloss_lp_ls_compiled = @compile bceloss_lp_ls(log.(ŷ_ra), y_ra)
@test bceloss_lp_ls(log.(ŷ), y) ≈ bceloss_lp_ls_compiled(log.(ŷ_ra), y_ra)
logit_bceloss_lp_ls = (ŷ, y) -> bceloss_lp_ls(log.(ŷ), y)
@test logit_bceloss_lp_ls(ŷ, y) ≈ @jit(logit_bceloss_lp_ls(ŷ_ra, y_ra))
end

@testset "BinaryFocalLoss" begin
@@ -120,8 +107,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

bfl = BinaryFocalLoss()
bfl_compiled = @compile bfl(ŷ_ra, y_ra)
@test bfl(ŷ, y) ≈ bfl_compiled(ŷ_ra, y_ra)
@test bfl(ŷ, y) ≈ @jit(bfl(ŷ_ra, y_ra))
end

@testset "FocalLoss" begin
@@ -134,8 +120,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

fl = FocalLoss()
fl_compiled = @compile fl(ŷ_ra, y_ra)
@test fl(ŷ, y) ≈ fl_compiled(ŷ_ra, y_ra)
@test fl(ŷ, y) ≈ @jit(fl(ŷ_ra, y_ra))
end
end

@@ -148,8 +133,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

kldl = KLDivergenceLoss()
kldl_compiled = @compile kldl(ŷ_ra, y_ra)
@test kldl(ŷ, y) ≈ kldl_compiled(ŷ_ra, y_ra)
@test kldl(ŷ, y) ≈ @jit(kldl(ŷ_ra, y_ra))
end

@testset "HingeLoss" begin
@@ -160,12 +144,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

hl = HingeLoss()
hl_compiled = @compile hl(ŷ_ra, y_ra)
@test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
@test hl(ŷ, y) ≈ @jit(hl(ŷ_ra, y_ra))

hl = HingeLoss(; agg=mean)
hl_compiled = @compile hl(ŷ_ra, y_ra)
@test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
@test hl(ŷ, y) ≈ @jit(hl(ŷ_ra, y_ra))
end

@testset "SquaredHingeLoss" begin
@@ -176,12 +158,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

hl = SquaredHingeLoss()
hl_compiled = @compile hl(ŷ_ra, y_ra)
@test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
@test hl(ŷ, y) ≈ @jit(hl(ŷ_ra, y_ra))

hl = SquaredHingeLoss(; agg=mean)
hl_compiled = @compile hl(ŷ_ra, y_ra)
@test hl(ŷ, y) ≈ hl_compiled(ŷ_ra, y_ra)
@test hl(ŷ, y) ≈ @jit(hl(ŷ_ra, y_ra))
end

@testset "PoissonLoss" begin
@@ -192,12 +172,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

pl = PoissonLoss()
pl_compiled = @compile pl(ŷ_ra, y_ra)
@test pl(ŷ, y) ≈ pl_compiled(ŷ_ra, y_ra)
@test pl(ŷ, y) ≈ @jit(pl(ŷ_ra, y_ra))

pl = PoissonLoss(; agg=mean)
pl_compiled = @compile pl(ŷ_ra, y_ra)
@test pl(ŷ, y) ≈ pl_compiled(ŷ_ra, y_ra)
@test pl(ŷ, y) ≈ @jit(pl(ŷ_ra, y_ra))
end

@testset "DiceCoeffLoss" begin
@@ -208,12 +186,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

dl = DiceCoeffLoss()
dl_compiled = @compile dl(ŷ_ra, y_ra)
@test dl(ŷ, y) ≈ dl_compiled(ŷ_ra, y_ra)
@test dl(ŷ, y) ≈ @jit(dl(ŷ_ra, y_ra))

dl = DiceCoeffLoss(; agg=mean)
dl_compiled = @compile dl(ŷ_ra, y_ra)
@test dl(ŷ, y) ≈ dl_compiled(ŷ_ra, y_ra)
@test dl(ŷ, y) ≈ @jit(dl(ŷ_ra, y_ra))
end

@testset "Siamese Contrastive Loss" begin
@@ -228,12 +204,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

sl = SiameseContrastiveLoss()
sl_compiled = @compile sl(ŷ_ra, y_ra)
@test sl(ŷ, y) ≈ sl_compiled(ŷ_ra, y_ra)
@test sl(ŷ, y) ≈ @jit(sl(ŷ_ra, y_ra))

sl = SiameseContrastiveLoss(; agg=mean)
sl_compiled = @compile sl(ŷ_ra, y_ra)
@test sl(ŷ, y) ≈ sl_compiled(ŷ_ra, y_ra)
@test sl(ŷ, y) ≈ @jit(sl(ŷ_ra, y_ra))
end
end
end
17 changes: 12 additions & 5 deletions test/reactant/training_tests.jl
@@ -24,17 +24,23 @@
ps, st = Lux.setup(StableRNG(1234), model) |> xdev

x_ra = randn(Float32, 2, 32) |> xdev
y_ra = rand(Float32, 2, 32) |> xdev

inference_fn = @compile model(x_ra, ps, Lux.testmode(st))
inference_loss_fn = (xᵢ, yᵢ, model, ps, st) -> begin
ŷᵢ, _ = model(xᵢ, ps, Lux.testmode(st))
return MSELoss()(ŷᵢ, yᵢ)
end
inference_loss_fn_compiled = @compile inference_loss_fn(
x_ra, y_ra, model, ps, st
)

x = [rand(Float32, 2, 32) for _ in 1:32]
y = [xᵢ .^ 2 for xᵢ in x]

dataloader = DeviceIterator(xdev, zip(x, y))

total_initial_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
ŷᵢ, _ = inference_fn(xᵢ, ps, Lux.testmode(st))
return MSELoss()(ŷᵢ, yᵢ)
inference_loss_fn_compiled(xᵢ, yᵢ, model, ps, st)
end

train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
@@ -52,8 +58,9 @@
end

total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
ŷᵢ, _ = inference_fn(xᵢ, train_state.parameters, Lux.testmode(st))
return MSELoss()(ŷᵢ, yᵢ)
inference_loss_fn_compiled(
xᵢ, yᵢ, model, train_state.parameters, train_state.states
)
end

@test total_final_loss < 100 * total_initial_loss
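
The refactor above compiles a single loss closure that takes the model and parameters as arguments, so the same compiled function is reused before and after training (with `train_state.parameters`) instead of recompiling per call. A rough standalone sketch of that pattern with a hypothetical model (not the test's exact setup):

    using Lux, Reactant, Random

    model = Chain(Dense(2 => 8, tanh), Dense(8 => 2))
    ps, st = Lux.setup(Random.default_rng(), model) |> Reactant.to_rarray

    loss_fn = (x, y, model, ps, st) -> begin
        ŷ, _ = model(x, ps, Lux.testmode(st))
        return MSELoss()(ŷ, y)
    end

    x_ra = Reactant.to_rarray(randn(Float32, 2, 32))
    y_ra = Reactant.to_rarray(rand(Float32, 2, 32))

    # Compile once against representative arguments...
    loss_compiled = @compile loss_fn(x_ra, y_ra, model, ps, st)
    # ...then reuse the compiled callable, e.g. with updated parameters of the same shapes.
    loss_compiled(x_ra, y_ra, model, ps, st)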