Autoencoder #38 (Open)
wants to merge 3 commits into base: master
18 changes: 10 additions & 8 deletions DESCRIPTION
@@ -46,20 +46,22 @@ Suggests:
mlr3tuning (>= 0.1.2),
rmarkdown,
testthat (>= 2.1.0)
Remotes:
mlr-org/mlr3@dtype_image
mlr-org/paradox
VignetteBuilder:
knitr
RdMacros:
mlr3misc
Remotes:
mlr-org/bbotk,
mlr-org/mlr3misc,
mlr-org/mlr3,
mlr-org/mlr3pipelines,
mlr-org/mlr3tuning,
mlr-org/paradox
Encoding: UTF-8
LazyData: true
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = FALSE)
RoxygenNote: 7.1.1
SystemRequirements: Keras >= 2.0 (https://keras.io)
Config/reticulate:
list(
packages = list(
list(package = "tensorflow", pip = TRUE),
list(package = "keras", pip = TRUE)
)
)
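
The new Config/reticulate field declares the Python dependencies so that the reticulate::configure_environment() call added to .onLoad in R/zzz.R below can install them automatically. A manual alternative, as a sketch — the environment name "r-mlr3keras" is illustrative, the reticulate functions are standard API:

# Manual alternative to the automatic Config/reticulate installation
reticulate::virtualenv_create("r-mlr3keras")
reticulate::virtualenv_install("r-mlr3keras", c("tensorflow", "keras"))
reticulate::use_virtualenv("r-mlr3keras", required = TRUE)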
1 change: 1 addition & 0 deletions NAMESPACE
Expand Up @@ -38,3 +38,4 @@ import(mlr3misc)
import(paradox)
importFrom(R6,R6Class)
importFrom(stats,setNames)
importFrom(tensorflow,tf)
158 changes: 158 additions & 0 deletions R/AutoEncoder.R
@@ -0,0 +1,158 @@
PipeOpAutoencoder = R6::R6Class("PipeOpAutoencoder",
inherit = mlr3pipelines::PipeOpTaskPreproc,
public = list(
initialize = function(id = "autoencode", param_vals = list()) {
ps = ParamSet$new(list(
ParamInt$new("epochs", default = 100L, lower = 0L, tags = "train"),
ParamDbl$new("validation_split", lower = 0, upper = 1, default = 1/3, tags = "train"),
ParamInt$new("batch_size", default = 128L, lower = 1L, tags = c("train", "predict", "predict_fun")),
ParamUty$new("callbacks", default = list(), tags = "train"),
ParamInt$new("verbose", lower = 0L, upper = 1L, tags = c("train", "predict", "predict_fun")),
ParamInt$new("n_max", default = 128L, tags = "train", lower = 1, upper = Inf),
ParamInt$new("n_layers", default = 2L, tags = "train", lower = 1, upper = Inf),
ParamInt$new("bottleneck_size", default = 10L, tags = "train", lower = 1, upper = Inf),
ParamUty$new("initializer", default = "initializer_glorot_uniform()", tags = "train"),
ParamUty$new("regularizer", default = "regularizer_l1_l2()", tags = "train"),
ParamUty$new("optimizer", default = "optimizer_sgd()", tags = "train"),
ParamFct$new("activation", default = "relu", tags = "train",
levels = c("elu", "relu", "selu", "tanh", "sigmoid", "PRelU", "LeakyReLu", "linear")),
ParamLgl$new("use_batchnorm", default = TRUE, tags = "train"),
ParamLgl$new("use_dropout", default = TRUE, tags = "train"),
ParamDbl$new("dropout", lower = 0, upper = 1, tags = "train"),
ParamFct$new("loss", default = "mean_squared_error", tags = "train", levels = keras_reflections$loss$regr),
ParamUty$new("metrics", tags = "train")
))
ps$values = list(
epochs = 100L,
callbacks = list(),
validation_split = 1/3,
batch_size = 128L,
activation = "relu",
n_max = 128L,
n_layers = 2L,
bottleneck_size = 10L,
initializer = initializer_glorot_uniform(),
optimizer = optimizer_sgd(lr = 3*10^-4, momentum = 0.9),
regularizer = regularizer_l1_l2(),
use_batchnorm = FALSE,
use_dropout = TRUE,
dropout = 0,
loss = "mean_squared_error",
metrics = "mean_squared_error",
verbose = 0L
)
super$initialize(id = id, param_set = ps, param_vals = param_vals, feature_types = c("numeric", "integer"))
}
),
private = list(
.train_task = function(task) {
pars = self$param_set$values

# Get columns from data
dt_columns = private$.select_cols(task)
cols = dt_columns
if (!length(cols)) {
self$state = list(dt_columns = dt_columns)
return(task)
}
x = data.matrix(task$data(cols = cols))  # use only the selected numeric/integer columns

# Train model
aenc = build_autoencoder(task, self$param_set$values)

history = invoke(keras::fit,
object = aenc$model,
x = x,
y = x,
epochs = as.integer(pars$epochs),
batch_size = as.integer(pars$batch_size),
validation_split = pars$validation_split,
verbose = as.integer(pars$verbose),
callbacks = pars$callbacks
)
self$state = list(model = aenc$encoder, history = history)

# Pass on encoded training data
dt = data.table(aenc$encoder %>% predict(x))
self$state$dt_columns = dt_columns
task$select(setdiff(task$feature_names, cols))$cbind(dt)
},

.predict_dt = function(dt, levels) {
# Apply the encoder stored during training to new data
x = data.matrix(dt)
self$state$model %>% predict(x)
}
)
)
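
A minimal usage sketch for the new PipeOp (not part of the diff; assumes mlr3, mlr3pipelines and a working keras/TensorFlow installation, and uses only parameters defined in the ParamSet above):

library(mlr3)
library(mlr3pipelines)
library(mlr3keras)

task = tsk("iris")
po = PipeOpAutoencoder$new(param_vals = list(
  epochs = 5L, n_layers = 2L, n_max = 32L, bottleneck_size = 2L
))

# $train() fits the autoencoder and replaces the numeric features
# with their encoded (bottleneck) representation
encoded = po$train(list(task))[[1]]
encoded$feature_names  # bottleneck_size encoded columns

# $predict() applies the encoder stored in $state to new data
po$predict(list(task))[[1]]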


# Feed-Forward Autoencoder
build_autoencoder = function(task, pars) {

if ("factor" %in% task$feature_types$type && !pars$use_embedding)
stop("Factor features are only available with use_embedding = TRUE!")

# Get input and output shape for model
input_shape = task$ncol - 1L  # number of features (task columns minus the target)
bottleneck_size = pars$bottleneck_size


# Build hidden layers
n_neurons_layer = integer(pars$n_layers)
n_neurons_layer[1] = pars$n_max

# Encoder
enc_input = layer_input(shape = input_shape)
enc_output = enc_input
for (i in seq_len(pars$n_layers)) {
enc_output = enc_output %>%
layer_dense(
units = n_neurons_layer[i],
kernel_regularizer = pars$regularizer,
kernel_initializer = pars$initializer,
bias_regularizer = pars$regularizer,
bias_initializer = pars$initializer
) %>%
layer_activation(pars$activation)
if (pars$use_batchnorm) enc_output = enc_output %>% layer_batch_normalization()
if (pars$use_dropout) enc_output = enc_output %>% layer_dropout(pars$dropout)
# Taper layer widths linearly from n_max down to bottleneck_size
if (i < pars$n_layers)
n_neurons_layer[i + 1] = ceiling(n_neurons_layer[i] - (pars$n_max - bottleneck_size) / (pars$n_layers - 1L))
}
encoder = keras_model(enc_input, enc_output)

# Decoder
n_neurons_layer_rev = c(rev(n_neurons_layer)[-1], input_shape)
dec_input = layer_input(shape = bottleneck_size)
dec_output = dec_input
for (i in seq_len(pars$n_layers)) {
dec_output = dec_output %>%
layer_dense(
units = n_neurons_layer_rev[i],
kernel_regularizer = pars$regularizer,
kernel_initializer = pars$initializer,
bias_regularizer = pars$regularizer,
bias_initializer = pars$initializer
) %>%
layer_activation(pars$activation)
if (i != 1) {
if (pars$use_batchnorm) dec_output = dec_output %>% layer_batch_normalization()
if (pars$use_dropout) dec_output = dec_output %>% layer_dropout(pars$dropout)
}
}
decoder = keras_model(dec_input, dec_output)

# AutoEncoder
a_input = layer_input(shape = input_shape)
a_output = a_input %>%
encoder() %>%
decoder()
model = keras_model(a_input, a_output)
model %>% compile(
optimizer = pars$optimizer,
loss = pars$loss,
metrics = pars$metrics
)
list(model = model, encoder = encoder, decoder = decoder)
}
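
To make the width schedule above concrete, a worked example (values are illustrative): with n_max = 128, bottleneck_size = 10 and n_layers = 3, each step subtracts (n_max - bottleneck_size) / (n_layers - 1) = 59 units:

n_max = 128; bottleneck_size = 10; n_layers = 3
n_neurons_layer = integer(n_layers)
n_neurons_layer[1] = n_max
for (i in seq_len(n_layers - 1)) {
  # same linear taper as in build_autoencoder()
  n_neurons_layer[i + 1] = ceiling(n_neurons_layer[i] - (n_max - bottleneck_size) / (n_layers - 1))
}
n_neurons_layer  # 128 69 10 -- encoder widths taper from n_max to the bottleneck

The decoder then uses c(rev(n_neurons_layer)[-1], input_shape), i.e. 69, 128, input_shape, mirroring the encoder back up to the original dimensionality.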
2 changes: 1 addition & 1 deletion R/keras_callbacks.R
@@ -35,7 +35,7 @@ cb_es = function(monitor = 'val_loss', patience = 3L) {
#' @export
cb_lr_scheduler_cosine_anneal = function(eta_max = 0.01, T_max = 10, T_mult = 2, M_mult = 1, eta_min = 0) {
callback_learning_rate_scheduler(
tf$keras$experimental$CosineDecayRestarts(eta_max, T_max, t_mul = T_mult, m_mul = M_mult, alpha = eta_min)
tensorflow::tf$keras$experimental$CosineDecayRestarts(eta_max, T_max, t_mul = T_mult, m_mul = M_mult, alpha = eta_min)
)
}
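
A hedged usage sketch for the now-namespaced callback (assumes a keras learner that accepts a callbacks parameter, as the PipeOp above does; learner id and values are illustrative):

lrn = lrn("classif.kerasff", epochs = 30L)
lrn$param_set$values$callbacks = list(
  cb_lr_scheduler_cosine_anneal(eta_max = 0.01, T_max = 10, T_mult = 2)
)
lrn$train(tsk("iris"))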

3 changes: 2 additions & 1 deletion R/zzz.R
@@ -6,11 +6,11 @@
#' @import checkmate
#' @importFrom R6 R6Class
#' @importFrom stats setNames
#' @importFrom tensorflow tf
#' @description
#' A package that connects mlr3 to keras.
"_PACKAGE"


#' @title Reflections mechanism for keras
#'
#' @details
@@ -45,6 +45,7 @@ register_mlr3 = function() { # nocov start
}

.onLoad = function(libname, pkgname) {
reticulate::configure_environment(pkgname)
register_mlr3()
setHook(packageEvent("mlr3", "onLoad"), function(...) register_mlr3(), action = "append")
}
2 changes: 1 addition & 1 deletion README.md
@@ -40,7 +40,7 @@ If you encounter problems using the correct python versions, see [here](https://
| [regr/classif.smlp]() | Shaped MLP as described in Configuration Space 1* | Zimmer, L. et al. (2020): Auto PyTorch Tabular |
| [regr/classif.smlp2]() | Shaped MLP as described in Configuration Space 2* | Zimmer, L. et al. (2020): Auto PyTorch Tabular |

* with some slight changes, namely no Shake-Shake, Shake-Drop, Mixup Training.
* with some changes, namely no Shake-Shake, Shake-Drop, or Mixup training,
and added Entity Embeddings for categorical variables.

Learners can be used for `training` and `prediction` as follows:
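
(The example itself is truncated by the diff viewer; a minimal sketch consistent with the vignette below, with illustrative learner id and task:)

library(mlr3)
library(mlr3keras)
lrn = lrn("classif.kerasff", epochs = 3L)
lrn$train(tsk("iris"))
prd = lrn$predict(tsk("iris"))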
2 changes: 1 addition & 1 deletion man/KerasArchitecture.Rd


14 changes: 13 additions & 1 deletion man/callbacks.Rd


10 changes: 10 additions & 0 deletions tests/testthat/test_autoencoder.R
@@ -0,0 +1,10 @@
context("PipeOpAutoencoder")

test_that("can be trained with cv3", {
skip_on_os("solaris")
# Build model
t = tsk("iris")
po = PipeOpAutoencoder$new()
expect_class(po$train(list(t))[[1L]], "Task")
expect_class(po$predict(list(t))[[1L]], "Task")
})
1 change: 1 addition & 0 deletions tests/testthat/test_classif_kerasff.R
@@ -66,6 +66,7 @@ test_that("can fit with binary_crossentropy", {
po_lrn$param_set$values$epochs = 10L
po_lrn$param_set$values$layer_units = c(12L, 12L)
po_lrn$param_set$values$loss = "binary_crossentropy"
po_lrn$param_set$values$output_activation = "sigmoid"
pipe = po_imp %>>% po_lrn
pipe$train(mlr_tasks$get("pima"))

10 changes: 8 additions & 2 deletions tests/testthat/test_entity_embedding.R
@@ -7,7 +7,8 @@ test_that("entity embedding works for all tasks", {
embds = make_embedding(task)
expect_list(embds, len = 2L, names = "named")
dt = task$feature_types[type %in% c("character", "factor", "ordered"), ]
expect_true(length(embds$inputs) == nrow(dt) + 1)

expect_true(length(embds$inputs) == nrow(dt) + (nrow(setdiff(task$feature_types, dt)) > 0))
expect_class(embds$layers, "tensorflow.tensor")
map(embds$inputs, expect_class, "tensorflow.tensor")
}
@@ -21,6 +22,11 @@ test_that("entity embedding works for all tasks", {
expect_list(embds, len = 2L, names = "named")
dt = task$feature_types[type %in% c("character", "factor", "ordered"), ]
expect_true(length(embds$fct_levels) == nrow(dt))
expect_true(length(embds$fct_levels) == length(embds$data) - 1)
# Either one fct_levels entry per categorical feature (data holds one extra
# element for the numeric features, if any), or none at all (no categoricals)
expect_true(
(length(embds$fct_levels) == length(embds$data) - (nrow(setdiff(task$feature_types, dt)) > 0)) ||
(length(embds$fct_levels) == 0 && length(embds$data) == 1L)
)
}
k_clear_session()
})
2 changes: 1 addition & 1 deletion vignettes/mlr3keras.Rmd
@@ -214,4 +214,4 @@ kr = import("keras_radam")
radam = kr$training$RAdamOptimizer()
lrn = lrn("classif.kerasff", predict_type = "prob", epochs = 3L, optimizer = radam)
lrn$train(mlr_tasks$get("iris"))
```
```