diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl
index ac36b6f57..6a7da56f6 100644
--- a/examples/ConvMixer/main.jl
+++ b/examples/ConvMixer/main.jl
@@ -30,10 +30,10 @@ function get_dataloaders(batchsize; kwargs...)
     test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)

     trainset = TensorDataset(CIFAR10(:train), train_transform)
-    trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true, kwargs...)
+    trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)

     testset = TensorDataset(CIFAR10(:test), test_transform)
-    testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, kwargs...)
+    testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)

     return trainloader, testloader
 end
@@ -42,17 +42,24 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
     #! format: off
     return Chain(
         Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
-        BatchNorm(dim),
-        [Chain(
-                SkipConnection(
-                    Chain(
-                        Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, pad=SamePad()),
-                        BatchNorm(dim)
-                    ),
-                    +
-                ),
-                Conv((1, 1), dim => dim, gelu), BatchNorm(dim))
-         for _ in 1:depth]...,
+        BatchNorm(dim; track_stats=false),
+        [
+            Chain(
+                SkipConnection(
+                    Chain(
+                        Conv(
+                            (kernel_size, kernel_size), dim => dim, gelu;
+                            groups=dim, pad=SamePad()
+                        ),
+                        BatchNorm(dim)
+                    ),
+                    +
+                ),
+                Conv((1, 1), dim => dim, gelu),
+                BatchNorm(dim)
+            )
+            for _ in 1:depth
+        ]...,
         GlobalMeanPool(),
         FlattenLayer(),
         Dense(dim => 10)
@@ -73,9 +80,9 @@ function accuracy(model, ps, st, dataloader)
     return total_correct / total
 end

-Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
-        patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
+Comonicon.@main function main(; batchsize::Int=64, hidden_dim::Int=256, depth::Int=8,
+        patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-4,
         clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
         backend::String="reactant")
     rng = StableRNG(seed)

@@ -118,7 +125,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
         model_compiled = model
     end

-    loss = CrossEntropyLoss(; logits=Val(true))
+    loss_fn = CrossEntropyLoss(; logits=Val(true))

     @printf "[Info] Training model\n"
     for epoch in 1:epochs
@@ -128,7 +135,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
             lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader))
             train_state = Optimisers.adjust!(train_state, lr)
             (_, _, _, train_state) = Training.single_train_step!(
-                adtype, loss_fn, (x, y), train_state
+                adtype, loss_fn, (x, y), train_state
             )
         end
         ttime = time() - stime
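
For reference, the comprehension in the reworked ConvMixer builds `depth` copies of one mixing block: a depthwise convolution wrapped in a residual SkipConnection, followed by a pointwise (1x1) channel-mixing convolution. Below is a minimal standalone sketch of a single block, assuming only Lux.jl layers that this example already imports; the `dim` and `kernel_size` values are illustrative, not the example's defaults.

using Lux, Random

dim, kernel_size = 256, 5

# One ConvMixer mixing block: depthwise conv (groups=dim) plus residual,
# then a pointwise 1x1 conv that mixes channels.
block = Chain(
    SkipConnection(
        Chain(
            Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, pad=SamePad()),
            BatchNorm(dim)
        ),
        +   # residual connection: output = layer(x) + x
    ),
    Conv((1, 1), dim => dim, gelu),
    BatchNorm(dim)
)

# Shape check on dummy WHCN data; SamePad() preserves the spatial size.
ps, st = Lux.setup(Random.default_rng(), block)
x = randn(Float32, 16, 16, dim, 2)
y, _ = block(x, ps, st)
@assert size(y) == size(x)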
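
The binding renamed to `loss_fn` is Lux's exported CrossEntropyLoss, which is directly callable on predictions and targets. A quick sketch of the logits form configured here, assuming the OneHotArrays package for the targets; shapes and values are illustrative.

using Lux, OneHotArrays

loss_fn = CrossEntropyLoss(; logits=Val(true))

ŷ = randn(Float32, 10, 4)             # raw logits: 10 classes × 4 samples
y = onehotbatch(rand(0:9, 4), 0:9)    # matching one-hot targets
loss_fn(ŷ, y)                         # logsoftmax + cross-entropy, mean over the batch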