feat: pipeline working 🎉
avik-pal committed Dec 20, 2024
1 parent e3ab45f commit d646b75
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions examples/ConvMixer/main.jl
@@ -30,10 +30,10 @@ function get_dataloaders(batchsize; kwargs...)
test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)

trainset = TensorDataset(CIFAR10(:train), train_transform)
-    trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true, kwargs...)
+    trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)

testset = TensorDataset(CIFAR10(:test), test_transform)
-    testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, kwargs...)
+    testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)

return trainloader, testloader
end
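Note: the hunk above drops the parallel=true keyword from both DataLoader calls, so batches are now materialized in the iterating task rather than in background worker tasks. Below is a minimal usage sketch, assuming the loaders come from MLUtils.jl's DataLoader (which accepts the batchsize, shuffle, and parallel keywords); the call and the assertion are illustrative only and not part of the commit.

# Hypothetical usage sketch: build the loaders with the post-commit settings
# and walk one epoch of CIFAR-10 training batches.
using MLUtils  # provides DataLoader with the batchsize/shuffle keywords used above

trainloader, testloader = get_dataloaders(512)

for (x, y) in trainloader
    # x holds a batch of images and y the matching labels; the trailing
    # dimension of each is the batch dimension, so the two should agree.
    @assert size(x, ndims(x)) == size(y, ndims(y))
end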
@@ -74,8 +74,8 @@ function accuracy(model, ps, st, dataloader)
end

Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
-    patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
-    clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
+    patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-3,
+    clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.001,
backend::String="reactant")
rng = StableRNG(seed)

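The retuned defaults (weight_decay=1e-3, lr_max=0.001) presumably feed the optimiser and the lr_schedule referenced in the training loop further down. Here is a hedged sketch of one plausible setup, assuming Optimisers.jl's AdamW and a hand-rolled warm-up/decay schedule; make_lr_schedule and the warmup length are hypothetical and not taken from the example.

# Hypothetical optimiser/schedule setup under the post-commit defaults.
using Optimisers  # AdamW couples the learning rate (eta) with decoupled weight decay (lambda)

lr_max, weight_decay, epochs = 0.001, 1e-3, 25   # defaults after this commit

# Linear warm-up to lr_max, then linear decay towards zero; t is a fractional epoch count.
function make_lr_schedule(lr_max, epochs; warmup=5)
    return function (t)
        t < warmup && return lr_max * t / warmup
        return lr_max * max(0.0, (epochs - t) / (epochs - warmup))
    end
end

opt = AdamW(; eta=lr_max, lambda=weight_decay)
lr_schedule = make_lr_schedule(lr_max, epochs)
lr_schedule(0.5)   # learning rate halfway through the first epoch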
@@ -118,7 +118,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
model_compiled = model
end

-    loss = CrossEntropyLoss(; logits=Val(true))
+    loss_fn = CrossEntropyLoss(; logits=Val(true))

@printf "[Info] Training model\n"
for epoch in 1:epochs
@@ -127,8 +127,8 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
for (i, (x, y)) in enumerate(trainloader)
lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader))
train_state = Optimisers.adjust!(train_state, lr)
-        (_, _, _, train_state) = Training.single_train_step!(
-            adtype, loss, (x, y), train_state
+        (_, loss, _, train_state) = Training.single_train_step!(
+            adtype, loss_fn, (x, y), train_state
)
end
ttime = time() - stime
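With the second return value of Training.single_train_step! now bound to loss instead of being discarded, it can be surfaced in the epoch summary. A hedged sketch of the post-commit inner loop with a hypothetical log line appended follows (the actual example may report different statistics); loss is initialised before the loop so its final value stays visible afterwards.

# Post-commit inner loop with a hypothetical per-epoch report appended.
loss = 0.0f0          # initialised outside the loop so the last mini-batch loss survives the loop scope
stime = time()
for (i, (x, y)) in enumerate(trainloader)
    lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader))
    train_state = Optimisers.adjust!(train_state, lr)
    (_, loss, _, train_state) = Training.single_train_step!(
        adtype, loss_fn, (x, y), train_state
    )
end
ttime = time() - stime
@printf "[Train] Epoch %2d: last-batch loss %.4f, time %.2fs\n" epoch loss ttime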
