Skip to content

Commit

Permalink
refactor: try to only implement the private methods
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 6, 2025
1 parent 35a705b commit dbbd151
Show file tree
Hide file tree
Showing 16 changed files with 506 additions and 244 deletions.
6 changes: 5 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Authors@R:
Description: Extends the 'mlr3' ecosystem to time series forecasting.
License: LGPL-3
Depends:
mlr3 (>= 0.22.1),
mlr3 (>= 0.22.1.9000),
R (>= 3.3.0)
Imports:
backports,
Expand All @@ -21,6 +21,8 @@ Suggests:
testthat (>= 3.2.0),
tsbox,
withr (>= 3.0.0)
Remotes:
mlr-org/mlr3
Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Expand All @@ -30,6 +32,8 @@ Collate:
'DataBackendTimeSeries.R'
'ForecastLearner.R'
'zzz.R'
'LearnerARIMA.R'
'LearnerFcst.R'
'MeasureDirectional.R'
'ResamplingForecastCV.R'
'ResamplingForecastHoldout.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ S3method(as_task_fcst,TaskFcst)
S3method(as_task_fcst,data.frame)
export(DataBackendTimeSeries)
export(ForecastLearner)
export(LearnerFcstARIMA)
export(ResamplingForecastCV)
export(ResamplingForecastHoldout)
export(TaskFcst)
Expand Down
133 changes: 34 additions & 99 deletions R/ForecastLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,118 +30,49 @@ ForecastLearner = R6::R6Class("ForecastLearner",

super$initialize(
id = learner$id,
task_type = learner$task_type,
task_type = "regr",
param_set = learner$param_set,
predict_types = learner$predict_types,
feature_types = learner$feature_types,
properties = learner$properties,
packages = c("mlr3forecast", learner$packages),
man = learner$man
)
},

#' @description
#' Uses the information stored during `$train()` in `$state` to create a new [Prediction]
#' for a set of observations of the provided `task`.
#'
#' @param task ([Task]).
#'
#' @param row_ids (`integer()`)\cr
#' Vector of test indices as subset of `task$row_ids`. For a simple split
#' into training and test set, see [partition()].
#'
#' @return [Prediction].
predict = function(task, row_ids = NULL) {
task = assert_task(as_task(task))
row_ids = assert_integerish(row_ids,
lower = 1L, any.missing = FALSE, coerce = TRUE, null.ok = TRUE
)

# 1. direct learner$predict(): entire task + row_ids or `NULL` for entire task prediction
# 2. resampling: test task and `NULL` row_ids, task$row_ids are from entire task
# 3. glrn$predict(): test task and `NULL` row_ids, task$row_ids are from train task
# 4. glrn$predict_newdata(): test task and `NULL` row_ids, task$row_ids are 1:n, i.e. not from entire task
# NB: this will need some special handling, how do I know if its called by glrn?
# check for glrn$predict_newdata() case
has_row_ids = !is.null(row_ids)
row_ids = row_ids %??% task$row_ids
row_ids = sort(row_ids)
if (!has_row_ids &&
nrow(fintersect(task$data(), private$.task$data())) == 0 &&
all(task$row_ids %in% private$.task$row_ids)) {
row_ids = seq_along(row_ids) + tail(private$.task$row_ids, 1L)
}
if (is.null(task$key) && !all(diff(row_ids) == 1L)) {
stopf("Row ids must be consecutive")
}
private$.predict_recursive(task, row_ids)
},

#' @description
#' Uses the model fitted during `$train()` to create a new [Prediction] based on the forecast horizon `n`.
#'
#' @param task ([Task]).
#' @param n (`integer(1)`).
#' @param newdata (any object supported by [as_data_backend()])\cr
#' New data to predict on.
#' All data formats convertible by [as_data_backend()] are supported, e.g.
#' `data.frame()` or [DataBackend].
#' If a [DataBackend] is provided as `newdata`, the row ids are preserved,
#' otherwise they are set to to the sequence `1:nrow(newdata)`.
#'
#' @return [Prediction].
predict_newdata = function(newdata, task) {
task = assert_task(as_task(task))
assert_learnable(task, self)
private$.predict_newdata_recursive(task, newdata)
}
),

private = list(
.task = NULL,
.max_index = NULL,

.train = function(task) {
private$.max_index = max(task$data(cols = task$col_roles$order)[[1L]])
private$.task = task$clone()
target = task$target_names
dt = private$.lag_transform(task$data(), target)
new_task = as_task_regr(dt, target = target)

learner = self$learner$clone(deep = TRUE)$train(new_task)
structure(list(learner = learner), class = c("forecast_learner_model", "list"))
},

.predict = function(task) {
self$predict(task)
private$.predict_recursive(task)
},

.lag_transform = function(dt, target) {
lag = self$lag
nms = sprintf("%s_lag_%s", target, lag)
dt = copy(dt)
key = private$.task$key
if (is.null(key)) {
dt[, (nms) := shift(.SD, n = lag, type = "lag"), .SDcols = target]
} else {
setorderv(dt, c(key))
dt[, (nms) := shift(.SD, n = lag, type = "lag"), by = key, .SDcols = target]
}
dt
},

.predict_recursive = function(task, row_ids) {
# join the training task with the prediction task for lag transformation
# in normal predict we get the entire task, in resampling we only get the subset
# TODO: check why `Task$data_formats` warning is thrown
if (suppressWarnings(isTRUE(all.equal(private$.task, task)))) {
dt = task$data()
.predict_recursive = function(task) {
target = private$.task$target_names
if (private$.is_newdata(task)) {
row_ids = private$.task$nrow + seq_len(task$nrow)
dt = rbind(private$.task$data(), task$data(), fill = TRUE)
} else {
dt = rbind(private$.task$data(), task$data())
row_ids = task$row_ids
dt = private$.task$data()
}
target = private$.task$target_names
# one model for all steps
preds = map(row_ids, function(i) {
new_x = private$.lag_transform(dt, target)[i]
pred = self$model$learner$predict_newdata(new_x)
# set is faster with DT
dt[i, (target) := pred$response]
pred
})
Expand All @@ -150,27 +81,31 @@ ForecastLearner = R6::R6Class("ForecastLearner",
preds
},

.predict_newdata_recursive = function(task, newdata) {
dt = task$data()
target = task$target_names
# create a new rows for the new prediction
dt = rbind(dt, newdata, fill = TRUE)
row_ids = task$nrow + seq_len(nrow(newdata))
# one model for all steps
preds = map(row_ids, function(i) {
new_x = private$.lag_transform(dt, target)[i]
pred = self$model$learner$predict_newdata(new_x)
dt[i, (target) := pred$response]
pred
})
preds = do.call(c, preds)
preds$data$row_ids = seq_len(nrow(newdata))
preds
},

.predict_direct = function(dt, n) {
# one model for each step
.NotYetImplemented()
},

.lag_transform = function(dt, target) {
lag = self$lag
nms = sprintf("%s_lag_%s", target, lag)
dt = copy(dt)
key = private$.task$key
if (is.null(key)) {
dt[, (nms) := shift(.SD, n = lag, type = "lag"), .SDcols = target]
} else {
dt[, (nms) := shift(.SD, n = lag, type = "lag"), by = key, .SDcols = target]
}
dt
},

.is_newdata = function(task) {
order_cols = task$col_roles$order
tab = task$backend$data(rows = task$row_ids, cols = order_cols)
if (nrow(tab) == 0L) {
return(TRUE)
}
!any(private$.max_index %in% tab[[1L]])
}
)
)
95 changes: 95 additions & 0 deletions R/LearnerARIMA.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#' @title ARIMA
#'
#' @name mlr_learners_fcst.arima
#'
#' @description
#' ...
#'
#' @templateVar id fcst.arima
#' @template learner
#'
#' @references
#' ...
#'
#' @export
#' @template seealso_learner
LearnerFcstARIMA = R6Class("LearnerFcstARIMA",
inherit = LearnerRegr,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {

ps = ps(
order = p_uty(default = c(0, 0, 0), tags = "train"),
seasonal = p_uty(default = c(0, 0, 0), tags = "train"),
include.mean = p_lgl(default = TRUE, tags = "train"),
include.drift = p_lgl(default = FALSE, tags = "train"),
biasadj = p_lgl(default = FALSE, tags = "train"),
method = p_fct(c("CSS-ML", "ML", "CSS"), default = "CSS-ML", tags = "train")
)

super$initialize(
id = "fcst.arima",
param_set = ps,
feature_types = c("logical", "integer", "numeric"),
packages = c("mlr3learners", "forecast"),
label = "ARIMA",
man = "mlr3learners::mlr_learners_arima.arima"
)
}
),

private = list(
.max_index = NULL,

.train = function(task) {
if (length(task$col_roles$order) == 0L) {
stopf("%s learner requires an ordered task.", self$id)
}
private$.max_index = max(task$data(cols = task$col_roles$order)[[1L]])
pv = self$param_set$get_values(tags = "train")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}
if (length(task$feature_names) > 0) {
xreg = as.matrix(task$data(cols = task$feature_names))
invoke(forecast::Arima,
y = task$data(rows = task$row_ids, cols = task$target_names),
xreg = xreg,
.args = pv
)
} else {
invoke(forecast::Arima,
y = task$data(rows = task$row_ids, cols = task$target_names),
.args = pv)
}
},

.predict = function(task) {
pv = self$param_set$get_values(tags = "predict")
if (private$.is_newdata(task)) {
if (length(task$feature_names) > 0) {
newdata = as.matrix(task$data(cols = task$feature_names))
prediction = invoke(forecast::forecast, self$model, xreg = newdata)
} else {
prediction = invoke(forecast::forecast, self$model, h = length(task$row_ids))
browser()
}
list(response = prediction$mean)
} else {
prediction = stats::fitted(self$model[task$row_ids])
list(response = prediction)
}
},

.is_newdata = function(task) {
order_cols = task$col_roles$order
idx = task$backend$data(rows = task$row_ids, cols = order_cols)[[1L]]
!any(private$.max_index %in% idx)
}
)
)

#' @include zzz.R
register_learner("fcst.arima", LearnerFcstARIMA)
23 changes: 23 additions & 0 deletions R/LearnerFcst.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#' @title Forecast Learner
#'
LearnerFcst = R6Class("LearnerFcst",
inherit = Learner,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, param_set = ps(), predict_types = "response", feature_types = character(), properties = character(), data_formats, packages = character(), label = NA_character_, man = NA_character_) {
super$initialize(
id = id,
task_type = "fcst",
param_set = param_set,
feature_types = feature_types,
predict_types = predict_types,
properties = properties,
data_formats,
packages = packages,
label = label,
man = man
)
}
)
)
39 changes: 36 additions & 3 deletions R/ResamplingForecastCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,11 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV",
private = list(
.sample = function(ids, ...) {
pars = self$param_set$get_values()
window_size = pars$window_size
horizon = pars$horizon

ids = sort(ids)
train_end = ids[ids <= (max(ids) - pars$horizon) & ids >= pars$window_size]
train_end = ids[ids <= (max(ids) - horizon) & ids >= window_size]
train_end = seq.int(
from = train_end[length(train_end)],
by = -pars$step_size,
Expand All @@ -93,14 +96,44 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV",
if (!pars$fixed_window) {
train_ids = map(train_end, function(x) ids[1L]:x)
} else {
train_ids = map(train_end, function(x) (x - pars$window_size + 1L):x)
train_ids = map(train_end, function(x) (x - window_size + 1L):x)
}
test_ids = map(train_ids, function(x) (x[length(x)] + 1L):(x[length(x)] + pars$horizon))
test_ids = map(train_ids, function(x) (x[length(x)] + 1L):(x[length(x)] + horizon))
list(train = train_ids, test = test_ids)
},

.sample_new = function(ids, task, ...) {
.NotYetImplemented()

pars = self$param_set$get_values()
horizon = pars$horizon
window_size = pars$window_size
step_size = pars$step_size
folds = pars$folds
fixed_window = pars$fixed_window

order_cols = task$col_roles$order
key_cols = task$key
has_key = length(key_cols) > 0L

tab = task$backend$data(
rows = ids,
cols = c(task$backend$primary_key, order_cols, key_cols)
)

if (has_key) {
setnames(tab, c("row_id", "order", "key"))
setorderv(tab, c("key", "order"))
} else {
setnames(tab, c("row_id", "order"))
setorderv(tab, "order")
}

if (!has_key) {
} else {

}

},

.get_train = function(i) {
Expand Down
Loading

0 comments on commit dbbd151

Please sign in to comment.