-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: update ConvMixer to support reactant #1063
Draft
avik-pal
wants to merge
9
commits into
ap/reactant_updates
Choose a base branch
from
ap/conv_mixer_reactant
base: ap/reactant_updates
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+277
−152
Draft
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
dccf81f
feat: conditional VAE testcase
avik-pal 108770e
feat: overload Utils.vec for upcoming wrapper array changes
avik-pal 1ada932
docs: update ConvMixer to support reactant
avik-pal 8d43074
docs: keep the ConvMixer default backend as cuda.jl for now
avik-pal e82b578
fix: remove unnecessary patches for now
avik-pal a684b56
fix: update reactant version
avik-pal 541d4c5
feat: pipeline working :tada:
avik-pal 05500f2
fix: more bug fixes for reactant
avik-pal dd724d6
refactor: centralize the CIFAR10 examples
avik-pal File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays, | ||
Printf, ProgressTables, Random | ||
using LuxCUDA, Reactant | ||
|
||
@concrete struct TensorDataset | ||
dataset | ||
transform | ||
end | ||
|
||
Base.length(ds::TensorDataset) = length(ds.dataset) | ||
|
||
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange}) | ||
img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3)) | ||
y = onehotbatch(ds.dataset.targets[idxs], 0:9) | ||
return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y | ||
end | ||
|
||
function get_cifar10_dataloaders(batchsize; kwargs...) | ||
cifar10_mean = (0.4914, 0.4822, 0.4465) | ||
cifar10_std = (0.2471, 0.2435, 0.2616) | ||
|
||
train_transform = RandomResizeCrop((32, 32)) |> | ||
Maybe(FlipX{2}()) |> | ||
ImageToTensor() |> | ||
Normalize(cifar10_mean, cifar10_std) | ||
|
||
test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) | ||
|
||
trainset = TensorDataset(CIFAR10(:train), train_transform) | ||
trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...) | ||
|
||
testset = TensorDataset(CIFAR10(:test), test_transform) | ||
testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...) | ||
|
||
return trainloader, testloader | ||
end | ||
|
||
function accuracy(model, ps, st, dataloader) | ||
total_correct, total = 0, 0 | ||
cdev = cpu_device() | ||
for (x, y) in dataloader | ||
target_class = onecold(cdev(y)) | ||
predicted_class = onecold(cdev(first(model(x, ps, st)))) | ||
total_correct += sum(target_class .== predicted_class) | ||
total += length(target_class) | ||
end | ||
return total_correct / total | ||
end | ||
|
||
function get_accelerator_device(backend::String) | ||
if backend == "gpu_if_available" | ||
return gpu_device() | ||
elseif backend == "gpu" | ||
return gpu_device(; force=true) | ||
elseif backend == "reactant" | ||
return reactant_device(; force=true) | ||
elseif backend == "cpu" | ||
return cpu_device() | ||
else | ||
error("Invalid backend: $(backend). Valid Options are: `gpu_if_available`, `gpu`, \ | ||
`reactant`, and `cpu`.") | ||
end | ||
end | ||
|
||
function train_model( | ||
model, opt, scheduler=nothing; | ||
backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25 | ||
) | ||
rng = Random.default_rng() | ||
Random.seed!(rng, seed) | ||
|
||
accelerator_device = get_accelerator_device(backend) | ||
kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : () | ||
trainloader, testloader = get_cifar10_dataloaders(batchsize; kwargs...) |> | ||
accelerator_device | ||
|
||
ps, st = Lux.setup(rng, model) |> accelerator_device | ||
|
||
train_state = Training.TrainState(model, ps, st, opt) | ||
|
||
adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote() | ||
|
||
if backend == "reactant" | ||
x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device | ||
@printf "[Info] Compiling model with Reactant.jl\n" | ||
st_test = Lux.testmode(st) | ||
model_compiled = @compile model(x_ra, ps, st_test) | ||
@printf "[Info] Model compiled!\n" | ||
else | ||
model_compiled = model | ||
end | ||
|
||
loss_fn = CrossEntropyLoss(; logits=Val(true)) | ||
|
||
pt = ProgressTable(; | ||
header=[ | ||
"Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)" | ||
], | ||
widths=[24, 24, 24, 24, 24], | ||
format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"], | ||
color=[:normal, :normal, :blue, :blue, :normal], | ||
border=true, | ||
alignment=[:center, :center, :center, :center, :center] | ||
) | ||
|
||
@printf "[Info] Training model\n" | ||
initialize(pt) | ||
|
||
for epoch in 1:epochs | ||
stime = time() | ||
lr = 0 | ||
for (i, (x, y)) in enumerate(trainloader) | ||
if scheduler !== nothing | ||
lr = scheduler((epoch - 1) + (i + 1) / length(trainloader)) | ||
train_state = Optimisers.adjust!(train_state, lr) | ||
end | ||
(_, loss, _, train_state) = Training.single_train_step!( | ||
adtype, loss_fn, (x, y), train_state | ||
) | ||
isnan(loss) && error("NaN loss encountered!") | ||
end | ||
ttime = time() - stime | ||
|
||
train_acc = accuracy( | ||
model_compiled, train_state.parameters, | ||
Lux.testmode(train_state.states), trainloader | ||
) * 100 | ||
test_acc = accuracy( | ||
model_compiled, train_state.parameters, | ||
Lux.testmode(train_state.states), testloader | ||
) * 100 | ||
|
||
scheduler === nothing && (lr = NaN32) | ||
next(pt, [epoch, lr, train_acc, test_acc, ttime]) | ||
end | ||
|
||
finalize(pt) | ||
@printf "[Info] Finished training\n" | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme | ||
|
||
@isdefined(includet) ? includet("common.jl") : include("common.jl") | ||
|
||
CUDA.allowscalar(false) | ||
|
||
function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) | ||
#! format: off | ||
return Chain( | ||
Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size), | ||
BatchNorm(dim), | ||
[ | ||
Chain( | ||
SkipConnection( | ||
Chain( | ||
Conv( | ||
(kernel_size, kernel_size), dim => dim, gelu; | ||
groups=dim, pad=SamePad() | ||
), | ||
BatchNorm(dim) | ||
), | ||
+ | ||
), | ||
Conv((1, 1), dim => dim, gelu), | ||
BatchNorm(dim) | ||
) | ||
for _ in 1:depth | ||
]..., | ||
GlobalMeanPool(), | ||
FlattenLayer(), | ||
Dense(dim => 10) | ||
) | ||
#! format: on | ||
end | ||
|
||
Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, | ||
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001, | ||
clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05, | ||
backend::String="reactant") | ||
model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) | ||
|
||
opt = AdamW(; eta=lr_max, lambda=weight_decay) | ||
clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) | ||
|
||
lr_schedule = linear_interpolation( | ||
[0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0] | ||
) | ||
|
||
return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme | ||
|
||
CUDA.allowscalar(false) | ||
|
||
@isdefined(includet) ? includet("common.jl") : include("common.jl") | ||
|
||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme | ||
|
||
@isdefined(includet) ? includet("common.jl") : include("common.jl") | ||
|
||
CUDA.allowscalar(false) | ||
|
||
function SimpleCNN() | ||
return Chain( | ||
Conv((3, 3), 3 => 16, gelu; stride=2, pad=1), | ||
BatchNorm(16), | ||
Conv((3, 3), 16 => 32, gelu; stride=2, pad=1), | ||
BatchNorm(32), | ||
Conv((3, 3), 32 => 64, gelu; stride=2, pad=1), | ||
BatchNorm(64), | ||
Conv((3, 3), 64 => 128, gelu; stride=2, pad=1), | ||
BatchNorm(128), | ||
GlobalMeanPool(), | ||
FlattenLayer(), | ||
Dense(128 => 64, gelu), | ||
BatchNorm(64), | ||
Dense(64 => 10) | ||
) | ||
end | ||
|
||
Comonicon.@main function main(; | ||
batchsize::Int=512, weight_decay::Float64=0.0001, | ||
clip_norm::Bool=false, seed::Int=1234, epochs::Int=50, lr::Float64=0.003, | ||
backend::String="reactant" | ||
) | ||
model = SimpleCNN() | ||
|
||
opt = AdamW(; eta=lr, lambda=weight_decay) | ||
clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) | ||
|
||
return train_model(model, opt, nothing; backend, batchsize, seed, epochs) | ||
end |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶