From 94f8e48fbbd86605dfcc70b39404cdeb27a5b23e Mon Sep 17 00:00:00 2001 From: Maximilian Muecke Date: Wed, 1 Jan 2025 19:06:03 +0100 Subject: [PATCH] refactor: adjust date col reflection --- R/ResamplingForecastCV.R | 13 ++++----- R/ResamplingForecastHoldout.R | 13 ++++----- R/TaskFcstAirpassengers.R | 2 +- R/zzz.R | 6 ++--- README.md | 50 +++++++++++++++++------------------ man/TaskFcst.Rd | 20 +------------- 6 files changed, 44 insertions(+), 60 deletions(-) diff --git a/R/ResamplingForecastCV.R b/R/ResamplingForecastCV.R index cd57bf2..6f30447 100644 --- a/R/ResamplingForecastCV.R +++ b/R/ResamplingForecastCV.R @@ -82,12 +82,13 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV", private = list( .sample = function(ids, task, ...) { - if (length(task$col_roles$order) == 0L) { - stopf( - "Resampling '%s' requires an ordered task, but Task '%s' has no order.", - self$id, task$id - ) - } + # if (length(task$col_roles$order) == 0L) { + # stopf( + # "Resampling '%s' requires an ordered task, but Task '%s' has no order.", + # self$id, task$id + # ) + # } + pars = self$param_set$get_values() ids = sort(ids) train_end = ids[ids <= (max(ids) - pars$horizon) & ids >= pars$window_size] diff --git a/R/ResamplingForecastHoldout.R b/R/ResamplingForecastHoldout.R index 8168e98..e53dee7 100644 --- a/R/ResamplingForecastHoldout.R +++ b/R/ResamplingForecastHoldout.R @@ -68,12 +68,13 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout", private = list( .sample = function(ids, task, ...) { - if (length(task$col_roles$order) == 0L) { - stopf( - "Resampling '%s' requires an ordered task, but Task '%s' has no order.", - self$id, task$id - ) - } + # if (length(task$col_roles$order) == 0L) { + # stopf( + # "Resampling '%s' requires an ordered task, but Task '%s' has no order.", + # self$id, task$id + # ) + # } + pars = self$param_set$get_values() ratio = pars$ratio n = pars$n diff --git a/R/TaskFcstAirpassengers.R b/R/TaskFcstAirpassengers.R index b6c65cd..5da6a48 100644 --- a/R/TaskFcstAirpassengers.R +++ b/R/TaskFcstAirpassengers.R @@ -18,7 +18,7 @@ load_task_airpassengers = function(id = "airpassengers") { if (!requireNamespace("tsbox", quietly = TRUE)) { stopf("Package 'tsbox' is required to load the 'AirPassengers' dataset.") } - dt = tsbox::ts_dt(load_dataset("AirPassengers", package = "datasets")) + dt = tsbox::ts_dt(load_dataset("AirPassengers", "datasets")) setnames(dt, c("date", "passengers")) b = as_data_backend(dt) diff --git a/R/zzz.R b/R/zzz.R index f0cc383..cfe2a98 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -10,7 +10,8 @@ mlr3forecast_resamplings = new.env() mlr3forecast_tasks = new.env() mlr3forecast_learners = new.env() mlr3forecast_measures = new.env() -mlr3forecast_feature_types = c(date = "Date") +# TODO: check if this can be moved to mlr3, copy components from @mllg PR +mlr3forecast_feature_types = c(dte = "Date") named_union = function(x, y) set_names(union(x, y), union(names(x), names(y))) @@ -35,7 +36,6 @@ register_mlr3 = function() { "fcst", "mlr3forecast", "TaskFcst", "LearnerFcst", "PredictionFcst", "PredictionDataFcst", "MeasureFcst" # nolint ), fill = TRUE), "type") mlr_reflections$learner_predict_types$fcst = mlr_reflections$learner_predict_types$regr - mlr_reflections$task_col_roles$fcst = union(mlr_reflections$task_col_roles$regr, "index") mlr_reflections$task_feature_types = named_union( mlr_reflections$task_feature_types, mlr3forecast_feature_types ) @@ -79,7 +79,7 @@ register_mlr3 = function() { mlr_reflections$task_types = mlr_reflections$task_types[!"fcst"] mlr_reflections$task_feature_types = mlr_reflections$task_feature_types[mlr_reflections$task_feature_types %nin% mlr3forecast_feature_types] # nolint - reflections = c("learner_predict_types", "task_col_roles", "task_properties") + reflections = c("learner_predict_types", "task_properties") walk(reflections, function(x) mlr_reflections[[x]] = remove_named(mlr_reflections[[x]], "fcst")) } diff --git a/README.md b/README.md index 0aec75b..0df8f6e 100644 --- a/README.md +++ b/README.md @@ -45,32 +45,32 @@ prediction = ff$predict_newdata(newdata, task) prediction #> for 3 observations: #> row_ids truth response -#> 1 NA 452.7575 -#> 2 NA 474.8485 -#> 3 NA 481.4720 +#> 1 NA 446.9409 +#> 2 NA 477.9439 +#> 3 NA 480.5694 prediction = ff$predict(task, 142:144) prediction #> for 3 observations: #> row_ids truth response -#> 1 461 458.1226 -#> 2 390 412.7669 -#> 3 432 396.2460 +#> 1 461 459.4145 +#> 2 390 411.2457 +#> 3 432 400.4514 prediction$score(measure) #> regr.rmse -#> 24.52863 +#> 21.97883 ff = Forecaster$new(lrn("regr.ranger"), 1:3) resampling = rsmp("forecast_holdout", ratio = 0.8) rr = resample(task, ff, resampling) rr$aggregate(measure) #> regr.rmse -#> 112.7031 +#> 105.0997 resampling = rsmp("forecast_cv") rr = resample(task, ff, resampling) rr$aggregate(measure) #> regr.rmse -#> 48.80621 +#> 54.93903 ``` ### Multivariate @@ -90,34 +90,34 @@ ff = Forecaster$new(lrn("regr.ranger"), 1:3)$train(new_task) prediction = ff$predict(new_task, 142:144) prediction$score(measure) #> regr.rmse -#> 19.26131 +#> 17.55705 row_ids = new_task$nrow - 0:2 ff$predict_newdata(new_task$data(rows = row_ids), new_task) #> for 3 observations: #> row_ids truth response -#> 1 432 407.1069 -#> 2 390 391.4622 -#> 3 461 393.6115 +#> 1 432 405.2216 +#> 2 390 388.3066 +#> 3 461 385.6412 newdata = new_task$data(rows = row_ids, cols = new_task$feature_names) ff$predict_newdata(newdata, new_task) #> for 3 observations: #> row_ids truth response -#> 1 NA 407.1069 -#> 2 NA 391.4622 -#> 3 NA 393.6115 +#> 1 NA 405.2216 +#> 2 NA 388.3066 +#> 3 NA 385.6412 resampling = rsmp("forecast_holdout", ratio = 0.8) rr = resample(new_task, ff, resampling) rr$aggregate(measure) #> regr.rmse -#> 81.82989 +#> 82.35283 resampling = rsmp("forecast_cv") rr = resample(new_task, ff, resampling) rr$aggregate(measure) #> regr.rmse -#> 45.80208 +#> 45.54337 ``` ### mlr3pipelines integration @@ -128,7 +128,7 @@ glrn = as_learner(graph %>>% ff)$train(task) prediction = glrn$predict(task, 142:144) prediction$score(measure) #> regr.rmse -#> 34.39579 +#> 34.29322 ``` ### Example: Forecasting electricity demand @@ -166,11 +166,11 @@ prediction = glrn$predict_newdata(newdata, task) prediction #> for 14 observations: #> row_ids truth response -#> 1 NA 187.1619 -#> 2 NA 191.8612 -#> 3 NA 184.2280 +#> 1 NA 187.9399 +#> 2 NA 190.5695 +#> 3 NA 184.2617 #> --- --- --- -#> 12 NA 214.1141 -#> 13 NA 216.5287 -#> 14 NA 217.9717 +#> 12 NA 214.6350 +#> 13 NA 218.8392 +#> 14 NA 221.4170 ``` diff --git a/man/TaskFcst.Rd b/man/TaskFcst.Rd index 5051324..61dcccd 100644 --- a/man/TaskFcst.Rd +++ b/man/TaskFcst.Rd @@ -46,14 +46,6 @@ Other Task: \section{Super classes}{ \code{\link[mlr3:Task]{mlr3::Task}} -> \code{\link[mlr3:TaskSupervised]{mlr3::TaskSupervised}} -> \code{\link[mlr3:TaskRegr]{mlr3::TaskRegr}} -> \code{TaskFcst} } -\section{Public fields}{ -\if{html}{\out{
}} -\describe{ -\item{\code{index}}{(\code{character(1)})\cr -Column name of the index variable.} -} -\if{html}{\out{
}} -} \section{Methods}{ \subsection{Public methods}{ \itemize{ @@ -94,14 +86,7 @@ Column name of the index variable.} Creates a new instance of this \link[R6:R6Class]{R6} class. The function \code{\link[=as_task_fcst]{as_task_fcst()}} provides an alternative way to construct forecast tasks. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{TaskFcst$new( - id, - backend, - target, - index, - label = NA_character_, - extra_args = list() -)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{TaskFcst$new(id, backend, target, label = NA_character_, extra_args = list())}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -117,9 +102,6 @@ E.g., a \code{data.frame()} will be converted to a \link{DataBackendDataTable}.} \item{\code{target}}{(\code{character(1)})\cr Name of the target column.} -\item{\code{index}}{(\code{character(1)})\cr -Column name of the index variable.} - \item{\code{label}}{(\code{character(1)})\cr Label for the new instance.}