Merge pull request #261 from tidymodels/sparsevctrs-predict
Make `predict()` work with sparse data
EmilHvitfeldt authored Sep 26, 2024
2 parents 602f37e + 3c11174 commit 35131ce
Showing 4 changed files with 59 additions and 0 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
@@ -35,6 +35,7 @@ Suggests:
    tailor (>= 0.0.0.9001),
    covr,
    dials (>= 1.0.0),
    glmnet,
    knitr,
    magrittr,
    Matrix,
@@ -54,6 +55,7 @@ Config/Needs/website:
    yardstick
Remotes:
    tidymodels/rsample,
    tidymodels/recipes,
    tidymodels/parsnip,
    tidymodels/tailor,
    r-lib/sparsevctrs
4 changes: 4 additions & 0 deletions NEWS.md
@@ -7,6 +7,10 @@

* New `extract_fit_time()` method has been added that returns the time it took to train the workflow (#191).

* `fit()` can now take dgCMatrix and sparse tibbles as data values when `add_recipe()` or `add_variables()` is used (#245, #258).

* `predict()` can now take dgCMatrix and sparse tibble input for the `new_data` argument (#261).

# workflows 1.1.4

* While `augment.workflow()` previously never returned a `.resid` column, the
4 changes: 4 additions & 0 deletions R/predict.R
@@ -59,6 +59,10 @@ predict.workflow <- function(object, new_data, type = NULL, opts = list(), ...)
    ))
  }

  if (is_sparse_matrix(new_data)) {
    new_data <- sparsevctrs::coerce_to_sparse_tibble(new_data)
  }

  fit <- extract_fit_parsnip(workflow)
  new_data <- forge_predictors(new_data, workflow)
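
The coercion step added to `predict.workflow()` can be tried in isolation. The following is a minimal sketch, not the package's implementation: `is_sparse_matrix_sketch()` is a hypothetical stand-in for the internal `is_sparse_matrix()` helper (its real definition is not shown in this diff), and the small matrix is made-up illustration data. It assumes the Matrix, sparsevctrs, and tibble packages are installed.

```r
# Hypothetical stand-in for the internal is_sparse_matrix() helper.
# Assumption: it checks for the virtual "sparseMatrix" class from Matrix.
is_sparse_matrix_sketch <- function(x) {
  methods::is(x, "sparseMatrix")
}

# A small 3 x 2 sparse matrix with column names (made-up values).
# Column names are needed so the result has usable variable names.
mat <- Matrix::sparseMatrix(
  i = c(1, 3, 2),
  j = c(1, 1, 2),
  x = c(5, 7, 2),
  dims = c(3, 2),
  dimnames = list(NULL, c("a", "b"))
)

# Mirror the branch added in the diff: coerce only when the input is sparse.
new_data <- mat
if (is_sparse_matrix_sketch(new_data)) {
  new_data <- sparsevctrs::coerce_to_sparse_tibble(new_data)
}

# Each column of the tibble should still be a sparse vector.
vapply(new_data, sparsevctrs::is_sparse_vector, logical(1))
```

Coercing up front means the rest of the prediction path only has to handle tibble-like input, while the columns stay sparse rather than being expanded to dense vectors.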

49 changes: 49 additions & 0 deletions tests/testthat/test-sparsevctrs.R
@@ -120,3 +120,52 @@ test_that("sparse matrices can be passed to `fit() - xy", {
  # We expect 1 materialization - the outcome
  expect_snapshot(wf_fit <- fit(wf_spec, hotel_data))
})

test_that("sparse tibble can be passed to `predict()`", {
  skip_if_not_installed("glmnet")
  # Make materialization of sparse vectors throw an error
  # https://r-lib.github.io/sparsevctrs/dev/reference/sparsevctrs_options.html
  withr::local_options("sparsevctrs.verbose_materialize" = 3)

  hotel_data <- sparse_hotel_rates(tibble = TRUE)

  spec <- parsnip::linear_reg(penalty = 0) %>%
    parsnip::set_mode("regression") %>%
    parsnip::set_engine("glmnet")

  rec <- recipes::recipe(avg_price_per_room ~ ., data = hotel_data)

  wf_spec <- workflow() %>%
    add_recipe(rec) %>%
    add_model(spec)

  wf_fit <- fit(wf_spec, hotel_data)

  expect_no_error(predict(wf_fit, hotel_data))
})

test_that("sparse matrix can be passed to `predict()`", {
  skip_if_not_installed("glmnet")
  # Make materialization of sparse vectors throw a warning
  # https://r-lib.github.io/sparsevctrs/dev/reference/sparsevctrs_options.html
  withr::local_options("sparsevctrs.verbose_materialize" = 2)

  hotel_data <- sparse_hotel_rates()

  spec <- parsnip::linear_reg(penalty = 0) %>%
    parsnip::set_mode("regression") %>%
    parsnip::set_engine("glmnet")

  rec <- recipes::recipe(avg_price_per_room ~ ., data = hotel_data)

  wf_spec <- workflow() %>%
    add_recipe(rec) %>%
    add_model(spec)

  # We know that this will cause 1 warning due to the outcome
  suppressWarnings(
    wf_fit <- fit(wf_spec, hotel_data)
  )

  expect_no_warning(predict(wf_fit, hotel_data))
})
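
Both tests set `sparsevctrs.verbose_materialize` through `withr::local_options()`, which scopes the option to the calling frame so it does not leak into other tests. The sketch below illustrates only that scoping behavior; `read_option_inside()` is a hypothetical helper written for this example, and the snippet assumes the withr package is installed. It does not fit a model or trigger any materialization.

```r
# Record the option's value before anything is changed.
before <- getOption("sparsevctrs.verbose_materialize")

# Hypothetical helper: sets the option locally, as the tests above do,
# and reports the value seen while the function is running.
read_option_inside <- function() {
  withr::local_options("sparsevctrs.verbose_materialize" = 2)
  getOption("sparsevctrs.verbose_materialize")
}

inside <- read_option_inside()

# Once the function has returned, the previous value is restored.
after <- getOption("sparsevctrs.verbose_materialize")

identical(before, after)
```

Inside a `test_that()` block the same mechanism applies: the stricter setting is active only for the duration of that test, so the warning-level test and the error-level test can use different values without interfering with each other.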
