From a263edf99f114c24e403bb0c60fc2f386a1707e5 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 7 Dec 2024 20:41:06 +0530
Subject: [PATCH 01/10] feat: conditional VAE testcase

---
 docs/src/.vitepress/config.mts       |   4 +
 docs/tutorials.jl                    |   1 +
 examples/ConditionalVAE/Project.toml |  28 +++
 examples/ConditionalVAE/main.jl      | 291 +++++++++++++++++++++++++++
 examples/ConvMixer/main.jl           |   5 +-
 5 files changed, 326 insertions(+), 3 deletions(-)
 create mode 100644 examples/ConditionalVAE/Project.toml
 create mode 100644 examples/ConditionalVAE/main.jl

diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts
index f785f6a31..35c573943 100644
--- a/docs/src/.vitepress/config.mts
+++ b/docs/src/.vitepress/config.mts
@@ -218,6 +218,10 @@ export default defineConfig({
            text: "Training a PINN on 2D PDE",
            link: "/tutorials/intermediate/4_PINN2DPDE",
          },
+         {
+           text: "Conditional VAE for MNIST using Reactant",
+           link: "/tutorials/intermediate/5_ConditionalVAE",
+         }
        ],
      },
      {
diff --git a/docs/tutorials.jl b/docs/tutorials.jl
index d9dad6510..b9b9971d3 100644
--- a/docs/tutorials.jl
+++ b/docs/tutorials.jl
@@ -11,6 +11,7 @@ const INTERMEDIATE_TUTORIALS = [
    "BayesianNN/main.jl" => "CPU",
    "HyperNet/main.jl" => "CUDA",
    "PINN2DPDE/main.jl" => "CUDA",
+   "ConditionalVAE/main.jl" => "CUDA",
 ]
 const ADVANCED_TUTORIALS = [
    "GravitationalWaveForm/main.jl" => "CPU",
diff --git a/examples/ConditionalVAE/Project.toml b/examples/ConditionalVAE/Project.toml
new file mode 100644
index 000000000..d11cad9b6
--- /dev/null
+++ b/examples/ConditionalVAE/Project.toml
@@ -0,0 +1,28 @@
[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"

[compat]
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3.2"
Enzyme = "0.13.20"
ImageShow = "0.3.8"
Images = "0.26.1"
Lux = "1.4.1"
MLDatasets = "0.7.18"
MLUtils = "0.4.4"
OneHotArrays = "0.2.6"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.9"
diff --git a/examples/ConditionalVAE/main.jl b/examples/ConditionalVAE/main.jl
new file mode 100644
index 000000000..7f0a9f6f8
--- /dev/null
+++ b/examples/ConditionalVAE/main.jl
@@ -0,0 +1,291 @@
# # [Conditional VAE for MNIST using Reactant](@id Conditional-VAE-Tutorial)

# Convolutional variational autoencoder (CVAE) implementation in Lux using MNIST. This is
# based on the [CVAE implementation in MLX](https://github.com/ml-explore/mlx-examples/blob/main/cvae/).

using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, MLUtils, DataAugmentation,
    ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers

const xdev = reactant_device()
const cdev = cpu_device()

# ## Model Definition

# First, we will define the encoder. It maps the input to a normal distribution in latent
# space and samples a latent vector from that distribution.
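
# Concretely, the encoder produces the mean μ and the log-variance logσ² of that latent
# distribution, and the sample is drawn with the reparameterization trick,
# z = μ + exp(0.5 ⋅ logσ²) ⊙ ε with ε ~ N(0, I), which keeps the sampling step
# differentiable with respect to μ and logσ². A minimal standalone sketch of the trick
# (the `_demo` variables below are illustrative placeholders and are not used by the model):

μ_demo = randn(Float32, 4)                      ## stand-in for an encoder mean
logvar_demo = randn(Float32, 4)                 ## stand-in for an encoder log-variance
σ_demo = exp.(0.5f0 .* logvar_demo)             ## standard deviation recovered from logσ²
z_demo = μ_demo .+ σ_demo .* randn(Float32, 4)  ## reparameterized latent sample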
+ +function cvae_encoder( + rng=Random.default_rng(); num_latent_dims::Int, + image_shape::Dims{3}, max_num_filters::Int +) + flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters + return @compact(; + embed=Chain( + Chain( + Conv((3, 3), image_shape[3] => max_num_filters ÷ 4; stride=2, pad=1), + BatchNorm(max_num_filters ÷ 4, leakyrelu) + ), + Chain( + Conv((3, 3), max_num_filters ÷ 4 => max_num_filters ÷ 2; stride=2, pad=1), + BatchNorm(max_num_filters ÷ 2, leakyrelu) + ), + Chain( + Conv((3, 3), max_num_filters ÷ 2 => max_num_filters; stride=2, pad=1), + BatchNorm(max_num_filters, leakyrelu) + ), + FlattenLayer() + ), + proj_mu=Dense(flattened_dim, num_latent_dims), + proj_log_var=Dense(flattened_dim, num_latent_dims), + rng) do x + y = embed(x) + + μ = proj_mu(y) + logσ² = proj_log_var(y) + σ² = exp.(logσ² .* eltype(logσ²)(0.5)) + + ## Generate a tensor of random values from a normal distribution + rng = Lux.replicate(rng) + ϵ = randn_like(rng, σ²) + + ## Reparametrization trick to brackpropagate through sampling + z = ϵ .* σ² .+ μ + + @return z, μ, logσ² + end +end + +# Similarly we define the decoder. + +function cvae_decoder(; num_latent_dims::Int, image_shape::Dims{3}, max_num_filters::Int) + flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters + return @compact(; + linear=Dense(num_latent_dims, flattened_dim), + upchain=Chain( + Chain( + Upsample(2), + Conv((3, 3), max_num_filters => max_num_filters ÷ 2; stride=1, pad=1), + BatchNorm(max_num_filters ÷ 2, leakyrelu) + ), + Chain( + Upsample(2), + Conv((3, 3), max_num_filters ÷ 2 => max_num_filters ÷ 4; stride=1, pad=1), + BatchNorm(max_num_filters ÷ 4, leakyrelu) + ), + Chain( + Upsample(2), + Conv((3, 3), max_num_filters ÷ 4 => image_shape[3]; stride=1, pad=1) + ) + ), + max_num_filters) do x + y = linear(x) + img = reshape(y, image_shape[1] ÷ 8, image_shape[2] ÷ 8, max_num_filters, :) + @return upchain(img) + end +end + +@concrete struct CVAE <: Lux.AbstractLuxContainerLayer{(:encoder, :decoder)} + encoder <: Lux.AbstractLuxLayer + decoder <: Lux.AbstractLuxLayer +end + +function CVAE(; num_latent_dims::Int, image_shape::Dims{3}, max_num_filters::Int) + decoder = cvae_decoder(; num_latent_dims, image_shape, max_num_filters) + encoder = cvae_encoder(; num_latent_dims, image_shape, max_num_filters) + return CVAE(encoder, decoder) +end + +function (cvae::CVAE)(x, ps, st) + (z, μ, logσ²), st_enc = cvae.encoder(x, ps.encoder, st.encoder) + x_rec, st_dec = cvae.decoder(z, ps.decoder, st.decoder) + return (x_rec, μ, logσ²), (; encoder=st_enc, decoder=st_dec) +end + +function encode(cvae::CVAE, x, ps, st) + (z, _, _), st_enc = cvae.encoder(x, ps.encoder, st.encoder) + return z, (; encoder=st_enc, st.decoder) +end + +function decode(cvae::CVAE, z, ps, st) + x_rec, st_dec = cvae.decoder(z, ps.decoder, st.decoder) + return x_rec, (; decoder=st_dec, st.encoder) +end + +# ## Loading MNIST + +@concrete struct TensorDataset + dataset + transform +end + +Base.length(ds::TensorDataset) = length(ds.dataset) + +function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange}) + img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3)) + return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img) +end + +function loadmnist(batchsize, image_size::Dims{2}) + ## Load MNIST: Only 1500 for demonstration purposes + N = parse(Bool, get(ENV, "CI", "false")) ? 
1500 : nothing + train_dataset = MNIST(; split=:train) + test_dataset = MNIST(; split=:test) + if N !== nothing + train_dataset = train_dataset[1:N] + test_dataset = test_dataset[1:N] + end + + train_transform = ScaleKeepAspect(image_size) |> Maybe(FlipX{2}()) |> ImageToTensor() + test_transform = ScaleKeepAspect(image_size) |> ImageToTensor() + + trainset = TensorDataset(train_dataset, train_transform) + trainloader = DataLoader( + trainset; batchsize, shuffle=true, parallel=true, partial=false) + + testset = TensorDataset(test_dataset, test_transform) + testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, partial=false) + + return trainloader, testloader +end + +# ## Helper Functions + +# Generate an Image Grid from a list of images + +function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int) + total_images = grid_rows * grid_cols + imgs = map(eachslice(imgs[:, :, :, 1:total_images]; dims=4)) do img + cimg = size(img, 3) == 1 ? colorview(Gray, view(img, :, :, 1)) : colorview(RGB, img) + return cimg' + end + return create_image_grid(imgs, grid_rows, grid_cols) +end + +function create_image_grid(images::Vector, grid_rows::Int, grid_cols::Int) + ## Check if the number of images matches the grid + total_images = grid_rows * grid_cols + @assert length(images) == total_images + + ## Get the size of a single image (assuming all images are the same size) + img_height, img_width = size(images[1]) + + ## Create a blank grid canvas + grid_height = img_height * grid_rows + grid_width = img_width * grid_cols + grid_canvas = similar(images[1], grid_height, grid_width) + + ## Place each image in the correct position on the canvas + for idx in 1:total_images + row = div(idx - 1, grid_cols) + 1 + col = mod(idx - 1, grid_cols) + 1 + + start_row = (row - 1) * img_height + 1 + start_col = (col - 1) * img_width + 1 + + grid_canvas[start_row:(start_row + img_height - 1), start_col:(start_col + img_width - 1)] .= images[idx] + end + + return grid_canvas +end + +function loss_function(model, ps, st, X) + (y, μ, logσ²), st = model(X, ps, st) + reconstruction_loss = MSELoss(; agg=sum)(y, X) + kldiv_loss = -0.5f0 * sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) + loss = reconstruction_loss + kldiv_loss + return loss, st, (; y, μ, logσ², reconstruction_loss, kldiv_loss) +end + +function generate_images(model, ps, st; num_samples::Int=128, num_latent_dims::Int) + z = randn(Float32, num_latent_dims, num_samples) |> get_device((ps, st)) + images, _ = decode(model, z, ps, Lux.testmode(st)) + return create_image_grid(images, 8, num_samples ÷ 8) +end + +# ## Training the Model + +Comonicon.@main function main(; batchsize=128, image_size=(64, 64), num_latent_dims=32, + max_num_filters=64, seed=0, epochs=100, weight_decay=1e-3, learning_rate=1e-3) + rng = Random.default_rng() + Random.seed!(rng, seed) + + cvae = CVAE(; num_latent_dims, image_shape=(image_size..., 1), max_num_filters) + ps, st = Lux.setup(rng, cvae) |> xdev + + train_dataloader, test_dataloader = loadmnist(batchsize, image_size) |> xdev + + opt = AdamW(; eta=learning_rate, lambda=weight_decay) + + train_state = Training.TrainState(cvae, ps, st, opt) + + for epoch in 1:epochs + loss_total = 0.0f0 + total_samples = 0 + + stime = time() + for (i, X) in enumerate(train_dataloader) + throughput_tic = time() + (_, loss, _, train_state) = Training.single_train_step!( + AutoEnzyme(), loss_function, X, train_state) + throughput_toc = time() + + loss_total += loss + total_samples += size(X, ndims(X)) + + if i % 10 == 0 || i == 
length(train_dataloader) + @printf "Epoch %d, Iter %d, Loss: %.4f, Throughput: %.6f it/s\n" epoch i loss ((throughput_toc - + throughput_tic)/size( + X, ndims(X))) + end + end + ttime = time() - stime + + train_loss = loss_total / total_samples + @printf "Epoch %d, Train Loss: %.4f, Time: %.4fs\n" epoch train_loss ttime + end +end + +# XXX: Move into a proper function + +rng = Random.default_rng() +Random.seed!(rng, 0) + +cvae = CVAE(; num_latent_dims=32, image_shape=(64, 64, 1), max_num_filters=64) +ps, st = Lux.setup(rng, cvae) |> xdev; + +train_dataloader, test_dataloader = loadmnist(128, (64, 64)) |> xdev + +opt = AdamW(; eta=1e-3, lambda=1e-3) + +epochs = 100 + +train_state = Training.TrainState(cvae, ps, st, opt) + +for epoch in 1:epochs + loss_total = 0.0f0 + total_samples = 0 + + stime = time() + for (i, X) in enumerate(train_dataloader) + throughput_tic = time() + (_, loss, _, train_state) = Training.single_train_step!( + AutoEnzyme(), loss_function, X, train_state) + throughput_toc = time() + + loss_total += loss + total_samples += size(X, ndims(X)) + + if i % 10 == 0 || i == length(train_dataloader) + @printf "Epoch %d, Iter %d, Loss: %.4f, Throughput: %.6f it/s\n" epoch i loss ((throughput_toc - + throughput_tic)/size( + X, ndims(X))) + end + end + ttime = time() - stime + + train_loss = loss_total / total_samples + @printf "Epoch %d, Train Loss: %.4f, Time: %.4fs\n" epoch train_loss ttime + + # XXX: Generate images conditionally + display(generate_images(cvae, ps, st; num_samples=128, num_latent_dims=32)) +end diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index 03ddc63a5..f51c2caf7 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -73,7 +73,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: rng = StableRNG(seed) gdev = gpu_device() - trainloader, testloader = get_dataloaders(batchsize) .|> gdev + trainloader, testloader = get_dataloaders(batchsize) |> gdev model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) ps, st = Lux.setup(rng, model) |> gdev @@ -81,8 +81,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: opt = AdamW(; eta=lr_max, lambda=weight_decay) clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) - train_state = Training.TrainState( - model, ps, st, AdamW(; eta=lr_max, lambda=weight_decay)) + train_state = Training.TrainState(model, ps, st, opt) lr_schedule = linear_interpolation( [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]) From cf44d9506e2116a5c8fa2a51efa2981d7ebecdd4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 10 Dec 2024 09:52:47 +0530 Subject: [PATCH 02/10] feat: overload Utils.vec for upcoming wrapper array changes --- examples/ConditionalVAE/Project.toml | 1 + examples/ConditionalVAE/main.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/ConditionalVAE/Project.toml b/examples/ConditionalVAE/Project.toml index d11cad9b6..6f39e0d2d 100644 --- a/examples/ConditionalVAE/Project.toml +++ b/examples/ConditionalVAE/Project.toml @@ -1,4 +1,5 @@ [deps] +Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" diff --git a/examples/ConditionalVAE/main.jl b/examples/ConditionalVAE/main.jl index 7f0a9f6f8..af26e910f 100644 --- a/examples/ConditionalVAE/main.jl +++ b/examples/ConditionalVAE/main.jl @@ -4,7 +4,7 @@ # 
based on the [CVAE implementation in MLX](https://github.com/ml-explore/mlx-examples/blob/main/cvae/). using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, MLUtils, DataAugmentation, - ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers + ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers, Comonicon const xdev = reactant_device() const cdev = cpu_device() From c44cc380fe8bdb8961a6f32e10acf01cf1f9fbdd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 17:34:20 -0500 Subject: [PATCH 03/10] docs: update ConvMixer to support reactant --- examples/ConvMixer/Project.toml | 4 ++ examples/ConvMixer/README.md | 3 ++ examples/ConvMixer/main.jl | 69 ++++++++++++++++++++++++--------- 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml index 8ae780657..db7dbc64a 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/ConvMixer/Project.toml @@ -2,6 +2,7 @@ Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" @@ -15,6 +16,7 @@ PreferenceTools = "ba661fbb-e901-4445-b070-854aec6bfbc5" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -23,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Comonicon = "1.0.8" ConcreteStructs = "0.2.3" DataAugmentation = "0.3" +Enzyme = "0.13.14" ImageCore = "0.10.2" ImageShow = "0.3.8" Interpolations = "0.15.1" @@ -36,6 +39,7 @@ PreferenceTools = "0.1.2" Printf = "1.10" ProgressBars = "1.5.1" Random = "1.10" +Reactant = "0.2.5" StableRNGs = "1.0.2" Statistics = "1.10" Zygote = "0.6.70" diff --git a/examples/ConvMixer/README.md b/examples/ConvMixer/README.md index f072c1074..560b2b1d3 100644 --- a/examples/ConvMixer/README.md +++ b/examples/ConvMixer/README.md @@ -11,6 +11,9 @@ for new experiments on small datasets. You can get around **90.0%** accuracy in just **25 epochs** by running the script with the following arguments, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2. +> [!NOTE] +> To train the model using Reactant.jl pass in `--backend=reactant` to the script. + ```bash julia --startup-file=no \ --project=. \ diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index f51c2caf7..d4da2ca94 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -1,6 +1,7 @@ using Comonicon, ConcreteStructs, DataAugmentation, ImageShow, Interpolations, Lux, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, ProgressBars, Random, StableRNGs, Statistics, Zygote +using Reactant, Enzyme CUDA.allowscalar(false) @@ -17,7 +18,7 @@ function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, Abstrac return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y end -function get_dataloaders(batchsize) +function get_dataloaders(batchsize; kwargs...) 
cifar10_mean = (0.4914, 0.4822, 0.4465) cifar10_std = (0.2471, 0.2435, 0.2616) @@ -29,10 +30,10 @@ function get_dataloaders(batchsize) test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) trainset = TensorDataset(CIFAR10(:train), train_transform) - trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true) + trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true, kwargs...) testset = TensorDataset(CIFAR10(:test), test_transform) - testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true) + testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, kwargs...) return trainloader, testloader end @@ -43,10 +44,14 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size), BatchNorm(dim), [Chain( - SkipConnection( - Chain(Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, - pad=SamePad()), BatchNorm(dim)), +), - Conv((1, 1), dim => dim, gelu), BatchNorm(dim)) + SkipConnection( + Chain( + Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, pad=SamePad()), + BatchNorm(dim) + ), + + + ), + Conv((1, 1), dim => dim, gelu), BatchNorm(dim)) for _ in 1:depth]..., GlobalMeanPool(), FlattenLayer(), @@ -57,10 +62,11 @@ end function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 + cdev = cpu_device() st = Lux.testmode(st) for (x, y) in dataloader - target_class = onecold(y) - predicted_class = onecold(first(model(x, ps, st))) + target_class = onecold(cdev(y)) + predicted_class = onecold(cdev(first(model(x, ps, st)))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end @@ -69,14 +75,28 @@ end Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, - clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01) + clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, + backend::String="reactant") rng = StableRNG(seed) - gdev = gpu_device() - trainloader, testloader = get_dataloaders(batchsize) |> gdev + if backend == "gpu_if_available" + accelerator_device = gpu_device() + elseif backend == "gpu" + accelerator_device = gpu_device(; force=true) + elseif backend == "reactant" + accelerator_device = reactant_device(; force=true) + elseif backend == "cpu" + accelerator_device = cpu_device() + else + error("Invalid backend: $(backend). Valid Options are: `gpu_if_available`, `gpu`, \ + `reactant`, and `cpu`.") + end + + kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : () + trainloader, testloader = get_dataloaders(batchsize; kwargs...) |> accelerator_device model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) - ps, st = Lux.setup(rng, model) |> gdev + ps, st = Lux.setup(rng, model) |> accelerator_device opt = AdamW(; eta=lr_max, lambda=weight_decay) clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) @@ -84,7 +104,17 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: train_state = Training.TrainState(model, ps, st, opt) lr_schedule = linear_interpolation( - [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]) + [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0] + ) + + adtype = backend == "reactant" ? 
AutoEnzyme() : AutoZygote() + + if backend == "reactant" + x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device + model_compiled = @compile model(x_ra, ps, st) + else + model_compiled = model + end loss = CrossEntropyLoss(; logits=Val(true)) @@ -95,14 +125,17 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader)) train_state = Optimisers.adjust!(train_state, lr) (_, _, _, train_state) = Training.single_train_step!( - AutoZygote(), loss, (x, y), train_state) + adtype, loss, (x, y), train_state + ) end ttime = time() - stime train_acc = accuracy( - model, train_state.parameters, train_state.states, trainloader) * 100 - test_acc = accuracy(model, train_state.parameters, train_state.states, testloader) * - 100 + model_compiled, train_state.parameters, train_state.states, trainloader + ) * 100 + test_acc = accuracy( + model_compiled, train_state.parameters, train_state.states, testloader + ) * 100 @printf "Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: %.2f%%, \ Time: %.2f\n" epoch lr train_acc test_acc ttime From 4acc535186f8c978b757c8ee252bb58f02692e5c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Nov 2024 21:59:02 -0500 Subject: [PATCH 04/10] docs: keep the ConvMixer default backend as cuda.jl for now --- examples/ConvMixer/Project.toml | 4 ++-- examples/ConvMixer/README.md | 1 + examples/ConvMixer/main.jl | 13 +++++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml index db7dbc64a..11e2f29d3 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/ConvMixer/Project.toml @@ -25,7 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Comonicon = "1.0.8" ConcreteStructs = "0.2.3" DataAugmentation = "0.3" -Enzyme = "0.13.14" +Enzyme = "0.13.16" ImageCore = "0.10.2" ImageShow = "0.3.8" Interpolations = "0.15.1" @@ -39,7 +39,7 @@ PreferenceTools = "0.1.2" Printf = "1.10" ProgressBars = "1.5.1" Random = "1.10" -Reactant = "0.2.5" +Reactant = "0.2.8" StableRNGs = "1.0.2" Statistics = "1.10" Zygote = "0.6.70" diff --git a/examples/ConvMixer/README.md b/examples/ConvMixer/README.md index 560b2b1d3..f61bf1c4e 100644 --- a/examples/ConvMixer/README.md +++ b/examples/ConvMixer/README.md @@ -69,6 +69,7 @@ Options --seed <42::Int> --epochs <25::Int> --lr-max <0.01::Float64> + --backend Flags --clip-norm diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index d4da2ca94..debce9c97 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -76,7 +76,7 @@ end Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, - backend::String="reactant") + backend::String="gpu_if_available") rng = StableRNG(seed) if backend == "gpu_if_available" @@ -111,13 +111,16 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: if backend == "reactant" x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device - model_compiled = @compile model(x_ra, ps, st) + @printf "[Info] Compiling model with Reactant.jl\n" + model_compiled = @compile model(x_ra, ps, Lux.testmode(st)) + @printf "[Info] Model compiled!\n" else model_compiled = model end loss = CrossEntropyLoss(; logits=Val(true)) + @printf "[Info] Training model\n" for epoch in 1:epochs stime = time() 
lr = 0 @@ -127,6 +130,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: (_, _, _, train_state) = Training.single_train_step!( adtype, loss, (x, y), train_state ) + @show i, time() - stime end ttime = time() - stime @@ -137,7 +141,8 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: model_compiled, train_state.parameters, train_state.states, testloader ) * 100 - @printf "Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: %.2f%%, \ - Time: %.2f\n" epoch lr train_acc test_acc ttime + @printf "[Train] Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: \ + %.2f%%, Time: %.2f\n" epoch lr train_acc test_acc ttime end + @printf "[Info] Finished training\n" end From a84b132b9f82e04b686bee5018b52cf0779077aa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Dec 2024 11:50:50 +0530 Subject: [PATCH 05/10] fix: remove unnecessary patches for now --- examples/ConditionalVAE/Project.toml | 29 --- examples/ConditionalVAE/main.jl | 291 --------------------------- examples/ConvMixer/main.jl | 5 +- ext/LuxReactantExt/training.jl | 11 + 4 files changed, 13 insertions(+), 323 deletions(-) delete mode 100644 examples/ConditionalVAE/Project.toml delete mode 100644 examples/ConditionalVAE/main.jl diff --git a/examples/ConditionalVAE/Project.toml b/examples/ConditionalVAE/Project.toml deleted file mode 100644 index 6f39e0d2d..000000000 --- a/examples/ConditionalVAE/Project.toml +++ /dev/null @@ -1,29 +0,0 @@ -[deps] -Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" -ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31" -Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" -OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" - -[compat] -ConcreteStructs = "0.2.3" -DataAugmentation = "0.3.2" -Enzyme = "0.13.20" -ImageShow = "0.3.8" -Images = "0.26.1" -Lux = "1.4.1" -MLDatasets = "0.7.18" -MLUtils = "0.4.4" -OneHotArrays = "0.2.6" -Printf = "1.10" -Random = "1.10" -Reactant = "0.2.9" diff --git a/examples/ConditionalVAE/main.jl b/examples/ConditionalVAE/main.jl deleted file mode 100644 index af26e910f..000000000 --- a/examples/ConditionalVAE/main.jl +++ /dev/null @@ -1,291 +0,0 @@ -# # [Conditional VAE for MNIST using Reactant](@id Conditional-VAE-Tutorial) - -# Convolutional variational autoencoder (CVAE) implementation in MLX using MNIST. This is -# based on the [CVAE implementation in MLX](https://github.com/ml-explore/mlx-examples/blob/main/cvae/). - -using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, MLUtils, DataAugmentation, - ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers, Comonicon - -const xdev = reactant_device() -const cdev = cpu_device() - -# ## Model Definition - -# First we will define the encoder.It maps the input to a normal distribution in latent -# space and sample a latent vector from that distribution. 
- -function cvae_encoder( - rng=Random.default_rng(); num_latent_dims::Int, - image_shape::Dims{3}, max_num_filters::Int -) - flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters - return @compact(; - embed=Chain( - Chain( - Conv((3, 3), image_shape[3] => max_num_filters ÷ 4; stride=2, pad=1), - BatchNorm(max_num_filters ÷ 4, leakyrelu) - ), - Chain( - Conv((3, 3), max_num_filters ÷ 4 => max_num_filters ÷ 2; stride=2, pad=1), - BatchNorm(max_num_filters ÷ 2, leakyrelu) - ), - Chain( - Conv((3, 3), max_num_filters ÷ 2 => max_num_filters; stride=2, pad=1), - BatchNorm(max_num_filters, leakyrelu) - ), - FlattenLayer() - ), - proj_mu=Dense(flattened_dim, num_latent_dims), - proj_log_var=Dense(flattened_dim, num_latent_dims), - rng) do x - y = embed(x) - - μ = proj_mu(y) - logσ² = proj_log_var(y) - σ² = exp.(logσ² .* eltype(logσ²)(0.5)) - - ## Generate a tensor of random values from a normal distribution - rng = Lux.replicate(rng) - ϵ = randn_like(rng, σ²) - - ## Reparametrization trick to brackpropagate through sampling - z = ϵ .* σ² .+ μ - - @return z, μ, logσ² - end -end - -# Similarly we define the decoder. - -function cvae_decoder(; num_latent_dims::Int, image_shape::Dims{3}, max_num_filters::Int) - flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters - return @compact(; - linear=Dense(num_latent_dims, flattened_dim), - upchain=Chain( - Chain( - Upsample(2), - Conv((3, 3), max_num_filters => max_num_filters ÷ 2; stride=1, pad=1), - BatchNorm(max_num_filters ÷ 2, leakyrelu) - ), - Chain( - Upsample(2), - Conv((3, 3), max_num_filters ÷ 2 => max_num_filters ÷ 4; stride=1, pad=1), - BatchNorm(max_num_filters ÷ 4, leakyrelu) - ), - Chain( - Upsample(2), - Conv((3, 3), max_num_filters ÷ 4 => image_shape[3]; stride=1, pad=1) - ) - ), - max_num_filters) do x - y = linear(x) - img = reshape(y, image_shape[1] ÷ 8, image_shape[2] ÷ 8, max_num_filters, :) - @return upchain(img) - end -end - -@concrete struct CVAE <: Lux.AbstractLuxContainerLayer{(:encoder, :decoder)} - encoder <: Lux.AbstractLuxLayer - decoder <: Lux.AbstractLuxLayer -end - -function CVAE(; num_latent_dims::Int, image_shape::Dims{3}, max_num_filters::Int) - decoder = cvae_decoder(; num_latent_dims, image_shape, max_num_filters) - encoder = cvae_encoder(; num_latent_dims, image_shape, max_num_filters) - return CVAE(encoder, decoder) -end - -function (cvae::CVAE)(x, ps, st) - (z, μ, logσ²), st_enc = cvae.encoder(x, ps.encoder, st.encoder) - x_rec, st_dec = cvae.decoder(z, ps.decoder, st.decoder) - return (x_rec, μ, logσ²), (; encoder=st_enc, decoder=st_dec) -end - -function encode(cvae::CVAE, x, ps, st) - (z, _, _), st_enc = cvae.encoder(x, ps.encoder, st.encoder) - return z, (; encoder=st_enc, st.decoder) -end - -function decode(cvae::CVAE, z, ps, st) - x_rec, st_dec = cvae.decoder(z, ps.decoder, st.decoder) - return x_rec, (; decoder=st_dec, st.encoder) -end - -# ## Loading MNIST - -@concrete struct TensorDataset - dataset - transform -end - -Base.length(ds::TensorDataset) = length(ds.dataset) - -function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange}) - img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3)) - return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img) -end - -function loadmnist(batchsize, image_size::Dims{2}) - ## Load MNIST: Only 1500 for demonstration purposes - N = parse(Bool, get(ENV, "CI", "false")) ? 
1500 : nothing - train_dataset = MNIST(; split=:train) - test_dataset = MNIST(; split=:test) - if N !== nothing - train_dataset = train_dataset[1:N] - test_dataset = test_dataset[1:N] - end - - train_transform = ScaleKeepAspect(image_size) |> Maybe(FlipX{2}()) |> ImageToTensor() - test_transform = ScaleKeepAspect(image_size) |> ImageToTensor() - - trainset = TensorDataset(train_dataset, train_transform) - trainloader = DataLoader( - trainset; batchsize, shuffle=true, parallel=true, partial=false) - - testset = TensorDataset(test_dataset, test_transform) - testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, partial=false) - - return trainloader, testloader -end - -# ## Helper Functions - -# Generate an Image Grid from a list of images - -function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int) - total_images = grid_rows * grid_cols - imgs = map(eachslice(imgs[:, :, :, 1:total_images]; dims=4)) do img - cimg = size(img, 3) == 1 ? colorview(Gray, view(img, :, :, 1)) : colorview(RGB, img) - return cimg' - end - return create_image_grid(imgs, grid_rows, grid_cols) -end - -function create_image_grid(images::Vector, grid_rows::Int, grid_cols::Int) - ## Check if the number of images matches the grid - total_images = grid_rows * grid_cols - @assert length(images) == total_images - - ## Get the size of a single image (assuming all images are the same size) - img_height, img_width = size(images[1]) - - ## Create a blank grid canvas - grid_height = img_height * grid_rows - grid_width = img_width * grid_cols - grid_canvas = similar(images[1], grid_height, grid_width) - - ## Place each image in the correct position on the canvas - for idx in 1:total_images - row = div(idx - 1, grid_cols) + 1 - col = mod(idx - 1, grid_cols) + 1 - - start_row = (row - 1) * img_height + 1 - start_col = (col - 1) * img_width + 1 - - grid_canvas[start_row:(start_row + img_height - 1), start_col:(start_col + img_width - 1)] .= images[idx] - end - - return grid_canvas -end - -function loss_function(model, ps, st, X) - (y, μ, logσ²), st = model(X, ps, st) - reconstruction_loss = MSELoss(; agg=sum)(y, X) - kldiv_loss = -0.5f0 * sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) - loss = reconstruction_loss + kldiv_loss - return loss, st, (; y, μ, logσ², reconstruction_loss, kldiv_loss) -end - -function generate_images(model, ps, st; num_samples::Int=128, num_latent_dims::Int) - z = randn(Float32, num_latent_dims, num_samples) |> get_device((ps, st)) - images, _ = decode(model, z, ps, Lux.testmode(st)) - return create_image_grid(images, 8, num_samples ÷ 8) -end - -# ## Training the Model - -Comonicon.@main function main(; batchsize=128, image_size=(64, 64), num_latent_dims=32, - max_num_filters=64, seed=0, epochs=100, weight_decay=1e-3, learning_rate=1e-3) - rng = Random.default_rng() - Random.seed!(rng, seed) - - cvae = CVAE(; num_latent_dims, image_shape=(image_size..., 1), max_num_filters) - ps, st = Lux.setup(rng, cvae) |> xdev - - train_dataloader, test_dataloader = loadmnist(batchsize, image_size) |> xdev - - opt = AdamW(; eta=learning_rate, lambda=weight_decay) - - train_state = Training.TrainState(cvae, ps, st, opt) - - for epoch in 1:epochs - loss_total = 0.0f0 - total_samples = 0 - - stime = time() - for (i, X) in enumerate(train_dataloader) - throughput_tic = time() - (_, loss, _, train_state) = Training.single_train_step!( - AutoEnzyme(), loss_function, X, train_state) - throughput_toc = time() - - loss_total += loss - total_samples += size(X, ndims(X)) - - if i % 10 == 0 || i == 
length(train_dataloader) - @printf "Epoch %d, Iter %d, Loss: %.4f, Throughput: %.6f it/s\n" epoch i loss ((throughput_toc - - throughput_tic)/size( - X, ndims(X))) - end - end - ttime = time() - stime - - train_loss = loss_total / total_samples - @printf "Epoch %d, Train Loss: %.4f, Time: %.4fs\n" epoch train_loss ttime - end -end - -# XXX: Move into a proper function - -rng = Random.default_rng() -Random.seed!(rng, 0) - -cvae = CVAE(; num_latent_dims=32, image_shape=(64, 64, 1), max_num_filters=64) -ps, st = Lux.setup(rng, cvae) |> xdev; - -train_dataloader, test_dataloader = loadmnist(128, (64, 64)) |> xdev - -opt = AdamW(; eta=1e-3, lambda=1e-3) - -epochs = 100 - -train_state = Training.TrainState(cvae, ps, st, opt) - -for epoch in 1:epochs - loss_total = 0.0f0 - total_samples = 0 - - stime = time() - for (i, X) in enumerate(train_dataloader) - throughput_tic = time() - (_, loss, _, train_state) = Training.single_train_step!( - AutoEnzyme(), loss_function, X, train_state) - throughput_toc = time() - - loss_total += loss - total_samples += size(X, ndims(X)) - - if i % 10 == 0 || i == length(train_dataloader) - @printf "Epoch %d, Iter %d, Loss: %.4f, Throughput: %.6f it/s\n" epoch i loss ((throughput_toc - - throughput_tic)/size( - X, ndims(X))) - end - end - ttime = time() - stime - - train_loss = loss_total / total_samples - @printf "Epoch %d, Train Loss: %.4f, Time: %.4fs\n" epoch train_loss ttime - - # XXX: Generate images conditionally - display(generate_images(cvae, ps, st; num_samples=128, num_latent_dims=32)) -end diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index debce9c97..08e5553e7 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -73,10 +73,10 @@ function accuracy(model, ps, st, dataloader) return total_correct / total end -Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, +Comonicon.@main function main(; batchsize::Int=64, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, - backend::String="gpu_if_available") + backend::String="reactant") rng = StableRNG(seed) if backend == "gpu_if_available" @@ -130,7 +130,6 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: (_, _, _, train_state) = Training.single_train_step!( adtype, loss, (x, y), train_state ) - @show i, time() - stime end ttime = time() - stime diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index c35d5cb05..5311098ab 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -76,14 +76,20 @@ for inplace in ("!", "") # used in the compiled function? We can re-trigger the compilation with a warning @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState) where {F} + @show 1213 + compiled_grad_and_step_function = @compile $(internal_fn)( objective_function, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) + @show Lux.Functors.fmap(typeof, ts.states) + grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function( objective_function, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) + @show Lux.Functors.fmap(typeof, st) + cache = TrainingBackendCache( backend, False(), nothing, (; compiled_grad_and_step_function)) @set! ts.cache = cache @@ -93,11 +99,16 @@ for inplace in ("!", "") @set! ts.optimizer_state = opt_state @set! 
ts.step = ts.step + 1 + @show Lux.Functors.fmap(typeof, ts.states) + return grads, loss, stats, ts end @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data, ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F} + @show Lux.Functors.fmap(typeof, ts.parameters) + @show Lux.Functors.fmap(typeof, ts.states) + grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function( obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) From c7dfc28916379fbcc716853e13a053c805a6d328 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Dec 2024 12:01:35 +0530 Subject: [PATCH 06/10] fix: update reactant version --- docs/src/.vitepress/config.mts | 4 ---- docs/tutorials.jl | 1 - examples/ConvMixer/Project.toml | 2 +- examples/ConvMixer/main.jl | 2 +- ext/LuxReactantExt/training.jl | 11 ----------- 5 files changed, 2 insertions(+), 18 deletions(-) diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 35c573943..f785f6a31 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -218,10 +218,6 @@ export default defineConfig({ text: "Training a PINN on 2D PDE", link: "/tutorials/intermediate/4_PINN2DPDE", }, - { - text: "Conditional VAE for MNIST using Reactant", - link: "/tutorials/intermediate/5_ConditionalVAE", - } ], }, { diff --git a/docs/tutorials.jl b/docs/tutorials.jl index b9b9971d3..d9dad6510 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -11,7 +11,6 @@ const INTERMEDIATE_TUTORIALS = [ "BayesianNN/main.jl" => "CPU", "HyperNet/main.jl" => "CUDA", "PINN2DPDE/main.jl" => "CUDA", - "ConditionalVAE/main.jl" => "CUDA", ] const ADVANCED_TUTORIALS = [ "GravitationalWaveForm/main.jl" => "CPU", diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml index 11e2f29d3..04fec524d 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/ConvMixer/Project.toml @@ -39,7 +39,7 @@ PreferenceTools = "0.1.2" Printf = "1.10" ProgressBars = "1.5.1" Random = "1.10" -Reactant = "0.2.8" +Reactant = "0.2.11" StableRNGs = "1.0.2" Statistics = "1.10" Zygote = "0.6.70" diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index 08e5553e7..ac36b6f57 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -73,7 +73,7 @@ function accuracy(model, ps, st, dataloader) return total_correct / total end -Comonicon.@main function main(; batchsize::Int=64, hidden_dim::Int=256, depth::Int=8, +Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, backend::String="reactant") diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index 5311098ab..c35d5cb05 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -76,20 +76,14 @@ for inplace in ("!", "") # used in the compiled function? 
We can re-trigger the compilation with a warning @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState) where {F} - @show 1213 - compiled_grad_and_step_function = @compile $(internal_fn)( objective_function, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) - @show Lux.Functors.fmap(typeof, ts.states) - grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function( objective_function, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) - @show Lux.Functors.fmap(typeof, st) - cache = TrainingBackendCache( backend, False(), nothing, (; compiled_grad_and_step_function)) @set! ts.cache = cache @@ -99,16 +93,11 @@ for inplace in ("!", "") @set! ts.optimizer_state = opt_state @set! ts.step = ts.step + 1 - @show Lux.Functors.fmap(typeof, ts.states) - return grads, loss, stats, ts end @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data, ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F} - @show Lux.Functors.fmap(typeof, ts.parameters) - @show Lux.Functors.fmap(typeof, ts.states) - grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function( obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) From af3d01f39624319f152306811db507330ded4b57 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Dec 2024 13:21:38 +0530 Subject: [PATCH 07/10] feat: pipeline working :tada: --- examples/ConvMixer/main.jl | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index ac36b6f57..d46154fd1 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -30,10 +30,10 @@ function get_dataloaders(batchsize; kwargs...) test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) trainset = TensorDataset(CIFAR10(:train), train_transform) - trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true, kwargs...) + trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...) testset = TensorDataset(CIFAR10(:test), test_transform) - testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, kwargs...) + testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...) 
return trainloader, testloader end @@ -43,16 +43,23 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) return Chain( Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size), BatchNorm(dim), - [Chain( - SkipConnection( - Chain( - Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, pad=SamePad()), - BatchNorm(dim) + [ + Chain( + SkipConnection( + Chain( + Conv( + (kernel_size, kernel_size), dim => dim, gelu; + groups=dim, pad=SamePad() + ), + BatchNorm(dim) + ), + + ), - + - ), - Conv((1, 1), dim => dim, gelu), BatchNorm(dim)) - for _ in 1:depth]..., + Conv((1, 1), dim => dim, gelu), + BatchNorm(dim) + ) + for _ in 1:depth + ]..., GlobalMeanPool(), FlattenLayer(), Dense(dim => 10) @@ -74,9 +81,9 @@ function accuracy(model, ps, st, dataloader) end Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, - patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, + patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-4, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, - backend::String="reactant") + backend::String="gpu_if_available") rng = StableRNG(seed) if backend == "gpu_if_available" @@ -118,7 +125,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: model_compiled = model end - loss = CrossEntropyLoss(; logits=Val(true)) + loss_fn = CrossEntropyLoss(; logits=Val(true)) @printf "[Info] Training model\n" for epoch in 1:epochs @@ -128,7 +135,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader)) train_state = Optimisers.adjust!(train_state, lr) (_, _, _, train_state) = Training.single_train_step!( - adtype, loss, (x, y), train_state + adtype, loss_fn, (x, y), train_state ) end ttime = time() - stime From 6526cbb1b9048e551c02507f742974562da32bd8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 22 Dec 2024 12:59:30 +0530 Subject: [PATCH 08/10] fix: more bug fixes for reactant --- examples/ConvMixer/Project.toml | 2 -- examples/ConvMixer/README.md | 7 ++++--- examples/ConvMixer/main.jl | 23 +++++++++++++---------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml index 04fec524d..125c6612d 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/ConvMixer/Project.toml @@ -17,7 +17,6 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -40,6 +39,5 @@ Printf = "1.10" ProgressBars = "1.5.1" Random = "1.10" Reactant = "0.2.11" -StableRNGs = "1.0.2" Statistics = "1.10" Zygote = "0.6.70" diff --git a/examples/ConvMixer/README.md b/examples/ConvMixer/README.md index f61bf1c4e..54d7d1f94 100644 --- a/examples/ConvMixer/README.md +++ b/examples/ConvMixer/README.md @@ -20,10 +20,11 @@ julia --startup-file=no \ --threads=auto \ main.jl \ --lr-max=0.05 \ - --weight-decay=0.0001 + --weight-decay=0.0001 \ + --backend=reactant ``` -Here's an example of the output of the above command (on a V100 32GB GPU): +Here's an example output of the above command (on a RTX 4050 6GB Laptop GPU): ``` Epoch 1: Learning Rate 5.05e-03, Train Acc: 56.91%, Test Acc: 56.49%, Time: 129.84 @@ -69,7 +70,7 @@ 
Options --seed <42::Int> --epochs <25::Int> --lr-max <0.01::Float64> - --backend + --backend Flags --clip-norm diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index d46154fd1..9d1c6cb5e 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -1,6 +1,6 @@ using Comonicon, ConcreteStructs, DataAugmentation, ImageShow, Interpolations, Lux, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, ProgressBars, Random, - StableRNGs, Statistics, Zygote + Statistics, Zygote using Reactant, Enzyme CUDA.allowscalar(false) @@ -70,7 +70,6 @@ end function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 cdev = cpu_device() - st = Lux.testmode(st) for (x, y) in dataloader target_class = onecold(cdev(y)) predicted_class = onecold(cdev(first(model(x, ps, st)))) @@ -81,10 +80,11 @@ function accuracy(model, ps, st, dataloader) end Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, - patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-4, - clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01, + patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.005, + clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05, backend::String="gpu_if_available") - rng = StableRNG(seed) + rng = Random.default_rng() + Random.seed!(rng, seed) if backend == "gpu_if_available" accelerator_device = gpu_device() @@ -119,7 +119,8 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: if backend == "reactant" x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device @printf "[Info] Compiling model with Reactant.jl\n" - model_compiled = @compile model(x_ra, ps, Lux.testmode(st)) + st_test = Lux.testmode(st) + model_compiled = @compile model(x_ra, ps, st_test) @printf "[Info] Model compiled!\n" else model_compiled = model @@ -141,14 +142,16 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: ttime = time() - stime train_acc = accuracy( - model_compiled, train_state.parameters, train_state.states, trainloader + model_compiled, train_state.parameters, + Lux.testmode(train_state.states), trainloader ) * 100 test_acc = accuracy( - model_compiled, train_state.parameters, train_state.states, testloader + model_compiled, train_state.parameters, + Lux.testmode(train_state.states), testloader ) * 100 - @printf "[Train] Epoch %2d: Learning Rate %.2e, Train Acc: %.2f%%, Test Acc: \ - %.2f%%, Time: %.2f\n" epoch lr train_acc test_acc ttime + @printf "[Train] Epoch %2d: Learning Rate %.6f, Train Acc: %.4f%%, Test Acc: \ + %.4f%%, Time: %.2f\n" epoch lr train_acc test_acc ttime end @printf "[Info] Finished training\n" end From ceda8c0bbd2eb3233c4db80d08c0bb1cc1d322c9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 22 Dec 2024 21:31:04 +0530 Subject: [PATCH 09/10] refactor: centralize the CIFAR10 examples --- docs/src/.vitepress/config.mts | 4 +- docs/src/tutorials/index.md | 6 +- examples/{ConvMixer => CIFAR10}/Project.toml | 5 +- examples/{ConvMixer => CIFAR10}/README.md | 65 +++++------ .../{ConvMixer/main.jl => CIFAR10/common.jl} | 104 ++++++++---------- examples/CIFAR10/conv_mixer.jl | 50 +++++++++ examples/CIFAR10/mlp_mixer.jl | 6 + examples/CIFAR10/simple_cnn.jl | 36 ++++++ 8 files changed, 174 insertions(+), 102 deletions(-) rename examples/{ConvMixer => CIFAR10}/Project.toml (89%) rename examples/{ConvMixer => CIFAR10}/README.md (79%) rename examples/{ConvMixer/main.jl => CIFAR10/common.jl} (56%) 
create mode 100644 examples/CIFAR10/conv_mixer.jl create mode 100644 examples/CIFAR10/mlp_mixer.jl create mode 100644 examples/CIFAR10/simple_cnn.jl diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index f785f6a31..bdd870e08 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -243,8 +243,8 @@ export default defineConfig({ link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/DDIM", }, { - text: "ConvMixer on CIFAR-10", - link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer", + text: "Different Vision Models on CIFAR-10", + link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/CIFAR10", }, ], }, diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index 75c45f7b9..6b01da2b7 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -97,10 +97,10 @@ const large_models = [ desc: "Train a Diffusion Model to generate images from Gaussian noises." }, { - href: "https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer", + href: "https://github.com/LuxDL/Lux.jl/tree/main/examples/CIFAR10", src: "https://datasets.activeloop.ai/wp-content/uploads/2022/09/CIFAR-10-dataset-Activeloop-Platform-visualization-image-1.webp", - caption: "ConvMixer on CIFAR-10", - desc: "Train ConvMixer on CIFAR-10 to 90% accuracy within 10 minutes." + caption: "Vision Models on CIFAR-10", + desc: "Train differnt vision models on CIFAR-10 to 90% accuracy within 10 minutes." } ]; diff --git a/examples/ConvMixer/Project.toml b/examples/CIFAR10/Project.toml similarity index 89% rename from examples/ConvMixer/Project.toml rename to examples/CIFAR10/Project.toml index 125c6612d..540774f4e 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/CIFAR10/Project.toml @@ -12,9 +12,8 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -PreferenceTools = "ba661fbb-e901-4445-b070-854aec6bfbc5" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568" +ProgressTables = "e0b4b9f6-8cc7-451e-9c86-94c5316e9f73" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -34,9 +33,7 @@ MLDatasets = "0.7.14" MLUtils = "0.4.4" OneHotArrays = "0.2.5" Optimisers = "0.4.1" -PreferenceTools = "0.1.2" Printf = "1.10" -ProgressBars = "1.5.1" Random = "1.10" Reactant = "0.2.11" Statistics = "1.10" diff --git a/examples/ConvMixer/README.md b/examples/CIFAR10/README.md similarity index 79% rename from examples/ConvMixer/README.md rename to examples/CIFAR10/README.md index 54d7d1f94..6e1841663 100644 --- a/examples/ConvMixer/README.md +++ b/examples/CIFAR10/README.md @@ -1,6 +1,35 @@ -# Train ConvMixer on CIFAR-10 +# Train Vision Models on CIFAR-10 - ✈️ 🚗 🐦 🐈 🦌 🐕 🐸 🐎 🚢 🚚 +✈️ 🚗 🐦 🐈 🦌 🐕 🐸 🐎 🚢 🚚 + +We have the following scripts to train vision models on CIFAR-10: + +1. `simple_cnn.jl`: Simple CNN model with a sequence of convolutional layers. +2. `mlp_mixer.jl`: MLP-Mixer model. +3. `conv_mixer.jl`: ConvMixer model. + +To get the options for each script, run the script with the `--help` flag. + +> [!NOTE] +> To train the model using Reactant.jl pass in `--backend=reactant` to the script. This is +> the recommended approch to train the models present in this directory. 
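+
+For example, an illustrative invocation (assuming the project environment has already been
+instantiated) that prints the available options, including `--backend`, for the simple CNN
+script is:
+
+```bash
+julia --startup-file=no --project=. simple_cnn.jl --help
+```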
+ +## Simple CNN + +```bash +julia --startup-file=no \ + --project=. \ + --threads=auto \ + simple_cnn.jl \ + --backend=reactant +``` + +On a RTX 4050 6GB Laptop GPU the training takes approximately 3 mins and the final training +and test accuracies are 97% and 65%, respectively. + +## MLP-Mixer + +## ConvMixer > [!NOTE] > This code has been adapted from https://github.com/locuslab/convmixer-cifar10 @@ -11,14 +40,11 @@ for new experiments on small datasets. You can get around **90.0%** accuracy in just **25 epochs** by running the script with the following arguments, which trains a ConvMixer-256/8 with kernel size 5 and patch size 2. -> [!NOTE] -> To train the model using Reactant.jl pass in `--backend=reactant` to the script. - ```bash julia --startup-file=no \ --project=. \ --threads=auto \ - main.jl \ + conv_mixer.jl \ --lr-max=0.05 \ --weight-decay=0.0001 \ --backend=reactant @@ -54,32 +80,7 @@ Epoch 24: Learning Rate 8.29e-04, Train Acc: 99.99%, Test Acc: 90.79%, Time: 21. Epoch 25: Learning Rate 4.12e-04, Train Acc: 100.00%, Test Acc: 90.83%, Time: 21.32 ``` -## Usage - -```bash - main [options] [flags] - -Options - - --batchsize <512::Int> - --hidden-dim <256::Int> - --depth <8::Int> - --patch-size <2::Int> - --kernel-size <5::Int> - --weight-decay <0.01::Float64> - --seed <42::Int> - --epochs <25::Int> - --lr-max <0.01::Float64> - --backend - -Flags - --clip-norm - - -h, --help Print this help message. - --version Print version. -``` - -## Notes +### Notes 1. To match the results from the original repo, we need more augmentation strategies, that are currently not implemented in DataAugmentation.jl. diff --git a/examples/ConvMixer/main.jl b/examples/CIFAR10/common.jl similarity index 56% rename from examples/ConvMixer/main.jl rename to examples/CIFAR10/common.jl index 9d1c6cb5e..84647e8ae 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/CIFAR10/common.jl @@ -1,9 +1,6 @@ -using Comonicon, ConcreteStructs, DataAugmentation, ImageShow, Interpolations, Lux, LuxCUDA, - MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, ProgressBars, Random, - Statistics, Zygote -using Reactant, Enzyme - -CUDA.allowscalar(false) +using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays, + Printf, ProgressTables, Random +using LuxCUDA, Reactant @concrete struct TensorDataset dataset @@ -18,7 +15,7 @@ function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, Abstrac return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y end -function get_dataloaders(batchsize; kwargs...) +function get_cifar10_dataloaders(batchsize; kwargs...) cifar10_mean = (0.4914, 0.4822, 0.4465) cifar10_std = (0.2471, 0.2435, 0.2616) @@ -38,35 +35,6 @@ function get_dataloaders(batchsize; kwargs...) return trainloader, testloader end -function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) - #! format: off - return Chain( - Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size), - BatchNorm(dim), - [ - Chain( - SkipConnection( - Chain( - Conv( - (kernel_size, kernel_size), dim => dim, gelu; - groups=dim, pad=SamePad() - ), - BatchNorm(dim) - ), - + - ), - Conv((1, 1), dim => dim, gelu), - BatchNorm(dim) - ) - for _ in 1:depth - ]..., - GlobalMeanPool(), - FlattenLayer(), - Dense(dim => 10) - ) - #! 
format: on -end - function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 cdev = cpu_device() @@ -79,41 +47,37 @@ function accuracy(model, ps, st, dataloader) return total_correct / total end -Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, - patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.005, - clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05, - backend::String="gpu_if_available") - rng = Random.default_rng() - Random.seed!(rng, seed) - +function get_accelerator_device(backend::String) if backend == "gpu_if_available" - accelerator_device = gpu_device() + return gpu_device() elseif backend == "gpu" - accelerator_device = gpu_device(; force=true) + return gpu_device(; force=true) elseif backend == "reactant" - accelerator_device = reactant_device(; force=true) + return reactant_device(; force=true) elseif backend == "cpu" - accelerator_device = cpu_device() + return cpu_device() else error("Invalid backend: $(backend). Valid Options are: `gpu_if_available`, `gpu`, \ `reactant`, and `cpu`.") end +end +function train_model( + model, opt, scheduler=nothing; + backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25 +) + rng = Random.default_rng() + Random.seed!(rng, seed) + + accelerator_device = get_accelerator_device(backend) kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : () - trainloader, testloader = get_dataloaders(batchsize; kwargs...) |> accelerator_device + trainloader, testloader = get_cifar10_dataloaders(batchsize; kwargs...) |> + accelerator_device - model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) ps, st = Lux.setup(rng, model) |> accelerator_device - opt = AdamW(; eta=lr_max, lambda=weight_decay) - clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) - train_state = Training.TrainState(model, ps, st, opt) - lr_schedule = linear_interpolation( - [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0] - ) - adtype = backend == "reactant" ? 
AutoEnzyme() : AutoZygote() if backend == "reactant" @@ -128,16 +92,32 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: loss_fn = CrossEntropyLoss(; logits=Val(true)) + pt = ProgressTable(; + header=[ + "Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)" + ], + widths=[24, 24, 24, 24, 24], + format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"], + color=[:normal, :normal, :blue, :blue, :normal], + border=true, + alignment=[:center, :center, :center, :center, :center] + ) + @printf "[Info] Training model\n" + initialize(pt) + for epoch in 1:epochs stime = time() lr = 0 for (i, (x, y)) in enumerate(trainloader) - lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader)) - train_state = Optimisers.adjust!(train_state, lr) - (_, _, _, train_state) = Training.single_train_step!( + if scheduler !== nothing + lr = scheduler((epoch - 1) + (i + 1) / length(trainloader)) + train_state = Optimisers.adjust!(train_state, lr) + end + (_, loss, _, train_state) = Training.single_train_step!( adtype, loss_fn, (x, y), train_state ) + isnan(loss) && error("NaN loss encountered!") end ttime = time() - stime @@ -150,8 +130,10 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth:: Lux.testmode(train_state.states), testloader ) * 100 - @printf "[Train] Epoch %2d: Learning Rate %.6f, Train Acc: %.4f%%, Test Acc: \ - %.4f%%, Time: %.2f\n" epoch lr train_acc test_acc ttime + scheduler === nothing && (lr = NaN32) + next(pt, [epoch, lr, train_acc, test_acc, ttime]) end + + finalize(pt) @printf "[Info] Finished training\n" end diff --git a/examples/CIFAR10/conv_mixer.jl b/examples/CIFAR10/conv_mixer.jl new file mode 100644 index 000000000..55f0b20da --- /dev/null +++ b/examples/CIFAR10/conv_mixer.jl @@ -0,0 +1,50 @@ +using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme + +@isdefined(includet) ? includet("common.jl") : include("common.jl") + +CUDA.allowscalar(false) + +function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) + #! format: off + return Chain( + Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size), + BatchNorm(dim), + [ + Chain( + SkipConnection( + Chain( + Conv( + (kernel_size, kernel_size), dim => dim, gelu; + groups=dim, pad=SamePad() + ), + BatchNorm(dim) + ), + + + ), + Conv((1, 1), dim => dim, gelu), + BatchNorm(dim) + ) + for _ in 1:depth + ]..., + GlobalMeanPool(), + FlattenLayer(), + Dense(dim => 10) + ) + #! format: on +end + +Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, + patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001, + clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05, + backend::String="reactant") + model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) + + opt = AdamW(; eta=lr_max, lambda=weight_decay) + clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) + + lr_schedule = linear_interpolation( + [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0] + ) + + return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs) +end diff --git a/examples/CIFAR10/mlp_mixer.jl b/examples/CIFAR10/mlp_mixer.jl new file mode 100644 index 000000000..1132d0991 --- /dev/null +++ b/examples/CIFAR10/mlp_mixer.jl @@ -0,0 +1,6 @@ +using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme + +CUDA.allowscalar(false) + +@isdefined(includet) ? 
includet("common.jl") : include("common.jl") + diff --git a/examples/CIFAR10/simple_cnn.jl b/examples/CIFAR10/simple_cnn.jl new file mode 100644 index 000000000..23dd51051 --- /dev/null +++ b/examples/CIFAR10/simple_cnn.jl @@ -0,0 +1,36 @@ +using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme + +@isdefined(includet) ? includet("common.jl") : include("common.jl") + +CUDA.allowscalar(false) + +function SimpleCNN() + return Chain( + Conv((3, 3), 3 => 16, gelu; stride=2, pad=1), + BatchNorm(16), + Conv((3, 3), 16 => 32, gelu; stride=2, pad=1), + BatchNorm(32), + Conv((3, 3), 32 => 64, gelu; stride=2, pad=1), + BatchNorm(64), + Conv((3, 3), 64 => 128, gelu; stride=2, pad=1), + BatchNorm(128), + GlobalMeanPool(), + FlattenLayer(), + Dense(128 => 64, gelu), + BatchNorm(64), + Dense(64 => 10) + ) +end + +Comonicon.@main function main(; + batchsize::Int=512, weight_decay::Float64=0.0001, + clip_norm::Bool=false, seed::Int=1234, epochs::Int=50, lr::Float64=0.003, + backend::String="reactant" +) + model = SimpleCNN() + + opt = AdamW(; eta=lr, lambda=weight_decay) + clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) + + return train_model(model, opt, nothing; backend, batchsize, seed, epochs) +end From d73e9f88a1b19558596ebe27315e99345aa4d7b3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 23 Dec 2024 03:49:39 -0500 Subject: [PATCH 10/10] feat: working ConvMixer --- examples/CIFAR10/README.md | 34 ---------------------------------- examples/CIFAR10/conv_mixer.jl | 6 ++++-- 2 files changed, 4 insertions(+), 36 deletions(-) diff --git a/examples/CIFAR10/README.md b/examples/CIFAR10/README.md index 6e1841663..dea9cfc3d 100644 --- a/examples/CIFAR10/README.md +++ b/examples/CIFAR10/README.md @@ -45,44 +45,10 @@ julia --startup-file=no \ --project=. 
\ --threads=auto \ conv_mixer.jl \ - --lr-max=0.05 \ - --weight-decay=0.0001 \ --backend=reactant ``` -Here's an example output of the above command (on a RTX 4050 6GB Laptop GPU): - -``` -Epoch 1: Learning Rate 5.05e-03, Train Acc: 56.91%, Test Acc: 56.49%, Time: 129.84 -Epoch 2: Learning Rate 1.01e-02, Train Acc: 69.75%, Test Acc: 68.40%, Time: 21.22 -Epoch 3: Learning Rate 1.51e-02, Train Acc: 76.86%, Test Acc: 74.73%, Time: 21.33 -Epoch 4: Learning Rate 2.01e-02, Train Acc: 81.03%, Test Acc: 78.14%, Time: 21.40 -Epoch 5: Learning Rate 2.51e-02, Train Acc: 72.71%, Test Acc: 70.29%, Time: 21.34 -Epoch 6: Learning Rate 3.01e-02, Train Acc: 83.12%, Test Acc: 80.20%, Time: 21.38 -Epoch 7: Learning Rate 3.51e-02, Train Acc: 82.38%, Test Acc: 78.66%, Time: 21.39 -Epoch 8: Learning Rate 4.01e-02, Train Acc: 84.24%, Test Acc: 79.97%, Time: 21.49 -Epoch 9: Learning Rate 4.51e-02, Train Acc: 84.93%, Test Acc: 80.18%, Time: 21.40 -Epoch 10: Learning Rate 5.00e-02, Train Acc: 84.97%, Test Acc: 80.26%, Time: 21.37 -Epoch 11: Learning Rate 4.52e-02, Train Acc: 89.09%, Test Acc: 83.53%, Time: 21.31 -Epoch 12: Learning Rate 4.05e-02, Train Acc: 91.62%, Test Acc: 85.10%, Time: 21.39 -Epoch 13: Learning Rate 3.57e-02, Train Acc: 93.71%, Test Acc: 86.78%, Time: 21.29 -Epoch 14: Learning Rate 3.10e-02, Train Acc: 95.14%, Test Acc: 87.23%, Time: 21.37 -Epoch 15: Learning Rate 2.62e-02, Train Acc: 95.36%, Test Acc: 87.08%, Time: 21.34 -Epoch 16: Learning Rate 2.15e-02, Train Acc: 97.07%, Test Acc: 87.91%, Time: 21.26 -Epoch 17: Learning Rate 1.67e-02, Train Acc: 98.67%, Test Acc: 89.57%, Time: 21.40 -Epoch 18: Learning Rate 1.20e-02, Train Acc: 99.41%, Test Acc: 89.77%, Time: 21.28 -Epoch 19: Learning Rate 7.20e-03, Train Acc: 99.81%, Test Acc: 90.31%, Time: 21.39 -Epoch 20: Learning Rate 2.50e-03, Train Acc: 99.94%, Test Acc: 90.83%, Time: 21.44 -Epoch 21: Learning Rate 2.08e-03, Train Acc: 99.96%, Test Acc: 90.83%, Time: 21.23 -Epoch 22: Learning Rate 1.66e-03, Train Acc: 99.97%, Test Acc: 90.91%, Time: 21.29 -Epoch 23: Learning Rate 1.25e-03, Train Acc: 99.99%, Test Acc: 90.82%, Time: 21.29 -Epoch 24: Learning Rate 8.29e-04, Train Acc: 99.99%, Test Acc: 90.79%, Time: 21.32 -Epoch 25: Learning Rate 4.12e-04, Train Acc: 100.00%, Test Acc: 90.83%, Time: 21.32 -``` - ### Notes 1. To match the results from the original repo, we need more augmentation strategies, that are currently not implemented in DataAugmentation.jl. - 2. Don't compare the reported timings in that repo against the numbers here. They time the - entire loop. We only time the training part of the loop. diff --git a/examples/CIFAR10/conv_mixer.jl b/examples/CIFAR10/conv_mixer.jl index 55f0b20da..170b11910 100644 --- a/examples/CIFAR10/conv_mixer.jl +++ b/examples/CIFAR10/conv_mixer.jl @@ -33,10 +33,12 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) #! format: on end -Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, +Comonicon.@main function main(; + batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001, clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05, - backend::String="reactant") + backend::String="reactant" +) model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) opt = AdamW(; eta=lr_max, lambda=weight_decay)