Skip to content

Commit

Permalink
Merge pull request #194 from AlbertoAlmuinha/master
Browse files Browse the repository at this point in the history
Recursive Fix #187 & #174
  • Loading branch information
mdancho84 authored Aug 10, 2022
2 parents 51d25d7 + afc9a81 commit ff32b26
Show file tree
Hide file tree
Showing 21 changed files with 161 additions and 29 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ export(damping)
export(damping_smooth)
export(default_forecast_accuracy_metric_set)
export(distribution)
export(drop_modeltime_model)
export(enquo)
export(enquos)
export(error)
Expand Down Expand Up @@ -306,7 +307,6 @@ export(window_function_predict_impl)
export(window_reg)
export(xgboost_impl)
export(xgboost_predict)
import(StanHeaders)
importFrom(magrittr,"%>%")
importFrom(parsnip,fit)
importFrom(parsnip,fit_xy)
Expand Down
46 changes: 46 additions & 0 deletions R/helpers-modeltime_table.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#' @seealso
#' - [combine_modeltime_tables()]: Combine 2 or more Modeltime Tables together
#' - [add_modeltime_model()]: Adds a new row with a new model to a Modeltime Table
#' - [drop_modeltime_model()]: Drop one or more models from a Modeltime Table
#' - [update_modeltime_description()]: Updates a description for a model inside a Modeltime Table
#' - [update_modeltime_model()]: Updates a model inside a Modeltime Table
#' - [pull_modeltime_model()]: Extracts a model from a Modeltime Table
Expand Down Expand Up @@ -112,6 +113,7 @@ combine_modeltime_tables <- function(...) {
#' @seealso
#' - [combine_modeltime_tables()]: Combine 2 or more Modeltime Tables together
#' - [add_modeltime_model()]: Adds a new row with a new model to a Modeltime Table
#' - [drop_modeltime_model()]: Drop one or more models from a Modeltime Table
#' - [update_modeltime_description()]: Updates a description for a model inside a Modeltime Table
#' - [update_modeltime_model()]: Updates a model inside a Modeltime Table
#' - [pull_modeltime_model()]: Extracts a model from a Modeltime Table
Expand Down Expand Up @@ -144,6 +146,47 @@ add_modeltime_model <- function(object, model, location = "bottom") {

}

# DROP MODEL -----

#' Drop a Model from a Modeltime Table
#'
#' @param object A Modeltime Table (class `mdl_time_tbl`)
#' @param .model_id A numeric value matching the .model_id that you want to drop
#'
#' @seealso
#' - [combine_modeltime_tables()]: Combine 2 or more Modeltime Tables together
#' - [add_modeltime_model()]: Adds a new row with a new model to a Modeltime Table
#' - [drop_modeltime_model()]: Drop one or more models from a Modeltime Table
#' - [update_modeltime_description()]: Updates a description for a model inside a Modeltime Table
#' - [update_modeltime_model()]: Updates a model inside a Modeltime Table
#' - [pull_modeltime_model()]: Extracts a model from a Modeltime Table
#'
#' @examples
#' \donttest{
#' library(tidymodels)
#'
#'
#' m750_models %>%
#' drop_modeltime_model(.model_id = c(2,3))
#' }
#'
#' @export
drop_modeltime_model <- function(object, .model_id) {

if (!rlang::is_bare_numeric(.model_id)){
rlang::abort(".model_id must be numeric")
}

if (!is_modeltime_table(object)){
rlang::abort("object must be a 'modeltime_table'")
}

ret <- object %>%
dplyr::filter(!(.model_id %in% !!.model_id))

return(ret)

}

# UPDATE MODEL ----

Expand All @@ -157,6 +200,7 @@ add_modeltime_model <- function(object, model, location = "bottom") {
#' @seealso
#' - [combine_modeltime_tables()]: Combine 2 or more Modeltime Tables together
#' - [add_modeltime_model()]: Adds a new row with a new model to a Modeltime Table
#' - [drop_modeltime_model()]: Drop one or more models from a Modeltime Table
#' - [update_modeltime_description()]: Updates a description for a model inside a Modeltime Table
#' - [update_modeltime_model()]: Updates a model inside a Modeltime Table
#' - [pull_modeltime_model()]: Extracts a model from a Modeltime Table
Expand Down Expand Up @@ -213,6 +257,7 @@ update_modeltime_model.mdl_time_tbl <- function(object, .model_id, .new_model) {
#' @seealso
#' - [combine_modeltime_tables()]: Combine 2 or more Modeltime Tables together
#' - [add_modeltime_model()]: Adds a new row with a new model to a Modeltime Table
#' - [drop_modeltime_model()]: Drop one or more models from a Modeltime Table
#' - [update_modeltime_description()]: Updates a description for a model inside a Modeltime Table
#' - [update_modeltime_model()]: Updates a model inside a Modeltime Table
#' - [pull_modeltime_model()]: Extracts a model from a Modeltime Table
Expand Down Expand Up @@ -256,6 +301,7 @@ update_modeltime_description <- update_model_description
#' @seealso
#' - [combine_modeltime_tables()]: Combine 2 or more Modeltime Tables together
#' - [add_modeltime_model()]: Adds a new row with a new model to a Modeltime Table
#' - [drop_modeltime_model()]: Drop one or more models from a Modeltime Table
#' - [update_modeltime_description()]: Updates a description for a model inside a Modeltime Table
#' - [update_modeltime_model()]: Updates a model inside a Modeltime Table
#' - [pull_modeltime_model()]: Extracts a model from a Modeltime Table
Expand Down
16 changes: 12 additions & 4 deletions R/modeltime-forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -627,9 +627,9 @@ mdl_time_forecast.model_fit <- function(object, calibration_data, new_data = NUL

modeltime_forecast <- tryCatch({

if (inherits(object, "_elnet") && inherits(object, "recursive")) {
if (detect_net(object) && inherits(object, "recursive")) {
predictions_tbl <- object %>% predict.recursive(new_data = new_data)
} else if (inherits(object, "_elnet") && inherits(object, "recursive_panel")) {
} else if (detect_net(object) && inherits(object, "recursive_panel")) {
predictions_tbl <- object %>% predict.recursive_panel(new_data = new_data)
} else {
predictions_tbl <- object %>% stats::predict(new_data = new_data)
Expand Down Expand Up @@ -914,9 +914,9 @@ mdl_time_forecast.workflow <- function(object, calibration_data, new_data = NULL
}

# PREDICT
if (inherits(fit, "_elnet") && inherits(fit, "recursive")) {
if (detect_net(fit) && inherits(fit, "recursive")) {
data_formatted <- fit %>% predict.recursive(new_data = df)
} else if (inherits(fit, "_elnet") && inherits(fit, "recursive_panel")) {
} else if (detect_net(fit) && inherits(fit, "recursive_panel")) {
data_formatted <- fit %>% predict.recursive_panel(new_data = df)
} else {
data_formatted <- fit %>% stats::predict(new_data = df)
Expand Down Expand Up @@ -1040,3 +1040,11 @@ mdl_time_forecast.workflow <- function(object, calibration_data, new_data = NULL

}


detect_net <- function(object){
res <- class(object) %>%
stringr::str_detect(., "net") %>%
sum()

if (res >= 1) {TRUE} else {FALSE}
}
30 changes: 17 additions & 13 deletions R/modeltime-recursive.R
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ print.recursive <- function(x, ...) {
} else if (inherits(x, "workflow")) {
cat("Recursive [workflow]\n\n")
} else {
cat("Recursive [modeltime ensemble]\n\n")
cat("Recursive [modeltime ensemble]\n\n")
}

y <- x
Expand Down Expand Up @@ -440,6 +440,14 @@ predict_recursive_panel_model_fit <- function(object, new_data, type = NULL, opt

.id <- dplyr::ensym(id)

unique_id_new_data <- new_data %>% dplyr::select(!! .id) %>% unique() %>% dplyr::pull()

unique_id_train_tail <- train_tail %>% dplyr::select(!! .id) %>% unique() %>% dplyr::pull()

if (length(dplyr::setdiff(unique_id_train_tail, unique_id_new_data)) >= 1){
train_tail <- train_tail %>% dplyr::filter(!! .id %in% unique_id_new_data)
}

# # Comment this out ----
# print("here")
# obj <<- object
Expand Down Expand Up @@ -483,10 +491,10 @@ predict_recursive_panel_model_fit <- function(object, new_data, type = NULL, opt
}

.preds[.preds$rowid.. == 1, 2] <- new_data[new_data$rowid.. == 1, y_var] <- pred_fun(object,
new_data = .first_slice,
type = type,
opts = opts,
...)
new_data = .first_slice,
type = type,
opts = opts,
...)

.groups <- new_data %>%
dplyr::group_by(!! .id) %>%
Expand Down Expand Up @@ -518,10 +526,10 @@ predict_recursive_panel_model_fit <- function(object, new_data, type = NULL, opt


.preds[.preds$rowid.. == i, 2] <- new_data[new_data$rowid.. == i, y_var] <- pred_fun(object,
new_data = .nth_slice,
type = type,
opts = opts,
...)
new_data = .nth_slice,
type = type,
opts = opts,
...)
}

return(.preds[,2])
Expand Down Expand Up @@ -689,7 +697,3 @@ is_prepped_recipe <- function(recipe) {
}
return(is_prepped)
}




5 changes: 4 additions & 1 deletion R/parsnip-adam.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ update.adam_reg <- function(object, parameters = NULL,
seasonal_period = NULL, select_order = NULL,
fresh = FALSE, ...) {

parsnip::update_dot_check(...)
eng_args <- parsnip::update_engine_parameters(object$eng_args, fresh, ...)

if (!is.null(parameters)) {
parameters <- parsnip::check_final_param(parameters)
Expand Down Expand Up @@ -308,12 +308,15 @@ update.adam_reg <- function(object, parameters = NULL,

if (fresh) {
object$args <- args
object$eng_args <- eng_args
} else {
null_args <- purrr::map_lgl(args, parsnip::null_value)
if (any(null_args))
args <- args[!null_args]
if (length(args) > 0)
object$args[names(args)] <- args
if (length(eng_args) > 0)
object$eng_args[names(eng_args)] <- eng_args
}

parsnip::new_model_spec(
Expand Down
5 changes: 4 additions & 1 deletion R/parsnip-arima_boost.R
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ update.arima_boost <- function(object,
sample_size = NULL, stop_iter = NULL,
fresh = FALSE, ...) {

parsnip::update_dot_check(...)
eng_args <- parsnip::update_engine_parameters(object$eng_args, fresh, ...)

if (!is.null(parameters)) {
parameters <- parsnip::check_final_param(parameters)
Expand Down Expand Up @@ -342,12 +342,15 @@ update.arima_boost <- function(object,

if (fresh) {
object$args <- args
object$eng_args <- eng_args
} else {
null_args <- purrr::map_lgl(args, parsnip::null_value)
if (any(null_args))
args <- args[!null_args]
if (length(args) > 0)
object$args[names(args)] <- args
if (length(eng_args) > 0)
object$eng_args[names(eng_args)] <- eng_args
}

parsnip::new_model_spec(
Expand Down
5 changes: 4 additions & 1 deletion R/parsnip-arima_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ update.arima_reg <- function(object, parameters = NULL,
seasonal_ar = NULL, seasonal_differences = NULL, seasonal_ma = NULL,
fresh = FALSE, ...) {

parsnip::update_dot_check(...)
eng_args <- parsnip::update_engine_parameters(object$eng_args, fresh, ...)

if (!is.null(parameters)) {
parameters <- parsnip::check_final_param(parameters)
Expand All @@ -275,12 +275,15 @@ update.arima_reg <- function(object, parameters = NULL,

if (fresh) {
object$args <- args
object$eng_args <- eng_args
} else {
null_args <- purrr::map_lgl(args, parsnip::null_value)
if (any(null_args))
args <- args[!null_args]
if (length(args) > 0)
object$args[names(args)] <- args
if (length(eng_args) > 0)
object$eng_args[names(eng_args)] <- eng_args
}

parsnip::new_model_spec(
Expand Down
5 changes: 4 additions & 1 deletion R/parsnip-exp_smoothing.R
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ update.exp_smoothing <- function(object, parameters = NULL,
smooth_level = NULL, smooth_trend = NULL, smooth_seasonal = NULL,
fresh = FALSE, ...) {

parsnip::update_dot_check(...)
eng_args <- parsnip::update_engine_parameters(object$eng_args, fresh, ...)

if (!is.null(parameters)) {
parameters <- parsnip::check_final_param(parameters)
Expand All @@ -351,12 +351,15 @@ update.exp_smoothing <- function(object, parameters = NULL,

if (fresh) {
object$args <- args
object$eng_args <- eng_args
} else {
null_args <- purrr::map_lgl(args, parsnip::null_value)
if (any(null_args))
args <- args[!null_args]
if (length(args) > 0)
object$args[names(args)] <- args
if (length(eng_args) > 0)
object$eng_args[names(eng_args)] <- eng_args
}

parsnip::new_model_spec(
Expand Down
5 changes: 4 additions & 1 deletion R/parsnip-naive_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ update.naive_reg <- function(object, parameters = NULL,
id = NULL, seasonal_period = NULL,
fresh = FALSE, ...) {

parsnip::update_dot_check(...)
eng_args <- parsnip::update_engine_parameters(object$eng_args, fresh, ...)

if (!is.null(parameters)) {
parameters <- parsnip::check_final_param(parameters)
Expand All @@ -191,12 +191,15 @@ update.naive_reg <- function(object, parameters = NULL,

if (fresh) {
object$args <- args
object$eng_args <- eng_args
} else {
null_args <- purrr::map_lgl(args, parsnip::null_value)
if (any(null_args))
args <- args[!null_args]
if (length(args) > 0)
object$args[names(args)] <- args
if (length(eng_args) > 0)
object$eng_args[names(eng_args)] <- eng_args
}

parsnip::new_model_spec(
Expand Down
5 changes: 4 additions & 1 deletion R/parsnip-nnetar_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ update.nnetar_reg <- function(object, parameters = NULL,
penalty = NULL, epochs = NULL,
fresh = FALSE, ...) {

parsnip::update_dot_check(...)
eng_args <- parsnip::update_engine_parameters(object$eng_args, fresh, ...)

if (!is.null(parameters)) {
parameters <- parsnip::check_final_param(parameters)
Expand All @@ -225,12 +225,15 @@ update.nnetar_reg <- function(object, parameters = NULL,

if (fresh) {
object$args <- args
object$eng_args <- eng_args
} else {
null_args <- purrr::map_lgl(args, parsnip::null_value)
if (any(null_args))
args <- args[!null_args]
if (length(args) > 0)
object$args[names(args)] <- args
if (length(eng_args) > 0)
object$eng_args[names(eng_args)] <- eng_args
}

parsnip::new_model_spec(
Expand Down
5 changes: 4 additions & 1 deletion R/parsnip-prophet_boost.R
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ update.prophet_boost <- function(object, parameters = NULL,
sample_size = NULL, stop_iter = NULL,
fresh = FALSE, ...) {

parsnip::update_dot_check(...)
args <- parsnip::update_main_parameters(args, parameters)

if (!is.null(parameters)) {
parameters <- parsnip::check_final_param(parameters)
Expand Down Expand Up @@ -342,12 +342,15 @@ update.prophet_boost <- function(object, parameters = NULL,

if (fresh) {
object$args <- args
object$eng_args <- eng_args
} else {
null_args <- purrr::map_lgl(args, parsnip::null_value)
if (any(null_args))
args <- args[!null_args]
if (length(args) > 0)
object$args[names(args)] <- args
if (length(eng_args) > 0)
object$eng_args[names(eng_args)] <- eng_args
}

parsnip::new_model_spec(
Expand Down
Loading

0 comments on commit ff32b26

Please sign in to comment.