Commit 8e1ba6a
feat: pipeline working 🎉
avik-pal committed Dec 21, 2024
1 parent 8665bc3 commit 8e1ba6a
Showing 1 changed file with 25 additions and 16 deletions.
examples/ConvMixer/main.jl (41 changes: 25 additions & 16 deletions)
@@ -30,10 +30,10 @@ function get_dataloaders(batchsize; kwargs...)
     test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)
 
     trainset = TensorDataset(CIFAR10(:train), train_transform)
-    trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true, kwargs...)
+    trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)
 
     testset = TensorDataset(CIFAR10(:test), test_transform)
-    testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, kwargs...)
+    testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)
 
     return trainloader, testloader
 end
@@ -42,17 +42,24 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
     #! format: off
     return Chain(
         Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
-        BatchNorm(dim),
-        [Chain(
-            SkipConnection(
-                Chain(
-                    Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, pad=SamePad()),
-                    BatchNorm(dim)
+        BatchNorm(dim; track_stats=false),
+        [
+            Chain(
+                SkipConnection(
+                    Chain(
+                        Conv(
+                            (kernel_size, kernel_size), dim => dim, gelu;
+                            groups=dim, pad=SamePad()
+                        ),
+                        BatchNorm(dim)
+                    ),
+                    +
                 ),
-                +
-            ),
-            Conv((1, 1), dim => dim, gelu), BatchNorm(dim))
-        for _ in 1:depth]...,
+                Conv((1, 1), dim => dim, gelu),
+                BatchNorm(dim)
+            )
+            for _ in 1:depth
+        ]...,
         GlobalMeanPool(),
         FlattenLayer(),
         Dense(dim => 10)
@@ -73,9 +80,11 @@ function accuracy(model, ps, st, dataloader)
     return total_correct / total
 end
 
-Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
-        patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
+Comonicon.@main function main(; batchsize::Int=64, hidden_dim::Int=256, depth::Int=8,
+# Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
+        patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-4,
         clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
+        # backend::String="gpu_if_available")
         backend::String="reactant")
     rng = StableRNG(seed)
 
@@ -118,7 +127,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
         model_compiled = model
     end
 
-    loss = CrossEntropyLoss(; logits=Val(true))
+    loss_fn = CrossEntropyLoss(; logits=Val(true))
 
     @printf "[Info] Training model\n"
     for epoch in 1:epochs
@@ -128,7 +137,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
             lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader))
             train_state = Optimisers.adjust!(train_state, lr)
             (_, _, _, train_state) = Training.single_train_step!(
-                adtype, loss, (x, y), train_state
+                adtype, loss_fn, (x, y), train_state
             )
         end
         ttime = time() - stime
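
For readers skimming the second hunk: the comprehension instantiates depth copies of the usual ConvMixer block, i.e. a depthwise convolution (groups=dim) wrapped in a residual SkipConnection(..., +), followed by a pointwise 1×1 convolution, with BatchNorm after each. A minimal Lux sketch of one such block, with stand-in values for dim and kernel_size:

    using Lux

    dim, kernel_size = 256, 5  # stand-ins for the ConvMixer keyword arguments
    block = Chain(
        SkipConnection(
            Chain(
                # depthwise conv: groups=dim applies one filter per channel
                Conv((kernel_size, kernel_size), dim => dim, gelu;
                     groups=dim, pad=SamePad()),
                BatchNorm(dim)
            ),
            +  # residual combine: output = layer(x) + x
        ),
        Conv((1, 1), dim => dim, gelu),  # pointwise channel mixing
        BatchNorm(dim)
    )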
Expand Down
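The loss to loss_fn rename in the last two hunks is cosmetic; the training call itself is unchanged. For context, a self-contained sketch of Lux's Training.single_train_step! as invoked above; the model, data, and optimizer here are illustrative stand-ins, not the example's own:

    using ADTypes, Lux, Optimisers, Random, Zygote

    rng = Random.default_rng()
    model = Dense(4 => 2)
    ps, st = Lux.setup(rng, model)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    loss_fn = CrossEntropyLoss(; logits=Val(true))

    x = randn(rng, Float32, 4, 8)
    labels = rand(rng, 1:2, 8)
    y = Float32.(labels' .== 1:2)  # 2×8 one-hot targets

    # Returns the gradients, loss value, stats, and updated train state.
    grads, loss_val, stats, train_state = Training.single_train_step!(
        AutoZygote(), loss_fn, (x, y), train_state
    )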

0 comments on commit 8e1ba6a
