From f2e2d30a6a3bf324e8199c19f74512c9c1bdb74f Mon Sep 17 00:00:00 2001
From: pfistfl
Date: Wed, 28 Oct 2020 10:08:37 +0100
Subject: [PATCH 1/3] minor update

---
 DESCRIPTION                            | 15 +++++++--------
 NAMESPACE                              |  1 +
 R/keras_callbacks.R                    |  2 +-
 R/zzz.R                                |  3 ++-
 man/KerasArchitecture.Rd               |  2 +-
 man/callbacks.Rd                       | 14 +++++++++++++-
 tests/testthat/test_classif_kerasff.R  |  1 +
 tests/testthat/test_entity_embedding.R | 10 ++++++++--
 8 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/DESCRIPTION b/DESCRIPTION
index dd0a634..68be41c 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -50,16 +50,15 @@ VignetteBuilder:
     knitr
 RdMacros:
     mlr3misc
-Remotes:
-    mlr-org/bbotk,
-    mlr-org/mlr3misc,
-    mlr-org/mlr3,
-    mlr-org/mlr3pipelines,
-    mlr-org/mlr3tuning,
-    mlr-org/paradox
 Encoding: UTF-8
 LazyData: true
 NeedsCompilation: no
 Roxygen: list(markdown = TRUE, r6 = FALSE)
 RoxygenNote: 7.1.1
-SystemRequirements: Keras >= 2.0 (https://keras.io)
+Config/reticulate:
+  list(
+    packages = list(
+      list(package = "tensorflow", pip = TRUE),
+      list(package = "keras", pip = TRUE)
+    )
+  )
diff --git a/NAMESPACE b/NAMESPACE
index b90181d..3e9c908 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -38,3 +38,4 @@ import(mlr3misc)
 import(paradox)
 importFrom(R6,R6Class)
 importFrom(stats,setNames)
+importFrom(tensorflow,tf)
diff --git a/R/keras_callbacks.R b/R/keras_callbacks.R
index a827165..86daa37 100644
--- a/R/keras_callbacks.R
+++ b/R/keras_callbacks.R
@@ -35,7 +35,7 @@ cb_es = function(monitor = 'val_loss', patience = 3L) {
 #' @export
 cb_lr_scheduler_cosine_anneal = function(eta_max = 0.01, T_max = 10, T_mult = 2, M_mult = 1, eta_min = 0) {
   callback_learning_rate_scheduler(
-    tf$keras$experimental$CosineDecayRestarts(eta_max, T_max, t_mul = T_mult, m_mul = M_mult, alpha = eta_min)
+    tensorflow::tf$keras$experimental$CosineDecayRestarts(eta_max, T_max, t_mul = T_mult, m_mul = M_mult, alpha = eta_min)
   )
 }
diff --git a/R/zzz.R b/R/zzz.R
index c1f0864..de2d26c 100644
--- a/R/zzz.R
+++ b/R/zzz.R
@@ -6,11 +6,11 @@
 #' @import checkmate
 #' @importFrom R6 R6Class
 #' @importFrom stats setNames
+#' @importFrom tensorflow tf
 #' @description
 #' A package that connects mlr3 to keras.
 "_PACKAGE"
-
 #' @title Reflections mechanism for keras
 #'
 #' @details
@@ -45,6 +45,7 @@ register_mlr3 = function() { # nocov start
 }
 
 .onLoad = function(libname, pkgname) {
+  reticulate::configure_environment(pkgname)
   register_mlr3()
   setHook(packageEvent("mlr3", "onLoad"), function(...) register_mlr3(), action = "append")
 }
diff --git a/man/KerasArchitecture.Rd b/man/KerasArchitecture.Rd
index 3f39e43..a22161d 100644
--- a/man/KerasArchitecture.Rd
+++ b/man/KerasArchitecture.Rd
@@ -37,7 +37,7 @@ It can be used to more easily and flexibly add architectures.
 
 Initialize architecture
 
-Obtain the model. Called by Learner during \code{train_internal}.
+Obtain the model. Called by Learner during \code{train()}.
 
 Setter method for 'x_transform' and 'y_transform'.
diff --git a/man/callbacks.Rd b/man/callbacks.Rd
index 178b6d6..58c647b 100644
--- a/man/callbacks.Rd
+++ b/man/callbacks.Rd
@@ -12,7 +12,13 @@
 \usage{
 cb_es(monitor = "val_loss", patience = 3L)
 
-cb_lr_scheduler_cosine_anneal(T_max = 10, T_mult = 2, eta_min = 0)
+cb_lr_scheduler_cosine_anneal(
+  eta_max = 0.01,
+  T_max = 10,
+  T_mult = 2,
+  M_mult = 1,
+  eta_min = 0
+)
 
 cb_lr_scheduler_exponential_decay()
 
@@ -27,12 +33,18 @@
 Quantity to be monitored.}
 
 \item{patience}{\code{\link{integer}}\cr
 Number of iterations without improvement to wait before stopping.}
 
+\item{eta_max}{\code{\link{numeric}}\cr
+Max learning rate.}
+
 \item{T_max}{\code{\link{integer}}\cr
 Reset learning rate every T_max epochs. Default 10.}
 
 \item{T_mult}{\code{\link{integer}}\cr
 Multiply T_max by T_mult every T_max iterations. Default 2.}
 
+\item{M_mult}{\code{\link{numeric}}\cr
+Decay learning rate by factor 'M_mult' after each learning rate reset.}
+
 \item{eta_min}{\code{\link{numeric}}\cr
 Minimal learning rate.}
 }
diff --git a/tests/testthat/test_classif_kerasff.R b/tests/testthat/test_classif_kerasff.R
index 08ef593..23effa9 100644
--- a/tests/testthat/test_classif_kerasff.R
+++ b/tests/testthat/test_classif_kerasff.R
@@ -66,6 +66,7 @@ test_that("can fit with binary_crossentropy", {
   po_lrn$param_set$values$epochs = 10L
   po_lrn$param_set$values$layer_units = c(12L, 12L)
   po_lrn$param_set$values$loss = "binary_crossentropy"
+  po_lrn$param_set$values$output_activation = "sigmoid"
   pipe = po_imp %>>% po_lrn
   pipe$train(mlr_tasks$get("pima"))
diff --git a/tests/testthat/test_entity_embedding.R b/tests/testthat/test_entity_embedding.R
index b454f0a..b9d9827 100644
--- a/tests/testthat/test_entity_embedding.R
+++ b/tests/testthat/test_entity_embedding.R
@@ -7,7 +7,8 @@ test_that("entity embedding works for all tasks", {
     embds = make_embedding(task)
     expect_list(embds, len = 2L, names = "named")
     dt = task$feature_types[type %in% c("character", "factor", "ordered"), ]
-    expect_true(length(embds$inputs) == nrow(dt) + 1)
+
+    expect_true(length(embds$inputs) == nrow(dt) + (nrow(setdiff(task$feature_types, dt)) > 0))
     expect_class(embds$layers, "tensorflow.tensor")
     map(embds$inputs, expect_class, "tensorflow.tensor")
   }
@@ -21,6 +22,11 @@ test_that("entity embedding works for all tasks", {
     expect_list(embds, len = 2L, names = "named")
     dt = task$feature_types[type %in% c("character", "factor", "ordered"), ]
     expect_true(length(embds$fct_levels) == nrow(dt))
-    expect_true(length(embds$fct_levels) == length(embds$data) - 1)
+    # Either length(embds$data) - 1 (numeric features present) or 0 (no categoricals)
+    expect_true(
+      (length(embds$fct_levels) == length(embds$data) - (nrow(setdiff(task$feature_types, dt)) > 0)) ||
+      (length(embds$fct_levels) == 0 && length(embds$data) == 1L)
+    )
   }
+  k_clear_session()
})
From 076094a5cdb7e458e4bd5ef18780bd13fdd19539 Mon Sep 17 00:00:00 2001
From: pfistfl
Date: Tue, 3 Nov 2020 10:10:55 +0100
Subject: [PATCH 2/3] update

---
 DESCRIPTION             | 3 +++
 README.md               | 2 +-
 vignettes/mlr3keras.Rmd | 2 +-
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/DESCRIPTION b/DESCRIPTION
index 68be41c..538f760 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -46,6 +46,9 @@ Suggests:
     mlr3tuning (>= 0.1.2),
     rmarkdown,
     testthat (>= 2.1.0)
+Remotes:
+    mlr-org/mlr3@dtype_image,
+    mlr-org/paradox
 VignetteBuilder: knitr
 RdMacros:
     mlr3misc
diff --git a/README.md b/README.md
index 38a15f2..9aa0423 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,7 @@ If you encounter problems using the correct python versions, see [here](https://
 | [regr/classif.smlp]() | Shaped MLP as described in Configuration Space 1* | Zimmer, L. et al. (2020): Auto PyTorch Tabular |
 | [regr/classif.smlp2]() | Shaped MLP as described in Configuration Space 2* | Zimmer, L. et al. (2020): Auto PyTorch Tabular |
 
-* with some slight changes, namely no Shake-Shake, Shake-Drop, Mixup Training.
+* with some changes, namely no Shake-Shake, Shake-Drop, or Mixup Training,
 and added Entity Embeddings for categorical variables.
 
 Learners can be used for `training` and `prediction` as follows:
diff --git a/vignettes/mlr3keras.Rmd b/vignettes/mlr3keras.Rmd
index deccf78..8ca246a 100644
--- a/vignettes/mlr3keras.Rmd
+++ b/vignettes/mlr3keras.Rmd
@@ -214,4 +214,4 @@ kr = import("keras_radam")
 radam = kr$training$RAdamOptimizer()
 lrn = lrn("classif.kerasff", predict_type = "prob", epochs = 3L, optimizer = radam)
 lrn$train(mlr_tasks$get("iris"))
-```
+```
\ No newline at end of file
From 2a906133be75efdf28bed26e65bfda08732a5c8a Mon Sep 17 00:00:00 2001
From: pfistfl
Date: Tue, 3 Nov 2020 14:37:37 +0100
Subject: [PATCH 3/3] first draft

---
 R/AutoEncoder.R                   | 158 ++++++++++++++++++++++++++++++
 tests/testthat/test_autoencoder.R |  10 ++
 2 files changed, 168 insertions(+)
 create mode 100644 R/AutoEncoder.R
 create mode 100644 tests/testthat/test_autoencoder.R

diff --git a/R/AutoEncoder.R b/R/AutoEncoder.R
new file mode 100644
index 0000000..1e84788
--- /dev/null
+++ b/R/AutoEncoder.R
@@ -0,0 +1,158 @@
+PipeOpAutoencoder = R6::R6Class("PipeOpAutoencoder",
+  inherit = mlr3pipelines::PipeOpTaskPreproc,
+  public = list(
+    initialize = function(id = "autoencode", param_vals = list()) {
+      ps = ParamSet$new(list(
+        ParamInt$new("epochs", default = 100L, lower = 0L, tags = "train"),
+        ParamDbl$new("validation_split", lower = 0, upper = 1, default = 1/3, tags = "train"),
+        ParamInt$new("batch_size", default = 128L, lower = 1L, tags = c("train", "predict", "predict_fun")),
+        ParamUty$new("callbacks", default = list(), tags = "train"),
+        ParamInt$new("verbose", lower = 0L, upper = 1L, tags = c("train", "predict", "predict_fun")),
+        ParamInt$new("n_max", default = 128L, tags = "train", lower = 1, upper = Inf),
+        ParamInt$new("n_layers", default = 2L, tags = "train", lower = 1, upper = Inf),
+        ParamInt$new("bottleneck_size", default = 10L, tags = "train", lower = 1, upper = Inf),
+        ParamUty$new("initializer", default = "initializer_glorot_uniform()", tags = "train"),
+        ParamUty$new("regularizer", default = "regularizer_l1_l2()", tags = "train"),
+        ParamUty$new("optimizer", default = "optimizer_sgd()", tags = "train"),
+        ParamFct$new("activation", default = "relu", tags = "train",
+          levels = c("elu", "relu", "selu", "tanh", "sigmoid", "PRelU", "LeakyReLu", "linear")),
+        ParamLgl$new("use_batchnorm", default = TRUE, tags = "train"),
+        ParamLgl$new("use_dropout", default = TRUE, tags = "train"),
+        ParamDbl$new("dropout", lower = 0, upper = 1, tags = "train"),
+        ParamFct$new("loss", default = "mean_squared_error", tags = "train", levels = keras_reflections$loss$regr),
+        ParamUty$new("metrics", tags = "train")
+      ))
+      ps$values = list(
+        epochs = 100L,
+        callbacks = list(),
+        validation_split = 1/3,
+        batch_size = 128L,
+        activation = "relu",
+        n_max = 128L,
+        n_layers = 2L,
+        bottleneck_size = 10L,
+        initializer = initializer_glorot_uniform(),
+        optimizer = optimizer_sgd(lr = 3*10^-4, momentum = 0.9),
+        regularizer = regularizer_l1_l2(),
+        use_batchnorm = FALSE,
+        use_dropout = TRUE,
+        dropout = 0,
+        loss = "mean_squared_error",
+        metrics = "mean_squared_error",
+        verbose = 0L
+      )
+      super$initialize(id = id, param_set = ps, param_vals = param_vals, feature_types = c("numeric", "integer"))
+    }
+  ),
+  private = list(
+    .train_task = function(task) {
+      pars = self$param_set$values
+
+      # Get the numeric/integer columns this PipeOp operates on
+      dt_columns = private$.select_cols(task)
+      cols = dt_columns
+      if (!length(cols)) {
+        self$state = list(dt_columns = dt_columns)
+        return(task)
+      }
+      x = data.matrix(task$data(cols = cols))
+
+      # Train model
+      aenc = build_autoencoder(task, pars)
+
+      history = invoke(keras::fit,
+        object = aenc$model,
+        x = x,
+        y = x,
+        epochs = as.integer(pars$epochs),
+        batch_size = as.integer(pars$batch_size),
+        validation_split = pars$validation_split,
+        verbose = as.integer(pars$verbose),
+        callbacks = pars$callbacks
+      )
+      self$state = list(model = aenc$encoder, history = history)
+
+      # Replace the original features by their encoded representation
+      dt = data.table(aenc$encoder %>% predict(x))
+      self$state$dt_columns = dt_columns
+      task$select(setdiff(task$feature_names, cols))$cbind(dt)
+    },
+
+    .predict_dt = function(dt, levels) {
+      x = data.matrix(dt)
+      self$state$model %>% predict(x)
+    }
+  )
+)
+
+
+# Feed-Forward Autoencoder
+build_autoencoder = function(task, pars) {
+
+  if ("factor" %in% task$feature_types$type)
+    stop("Factor features are not supported, please encode them first!")
+
+  # Get input and output shape for model
+  input_shape = task$ncol - 1L
+  bottleneck_size = pars$bottleneck_size
+
+  # Build hidden layer sizes: shrink linearly from n_max down to bottleneck_size
+  n_neurons_layer = integer(pars$n_layers)
+  n_neurons_layer[1] = pars$n_max
+
+  # Encoder
+  enc_input = layer_input(shape = input_shape)
+  enc_output = enc_input
+  for (i in seq_len(pars$n_layers)) {
+    enc_output = enc_output %>%
+      layer_dense(
+        units = n_neurons_layer[i],
+        kernel_regularizer = pars$regularizer,
+        kernel_initializer = pars$initializer,
+        bias_regularizer = pars$regularizer,
+        bias_initializer = pars$initializer
+      ) %>%
+      layer_activation(pars$activation)
+    if (pars$use_batchnorm) enc_output = enc_output %>% layer_batch_normalization()
+    if (pars$use_dropout) enc_output = enc_output %>% layer_dropout(pars$dropout)
+    if (i < pars$n_layers)
+      n_neurons_layer[i+1] = ceiling(n_neurons_layer[i] - (pars$n_max - bottleneck_size) / (pars$n_layers - 1L))
+  }
+  encoder = keras_model(enc_input, enc_output)
+
+  # Decoder (mirrors the encoder layer sizes)
+  n_neurons_layer_rev = c(rev(n_neurons_layer)[-1], input_shape)
+  dec_input = layer_input(shape = bottleneck_size)
+  dec_output = dec_input
+  for (i in seq_len(pars$n_layers)) {
+    dec_output = dec_output %>%
+      layer_dense(
+        units = n_neurons_layer_rev[i],
+        kernel_regularizer = pars$regularizer,
+        kernel_initializer = pars$initializer,
+        bias_regularizer = pars$regularizer,
+        bias_initializer = pars$initializer
+      ) %>%
+      layer_activation(pars$activation)
+    if (i != 1) {
+      if (pars$use_batchnorm) dec_output = dec_output %>% layer_batch_normalization()
+      if (pars$use_dropout) dec_output = dec_output %>% layer_dropout(pars$dropout)
+    }
+  }
+  decoder = keras_model(dec_input, dec_output)
+
+  # AutoEncoder: chain encoder and decoder into the full model
+  a_input = layer_input(shape = input_shape)
+  a_output = a_input %>%
+    encoder() %>%
+    decoder()
+  model = keras_model(a_input, a_output)
+  model %>% compile(
+    optimizer = pars$optimizer,
+    loss = pars$loss,
+    metrics = pars$metrics
+  )
+  list(model = model, encoder = encoder, decoder = decoder)
+}
diff --git a/tests/testthat/test_autoencoder.R b/tests/testthat/test_autoencoder.R
new file mode 100644
index 0000000..41de3df
--- /dev/null
+++ b/tests/testthat/test_autoencoder.R
@@ -0,0 +1,10 @@
+context("PipeOpAutoencoder")
+
+test_that("can be trained and used for prediction", {
+  skip_on_os("solaris")
+  # Build model
+  t = tsk("iris")
+  po = PipeOpAutoencoder$new()
+  po$train(list(t))
+  po$predict(list(t))
+})
\ No newline at end of file