Skip to content

Commit

Permalink
fix: more bug fixes for reactant
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 22, 2024
1 parent 541d4c5 commit 05500f2
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
2 changes: 0 additions & 2 deletions examples/ConvMixer/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -40,6 +39,5 @@ Printf = "1.10"
ProgressBars = "1.5.1"
Random = "1.10"
Reactant = "0.2.11"
StableRNGs = "1.0.2"
Statistics = "1.10"
Zygote = "0.6.70"
7 changes: 4 additions & 3 deletions examples/ConvMixer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ julia --startup-file=no \
--threads=auto \
main.jl \
--lr-max=0.05 \
--weight-decay=0.0001
--weight-decay=0.0001 \
--backend=reactant
```

Here's an example of the output of the above command (on a V100 32GB GPU):
Here's an example output of the above command (on a RTX 4050 6GB Laptop GPU):

```
Epoch 1: Learning Rate 5.05e-03, Train Acc: 56.91%, Test Acc: 56.49%, Time: 129.84
Expand Down Expand Up @@ -69,7 +70,7 @@ Options
--seed <42::Int>
--epochs <25::Int>
--lr-max <0.01::Float64>
--backend <reactant::String>
--backend <gpu_if_available::String>

Flags
--clip-norm
Expand Down
23 changes: 13 additions & 10 deletions examples/ConvMixer/main.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Comonicon, ConcreteStructs, DataAugmentation, ImageShow, Interpolations, Lux, LuxCUDA,
MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, ProgressBars, Random,
StableRNGs, Statistics, Zygote
Statistics, Zygote
using Reactant, Enzyme

CUDA.allowscalar(false)
Expand Down Expand Up @@ -70,7 +70,6 @@ end
function accuracy(model, ps, st, dataloader)
total_correct, total = 0, 0
cdev = cpu_device()
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(cdev(y))
predicted_class = onecold(cdev(first(model(x, ps, st))))
Expand All @@ -81,10 +80,11 @@ function accuracy(model, ps, st, dataloader)
end

Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-4,
clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.005,
clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05,
backend::String="gpu_if_available")
rng = StableRNG(seed)
rng = Random.default_rng()
Random.seed!(rng, seed)

if backend == "gpu_if_available"
accelerator_device = gpu_device()
Expand Down Expand Up @@ -119,7 +119,8 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
if backend == "reactant"
x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device
@printf "[Info] Compiling model with Reactant.jl\n"
model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
st_test = Lux.testmode(st)
model_compiled = @compile model(x_ra, ps, st_test)
@printf "[Info] Model compiled!\n"
else
model_compiled = model
Expand All @@ -141,14 +142,16 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
ttime = time() - stime

train_acc = accuracy(
model_compiled, train_state.parameters, train_state.states, trainloader
model_compiled, train_state.parameters,
Lux.testmode(train_state.states), trainloader
) * 100
test_acc = accuracy(
model_compiled, train_state.parameters, train_state.states, testloader
model_compiled, train_state.parameters,
Lux.testmode(train_state.states), testloader
) * 100

@printf "[Train] Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: \
%.2f%%, Time: %.2f\n" epoch lr train_acc test_acc ttime
@printf "[Train] Epoch %2d: Learning Rate %.6f, Train Acc: %.4f%%, Test Acc: \
%.4f%%, Time: %.2f\n" epoch lr train_acc test_acc ttime
end
@printf "[Info] Finished training\n"
end

0 comments on commit 05500f2

Please sign in to comment.