More activation functions (#74)
* add more activation functions for #69

* update error checking and tests

* fix unit test

* unit tests

* make a function to get possible values

* small updates

* update snapshot

* redoc with function link

* add skips; will re-write tests in next PR
topepo authored Nov 5, 2023
1 parent ec4756c commit 88d6002
Showing 11 changed files with 424 additions and 89 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
@@ -42,6 +42,7 @@ S3method(tunable,brulee_mlp)
S3method(tunable,brulee_multinomial_reg)
export("%>%")
export(autoplot)
export(brulee_activations)
export(brulee_linear_reg)
export(brulee_logistic_reg)
export(brulee_mlp)
25 changes: 25 additions & 0 deletions R/activation.R
@@ -0,0 +1,25 @@
allowed_activation <-
c("celu", "elu", "gelu", "hardshrink", "hardsigmoid",
"hardtanh", "leaky_relu", "linear", "log_sigmoid", "relu", "relu6",
"rrelu", "selu", "sigmoid", "silu", "softplus", "softshrink",
"softsign", "tanh", "tanhshrink")

#' Activation functions for neural networks in brulee
#'
#' @return A character vector of values.
#' @export
brulee_activations <- function() {
allowed_activation
}

get_activation_fn <- function(arg, ...) {

if (arg == "linear") {
res <- identity
} else {
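# assemble and evaluate the call torch::nn_<arg>(); for example,
# arg = "relu" evaluates torch::nn_relu(), so the result is an
# instantiated module rather than a constructor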
cl <- rlang::call2(paste0("nn_", arg), .ns = "torch")
res <- rlang::eval_bare(cl)
}

res
}
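
A quick sketch of the new helpers in use (hedged: get_activation_fn() is internal, so it is shown via ":::"; outputs assume the torch package is installed and are illustrative):

library(torch)

brulee::brulee_activations()
#> [1] "celu" "elu" "gelu" ... (the 20 allowed names)

# the internal helper returns an instantiated module, not a constructor
act <- brulee:::get_activation_fn("relu")
act(torch_tensor(c(-1, 0, 2)))   # negative inputs clamp to zero

# "linear" is special-cased to base R's identity()
brulee:::get_activation_fn("linear")(c(-1, 0, 2))
#> [1] -1  0  2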
38 changes: 16 additions & 22 deletions R/mlp-fit.R
@@ -33,10 +33,11 @@
#' @param hidden_units An integer for the number of hidden units, or a vector
#' of integers. If a vector of integers, the model will have `length(hidden_units)`
#' layers each with `hidden_units[i]` hidden units.
#' @param activation A string for the activation function. Possible values are
#' "relu", "elu", "tanh", and "linear". If `hidden_units` is a vector, `activation`
#' can be a character vector with length equals to `length(hidden_units)` specifying
#' the activation for each hidden layer.
#' @param activation A character vector for the activation function (such as
#' "relu", "tanh", "sigmoid", and so on). See [brulee_activations()] for
#' a list of possible values. If `hidden_units` is a vector, `activation`
#' can be a character vector with length equal to `length(hidden_units)`
#' specifying the activation for each hidden layer.
#' @param optimizer The method used in the optimization procedure. Possible choices
#' are 'LBFGS' and 'SGD'. Default is 'LBFGS'.
#' @param learn_rate A positive number that controls the initial rapidity that
@@ -435,18 +436,26 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
if (length(hidden_units) != length(activation)) {
rlang::abort("'activation' must be a single value or a vector with the same length as 'hidden_units'")
}

allowed_activation <- brulee_activations()
good_activation <- activation %in% allowed_activation
if (!all(good_activation)) {
rlang::abort(paste0("'activation' should be one of: ", paste0(allowed_activation, collapse = ", ")))
}

if (optimizer == "LBFGS" & !is.null(batch_size)) {
rlang::warn("'batch_size' is only used for the SGD optimizer.")
batch_size <- NULL
}
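# illustrative (not part of the commit): with optimizer = "LBFGS", a
# supplied batch_size is dropped with the warning above, and fitting
# proceeds as if batch_size had been left NULL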

check_integer(epochs, single = TRUE, 1, fn = f_nm)
if (!is.null(batch_size)) {
if (!is.null(batch_size) & optimizer == "SGD") {
if (is.numeric(batch_size) & !is.integer(batch_size)) {
batch_size <- as.integer(batch_size)
}
check_integer(batch_size, single = TRUE, 1, fn = f_nm)
}

check_integer(epochs, single = TRUE, 1, fn = f_nm)
check_integer(hidden_units, single = FALSE, 1, fn = f_nm)
check_double(penalty, single = TRUE, 0, incl = c(TRUE, TRUE), fn = f_nm)
check_double(mixture, single = TRUE, 0, 1, incl = c(TRUE, TRUE), fn = f_nm)
@@ -457,8 +466,6 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
check_logical(verbose, single = TRUE, fn = f_nm)
check_character(activation, single = FALSE, fn = f_nm)



## -----------------------------------------------------------------------------

predictors <- processed$predictors
@@ -635,7 +642,7 @@ mlp_fit_imp <-
loss_label <- "\tLoss:"
}

if (is.null(batch_size)) {
if (is.null(batch_size) & optimizer == "SGD") {
batch_size <- nrow(x)
} else {
batch_size <- min(batch_size, nrow(x))
@@ -854,19 +861,6 @@ print.brulee_mlp <- function(x, ...) {

## -----------------------------------------------------------------------------

get_activation_fn <- function(arg, ...) {
if (arg == "relu") {
res <- torch::nn_relu(...)
} else if (arg == "elu") {
res <- torch::nn_elu(...)
} else if (arg == "tanh") {
res <- torch::nn_tanh(...)
} else {
res <- identity
}
res
}

set_optimizer <- function(optimizer, model, learn_rate, momentum) {
if (optimizer == "LBFGS") {
res <- torch::optim_lbfgs(model$parameters, lr = learn_rate, history_size = 5)
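
A hypothetical end-to-end call exercising the expanded activation list (a sketch only, not from this commit; dat is a placeholder data frame with a numeric outcome y):

library(brulee)

set.seed(1)
dat <- data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100))

# two hidden layers, each with its own (newly supported) activation
fit <- brulee_mlp(y ~ ., data = dat,
                  hidden_units = c(16, 8),
                  activation = c("gelu", "tanhshrink"),
                  epochs = 20)

# an unrecognized string now aborts with the full list of allowed values:
# brulee_mlp(y ~ ., data = dat, activation = "swishh")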
1 change: 1 addition & 0 deletions inst/WORDLIST
@@ -12,5 +12,6 @@ mlp
multilayer
perceptrons
relu
sigmoid
tanh
tibble
14 changes: 14 additions & 0 deletions man/brulee_activations.Rd


9 changes: 5 additions & 4 deletions man/brulee_mlp.Rd

