Use OneHotArrays #105

Merged: 11 commits, Jul 24, 2022
21 changes: 7 additions & 14 deletions examples/NeuralODE/main.jl
@@ -5,27 +5,22 @@ using Lux
using Pkg #hide
Pkg.activate(joinpath(dirname(pathof(Lux)), "..", "examples")) #hide
using ComponentArrays, CUDA, DiffEqSensitivity, NNlib, Optimisers, OrdinaryDiffEq, Random,
-Statistics, Zygote
+Statistics, Zygote, OneHotArrays
import MLDatasets: MNIST
-import MLDataUtils: convertlabel, LabelEnc
import MLUtils: DataLoader, splitobs
CUDA.allowscalar(false)

# ## Loading MNIST
-## Use MLDataUtils LabelEnc for natural onehot conversion
-function onehot(labels_raw)
-return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
-end
-
function loadmnist(batchsize, train_split)
## Load MNIST: Only 1500 for demonstration purposes
N = 1500
-imgs = MNIST.traintensor(1:N)
-labels_raw = MNIST.trainlabels(1:N)
+dataset = MNIST(; split=:train)
+imgs = dataset.features[:, :, 1:N]
+labels_raw = dataset.targets[1:N]

## Process images into (H,W,C,BS) batches
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
-y_data = onehot(labels_raw)
+y_data = onehotbatch(labels_raw, 0:9)
(x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

return (
@@ -88,8 +83,6 @@ function create_model()
end

# ## Define Utility Functions
-get_class(x) = argmax.(eachcol(x))
-
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(x, y, model, ps, st)
@@ -102,8 +95,8 @@ function accuracy(model, ps, st, dataloader)
st = Lux.testmode(st)
iterator = CUDA.functional() ? CuIterator(dataloader) : dataloader
for (x, y) in iterator
-target_class = get_class(cpu(y))
-predicted_class = get_class(cpu(model(x, ps, st)[1]))
+target_class = onecold(cpu(y))
+predicted_class = onecold(cpu(model(x, ps, st)[1]))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
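For context, a minimal sketch of the two OneHotArrays functions this diff swaps in, `onehotbatch` and `onecold` (the label set `0:9` mirrors the MNIST digits used above; the sample labels are purely illustrative):

```julia
using OneHotArrays

labels_raw = [0, 3, 9]            # illustrative stand-in for MNIST targets
y = onehotbatch(labels_raw, 0:9)  # 10×3 one-hot matrix, replacing the MLDataUtils-based onehot helper
onecold(y, 0:9)                   # maps columns back to labels, returns [0, 3, 9]
onecold(y)                        # 1-based argmax of each column, used here in place of get_class
```

Since `accuracy` calls `onecold` without an explicit label set on both the targets and the model output, the two sides are compared as plain column-wise argmax indices, which is all the elementwise equality check needs.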
2 changes: 2 additions & 0 deletions examples/Project.toml
@@ -20,6 +20,7 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
@@ -54,6 +55,7 @@ MLDatasets = "0.5, 0.7"
MLUtils = "0.2"
Metalhead = "0.7"
NNlib = "0.8"
OneHotArrays = "0.1"
Optimisers = "0.2"
OrdinaryDiffEq = "6"
ParameterSchedulers = "0.3"
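The two new Project.toml entries can also be generated from Pkg rather than edited by hand; a sketch assuming the examples environment is active and a Pkg version that provides `Pkg.compat` (Julia 1.8 or newer):

```julia
using Pkg
Pkg.activate("examples")           # path to the examples project; adjust as needed
Pkg.add("OneHotArrays")            # records the [deps] UUID entry
Pkg.compat("OneHotArrays", "0.1")  # records the [compat] bound shown above
```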