Merge 705be28 into 7beb0fe
SvenVw authored Jun 6, 2023
2 parents 7beb0fe + 705be28 commit 7bd9c4f
Showing 7 changed files with 72 additions and 13 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
@@ -6,7 +6,8 @@ Authors@R: c(
 person(family = "RStudio", role = c("cph")),
 person(given = "Christophe", family = "Regouby", role = c("cre", "ctb"), email = "[email protected]"),
 person(given = "Egill", family = "Fridgeirsson", role = c("ctb")),
-person(given = "Philipp", family = "Haarmeyer", role = c("ctb"))
+person(given = "Philipp", family = "Haarmeyer", role = c("ctb")),
+person(given = "Sven", family = "Verweij", role = c("ctb"), comment = c(ORCID = "0000-0002-5573-3952"))
 )
 Description: Implements the 'TabNet' model by Sercan O. Arik et al. (2019) <arXiv:1908.07442>
 and provides a consistent interface for fitting and creating predictions. It's
3 changes: 3 additions & 0 deletions NEWS.md
@@ -1,5 +1,8 @@
 # tabnet (development version)
 
+## New features
+* Add `reduce_on_plateau` as an option for `lr_scheduler` in `tabnet_config()` (@SvenVw, #120)
+
 # tabnet 0.4.0
 
 ## New features
18 changes: 13 additions & 5 deletions R/model.R
@@ -90,9 +90,10 @@ resolve_data <- function(x, y) {
 #' range from 1 to 5
 #' @param verbose (logical) Whether to print progress and loss values during
 #' training.
-#' @param lr_scheduler if `NULL`, no learning rate decay is used. if "step"
-#' decays the learning rate by `lr_decay` every `step_size` epochs. It can
-#' also be a [torch::lr_scheduler] function that only takes the optimizer
+#' @param lr_scheduler if `NULL`, no learning rate decay is used. If "step",
+#' decays the learning rate by `lr_decay` every `step_size` epochs. If "reduce_on_plateau",
+#' decays the learning rate by `lr_decay` when no improvement is seen for `step_size` epochs.
+#' It can also be a [torch::lr_scheduler] function that only takes the optimizer
 #' as parameter. The `step` method is called once per epoch.
 #' @param lr_decay multiplies the initial learning rate by `lr_decay` every
 #' `step_size` epochs. Unused if `lr_scheduler` is a `torch::lr_scheduler`
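
For illustration, a minimal sketch of the three accepted forms, assuming predictors `x` and outcome `y` as in the package's test suite:

    # "step": multiply the learning rate by lr_decay every step_size epochs
    fit <- tabnet_fit(x, y, epochs = 10, lr_scheduler = "step",
                      lr_decay = 0.1, step_size = 3)

    # "reduce_on_plateau": decay only once the loss has stopped improving
    fit <- tabnet_fit(x, y, epochs = 10, lr_scheduler = "reduce_on_plateau",
                      lr_decay = 0.1, step_size = 3)

    # a torch::lr_scheduler function that takes only the optimizer
    sc_fn <- function(optimizer) torch::lr_step(optimizer, step_size = 1, gamma = 0.1)
    fit <- tabnet_fit(x, y, epochs = 10, lr_scheduler = sc_fn)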
@@ -457,8 +458,12 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
   scheduler <- list(step = function() {})
 } else if (rlang::is_function(config$lr_scheduler)) {
   scheduler <- config$lr_scheduler(optimizer)
+} else if (config$lr_scheduler == "reduce_on_plateau") {
+  scheduler <- torch::lr_reduce_on_plateau(optimizer, factor = config$lr_decay, patience = config$step_size)
 } else if (config$lr_scheduler == "step") {
   scheduler <- torch::lr_step(optimizer, config$step_size, config$lr_decay)
+} else {
+  rlang::abort("Currently only the 'step' and 'reduce_on_plateau' schedulers are supported.")
 }
 
 # restore previous metrics & checkpoints
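
The new branch maps `tabnet_config()`'s `lr_decay` onto torch's `factor` argument and `step_size` onto `patience`. A standalone sketch of the equivalent torch call, using a toy optimizer and a loss that never improves (assumes only the torch package):

    library(torch)

    p   <- torch_tensor(1, requires_grad = TRUE)
    opt <- optim_adam(list(p), lr = 0.02)

    # equivalent to lr_decay = 0.5, step_size = 5 in tabnet_config()
    sch <- lr_reduce_on_plateau(opt, factor = 0.5, patience = 5)

    for (epoch in 1:20) {
      sch$step(1)  # constant metric: no improvement, so patience runs down
    }
    opt$param_groups[[1]]$lr  # multiplied by 0.5 each time patience is exhausted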
@@ -537,8 +542,11 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
   best_metric <- current_loss
 }
 
-
-scheduler$step()
+if ("metrics" %in% names(formals(scheduler$step))) {
+  scheduler$step(current_loss)
+} else {
+  scheduler$step()
+}
 }
 
 network$to(device = "cpu")
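
The `formals()` check above works because, in the torch R package, the `step()` method of `lr_reduce_on_plateau` takes the monitored value as its `metrics` argument, while `lr_step`'s `step()` takes no metric. A quick sketch of the dispatch (assumes only the torch package):

    library(torch)

    p   <- torch_tensor(1, requires_grad = TRUE)
    opt <- optim_adam(list(p), lr = 0.1)

    plateau_sch <- lr_reduce_on_plateau(opt, factor = 0.1, patience = 2)
    step_sch    <- lr_step(opt, step_size = 1, gamma = 0.1)

    "metrics" %in% names(formals(plateau_sch$step))  # TRUE  -> scheduler$step(current_loss)
    "metrics" %in% names(formals(step_sch$step))     # FALSE -> scheduler$step()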
11 changes: 9 additions & 2 deletions R/pretraining.R
@@ -143,8 +143,12 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
   scheduler <- list(step = function() {})
 } else if (rlang::is_function(config$lr_scheduler)) {
   scheduler <- config$lr_scheduler(optimizer)
+} else if (config$lr_scheduler == "reduce_on_plateau") {
+  scheduler <- torch::lr_reduce_on_plateau(optimizer, factor = config$lr_decay, patience = config$step_size)
 } else if (config$lr_scheduler == "step") {
   scheduler <- torch::lr_step(optimizer, config$step_size, config$lr_decay)
+} else {
+  rlang::abort("Currently only the 'step' and 'reduce_on_plateau' schedulers are supported.")
 }
 
 # initialize metrics & checkpoints
@@ -223,8 +227,11 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
   best_metric <- current_loss
 }
 
-
-scheduler$step()
+if ("metrics" %in% names(formals(scheduler$step))) {
+  scheduler$step(current_loss)
+} else {
+  scheduler$step()
+}
 }
 
 network$to(device = "cpu")
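
The unsupervised path mirrors the supervised one, so the same scheduler options apply during pretraining; e.g. (sketch, with `x` and `y` as in the tests):

    pre <- tabnet_pretrain(x, y, epochs = 3, lr_scheduler = "reduce_on_plateau",
                           lr_decay = 0.1, step_size = 1)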
7 changes: 4 additions & 3 deletions man/tabnet_config.Rd

Some generated files are not rendered by default.

21 changes: 20 additions & 1 deletion tests/testthat/test-hardhat_parameters.R
@@ -82,7 +82,7 @@ test_that("explicit error message when categorical embedding dimension vector ha
   )
 })
 
-test_that("scheduler works", {
+test_that("step scheduler works", {
 
   expect_error(
     fit <- tabnet_fit(x, y, epochs = 3, lr_scheduler = "step",
@@ -102,6 +102,25 @@ test_that("scheduler works", {
 
 })
 
+test_that("reduce_on_plateau scheduler works", {
+
+  expect_error(
+    fit <- tabnet_fit(x, y, epochs = 3, lr_scheduler = "reduce_on_plateau",
+                      lr_decay = 0.1, step_size = 1),
+    regexp = NA
+  )
+
+  sc_fn <- function(optimizer) {
+    torch::lr_step(optimizer, step_size = 1, gamma = 0.1)
+  }
+
+  expect_error(
+    fit <- tabnet_fit(x, y, epochs = 3, lr_scheduler = sc_fn,
+                      lr_decay = 0.1, step_size = 1),
+    regexp = NA
+  )
+
+})
 
 test_that("fit uses config parameters mix from config= and ...", {
 
22 changes: 21 additions & 1 deletion tests/testthat/test-pretraining.R
@@ -108,7 +108,7 @@ test_that("can train from a recipe", {
 
 })
 
-test_that("lr scheduler works", {
+test_that("lr scheduler step works", {
 
   expect_error(
     fit <- tabnet_pretrain(x, y, epochs = 3, lr_scheduler = "step",
@@ -128,6 +128,26 @@ test_that("lr scheduler works", {
 
 })
 
+test_that("lr scheduler reduce_on_plateau works", {
+
+  expect_error(
+    fit <- tabnet_pretrain(x, y, epochs = 3, lr_scheduler = "reduce_on_plateau",
+                           lr_decay = 0.1, step_size = 1),
+    regexp = NA
+  )
+
+  sc_fn <- function(optimizer) {
+    torch::lr_reduce_on_plateau(optimizer, factor = 0.1, patience = 10)
+  }
+
+  expect_error(
+    fit <- tabnet_pretrain(x, y, epochs = 3, lr_scheduler = sc_fn,
+                           lr_decay = 0.1, step_size = 1),
+    regexp = NA
+  )
+
+})
+
 test_that("checkpoints works", {
 
   expect_error(
