Skip to content

Commit

Permalink
Support All keras activations functions (#1244)
Browse files Browse the repository at this point in the history
* allow all keras activations

* update news

* add tests

---------

Co-authored-by: ‘topepo’ <[email protected]>
  • Loading branch information
EmilHvitfeldt and topepo authored Jan 29, 2025
1 parent e9354e7 commit 556c732
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 5 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ export(glm_grouped)
export(has_multi_predict)
export(importance_weights)
export(is_varying)
export(keras_activations)
export(keras_mlp)
export(keras_predict_classes)
export(knit_engine_docs)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).

* `mlp()` with the `keras` engine now works for all activation functions currently supported by `keras` (#1127).

## Other Changes

* Transitioned package errors and warnings to use cli (#1147 and #1148 by @shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160, #1161, #1081).
Expand Down
22 changes: 20 additions & 2 deletions R/mlp.R
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,13 @@ keras_mlp <-
seeds = sample.int(10^5, size = 3),
...) {

act_funs <- c("linear", "softmax", "relu", "elu", "tanh")
rlang::arg_match(activation, act_funs)
# Validate `activation` against the activation names returned by
# `keras_activations()`; every supplied value must be in that set.
allowed_keras_activation <- keras_activations()
good_activation <- activation %in% allowed_keras_activation
if (!all(good_activation)) {
  # The interpolated name must match the local defined above. The previous
  # `{allowed_activation}` referred to an object that does not exist, so the
  # abort itself failed with an "object not found" error instead of listing
  # the permitted values.
  cli::cli_abort(
    "{.arg activation} should be one of: {allowed_keras_activation}."
  )
}

if (penalty > 0 & dropout > 0) {
cli::cli_abort("Please use either dropout or weight decay.", call = NULL)
Expand Down Expand Up @@ -344,6 +349,19 @@ mlp_num_weights <- function(p, hidden_units, classes) {
((p + 1) * hidden_units) + ((hidden_units+1) * classes)
}

# Activation names that `keras_mlp()` accepts for its `activation` argument.
allowed_keras_activation <- c(
  "elu",
  "exponential",
  "gelu",
  "hard_sigmoid",
  "linear",
  "relu",
  "selu",
  "sigmoid",
  "softmax",
  "softplus",
  "softsign",
  "swish",
  "tanh"
)

#' Activation functions for neural networks in keras
#'
#' Accessor for the activation function names that are checked when a
#' keras multilayer perceptron is fit.
#'
#' @keywords internal
#' @return A character vector of values.
#' @export
keras_activations <- function() {
  allowed_keras_activation
}

## -----------------------------------------------------------------------------

#' @importFrom purrr map
Expand Down
15 changes: 15 additions & 0 deletions man/keras_activations.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 32 additions & 3 deletions tests/testthat/test-mlp_keras.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,10 @@ car_basic <- mlp(mode = "regression", epochs = 10) %>%

bad_keras_reg <-
mlp(mode = "regression") %>%
set_engine("keras", min.node.size = -10)
set_engine("keras", min.node.size = -10, verbose = 0)

# ------------------------------------------------------------------------------


test_that('keras execution, regression', {
skip_on_cran()
skip_if_not_installed("keras")
Expand Down Expand Up @@ -211,7 +210,6 @@ test_that('keras regression prediction', {
keras::backend()$clear_session()
})


# ------------------------------------------------------------------------------

test_that('multivariate nnet formula', {
Expand Down Expand Up @@ -247,3 +245,34 @@ test_that('multivariate nnet formula', {

keras::backend()$clear_session()
})

# ------------------------------------------------------------------------------

test_that('all keras activation functions', {
  skip_on_cran()
  skip_if_not_installed("keras")
  skip_if_not_installed("modeldata")
  skip_if(!is_tf_ok())

  activation_names <- parsnip:::keras_activations()

  # Fit a tiny two-class network with the requested activation. The whole
  # specification-and-fit expression sits inside `try()` so that a failure at
  # any step yields a "try-error" object rather than aborting the loop.
  fit_with_activation <- function(activation_name) {
    set.seed(1)
    try(
      parsnip::fit(
        set_engine(
          mlp(
            mode = "classification",
            hidden_units = 2,
            penalty = 0.01,
            epochs = 2,
            activation = !!activation_name
          ),
          "keras",
          verbose = 0
        ),
        Class ~ A + B,
        data = modeldata::two_class_dat
      ),
      silent = TRUE
    )
  }
  # Capture printed output, messages and warnings instead of emitting them.
  quiet_fit <- purrr::quietly(fit_with_activation)

  for (activation_name in activation_names) {
    # Reset the keras backend session before and after each fit.
    keras::backend()$clear_session()
    captured <- quiet_fit(activation_name)
    expect_s3_class(captured$result, "model_fit")
    keras::backend()$clear_session()
  }

})

0 comments on commit 556c732

Please sign in to comment.