diff --git a/NEWS.md b/NEWS.md
index 772f858b1..78f412ac8 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,7 @@
 # parsnip (development version)
 
+* parsnip now lets the engines for [mlp()] check for acceptable values of the activation function (#1019).
+
 * Tightened logic for outcome checking. This resolves issues—some errors and some silent failures—when atomic outcome variables have an attribute (#1060, #1061).
 
 * `rpart_train()` has been deprecated in favor of using `decision_tree()` with the `"rpart"` engine or `rpart::rpart()` directly (#1044).
diff --git a/R/mlp.R b/R/mlp.R
index d607bae2a..8b5ba1b59 100644
--- a/R/mlp.R
+++ b/R/mlp.R
@@ -21,8 +21,8 @@
 #' @param activation A single character string denoting the type of relationship
 #' between the original predictors and the hidden unit layer. The activation
 #' function between the hidden and output layers is automatically set to either
-#' "linear" or "softmax" depending on the type of outcome. Possible values are:
-#' "linear", "softmax", "relu", and "elu"
+#' "linear" or "softmax" depending on the type of outcome. Possible values
+#' depend on the engine being used.
 #'
 #' @templateVar modeltype mlp
 #' @template spec-details
@@ -142,24 +142,6 @@ check_args.mlp <- function(object) {
 
   if (args$dropout > 0 & args$penalty > 0)
     rlang::abort("Both weight decay and dropout should not be specified.")
-
-  if (object$engine == "brulee") {
-    act_funs <- c("linear", "relu", "elu", "tanh")
-  } else if (object$engine == "keras") {
-    act_funs <- c("linear", "softmax", "relu", "elu")
-  } else if (object$engine == "h2o") {
-    act_funs <- c("relu", "tanh")
-  }
-
-  if (is.character(args$activation)) {
-    if (!any(args$activation %in% c(act_funs))) {
-      rlang::abort(
-        glue::glue("`activation` should be one of: ",
-                   glue::glue_collapse(glue::glue("'{act_funs}'"), sep = ", "))
-      )
-    }
-  }
-
   invisible(object)
 }
 
@@ -210,6 +192,9 @@ keras_mlp <-
                     seeds = sample.int(10^5, size = 3),
                     ...) {
 
+    act_funs <- c("linear", "softmax", "relu", "elu")
+    rlang::arg_match(activation, act_funs)
+
     if (penalty > 0 & dropout > 0) {
       rlang::abort("Please use either dropoput or weight decay.", call. = FALSE)
     }
diff --git a/tests/testthat/test_mlp.R b/tests/testthat/test_mlp.R
index da4e639d7..ec9dafad9 100644
--- a/tests/testthat/test_mlp.R
+++ b/tests/testthat/test_mlp.R
@@ -23,3 +23,28 @@ test_that("nnet_softmax", {
   expect_equal(res$b, 1 - res$a)
 })
 
+test_that("more activations for brulee", {
+  skip_if_not_installed("brulee", minimum_version = "0.3.0")
+  skip_on_cran()
+
+  data(ames, package = "modeldata")
+
+  ames$Sale_Price <- log10(ames$Sale_Price)
+
+  set.seed(122)
+  in_train <- sample(1:nrow(ames), 2000)
+  ames_train <- ames[ in_train,]
+  ames_test  <- ames[-in_train,]
+
+  set.seed(1)
+  fit <-
+    try(
+      mlp(penalty = 0.10, activation = "softplus") %>%
+        set_mode("regression") %>%
+        set_engine("brulee") %>%
+        fit_xy(x = as.matrix(ames_train[, c("Longitude", "Latitude")]),
+               y = ames_train$Sale_Price),
+      silent = TRUE)
+  expect_true(inherits(fit$fit, "brulee_mlp"))
+})
+
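Note (not part of the patch): the whitelist check that used to live in `check_args.mlp()` is now performed inside `keras_mlp()` via `rlang::arg_match()`, which reports the allowed set and suggests close matches. A minimal sketch of that behavior using only `rlang`; the error message shown is approximate:

```r
# Sketch of the validation now done inside keras_mlp(): arg_match()
# errors when the supplied value falls outside the allowed set.
act_funs <- c("linear", "softmax", "relu", "elu")

activation <- "softplus"  # not supported by the keras engine
try(rlang::arg_match(activation, act_funs))
#> Error: `activation` must be one of "linear", "softmax", "relu", or
#> "elu", not "softplus".
```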
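With the engine-specific lists removed from `check_args.mlp()`, activations that only some engines support no longer have to be enumerated by parsnip itself. A hypothetical interactive counterpart to the new test (assumes brulee >= 0.3.0, the minimum version the test requires):

```r
# Hypothetical spec mirroring the new test: parsnip no longer aborts on
# "softplus" at the spec level, leaving validation to the engine.
library(parsnip)

spec <-
  mlp(penalty = 0.10, activation = "softplus") %>%
  set_mode("regression") %>%
  set_engine("brulee")

# Fitting (requires the brulee and torch packages) now passes the
# activation value through to brulee::brulee_mlp().
print(spec)
```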