Skip to content

Commit

Permalink
docs: run partial dataset only on CI (#1128)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Dec 7, 2024
1 parent 1ea272a commit 546798a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 10 deletions.
19 changes: 15 additions & 4 deletions examples/HyperNet/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,20 @@ using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays,
CUDA.allowscalar(false)

# ## Loading Datasets
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
imgs, labels = dset(:train)[1:n_train]
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
if n_train === nothing
imgs, labels = dset(:train)
else
imgs, labels = dset(:train)[1:n_train]
end
x_train, y_train = reshape(imgs, 28, 28, 1, n_train), onehotbatch(labels, 0:9)

imgs, labels = dset(:test)[1:n_eval]
if n_eval === nothing
imgs, labels = dset(:test)
else
imgs, labels = dset(:test)[1:n_eval]
end
x_test, y_test = reshape(imgs, 28, 28, 1, n_eval), onehotbatch(labels, 0:9)

return (
Expand All @@ -21,7 +30,9 @@ function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) w
)
end

function load_datasets(n_train=1024, n_eval=32, batchsize=256)
function load_datasets(batchsize=256)
n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end

Expand Down
11 changes: 8 additions & 3 deletions examples/NeuralODE/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@ CUDA.allowscalar(false)
# ## Loading MNIST
function loadmnist(batchsize, train_split)
## Load MNIST: Only 1500 for demonstration purposes
N = 1500
N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
dataset = MNIST(; split=:train)
imgs = dataset.features[:, :, 1:N]
labels_raw = dataset.targets[1:N]
if N !== nothing
imgs = dataset.features[:, :, 1:N]
labels_raw = dataset.targets[1:N]
else
imgs = dataset.features
labels_raw = dataset.targets
end

## Process images into (H,W,C,BS) batches
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
Expand Down
11 changes: 8 additions & 3 deletions examples/SimpleChains/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ using SimpleChains: SimpleChains
# ## Loading MNIST
function loadmnist(batchsize, train_split)
## Load MNIST
N = 2000
N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
dataset = MNIST(; split=:train)
imgs = dataset.features[:, :, 1:N]
labels_raw = dataset.targets[1:N]
if N !== nothing
imgs = dataset.features[:, :, 1:N]
labels_raw = dataset.targets[1:N]
else
imgs = dataset.features
labels_raw = dataset.targets
end

## Process images into (H, W, C, BS) batches
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
Expand Down

0 comments on commit 546798a

Please sign in to comment.