Skip to content

Commit

Permalink
feat: more params for {forecast} learners
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 8, 2025
1 parent 6c1981d commit 19c50da
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 20 deletions.
23 changes: 20 additions & 3 deletions R/LearnerRegrArfima.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,25 @@ LearnerFcstArfima = R6Class("LearnerFcstArfima",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps()
param_set = ps(
drange = p_uty(default = c(0, 0.5), tags = "train"),
estim = p_fct(default = "mle", levels = c("mle", "ls"), tags = "train"),
lambda = p_uty(default = NULL, tags = "train"),
order = p_uty(
default = c(0L, 0L, 0L),
tags = "train",
custom_check = crate(function(x) check_integerish(x, lower = 0L, len = 3L))
),
seasonal = p_uty(
default = c(0L, 0L, 0L),
tags = "train",
custom_check = crate(function(x) check_integerish(x, lower = 0L, len = 3L))
),
include.mean = p_lgl(default = TRUE, tags = "train"),
include.drift = p_lgl(default = FALSE, tags = "train"),
biasadj = p_lgl(default = FALSE, tags = "train"),
method = p_fct(c("CSS-ML", "ML", "CSS"), default = "CSS-ML", tags = "train")
)

super$initialize(
id = "fcst.arfima",
Expand All @@ -46,9 +64,8 @@ LearnerFcstArfima = R6Class("LearnerFcstArfima",
pv = insert_named(pv, list(weights = task$weights$weight))
}

xreg = NULL
if (is_task_featureless(task)) {
xreg = NULL
} else {
xreg = as.matrix(task$data(cols = fcst_feature_names(task)))
}
invoke(forecast::arfima,
Expand Down
3 changes: 1 addition & 2 deletions R/LearnerRegrArima.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ LearnerFcstArima = R6Class("LearnerFcstArima",
pv = insert_named(pv, list(weights = task$weights$weight))
}

xreg = NULL
if (is_task_featureless(task)) {
xreg = NULL
} else {
xreg = as.matrix(task$data(cols = fcst_feature_names(task)))
}
invoke(forecast::Arima,
Expand Down
3 changes: 1 addition & 2 deletions R/LearnerRegrAutoArima.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ LearnerFcstAutoArima = R6Class("LearnerFcstAutoArima",
pv = insert_named(pv, list(weights = task$weights$weight))
}

xreg = NULL
if (is_task_featureless(task)) {
xreg = NULL
} else {
xreg = as.matrix(task$data(cols = fcst_feature_names(task)))
}
invoke(forecast::auto.arima,
Expand Down
21 changes: 17 additions & 4 deletions R/LearnerRegrEts.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ LearnerFcstEts = R6Class("LearnerFcstEts",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(
model = p_uty(default = "ZZZ", tags = "train", custom_check = crate(function(x) check_string(x, n.chars = 3L))),
model = p_uty(
default = "ZZZ",
tags = "train",
custom_check = crate(function(x) check_string(x, n.chars = 3L))
),
damped = p_lgl(default = NULL, special_vals = list(NULL), tags = "train"),
alpha = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
beta = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
Expand All @@ -32,13 +36,22 @@ LearnerFcstEts = R6Class("LearnerFcstEts",
biasadj = p_lgl(default = FALSE, tags = "train"),
lower = p_uty(default = c(rep(1e-04, 3), 0.8), tags = "train"),
upper = p_uty(default = c(rep(0.9999, 3), 0.98), tags = "train"),
opt.crit = p_fct(default = "lik", levels = c("lik", "amse", "mse", "sigma", "mae"), tags = "train"),
opt.crit = p_fct(
default = "lik",
levels = c("lik", "amse", "mse", "sigma", "mae"),
tags = "train"
),
nmse = p_int(0L, 30L, default = 3, tags = "train"),
bounds = p_fct(default = "both", levels = c("both", "usual", "admissible"), tags = "train"),
bounds = p_fct(
default = "both", levels = c("both", "usual", "admissible"), tags = "train"
),
ic = p_fct(default = "aicc", levels = c("aicc", "aic", "bic"), tags = "train"),
restrict = p_lgl(default = TRUE, tags = "train"),
allow.multiplicative.trend = p_lgl(default = FALSE, tags = "train"),
na.action = p_fct(default = "na.contiguous", levels = c("na.contiguous", "na.interp", "na.fail"))
na.action = p_fct(
default = "na.contiguous",
levels = c("na.contiguous", "na.interp", "na.fail")
)
)

super$initialize(
Expand Down
1 change: 0 additions & 1 deletion R/LearnerRegrForecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ LearnerRegrForecast = R6Class("LearnerRegrForecast",
return(list(response = as.numeric(pred$mean)))
}

# might not be robust enough with position instead of name
pred$lower = pred$lower[, rev(seq_len(ncol(pred$lower)))]
quantiles = cbind(
pred$lower,
Expand Down
16 changes: 13 additions & 3 deletions man/mlr_learners_fcst.arfima.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_fcst.arima.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_fcst.auto_arima.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_fcst.bats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_fcst.ets.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_fcst.tbats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 19c50da

Please sign in to comment.