Shapley does not allow data.table inputs #217

Closed
dandls opened this issue Oct 16, 2024 · 0 comments

dandls commented Oct 16, 2024

When I pass data.table(x.interest) instead of x.interest in the example code of iml::Shapley, I get an error:

```r
library("iml")
library("data.table")
library("rpart")
# First we fit a machine learning model on the Boston housing data
data("Boston", package = "MASS")
rf <- rpart(medv ~ ., data = Boston)
X <- Boston[-which(names(Boston) == "medv")]
mod <- Predictor$new(rf, data = X)

# Then we explain the first instance of the dataset with the Shapley method:
x.interest <- X[1, ]
# Passing x.interest as a data.table instead of a data.frame triggers the error
shapley <- Shapley$new(mod, x.interest = data.table(x.interest))
```

The error message is:

```
Error in `[.data.table`(x.interest, setdiff(colnames(x.interest), predictor$data$y.names)) : 
  When i is a data.table (or character vector), the columns to join by must be specified using 'on=' argument (see ?data.table), by keying x (i.e. sorted, and, marked as sorted, see ?setkey), or by sharing column names between x and i (i.e., a natural join). Keyed joins might have further speed benefits on very large data due to x being sorted in RAM.
```
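The traceback shows that `Shapley` drops the target column with `x.interest[setdiff(colnames(x.interest), predictor$data$y.names)]`. For a data.frame, single-bracket indexing with a character vector selects columns, but data.table reads a character vector in `i` as join values, which fails on an unkeyed table. The sketch below reproduces this on a small hypothetical one-row table (values are made up; `medv` stands in for the target column) and shows selections that behave as intended.

```r
library("data.table")

# Hypothetical one-row stand-in for x.interest; "medv" plays the target column
df <- data.frame(crim = 0.1, zn = 18, medv = 24)
dt <- data.table(df)

# data.frame: `[` with a character vector selects the named columns
df_sub <- df[setdiff(colnames(df), "medv")]

# data.table: the same call treats the character vector as `i`, i.e. as a
# join against an unkeyed table, so it fails with the error from the report
err <- tryCatch(
  dt[setdiff(colnames(dt), "medv")],
  error = function(e) conditionMessage(e)
)

# data.table column selection by a character vector needs `with = FALSE`
keep <- setdiff(colnames(dt), "medv")
dt_sub <- dt[, keep, with = FALSE]

# Converting back to a plain data.frame restores data.frame indexing semantics
df_again <- as.data.frame(dt)
df_again_sub <- df_again[setdiff(colnames(df_again), "medv")]
```

By the same reasoning, passing `as.data.frame(data.table(x.interest))` to `Shapley$new()` presumably avoids the error, since the unmodified example with a data.frame runs; this workaround is an untested assumption, not something confirmed in the report.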

I used the latest CRAN releases of all packages. My sessionInfo() output:

```
> sessionInfo()
R version 4.4.0 (2024-04-24)
Platform: x86_64-pc-linux-gnu
Running under: Debian GNU/Linux 12 (bookworm)

Matrix products: default
BLAS:   /usr/local/lib/R/lib/libRblas.so 
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.11.0

locale:
 [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
 [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
 [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       

time zone: Europe/Berlin
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] data.table_1.16.2 rpart_4.1.23      iml_0.11.3       

loaded via a namespace (and not attached):
 [1] vctrs_0.6.5       cli_3.6.3         rlang_1.1.4       Formula_1.2-5    
 [5] generics_0.1.3    glue_1.7.0        colorspace_2.1-0  listenv_0.9.1    
 [9] backports_1.5.0   Metrics_0.1.4     scales_1.3.0      fansi_1.0.6      
[13] grid_4.4.0        munsell_0.5.1     tibble_3.2.1      lifecycle_1.0.4  
[17] compiler_4.4.0    dplyr_1.1.4       codetools_0.2-20  pkgconfig_2.0.3  
[21] future_1.34.0     digest_0.6.37     R6_2.5.1          tidyselect_1.2.1 
[25] utf8_1.2.4        pillar_1.9.0      parallelly_1.38.0 parallel_4.4.0   
[29] magrittr_2.0.3    checkmate_2.3.2   tools_4.4.0       gtable_0.3.5     
[33] globals_0.16.3    ggplot2_3.5.1
```