-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #24 from JuliaTeachingCTU/macha/refactor-lectures-6-7
Macha/fixes-neural-networks
- Loading branch information
Showing
12 changed files
with
174 additions
and
172 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,108 +1,54 @@ | ||
using BSON | ||
using Flux | ||
using Flux: onehotbatch, onecold | ||
using Flux: onecold | ||
using MLDatasets | ||
|
||
Core.eval(Main, :(using Flux)) # hide | ||
|
||
function reshape_data(X::AbstractArray{T, 3}, y::AbstractVector) where T | ||
s = size(X) | ||
return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :) | ||
end | ||
|
||
function train_or_load!(file_name, m, X, y; force=false, kwargs...) | ||
|
||
!isdir(dirname(file_name)) && mkpath(dirname(file_name)) | ||
|
||
if force || !isfile(file_name) | ||
train_model!(m, X, y; file_name=file_name, kwargs...) | ||
else | ||
m_loaded = BSON.load(file_name)[:m] | ||
Flux.loadparams!(m, params(m_loaded)) | ||
end | ||
end | ||
|
||
function load_data(dataset; onehot=false, T=Float32) | ||
classes = 0:9 | ||
X_train, y_train = reshape_data(dataset(T, :train)[:]...) | ||
X_test, y_test = reshape_data(dataset(T, :test)[:]...) | ||
y_train = T.(y_train) | ||
y_test = T.(y_test) | ||
|
||
if onehot | ||
y_train = onehotbatch(y_train[:], classes) | ||
y_test = onehotbatch(y_test[:], classes) | ||
end | ||
|
||
return X_train, y_train, X_test, y_test | ||
end | ||
|
||
using Plots | ||
|
||
plot_image(x::AbstractArray{T, 2}) where T = plot(Gray.(x'), axis=nothing) | ||
|
||
function plot_image(x::AbstractArray{T, 4}) where T | ||
@assert size(x,4) == 1 | ||
plot_image(x[:,:,:,1]) | ||
end | ||
|
||
function plot_image(x::AbstractArray{T, 3}) where T | ||
@assert size(x,3) == 1 | ||
plot_image(x[:,:,1]) | ||
end | ||
|
||
include(joinpath(dirname(@__FILE__), "utilities.jl")) | ||
|
||
T = Float32 | ||
dataset = MLDatasets.MNIST | ||
|
||
X_train, y_train, X_test, y_test = load_data(dataset; T=T, onehot=true) | ||
|
||
model = Chain( | ||
Conv((2, 2), 1 => 16, sigmoid), | ||
MaxPool((2, 2)), | ||
Conv((2, 2), 16 => 8, sigmoid), | ||
MaxPool((2, 2)), | ||
Flux.flatten, | ||
Dense(288, size(y_train, 1)), | ||
softmax, | ||
) | ||
|
||
|
||
|
||
|
||
m = Chain( | ||
Conv((2,2), 1=>16, sigmoid), | ||
MaxPool((2,2)), | ||
Conv((2,2), 16=>8, sigmoid), | ||
MaxPool((2,2)), | ||
flatten, | ||
Dense(288, size(y_train,1)), softmax) | ||
|
||
file_name = joinpath("data", "mnist_sigmoid.bson") | ||
train_or_load!(file_name, m, X_train, y_train) | ||
|
||
|
||
|
||
file_name = joinpath("data", "mnist_sigmoid.jld2") | ||
train_or_load!(file_name, model, X_train, y_train) | ||
|
||
ii1 = findall(onecold(y_train, 0:9) .== 1)[1:5] | ||
ii2 = findall(onecold(y_train, 0:9) .== 9)[1:5] | ||
|
||
|
||
for qwe = 0:9 | ||
ii0 = findall(onecold(y_train, 0:9) .== qwe)[1:5] | ||
|
||
p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii0] | ||
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii0] | ||
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii0] | ||
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii0] | ||
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii0] | ||
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii0] | ||
|
||
p = plot(p0..., p1..., p2...; layout=(3,5)) | ||
p = plot(p0..., p1..., p2...; layout=(3, 5)) | ||
display(p) | ||
end | ||
|
||
p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii1] | ||
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii1] | ||
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii1] | ||
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii1] | ||
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii1] | ||
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii1] | ||
|
||
plot(p0..., p1..., p2...; layout=(3,5)) | ||
plot(p0..., p1..., p2...; layout=(3, 5)) | ||
|
||
|
||
p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii2] | ||
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii2] | ||
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii2] | ||
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii2] | ||
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii2] | ||
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii2] | ||
|
||
plot(p0..., p1..., p2...; layout=(3,5)) | ||
plot(p0..., p1..., p2...; layout=(3, 5)) | ||
|
||
for i in 1:length(m) | ||
println(size(m[1:i](X_train[:,:,:,1:1]))) | ||
for i in 1:length(model) | ||
println(size(model[1:i](X_train[:, :, :, 1:1]))) | ||
end |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,69 +1,23 @@ | ||
using MLDatasets | ||
using Flux | ||
using BSON | ||
using Random | ||
using Statistics | ||
using Base.Iterators: partition | ||
using Flux: crossentropy, onehotbatch, onecold | ||
|
||
|
||
accuracy(x, y) = mean(onecold(cpu(m(x))) .== onecold(cpu(y))) | ||
|
||
function reshape_data(X::AbstractArray{T, 3}, y::AbstractVector) where T | ||
s = size(X) | ||
return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :) | ||
end | ||
|
||
function train_model!(m, X, y; | ||
opt=ADAM(0.001), | ||
batch_size=128, | ||
n_epochs=10, | ||
file_name="") | ||
|
||
loss(x, y) = crossentropy(m(x), y) | ||
|
||
batches_train = map(partition(randperm(size(y, 2)), batch_size)) do inds | ||
return (gpu(X[:, :, :, inds]), gpu(y[:, inds])) | ||
end | ||
|
||
for i in 1:n_epochs | ||
println("Iteration " * string(i)) | ||
Flux.train!(loss, params(m), batches_train, opt) | ||
end | ||
|
||
!isempty(file_name) && BSON.bson(file_name, m=m|>cpu) | ||
|
||
return | ||
end | ||
|
||
function load_data(dataset; onehot=false, T=Float32) | ||
classes = 0:9 | ||
X_train, y_train = reshape_data(dataset(T, :train)[:]...) | ||
X_test, y_test = reshape_data(dataset(T, :test)[:]...) | ||
y_train = T.(y_train) | ||
y_test = T.(y_test) | ||
|
||
if onehot | ||
y_train = onehotbatch(y_train[:], classes) | ||
y_test = onehotbatch(y_test[:], classes) | ||
end | ||
|
||
return X_train, y_train, X_test, y_test | ||
end | ||
include(joinpath(dirname(@__FILE__), "utilities.jl")) | ||
|
||
dataset = MLDatasets.MNIST | ||
T = Float32 | ||
X_train, y_train, X_test, y_test = load_data(dataset; T=T, onehot=true) | ||
|
||
m = Chain( | ||
Conv((2,2), 1=>16, sigmoid), | ||
MaxPool((2,2)), | ||
Conv((2,2), 16=>8, sigmoid), | ||
MaxPool((2,2)), | ||
flatten, | ||
Dense(288, size(y_train,1)), softmax) |> gpu | ||
model = Chain( | ||
Conv((2, 2), 1 => 16, sigmoid), | ||
MaxPool((2, 2)), | ||
Conv((2, 2), 16 => 8, sigmoid), | ||
MaxPool((2, 2)), | ||
Flux.flatten, | ||
Dense(288, size(y_train, 1)), | ||
softmax, | ||
) |> gpu | ||
|
||
file_name = joinpath("data", "mnist_sigmoid.bson") | ||
train_model!(m, X_train, y_train; file_name=file_name, n_epochs=100) | ||
file_name = joinpath("data", "mnist_sigmoid.jld2") | ||
train_model!(model, X_train, y_train; file_name=file_name, n_epochs=100) | ||
|
||
accuracy(X_test |> gpu, y_test |> gpu) | ||
accuracy(model, X_test, y_test) |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
using MLDatasets | ||
using Flux | ||
using JLD2 | ||
using Random | ||
using Statistics | ||
using Base.Iterators: partition | ||
using Flux: crossentropy, onehotbatch, onecold | ||
using Plots | ||
using Pkg | ||
|
||
if haskey(Pkg.project().dependencies, "CUDA") | ||
using CUDA | ||
else | ||
gpu(x) = x | ||
end | ||
|
||
accuracy(model, x, y) = mean(onecold(cpu(model(x))) .== onecold(cpu(y))) | ||
|
||
function reshape_data(X::AbstractArray{T,3}, y::AbstractVector) where {T} | ||
s = size(X) | ||
return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :) | ||
end | ||
|
||
function load_data(dataset; onehot=false, T=Float32) | ||
classes = 0:9 | ||
X_train, y_train = reshape_data(dataset(T, :train)[:]...) | ||
X_test, y_test = reshape_data(dataset(T, :test)[:]...) | ||
y_train = T.(y_train) | ||
y_test = T.(y_test) | ||
|
||
if onehot | ||
y_train = onehotbatch(y_train[:], classes) | ||
y_test = onehotbatch(y_test[:], classes) | ||
end | ||
|
||
return X_train, y_train, X_test, y_test | ||
end | ||
|
||
function train_model!( | ||
model, | ||
X, | ||
y; | ||
opt=Adam(0.001), | ||
batch_size=128, | ||
n_epochs=10, | ||
file_name="", | ||
) | ||
|
||
loss(x, y) = crossentropy(model(x), y) | ||
|
||
batches_train = map(partition(randperm(size(y, 2)), batch_size)) do inds | ||
return (gpu(X[:, :, :, inds]), gpu(y[:, inds])) | ||
end | ||
|
||
for epoch in 1:n_epochs | ||
@show epoch | ||
Flux.train!(loss, Flux.params(model), batches_train, opt) | ||
end | ||
|
||
!isempty(file_name) && jldsave(file_name; model_state=Flux.state(model) |> cpu) | ||
|
||
return | ||
end | ||
|
||
function train_or_load!(file_name, model, args...; force=false, kwargs...) | ||
|
||
!isdir(dirname(file_name)) && mkpath(dirname(file_name)) | ||
|
||
if force || !isfile(file_name) | ||
train_model!(model, args...; file_name=file_name, kwargs...) | ||
else | ||
model_state = JLD2.load(file_name, "model_state") | ||
Flux.loadmodel!(model, model_state) | ||
end | ||
end | ||
|
||
plot_image(x::AbstractArray{T,2}) where {T} = plot(Gray.(x'), axis=nothing) | ||
|
||
function plot_image(x::AbstractArray{T,4}) where {T} | ||
@assert size(x, 4) == 1 | ||
plot_image(x[:, :, :, 1]) | ||
end | ||
|
||
function plot_image(x::AbstractArray{T,3}) where {T} | ||
@assert size(x, 3) == 1 | ||
plot_image(x[:, :, 1]) | ||
end |
Oops, something went wrong.