Skip to content

Commit

Permalink
add test for weight functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mirka-henninger committed Feb 2, 2021
1 parent bf55da2 commit fff827b
Showing 1 changed file with 34 additions and 7 deletions.
41 changes: 34 additions & 7 deletions tests/testthat/test-LocalModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,26 @@ test_that("LocalModel works for single output and single feature", {
p <- plot(LocalModel1)
expect_s3_class(p, c("gg", "ggplot"))
p

x.interest2 <- X[4, ]
LocalModel1$explain(x.interest2)
dat <- LocalModel1$results
expect_equal(colnames(dat), expected.colnames)
expect_lte(nrow(dat), k)

pred <- predict(LocalModel1, newdata = X[3:4, ])
expect_data_frame(pred, nrows = 2)
expect_equal(colnames(pred), "prediction")

LocalModel1 <- LocalModel$new(predictor1,
x.interest = x.interest, k = k,
dist.fun = "euclidean", kernel.width = 1
x.interest = x.interest, k = k,
dist.fun = "euclidean", kernel.width = 1
)
LocalModel1$explain(x.interest2)
dat <- LocalModel1$results
expect_equal(colnames(dat), expected.colnames)
expect_lte(nrow(dat), k)

pred <- predict(LocalModel1, newdata = X[3:4, ])
expect_data_frame(pred, nrows = 2)
expect_equal(colnames(pred), "prediction")
Expand All @@ -57,7 +57,7 @@ test_that("LocalModel works for multiple output", {
expect_class(dat, "data.frame")
expect_data_frame(pred2, nrows = 2)
expect_equal(colnames(pred2), c("setosa", "versicolor", "virginica"))

p <- plot(LocalModel1)
expect_s3_class(p, c("gg", "ggplot"))
p
Expand All @@ -74,3 +74,30 @@ test_that("LocalModel prediction expects same cols as training dat", {
expect_warning(LocalModel1$predict(cbind(x.interest, data.frame(blabla = 1))))
expect_error(LocalModel1$predict(x.interest[-2]), "Missing")
})




test_that("LocalModel distance functions work as expected", {
kernel.width <- 1
distance.functions <- c(
"gower", "euclidean", "maximum",
"manhattan", "canberra", "binary", "minkowski")
x.interest <- X[2, ]
k <- 1
set.seed(42)
LocalModel1 <- LocalModel$new(predictor1, x.interest = x.interest, k = k,
dist = "euclidean", kernel.width = kernel.width)
# recode to avoid warning for categorical variables (NAs introduced by coercion)
X.recode <- recode(LocalModel1$.__enclos_env__$private$dataDesign, x.interest)
x.recoded <- recode(x.interest, x.interest)
# first test the function that was used for fitting
weights <- LocalModel1$.__enclos_env__$private$weight.fun(X.recode, x.recoded)
expect_equal(object = weights[2], expected = 1)
# test all distance functions by explicitly constructing them via get.weight.fun()
for(fun in distance.functions){
weight_fun <- LocalModel1$.__enclos_env__$private$get.weight.fun(
dist.fun = fun, kernel.width = kernel.width)
expect_equal(object = weight_fun(X.recode, x.recoded)[2], expected = 1)
}
})

0 comments on commit fff827b

Please sign in to comment.