diff --git a/NAMESPACE b/NAMESPACE index 67ce6a2b4..b7c6b6b24 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -253,6 +253,7 @@ export(glm_grouped) export(has_multi_predict) export(importance_weights) export(is_varying) +export(keras_activations) export(keras_mlp) export(keras_predict_classes) export(knit_engine_docs) diff --git a/NEWS.md b/NEWS.md index 54e0c0ace..b7f4641b7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -15,6 +15,8 @@ * New `extract_fit_time()` method has been added that returns the time it took to train the model (#853). +* `mlp()` with `keras` engine now work for all activation functions currently supported by `keras` (#1127). + ## Other Changes * Transitioned package errors and warnings to use cli (#1147 and #1148 by @shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160, #1161, #1081). diff --git a/R/mlp.R b/R/mlp.R index 4ea677e09..ee40edbd9 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -192,8 +192,13 @@ keras_mlp <- seeds = sample.int(10^5, size = 3), ...) { - act_funs <- c("linear", "softmax", "relu", "elu", "tanh") - rlang::arg_match(activation, act_funs) + allowed_keras_activation <- keras_activations() + good_activation <- activation %in% allowed_keras_activation + if (!all(good_activation)) { + cli::cli_abort( + "{.arg activation} should be one of: {allowed_activation}." + ) + } if (penalty > 0 & dropout > 0) { cli::cli_abort("Please use either dropout or weight decay.", call = NULL) @@ -344,6 +349,19 @@ mlp_num_weights <- function(p, hidden_units, classes) { ((p + 1) * hidden_units) + ((hidden_units+1) * classes) } +allowed_keras_activation <- + c("elu", "exponential", "gelu", "hard_sigmoid", "linear", "relu", "selu", + "sigmoid", "softmax", "softplus", "softsign", "swish", "tanh") + +#' Activation functions for neural networks in keras +#' +#' @keywords internal +#' @return A character vector of values. +#' @export +keras_activations <- function() { + allowed_keras_activation +} + ## ----------------------------------------------------------------------------- #' @importFrom purrr map diff --git a/man/keras_activations.Rd b/man/keras_activations.Rd new file mode 100644 index 000000000..c5fb4e610 --- /dev/null +++ b/man/keras_activations.Rd @@ -0,0 +1,15 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/mlp.R +\name{keras_activations} +\alias{keras_activations} +\title{Activation functions for neural networks in keras} +\usage{ +keras_activations() +} +\value{ +A character vector of values. +} +\description{ +Activation functions for neural networks in keras +} +\keyword{internal} diff --git a/tests/testthat/test-mlp_keras.R b/tests/testthat/test-mlp_keras.R index b56300bfc..640ac2098 100644 --- a/tests/testthat/test-mlp_keras.R +++ b/tests/testthat/test-mlp_keras.R @@ -149,11 +149,10 @@ car_basic <- mlp(mode = "regression", epochs = 10) %>% bad_keras_reg <- mlp(mode = "regression") %>% - set_engine("keras", min.node.size = -10) + set_engine("keras", min.node.size = -10, verbose = 0) # ------------------------------------------------------------------------------ - test_that('keras execution, regression', { skip_on_cran() skip_if_not_installed("keras") @@ -211,7 +210,6 @@ test_that('keras regression prediction', { keras::backend()$clear_session() }) - # ------------------------------------------------------------------------------ test_that('multivariate nnet formula', { @@ -247,3 +245,34 @@ test_that('multivariate nnet formula', { keras::backend()$clear_session() }) + +# ------------------------------------------------------------------------------ + +test_that('all keras activation functions', { + skip_on_cran() + skip_if_not_installed("keras") + skip_if_not_installed("modeldata") + skip_if(!is_tf_ok()) + + act <- parsnip:::keras_activations() + + test_act <- function(fn) { + set.seed(1) + try( + mlp(mode = "classification", hidden_units = 2, penalty = 0.01, epochs = 2, + activation = !!fn) %>% + set_engine("keras", verbose = 0) %>% + parsnip::fit(Class ~ A + B, data = modeldata::two_class_dat), + silent = TRUE) + + } + test_act_sshhh <- purrr::quietly(test_act) + + for (i in act) { + keras::backend()$clear_session() + act_res <- test_act_sshhh(i) + expect_s3_class(act_res$result, "model_fit") + keras::backend()$clear_session() + } + +})