
Improve Survival stuff #1833

Merged: 19 commits, Jul 11, 2017
Changes from 8 commits
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -150,6 +150,7 @@ Suggests:
smoof,
sparseLDA,
stepPlr,
survAUC,
SwarmSVM,
svglite,
testthat,
6 changes: 3 additions & 3 deletions NAMESPACE
@@ -288,7 +288,6 @@ S3method(makeRLearner,surv.gamboost)
S3method(makeRLearner,surv.gbm)
S3method(makeRLearner,surv.glmboost)
S3method(makeRLearner,surv.glmnet)
S3method(makeRLearner,surv.penalized)
S3method(makeRLearner,surv.randomForestSRC)
S3method(makeRLearner,surv.ranger)
S3method(makeRLearner,surv.rpart)
@@ -484,7 +483,6 @@ S3method(predictLearner,surv.gamboost)
S3method(predictLearner,surv.gbm)
S3method(predictLearner,surv.glmboost)
S3method(predictLearner,surv.glmnet)
S3method(predictLearner,surv.penalized)
S3method(predictLearner,surv.randomForestSRC)
S3method(predictLearner,surv.ranger)
S3method(predictLearner,surv.rpart)
@@ -735,7 +733,6 @@ S3method(trainLearner,surv.gamboost)
S3method(trainLearner,surv.gbm)
S3method(trainLearner,surv.glmboost)
S3method(trainLearner,surv.glmnet)
S3method(trainLearner,surv.penalized)
S3method(trainLearner,surv.randomForestSRC)
S3method(trainLearner,surv.ranger)
S3method(trainLearner,surv.rpart)
@@ -762,6 +759,7 @@ export(calculateConfusionMatrix)
export(calculateROCMeasures)
export(capLargeValues)
export(cindex)
export(cindex.uno)
export(configureMlr)
export(convertBMRToRankMatrix)
export(convertMLBenchObjToTask)
@@ -877,6 +875,7 @@ export(helpLearner)
export(helpLearnerParam)
export(holdout)
export(hout)
export(iauc.uno)
export(impute)
export(imputeConstant)
export(imputeHist)
@@ -1090,6 +1089,7 @@ export(setHyperPars)
export(setHyperPars2)
export(setId)
export(setLearnerId)
export(setMeasurePars)
export(setPredictThreshold)
export(setPredictType)
export(setThreshold)
22 changes: 3 additions & 19 deletions R/Measure.R
@@ -43,7 +43,7 @@
#' \item{req.task}{Is task object required in calculation? Usually not the case}
#' \item{req.model}{Is model object required in calculation? Usually not the case.}
#' \item{req.feats}{Are feature values required in calculation? Usually not the case.}
#' \item{req.prob}{Are predicted probabilites required in calculation? Usually not the case, example would be AUC.}
#' \item{req.prob}{Are predicted probabilities required in calculation? Usually not the case, example would be AUC.}
#' }
#' Default is \code{character(0)}.
#' @param fun [\code{function(task, model, pred, feats, extra.args)}]\cr
@@ -63,6 +63,7 @@
#' }
#' @param extra.args [\code{list}]\cr
#' List of extra arguments which will always be passed to \code{fun}.
#' Can be changed after construction via \code{\link{setMeasurePars}}.
#' Default is empty list.
#' @param aggr [\code{\link{Aggregation}}]\cr
#' Aggregation function, which is used to aggregate the values measured
@@ -156,24 +157,6 @@ getDefaultMeasure = function(x) {
)
}


#' Set aggregation function of measure.
#'
#' Set how this measure will be aggregated after resampling.
#' To see possible aggregation functions: \code{\link{aggregations}}.
#'
#' @param measure [\code{\link{Measure}}]\cr
#' Performance measure.
#' @template arg_aggr
#' @return [\code{\link{Measure}}] with changed aggregation behaviour.
#' @export
setAggregation = function(measure, aggr) {
assertClass(measure, classes = "Measure")
assertClass(aggr, classes = "Aggregation")
measure$aggr = aggr
return(measure)
}

#' @export
print.Measure = function(x, ...) {
catf("Name: %s", x$name)
@@ -182,5 +165,6 @@ print.Measure = function(x, ...) {
catf("Minimize: %s", x$minimize)
catf("Best: %g; Worst: %g", x$best, x$worst)
catf("Aggregated by: %s", x$aggr$id)
catf("Arguments: %s", listToShortString(x$extra.args))
catf("Note: %s", x$note)
}
43 changes: 43 additions & 0 deletions R/Measure_operators.R
@@ -0,0 +1,43 @@
#' @title Set parameters of performance measures
#'
#' @description
#' Sets hyperparameters of measures.
#'
#' @param measure [\code{\link{Measure}}]\cr
#' Performance measure.
#' @param ... [any]\cr
#' Named (hyper)parameters with new settings. Alternatively these can be passed
#' using the \code{par.vals} argument.
#' @param par.vals [\code{list}]\cr
#' Optional list of named (hyper)parameter settings. The arguments in
#' \code{...} take precedence over values in this list.
#' @template ret_measure
#' @family performance
#' @export
setMeasurePars = function(measure, ..., par.vals = list()) {
args = list(...)
assertClass(measure, classes = "Measure")
assertList(args, names = "unique", .var.name = "parameter settings")
assertList(par.vals, names = "unique", .var.name = "parameter settings")
measure$extra.args = insert(measure$extra.args, insert(par.vals, args))
measure
}
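
A short usage sketch (hedged: it relies on the cindex.uno measure and its max.time parameter, both added further down in this PR):

m = setMeasurePars(cindex.uno, max.time = 100)
# arguments in ... take precedence over par.vals, so this also yields max.time = 100:
m = setMeasurePars(cindex.uno, max.time = 100, par.vals = list(max.time = 50))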

#' @title Set aggregation function of measure.
#'
#' @description
#' Set how this measure will be aggregated after resampling.
#' To see possible aggregation functions: \code{\link{aggregations}}.
#'
#' @param measure [\code{\link{Measure}}]\cr
#' Performance measure.
#' @template arg_aggr
#' @return [\code{\link{Measure}}] with changed aggregation behaviour.
#' @family performance
#' @export
setAggregation = function(measure, aggr) {
assertClass(measure, classes = "Measure")
assertClass(aggr, classes = "Aggregation")
measure$aggr = aggr
return(measure)
}
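
For illustration, a minimal sketch using one of mlr's built-in aggregations (see aggregations):

# aggregate the C-index by its standard deviation across resampling
# iterations instead of the default mean:
m = setAggregation(cindex, test.sd)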
3 changes: 2 additions & 1 deletion R/RLearner_surv_cforest.R
@@ -53,7 +53,8 @@ trainLearner.surv.cforest = function(.learner, .task, .subset,

#' @export
predictLearner.surv.cforest = function(.learner, .model, .newdata, ...) {
predict(.model$learner.model, newdata = .newdata, ...)
# cforest returns median survival times; multiply by -1 so that high values correspond to high risk
-1 * predict(.model$learner.model, newdata = .newdata, type = "response", ...)
Member: Could you add a test for this please?

Member Author: There is a test in this PR which detects if the predictions are reversed/inverted.

}
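
The referenced test is not shown in this hunk; a minimal testthat-style sketch of such a direction check (illustrative only, assuming mlr's bundled survival task lung.task; not the PR's actual test code):

mod = train(makeLearner("surv.cforest"), lung.task)
p = predict(mod, lung.task)
risk = getPredictionResponse(p)
truth = getPredictionTruth(p)  # a survival::Surv object
# among uncensored cases, higher predicted risk should pair with shorter
# observed survival time, i.e. a negative rank correlation:
obs = truth[, "status"] == 1
expect_true(cor(risk[obs], truth[obs, "time"], method = "spearman") < 0)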

#' @export
16 changes: 3 additions & 13 deletions R/RLearner_surv_coxph.R
@@ -27,23 +27,13 @@ trainLearner.surv.coxph = function(.learner, .task, .subset, .weights = NULL, .
f = getTaskFormula(.task)
data = getTaskData(.task, subset = .subset)
if (is.null(.weights)) {
mod = survival::coxph(formula = f, data = data, ...)
survival::coxph(formula = f, data = data, ...)
} else {
mod = survival::coxph(formula = f, data = data, weights = .weights, ...)
survival::coxph(formula = f, data = data, weights = .weights, ...)
}
#if (.learner$predict.type == "prob")
# mod = attachTrainingInfo(mod, list(surv.range = range(getTaskTargets(.task)[, 1L])))
mod
}

#' @export
predictLearner.surv.coxph = function(.learner, .model, .newdata, ...) {
if (.learner$predict.type == "response") {
predict(.model$learner.model, newdata = .newdata, type = "lp", ...)
}
# else if (.learner$predict.type == "prob") {
# surv.range = getTrainingInfo(.model$learner.model)$surv.range
# times = seq(from = surv.range[1L], to = surv.range[2L], length.out = 1000)
# t(summary(survival::survfit(.model$learner.model, newdata = .newdata, se.fit = FALSE, conf.int = FALSE), times = times)$surv)
# }
predict(.model$learner.model, newdata = .newdata, type = "lp", ...)
}
5 changes: 1 addition & 4 deletions R/RLearner_surv_gamboost.R
@@ -52,8 +52,5 @@ trainLearner.surv.gamboost = function(.learner, .task, .subset, .weights = NULL,

#' @export
predictLearner.surv.gamboost = function(.learner, .model, .newdata, ...) {
if (.learner$predict.type == "response")
predict(.model$learner.model, newdata = .newdata, type = "link")
else
stop("Unknown predict type")
predict(.model$learner.model, newdata = .newdata, type = "link")
Member: Why is the if no longer necessary here (and below)?

Member Author: Survival learners currently do not support multiple predict types. There was an attempt to support survival probabilities, but it was never implemented. The calling function checks the predict type against the learner's properties, so this branch is dead code.

}
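
To make the "dead code" point concrete, the upstream guard works roughly like this sketch (names illustrative; this is not mlr's literal implementation):

# non-default predict types must be backed by a matching learner property:
if (learner$predict.type != "response" &&
    learner$predict.type %nin% getLearnerProperties(learner))
  stopf("Learner '%s' does not support predict type '%s'!",
    learner$id, learner$predict.type)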
5 changes: 1 addition & 4 deletions R/RLearner_surv_glmboost.R
@@ -66,8 +66,5 @@ predictLearner.surv.glmboost = function(.learner, .model, .newdata, use.formula,
info = getTrainingInfo(.model)
.newdata = as.matrix(fixDataForLearner(.newdata, info))
}
if (.learner$predict.type == "response")
predict(.model$learner.model, newdata = .newdata, type = "link")
else
stop("Unknown predict type")
predict(.model$learner.model, newdata = .newdata, type = "link")
}
40 changes: 0 additions & 40 deletions R/RLearner_surv_penalized.R

This file was deleted.

6 changes: 1 addition & 5 deletions R/RLearner_surv_rpart.R
@@ -39,11 +39,7 @@ trainLearner.surv.rpart = function(.learner, .task, .subset, .weights = NULL, ..

#' @export
predictLearner.surv.rpart = function(.learner, .model, .newdata, ...) {
if (.learner$predict.type == "response") {
predict(.model$learner.model, newdata = .newdata, type = "vector", ...)
} else {
stop("Unsupported predict type")
}
predict(.model$learner.model, newdata = .newdata, type = "vector", ...)
}

#' @export
63 changes: 56 additions & 7 deletions R/measures.R
@@ -17,6 +17,9 @@
#' For clustering measures, we compact the predicted cluster IDs such that they form a continuous series
#' starting with 1. If this is not the case, some of the measures will generate warnings.
#'
#' Some measures have parameters. Their defaults are set in the constructor \code{\link{makeMeasure}} and can be
#' overwritten using \code{\link{setMeasurePars}}.
#'
#' @param truth [\code{factor}]\cr
#' Vector of the true class.
#' @param response [\code{factor}]\cr
@@ -1337,19 +1340,65 @@ measureMultilabelTPR = function(truth, response) {
#' @format none
cindex = makeMeasure(id = "cindex", minimize = FALSE, best = 1, worst = 0,
properties = c("surv", "req.pred", "req.truth"),
name = "Concordance index",
name = "Harrell's Concordance index",
note = "Fraction of all pairs of subjects whose predicted survival times are correctly ordered among all subjects that can actually be ordered. In other words, it is the probability of concordance between the predicted and the observed survival.",
fun = function(task, model, pred, feats, extra.args) {
requirePackages("Hmisc", default.method = "load")
resp = pred$data$response
if (anyMissing(resp))
requirePackages("_Hmisc")
y = getPredictionResponse(pred)
if (anyMissing(y))
return(NA_real_)
# FIXME: we need to convert to the correct survival type
s = Surv(pred$data$truth.time, pred$data$truth.event)
Hmisc::rcorr.cens(-1 * resp, s)[["C Index"]]
s = getPredictionTruth(pred)
Hmisc::rcorr.cens(-1 * y, s)[["C Index"]]
}
)
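
A hand-checkable illustration of the measure above (assuming Hmisc and survival are installed; the data are made up):

library(survival)
time  = c(1, 3, 5, 7)
event = c(1, 1, 1, 1)  # no censoring: all 6 pairs are comparable
risk  = c(4, 3, 2, 1)  # highest risk fails first, a perfect ranking
# every pair is concordant, so the C index is exactly 1:
Hmisc::rcorr.cens(-1 * risk, Surv(time, event))[["C Index"]]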

#' @export cindex.uno
#' @rdname measures
#' @format none
#' @references
#' H. Uno et al.
#' \emph{On the C-statistics for Evaluating Overall Adequacy of Risk Prediction Procedures with Censored Survival Data}
#' Statistics in medicine. 2011;30(10):1105-1117. \url{http://dx.doi.org/10.1002/sim.4154}.
cindex.uno = makeMeasure(id = "cindex.uno", minimize = FALSE, best = 1, worst = 0,
properties = c("surv", "req.pred", "req.truth", "req.model"),
name = "Uno's Concordance index",
note = "Fraction of all pairs of subjects whose predicted survival times are correctly ordered among all subjects that can actually be ordered. In other words, it is the probability of concordance between the predicted and the observed survival. Corrected by weighting with IPCW as suggested by Uno. Implemented in survAUC::UnoC.",
fun = function(task, model, pred, feats, extra.args) {
requirePackages("_survAUC")
y = getPredictionResponse(pred)
if (anyMissing(y))
return(NA_real_)
surv.train = getTaskTargets(task, recode.target = "rcens")[model$subset]
max.time = assertNumber(extra.args$max.time, null.ok = TRUE) %??% max(getTaskTargets(task)[, 1L])
survAUC::UnoC(Surv.rsp = surv.train, Surv.rsp.new = getPredictionTruth(pred), time = max.time, lpnew = y)
},
extra.args = list(max.time = NULL)
)

#' @export iauc.uno
#' @rdname measures
#' @format none
#' @references
#' H. Uno et al.
#' \emph{Evaluating Prediction Rules for T-Year Survivors with Censored Regression Models}
#' Journal of the American Statistical Association 102, no. 478 (2007): 527-37. \url{http://www.jstor.org/stable/27639883}.
iauc.uno = makeMeasure(id = "iauc.uno", minimize = FALSE, best = 1, worst = 0,
properties = c("surv", "req.pred", "req.truth", "req.model", "req.task"),
name = "Uno's estimator of cumulative AUC for right censored time-to-event data",
note = "To set an upper time limit, set argument max.time (defaults to max time in complete task). Implemented in survAUC::AUC.uno.",
fun = function(task, model, pred, feats, extra.args) {
requirePackages("_survAUC")
max.time = assertNumber(extra.args$max.time, null.ok = TRUE) %??% max(getTaskTargets(task)[, 1L])
times = seq(from = 0, to = max.time, length.out = extra.args$resolution)
surv.train = getTaskTargets(task, recode.target = "rcens")[model$subset]
y = getPredictionResponse(pred)
if (anyMissing(y))
return(NA_real_)
survAUC::AUC.uno(Surv.rsp = surv.train, Surv.rsp.new = getPredictionTruth(pred), times = times, lpnew = y)$iauc
},
extra.args = list(max.time = NULL, resolution = 1000)
)
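
A resampling sketch for the new measures (hedged: assumes mlr's bundled survival task lung.task; exact values will vary):

lrn = makeLearner("surv.coxph")
rdesc = makeResampleDesc("CV", iters = 3)
ms = list(cindex, cindex.uno, setMeasurePars(iauc.uno, resolution = 100))
r = resample(lrn, lung.task, rdesc, measures = ms)
r$aggr  # aggregated cindex, cindex.uno and iauc.uno values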

Member: Could you add hand-constructed tests for the new measures please?

Member Author: What do you mean by hand-constructed? Calculating these measures without a package would require a few hundred LOC.

Member: For most of the other measures the tests are along the lines of: 5 incorrect predictions, 10 correct predictions, therefore an error rate of 33%; check that the implemented measure returns that number. The point is to verify the number is correct for specific cases (and these can be constructed, i.e. you know what the answer should be).

Member Author: It's complicated. I've added a small test which checks that perfect predictions lead to (nearly) perfect performance when there is no censoring. For all other cases, I'd need an external package, because you cannot compute this by hand in a reasonable time frame. I guess we have to rely on the package authors of survAUC for correctness.

@PhilippPro Do you have any ideas how to test those measures?

Member: No, not really. Of course one can construct simple cases without censoring that can be calculated by hand, but with censoring we have to use the complicated formulas from Uno's paper (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3079915/), which do not look very simple at first glance.
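
Along the lines described above, a small no-censoring sanity check against survAUC directly might look like this (illustrative, not the PR's actual test code):

library(survival)
library(survAUC)
time  = 1:20
event = rep(1, 20)  # no censoring
s = Surv(time, event)
lp = -time          # perfect risk ranking: earliest deaths get the highest risk
# with train == test data and a perfect ranking, Uno's C should be (close to) 1:
UnoC(Surv.rsp = s, Surv.rsp.new = s, lpnew = lp, time = max(time))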

###############################################################################
### cost-sensitive ###
###############################################################################
Expand Down
3 changes: 2 additions & 1 deletion man/ConfusionMatrix.Rd
3 changes: 2 additions & 1 deletion man/calculateConfusionMatrix.Rd
3 changes: 2 additions & 1 deletion man/calculateROCMeasures.Rd
3 changes: 2 additions & 1 deletion man/estimateRelativeOverfitting.Rd
3 changes: 2 additions & 1 deletion man/makeCostMeasure.Rd
4 changes: 3 additions & 1 deletion man/makeCustomResampledMeasure.Rd

(Generated man/*.Rd files are not rendered by default.)