From 5d2a714fab621af445bc5bbf5800eac78f966b57 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 6 Jan 2025 09:31:46 -0500 Subject: [PATCH] docs: update SimpleRNN --- examples/SimpleRNN/Project.toml | 7 +------ examples/SimpleRNN/main.jl | 29 +++++++++++++++++++---------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/examples/SimpleRNN/Project.toml b/examples/SimpleRNN/Project.toml index 81e54f61e5..4eba4ce69e 100644 --- a/examples/SimpleRNN/Project.toml +++ b/examples/SimpleRNN/Project.toml @@ -2,20 +2,15 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [compat] ADTypes = "1.10" JLD2 = "0.5" Lux = "1" -LuxCUDA = "0.3" MLUtils = "0.4" Optimisers = "0.4.1" -Statistics = "1" -Zygote = "0.6" diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index a11a2c5cbc..bf003bde85 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -9,7 +9,7 @@ # ## Package Imports -using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics +using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random # ## Dataset @@ -34,9 +34,11 @@ function get_dataloaders(; dataset_size=1000, sequence_length=50) ## Create DataLoaders return ( ## Use DataLoader to automatically minibatch and shuffle the data - DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true), + DataLoader( + collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false), ## Don't shuffle the validation data - DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false)) + DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false) + ) end # ## Creating a Classifier @@ -128,31 +130,38 @@ accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred) # ## Training the Model function main(model_type) - dev = gpu_device() + dev = reactant_device() + cdev = cpu_device() ## Get the dataloaders - train_loader, val_loader = get_dataloaders() .|> dev + train_loader, val_loader = get_dataloaders() |> dev ## Create the model model = model_type(2, 8, 1) - rng = Xoshiro(0) - ps, st = Lux.setup(rng, model) |> dev + ps, st = Lux.setup(Random.default_rng(), model) |> dev train_state = Training.TrainState(model, ps, st, Adam(0.01f0)) + model_compiled = if dev isa ReactantDevice + @compile model(first(train_loader)[1], ps, Lux.testmode(st)) + else + model + end + ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote() for epoch in 1:25 ## Train the model for (x, y) in train_loader (_, loss, _, train_state) = Training.single_train_step!( - AutoZygote(), lossfn, (x, y), train_state) - + ad, lossfn, (x, y), train_state + ) @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss end ## Validate the model st_ = Lux.testmode(train_state.states) for (x, y) in val_loader - ŷ, st_ = model(x, train_state.parameters, st_) + ŷ, st_ = model_compiled(x, train_state.parameters, st_) + ŷ, y = cdev(ŷ), cdev(y) loss = lossfn(ŷ, y) acc = accuracy(ŷ, y) @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc