From 780790a0acc27dfe82a5edb1b9a69ad6707ba576 Mon Sep 17 00:00:00 2001 From: cossio Date: Thu, 3 Sep 2020 14:21:37 +0200 Subject: [PATCH] cd keyword and up version --- Project.toml | 2 +- src/train/cd.jl | 4 ++-- test/profile/profiling.jl | 2 +- test/rbm.jl | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index dcc70b50..1073a9c2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RestrictedBoltzmannMachines" uuid = "12e6b396-7db5-4506-8cb6-664a4fe1e50e" authors = ["Jorge Fernandez-de-Cossio-Diaz "] -version = "0.2.5" +version = "0.2.6" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/train/cd.jl b/src/train/cd.jl index 81a6c220..886e8d2d 100644 --- a/src/train/cd.jl +++ b/src/train/cd.jl @@ -16,13 +16,13 @@ end end """ - train!(rbm, data, [cd]) + train!(rbm, data) We measure training time in units of observations presented to the model. For example, if you want to train for 100 epochs, then set `iters = 100 * data.nobs`, where `data.nobs` is the number of observations in the dataset. """ -function train!(rbm::RBM, data::Data, cd::Union{CD,PCD} = PCD(); +function train!(rbm::RBM, data::Data; cd::Union{CD,PCD} = PCD(), iters::Int, opt = ADAM(), ps::Params = params(rbm), vm::AbstractArray = update_chains(rbm, cd, first(data).v), # Markov chains history = nothing, # stores training history diff --git a/test/profile/profiling.jl b/test/profile/profiling.jl index 9b4fe2fc..f1fd0fc7 100644 --- a/test/profile/profiling.jl +++ b/test/profile/profiling.jl @@ -31,7 +31,7 @@ Profile.init(n=10^7, delay=0.01) @time train!(rbm, train_loader) @profiler train!(rbm, train_loader) -train!(rbm, train_loader, PCD(1)) +train!(rbm, train_loader) #@descend train!(rbm, train_loader) first(train_loader) diff --git a/test/rbm.jl b/test/rbm.jl index 204543a6..e70448dd 100644 --- a/test/rbm.jl +++ b/test/rbm.jl @@ -152,7 +152,7 @@ end gauge!(student) @test norm(teacher.weights) ≈ 1 ps = params(student.weights) - train!(student, train_data, PCD(5); iters=10000 * 32, ps = ps, opt = Flux.ADAM()) + train!(student, train_data; cd=PCD(5), iters=10000 * 32, ps = ps, opt = Flux.ADAM()) @test norm(teacher.weights) ≈ 1 @show dot(teacher.weights, student.weights) @test abs(dot(teacher.weights, student.weights)) ≥ 0.8 @@ -174,7 +174,7 @@ end gauge!(student) @test norm(teacher.weights) ≈ 1 ps = params(student.weights) - train!(student, train_data, PCD(5); iters=10000 * 32, ps = ps, opt = Flux.ADAM()) + train!(student, train_data; cd=PCD(5), iters=10000 * 32, ps = ps, opt = Flux.ADAM()) @test norm(teacher.weights) ≈ 1 @show dot(teacher.weights, student.weights) @test abs(dot(teacher.weights, student.weights)) ≥ 0.8