Skip to content

Commit

Permalink
feat(resample): feature flag for sorting based order col
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 1, 2025
1 parent 94f8e48 commit d1f3df2
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 41 deletions.
7 changes: 0 additions & 7 deletions R/ResamplingForecastCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,6 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV",

private = list(
.sample = function(ids, task, ...) {
# if (length(task$col_roles$order) == 0L) {
# stopf(
# "Resampling '%s' requires an ordered task, but Task '%s' has no order.",
# self$id, task$id
# )
# }

pars = self$param_set$get_values()
ids = sort(ids)
train_end = ids[ids <= (max(ids) - pars$horizon) & ids >= pars$window_size]
Expand Down
26 changes: 17 additions & 9 deletions R/ResamplingForecastHoldout.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,6 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",

private = list(
.sample = function(ids, task, ...) {
# if (length(task$col_roles$order) == 0L) {
# stopf(
# "Resampling '%s' requires an ordered task, but Task '%s' has no order.",
# self$id, task$id
# )
# }

pars = self$param_set$get_values()
ratio = pars$ratio
n = pars$n
Expand All @@ -91,8 +84,23 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
} else {
nr = max(n_obs + n, 0L)
}
ii = ids[1:nr]
list(train = ii, test = ids[(nr + 1L):n_obs])

if (TRUE) {
ids = sort(ids)
ii = ids[1:nr]
list(train = ii, test = ids[(nr + 1L):n_obs])
} else {
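# TODO: determine when sorting row ids by the task's order column (instead of
# sorting the ids themselves) is actually needed; this branch is disabled for now.
# NOTE(review): the order column is read from `private$.col_roles$order` here,
# while the removed check used `task$col_roles$order` - confirm which is intended.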
order = row_id = NULL
order_cols = private$.col_roles$order
tab = task$backend$data(rows = ids, cols = c(task$backend$primary_key, order_cols))
setnames(tab, c("row_id", "order"))
setorder(tab, order)
list(
train = tab[1:nr, row_id],
test = tab[(nr + 1L):n_obs, row_id]
)
}
},

.get_train = function(i) {
Expand Down
50 changes: 25 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,32 @@ prediction = ff$predict_newdata(newdata, task)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 NA 446.9409
#> 2 NA 477.9439
#> 3 NA 480.5694
#> 1 NA 448.8710
#> 2 NA 475.2456
#> 3 NA 480.5179
prediction = ff$predict(task, 142:144)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 461 459.4145
#> 2 390 411.2457
#> 3 432 400.4514
#> 1 461 456.4968
#> 2 390 411.1712
#> 3 432 393.9585
prediction$score(measure)
#> regr.rmse
#> 21.97883
#> 25.26957

ff = Forecaster$new(lrn("regr.ranger"), 1:3)
resampling = rsmp("forecast_holdout", ratio = 0.8)
rr = resample(task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 105.0997
#> 105.8215

resampling = rsmp("forecast_cv")
rr = resample(task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 54.93903
#> 54.28352
```

### Multivariate
Expand All @@ -90,34 +90,34 @@ ff = Forecaster$new(lrn("regr.ranger"), 1:3)$train(new_task)
prediction = ff$predict(new_task, 142:144)
prediction$score(measure)
#> regr.rmse
#> 17.55705
#> 17.0878

row_ids = new_task$nrow - 0:2
ff$predict_newdata(new_task$data(rows = row_ids), new_task)
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 432 405.2216
#> 2 390 388.3066
#> 3 461 385.6412
#> 1 432 405.5814
#> 2 390 388.3657
#> 3 461 390.9778
newdata = new_task$data(rows = row_ids, cols = new_task$feature_names)
ff$predict_newdata(newdata, new_task)
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 NA 405.2216
#> 2 NA 388.3066
#> 3 NA 385.6412
#> 1 NA 405.5814
#> 2 NA 388.3657
#> 3 NA 390.9778

resampling = rsmp("forecast_holdout", ratio = 0.8)
rr = resample(new_task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 82.35283
#> 81.91252

resampling = rsmp("forecast_cv")
rr = resample(new_task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 45.54337
#> 41.87113
```

### mlr3pipelines integration
Expand All @@ -128,7 +128,7 @@ glrn = as_learner(graph %>>% ff)$train(task)
prediction = glrn$predict(task, 142:144)
prediction$score(measure)
#> regr.rmse
#> 34.29322
#> 33.74039
```

### Example: Forecasting electricity demand
Expand Down Expand Up @@ -166,11 +166,11 @@ prediction = glrn$predict_newdata(newdata, task)
prediction
#> <PredictionRegr> for 14 observations:
#> row_ids truth response
#> 1 NA 187.9399
#> 2 NA 190.5695
#> 3 NA 184.2617
#> 1 NA 187.6208
#> 2 NA 191.8121
#> 3 NA 183.6753
#> --- --- ---
#> 12 NA 214.6350
#> 13 NA 218.8392
#> 14 NA 221.4170
#> 12 NA 213.8759
#> 13 NA 218.4198
#> 14 NA 218.8139
```

0 comments on commit d1f3df2

Please sign in to comment.