Skip to content

Commit

Permalink
Merge pull request #504 from OHDSI/existing_split
Browse files Browse the repository at this point in the history
Add existing splitSettings and tests. Solves #487
  • Loading branch information
egillax authored Nov 21, 2024
2 parents 77cc0bf + 7bddc38 commit fb99638
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 31 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ export(createDatabaseSchemaSettings)
export(createDefaultExecuteSettings)
export(createDefaultSplitSetting)
export(createExecuteSettings)
export(createExistingSplitSettings)
export(createFeatureEngineeringSettings)
export(createGlmModel)
export(createLearningCurve)
Expand Down
35 changes: 32 additions & 3 deletions R/DataSplitting.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,26 @@ createDefaultSplitSetting <- function(testFraction = 0.25,
return(splitSettings)
}

#' Create the settings for defining how the plpData are split into
#' test/validation/train sets using an existing split - good to use for
#' reproducing results from a different run
#' @param splitIds (data.frame) A data frame with rowId and index columns of
#' type integer/numeric. Index is -1 for test set, positive integer for train
#' set folds
#' @return An object of class \code{splitSettings}
#' @export
createExistingSplitSettings <- function(splitIds) {
checkIsClass(splitIds, "data.frame")
checkColumnNames(splitIds, c("rowId", "index"))
checkIsClass(splitIds$rowId, c("integer", "numeric"))
checkIsClass(splitIds$index, c("integer", "numeric"))
checkHigherEqual(splitIds$index, -1)

splitSettings <- list(splitIds = splitIds)
attr(splitSettings, "fun") <- "existingSplitter"
class(splitSettings) <- "splitSettings"
return(splitSettings)
}


#' Split the plpData into test/train sets using a splitting settings of class
Expand Down Expand Up @@ -561,7 +581,16 @@ checkInputsSplit <- function(test, train, nfold, seed) {
ParallelLogger::logDebug(paste0("nfold: ", nfold))
checkIsClass(nfold, c("numeric", "integer"))
checkHigher(nfold, 1)

ParallelLogger::logInfo(paste0('seed: ', seed))
checkIsClass(seed, c('numeric','integer'))

ParallelLogger::logInfo(paste0("seed: ", seed))
checkIsClass(seed, c("numeric", "integer"))
}

existingSplitter <- function(population, splitSettings) {
splitIds <- splitSettings$splitIds
# check all row Ids are in population
if (sum(!splitIds$rowId %in% population$rowId) > 0) {
stop("Not all rowIds in splitIds are in the population")
}
return(splitIds)
}
42 changes: 27 additions & 15 deletions man/createDefaultSplitSetting.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions man/createExistingSplitSettings.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 23 additions & 13 deletions man/splitData.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions tests/testthat/test-dataSplitting.R
Original file line number Diff line number Diff line change
Expand Up @@ -420,5 +420,40 @@ test_that("Data splitting by subject", {
expect_equal(unique(table(test$subjectId[test$index == 3])), 4)
expect_equal(unique(table(test$subjectId[test$index == 1])), 4)

# test that no subject is not assigned a fold
expect_equal(sum(test$index == 0), 0)
})

test_that("Existing data splitter works", {
# split by age
age <- population$ageYear
# create empty index same lengths as age
index <- rep(0, length(age))
index[age > 43] <- -1 # test set
index[age <= 35] <- 1 # train fold 1
index[age > 35 & age <= 43] <- 2 # train fold 2
splitIds <- data.frame(rowId = population$rowId, index = index)
splitSettings <- createExistingSplitSettings(splitIds)
ageSplit <- splitData(
plpData = plpData,
population = population,
splitSettings = splitSettings
)

# test only old people in test
expect_equal(
length(ageSplit$Test$labels$rowId),
sum(age > 43)
)
# only young people in train
expect_equal(
length(ageSplit$Train$labels$rowId),
sum(age <= 43)
)
# no overlap
expect_equal(
length(intersect(ageSplit$Test$labels$rowId, ageSplit$Train$labels$rowId)),
0
)

})

0 comments on commit fb99638

Please sign in to comment.