Skip to content

Commit

Permalink
Merge pull request #24 from JuliaTeachingCTU/macha/refactor-lectures-6-7
Browse files Browse the repository at this point in the history
Macha/fixes-neural-networks
  • Loading branch information
VaclavMacha authored Oct 27, 2024
2 parents 9392514 + c1b102b commit 220278c
Show file tree
Hide file tree
Showing 12 changed files with 174 additions and 172 deletions.
8 changes: 5 additions & 3 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Expand All @@ -13,9 +12,11 @@ GLPK = "60bf3e95-4087-53dc-ae20-288a0d20c6a6"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Query = "1a8c2f83-1ff3-5112-b086-8aa67b057ba1"
Expand All @@ -29,8 +30,7 @@ StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
BSON = "0.3"
BenchmarkTools = "1.5"
CSV = "0.10"
DataFrames = "1.6"
DifferentialEquations = "7.14"
DataFrames = "1.7"
Distributions = "0.25"
Documenter = "1.7"
Flux = "0.14"
Expand All @@ -39,8 +39,10 @@ GLPK = "1.2"
GR = "0.73"
HypothesisTests = "0.11"
Ipopt = "1.6"
JLD2 = "0.5"
JuMP = "1.23"
MLDatasets = "0.7"
MLUtils = "0.4"
Plots = "1.40"
ProgressMeter = "1.10"
Query = "1.0"
Expand Down
Binary file removed docs/src/lecture_11/data/mnist.bson
Binary file not shown.
108 changes: 27 additions & 81 deletions docs/src/lecture_11/data/mnist.jl
Original file line number Diff line number Diff line change
@@ -1,108 +1,54 @@
using BSON
using Flux
using Flux: onehotbatch, onecold
using Flux: onecold
using MLDatasets

Core.eval(Main, :(using Flux)) # hide

"""
    reshape_data(X, y)

Add a singleton channel dimension to the 3-D image array `X`
(width × height × samples → width × height × 1 × samples) and reshape the
label vector `y` into a 1 × n row matrix.
"""
function reshape_data(X::AbstractArray{T, 3}, y::AbstractVector) where T
# s = (width, height, number of samples)
s = size(X)
return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :)
end

"""
    train_or_load!(file_name, m, X, y; force=false, kwargs...)

Load model parameters for `m` from the BSON file `file_name` when it exists;
otherwise train `m` on `(X, y)` via `train_model!` and save it there. Pass
`force=true` to retrain even when a saved file is present. The directory of
`file_name` is created if missing. Extra keyword arguments are forwarded to
`train_model!`.
"""
function train_or_load!(file_name, m, X, y; force=false, kwargs...)

!isdir(dirname(file_name)) && mkpath(dirname(file_name))

if force || !isfile(file_name)
train_model!(m, X, y; file_name=file_name, kwargs...)
else
# copy the saved model's parameters into `m` in place
# NOTE(review): `loadparams!`/`params` belong to Flux's older implicit-
# parameter API — confirm they still exist in the installed Flux version
m_loaded = BSON.load(file_name)[:m]
Flux.loadparams!(m, params(m_loaded))
end
end

"""
    load_data(dataset; onehot=false, T=Float32)

Load the train and test splits of `dataset` (an MLDatasets dataset such as
`MLDatasets.MNIST`), reshape the images to 4-D arrays via `reshape_data`,
convert the labels to element type `T`, and optionally one-hot encode the
labels over the digit classes 0:9.

Returns `(X_train, y_train, X_test, y_test)`.
"""
function load_data(dataset; onehot=false, T=Float32)
classes = 0:9
X_train, y_train = reshape_data(dataset(T, :train)[:]...)
X_test, y_test = reshape_data(dataset(T, :test)[:]...)
y_train = T.(y_train)
y_test = T.(y_test)

if onehot
# flatten the 1 × n label matrices back to vectors before encoding
y_train = onehotbatch(y_train[:], classes)
y_test = onehotbatch(y_test[:], classes)
end

return X_train, y_train, X_test, y_test
end

using Plots

# Show a 2-D grayscale image; the array is transposed so its first axis is
# drawn horizontally, and plot axes are hidden.
plot_image(x::AbstractArray{T, 2}) where T = plot(Gray.(x'), axis=nothing)

# Plot a 4-D array holding exactly one sample (the 4th dimension must be 1)
# by dropping the sample dimension and delegating to the 3-D method.
function plot_image(x::AbstractArray{T, 4}) where T
@assert size(x,4) == 1
plot_image(x[:,:,:,1])
end

# Plot a 3-D array holding exactly one channel (the 3rd dimension must be 1)
# by dropping the channel dimension and delegating to the 2-D method.
function plot_image(x::AbstractArray{T, 3}) where T
@assert size(x,3) == 1
plot_image(x[:,:,1])
end

include(joinpath(dirname(@__FILE__), "utilities.jl"))

T = Float32
dataset = MLDatasets.MNIST

X_train, y_train, X_test, y_test = load_data(dataset; T=T, onehot=true)

model = Chain(
Conv((2, 2), 1 => 16, sigmoid),
MaxPool((2, 2)),
Conv((2, 2), 16 => 8, sigmoid),
MaxPool((2, 2)),
Flux.flatten,
Dense(288, size(y_train, 1)),
softmax,
)




m = Chain(
Conv((2,2), 1=>16, sigmoid),
MaxPool((2,2)),
Conv((2,2), 16=>8, sigmoid),
MaxPool((2,2)),
flatten,
Dense(288, size(y_train,1)), softmax)

file_name = joinpath("data", "mnist_sigmoid.bson")
train_or_load!(file_name, m, X_train, y_train)



file_name = joinpath("data", "mnist_sigmoid.jld2")
train_or_load!(file_name, model, X_train, y_train)

ii1 = findall(onecold(y_train, 0:9) .== 1)[1:5]
ii2 = findall(onecold(y_train, 0:9) .== 9)[1:5]


for qwe = 0:9
ii0 = findall(onecold(y_train, 0:9) .== qwe)[1:5]

p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii0]
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii0]
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii0]
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii0]
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii0]
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii0]

p = plot(p0..., p1..., p2...; layout=(3,5))
p = plot(p0..., p1..., p2...; layout=(3, 5))
display(p)
end

p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii1]
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii1]
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii1]
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii1]
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii1]
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii1]

plot(p0..., p1..., p2...; layout=(3,5))
plot(p0..., p1..., p2...; layout=(3, 5))


p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii2]
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii2]
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii2]
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii2]
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii2]
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii2]

plot(p0..., p1..., p2...; layout=(3,5))
plot(p0..., p1..., p2...; layout=(3, 5))

for i in 1:length(m)
println(size(m[1:i](X_train[:,:,:,1:1])))
for i in 1:length(model)
println(size(model[1:i](X_train[:, :, :, 1:1])))
end
Binary file added docs/src/lecture_11/data/mnist.jld2
Binary file not shown.
72 changes: 13 additions & 59 deletions docs/src/lecture_11/data/mnist_gpu.jl
Original file line number Diff line number Diff line change
@@ -1,69 +1,23 @@
using MLDatasets
using Flux
using BSON
using Random
using Statistics
using Base.Iterators: partition
using Flux: crossentropy, onehotbatch, onecold


# Classification accuracy: fraction of samples whose predicted class (argmax
# of the model output, decoded by `onecold`) matches the one-hot label.
# Relies on a global model `m` and moves both sides to the CPU for comparison.
accuracy(x, y) = mean(onecold(cpu(m(x))) .== onecold(cpu(y)))

"""
    reshape_data(X, y)

Add a singleton channel dimension to the 3-D image array `X`
(width × height × samples → width × height × 1 × samples) and reshape the
label vector `y` into a 1 × n row matrix.
"""
function reshape_data(X::AbstractArray{T, 3}, y::AbstractVector) where T
# s = (width, height, number of samples)
s = size(X)
return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :)
end

"""
    train_model!(m, X, y; opt=ADAM(0.001), batch_size=128, n_epochs=10, file_name="")

Train model `m` on inputs `X` (samples along the 4th dimension) and one-hot
labels `y` (samples along columns) with cross-entropy loss, then save the
trained model to `file_name` via BSON (skipped when the name is empty).
"""
function train_model!(m, X, y;
opt=ADAM(0.001),
batch_size=128,
n_epochs=10,
file_name="")

# cross-entropy between the model output and the one-hot targets
loss(x, y) = crossentropy(m(x), y)

# shuffle sample indices, split them into mini-batches, and move each batch
# (inputs and labels) to the GPU
batches_train = map(partition(randperm(size(y, 2)), batch_size)) do inds
return (gpu(X[:, :, :, inds]), gpu(y[:, inds]))
end

# NOTE(review): `params` + this `Flux.train!` signature are the older
# implicit-parameter API — confirm against the installed Flux version
for i in 1:n_epochs
println("Iteration " * string(i))
Flux.train!(loss, params(m), batches_train, opt)
end

# persist the trained model, moved back to the CPU first
!isempty(file_name) && BSON.bson(file_name, m=m|>cpu)

return
end

"""
    load_data(dataset; onehot=false, T=Float32)

Load the train and test splits of `dataset` (an MLDatasets dataset such as
`MLDatasets.MNIST`), reshape the images to 4-D arrays via `reshape_data`,
convert the labels to element type `T`, and optionally one-hot encode the
labels over the digit classes 0:9.

Returns `(X_train, y_train, X_test, y_test)`.
"""
function load_data(dataset; onehot=false, T=Float32)
classes = 0:9
X_train, y_train = reshape_data(dataset(T, :train)[:]...)
X_test, y_test = reshape_data(dataset(T, :test)[:]...)
y_train = T.(y_train)
y_test = T.(y_test)

if onehot
# flatten the 1 × n label matrices back to vectors before encoding
y_train = onehotbatch(y_train[:], classes)
y_test = onehotbatch(y_test[:], classes)
end

return X_train, y_train, X_test, y_test
end
include(joinpath(dirname(@__FILE__), "utilities.jl"))

dataset = MLDatasets.MNIST
T = Float32
X_train, y_train, X_test, y_test = load_data(dataset; T=T, onehot=true)

m = Chain(
Conv((2,2), 1=>16, sigmoid),
MaxPool((2,2)),
Conv((2,2), 16=>8, sigmoid),
MaxPool((2,2)),
flatten,
Dense(288, size(y_train,1)), softmax) |> gpu
model = Chain(
Conv((2, 2), 1 => 16, sigmoid),
MaxPool((2, 2)),
Conv((2, 2), 16 => 8, sigmoid),
MaxPool((2, 2)),
Flux.flatten,
Dense(288, size(y_train, 1)),
softmax,
) |> gpu

file_name = joinpath("data", "mnist_sigmoid.bson")
train_model!(m, X_train, y_train; file_name=file_name, n_epochs=100)
file_name = joinpath("data", "mnist_sigmoid.jld2")
train_model!(model, X_train, y_train; file_name=file_name, n_epochs=100)

accuracy(X_test |> gpu, y_test |> gpu)
accuracy(model, X_test, y_test)
Binary file removed docs/src/lecture_11/data/mnist_sigmoid.bson
Binary file not shown.
Binary file added docs/src/lecture_11/data/mnist_sigmoid.jld2
Binary file not shown.
87 changes: 87 additions & 0 deletions docs/src/lecture_11/data/utilities.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using MLDatasets
using Flux
using JLD2
using Random
using Statistics
using Base.Iterators: partition
using Flux: crossentropy, onehotbatch, onecold
using Plots
using Pkg

# Use CUDA for GPU acceleration only when it is declared as a dependency of
# the active project; otherwise define `gpu` as the identity so the rest of
# the code transparently runs on the CPU.
if haskey(Pkg.project().dependencies, "CUDA")
using CUDA
else
# no-op fallback standing in for the GPU transfer function
gpu(x) = x
end

"""
    accuracy(model, x, y)

Classification accuracy of `model` on inputs `x` with one-hot labels `y`:
the fraction of samples whose predicted class (column-wise argmax of the
model output, decoded by `onecold`) matches the label. Predictions and
labels are moved to the CPU before decoding.
"""
function accuracy(model, x, y)
    predictions = onecold(cpu(model(x)))
    targets = onecold(cpu(y))
    return mean(predictions .== targets)
end

"""
    reshape_data(X, y)

Insert a singleton channel dimension into the 3-D image array `X`
(width × height × samples → width × height × 1 × samples) and turn the
label vector `y` into a 1 × n row matrix. Both results share data with
the inputs (no copies are made).
"""
function reshape_data(X::AbstractArray{T,3}, y::AbstractVector) where {T}
    width, height, samples = size(X)
    images = reshape(X, width, height, 1, samples)
    labels = reshape(y, 1, length(y))
    return images, labels
end

"""
    load_data(dataset; onehot=false, T=Float32)

Load the train and test splits of `dataset` (an MLDatasets dataset such as
`MLDatasets.MNIST`), reshape the images to 4-D arrays via `reshape_data`,
convert the labels to element type `T`, and optionally one-hot encode the
labels over the digit classes 0:9.

Returns `(X_train, y_train, X_test, y_test)`.
"""
function load_data(dataset; onehot=false, T=Float32)
classes = 0:9
X_train, y_train = reshape_data(dataset(T, :train)[:]...)
X_test, y_test = reshape_data(dataset(T, :test)[:]...)
y_train = T.(y_train)
y_test = T.(y_test)

if onehot
# flatten the 1 × n label matrices back to vectors before encoding
y_train = onehotbatch(y_train[:], classes)
y_test = onehotbatch(y_test[:], classes)
end

return X_train, y_train, X_test, y_test
end

"""
    train_model!(model, X, y; opt=Adam(0.001), batch_size=128, n_epochs=10, file_name="")

Train `model` on inputs `X` (samples along the 4th dimension) and one-hot
labels `y` (samples along columns) with cross-entropy loss, then save the
trained model state to `file_name` with JLD2 (skipped when the name is empty).
"""
function train_model!(
model,
X,
y;
opt=Adam(0.001),
batch_size=128,
n_epochs=10,
file_name="",
)

# cross-entropy between the model output and the one-hot targets
loss(x, y) = crossentropy(model(x), y)

# shuffle sample indices, split them into mini-batches, and move each batch
# to the GPU (a no-op when CUDA is not an active dependency)
batches_train = map(partition(randperm(size(y, 2)), batch_size)) do inds
return (gpu(X[:, :, :, inds]), gpu(y[:, inds]))
end

for epoch in 1:n_epochs
@show epoch
Flux.train!(loss, Flux.params(model), batches_train, opt)
end

# persist only the model state (weights), moved back to the CPU first, so
# it can be restored later with Flux.loadmodel!
!isempty(file_name) && jldsave(file_name; model_state=Flux.state(model) |> cpu)

return
end

"""
    train_or_load!(file_name, model, args...; force=false, kwargs...)

Restore a saved model state from the JLD2 file `file_name` into `model` when
the file exists; otherwise train `model` via `train_model!` and save it
there. Pass `force=true` to retrain even when a saved state is present. The
directory of `file_name` is created if missing. Positional and keyword
arguments beyond `model` are forwarded to `train_model!`.
"""
function train_or_load!(file_name, model, args...; force=false, kwargs...)
    dir = dirname(file_name)
    isdir(dir) || mkpath(dir)

    if !force && isfile(file_name)
        # restore previously trained parameters in place
        Flux.loadmodel!(model, JLD2.load(file_name, "model_state"))
    else
        train_model!(model, args...; file_name=file_name, kwargs...)
    end
end

# Show a 2-D grayscale image; the array is transposed so its first axis is
# drawn horizontally, and plot axes are hidden.
plot_image(x::AbstractArray{T,2}) where {T} = plot(Gray.(x'), axis=nothing)

# Plot a 4-D array holding exactly one sample (the 4th dimension must be 1)
# by dropping the sample dimension and delegating to the 3-D method.
function plot_image(x::AbstractArray{T,4}) where {T}
@assert size(x, 4) == 1
plot_image(x[:, :, :, 1])
end

# Plot a 3-D array holding exactly one channel (the 3rd dimension must be 1)
# by dropping the channel dimension and delegating to the 2-D method.
function plot_image(x::AbstractArray{T,3}) where {T}
@assert size(x, 3) == 1
plot_image(x[:, :, 1])
end
Loading

0 comments on commit 220278c

Please sign in to comment.