modified: auton_survival/__init__.py
	modified:   auton_survival/datasets.py
	modified:   auton_survival/metrics.py
chiragnagpal committed Feb 16, 2022
1 parent 4d191ce commit ddedb91
Showing 3 changed files with 128 additions and 109 deletions.
35 changes: 26 additions & 9 deletions auton_survival/__init__.py
@@ -1,20 +1,20 @@
'''
r'''
[![Build Status](https://travis-ci.org/autonlab/DeepSurvivalMachines.svg?branch=master)](https://travis-ci.org/autonlab/DeepSurvivalMachines)
   
[![codecov](https://codecov.io/gh/autonlab/DeepSurvivalMachines/branch/master/graph/badge.svg?token=FU1HB5O92D)](https://codecov.io/gh/autonlab/DeepSurvivalMachines)
   
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
   
[![GitHub Repo stars](https://img.shields.io/github/stars/autonlab/DeepSurvivalMachines?style=social)](https://github.com/autonlab/DeepSurvivalMachines)
[![GitHub Repo stars](https://img.shields.io/github/stars/autonlab/auton-survival?style=social)](https://github.com/autonlab/auton-survival)
Python package `auton_survival` provides a flexible API for various problems
in survival analysis, including regression, counterfactual estimation,
and phenotyping.
What is Survival Analysis?
------------------------
--------------------------
**Survival Analysis** involves estimating when an event of interest, \( T \)
would take place given some features or covariates \( X \). In statistics
@@ -27,11 +27,13 @@
* Censoring is present, i.e. a large number of instances of data are
lost to follow up.
# Auton Survival
Auton Survival
----------------
Repository of reusable code utilities for Survival Analysis projects.
## `auton_survival.datasets`
Dataset Loading and Preprocessing
---------------------------------
Helper functions to load various trial data like `TOPCAT`, `BARI2D` and `ALLHAT`.
@@ -41,7 +43,7 @@
features, outcomes = datasets.load_topcat()
```
## `auton_survival.preprocessing`
### `auton_survival.preprocessing`
This module provides a flexible API to perform imputation and data
normalization for downstream machine learning models. The module has
3 distinct classes, `Scaler`, `Imputer` and `Preprocessor`. The `Preprocessor`
@@ -58,9 +60,14 @@ class is a composite transform that does both Imputing ***and*** Scaling.
num_feats=['height', 'weight'])
# The `cat_feats` and `num_feats` lists would contain all the categorical and numerical features in the dataset.
```
## `auton_survival.estimators`
Survival Regression
-------------------
### `auton_survival.estimators`
This module provides a wrapper to model BioLINCC datasets with standard
survival (time-to-event) analysis methods.
@@ -87,7 +94,7 @@ class is a composite transform that does both Imputing ***and*** Scaling.
```
## `auton_survival.experiments`
### `auton_survival.experiments`
Modules to perform standard survival analysis experiments. This module
provides a top-level interface to run `auton_survival`-style experiments
@@ -118,7 +125,17 @@ class is a composite transform that does both Imputing ***and*** Scaling.
print(scores)
```
## `auton_survival.reporting`
Phenotyping and Knowledge Discovery
-----------------------------------
### `auton_survival.phenotyping`
Reporting
----------
### `auton_survival.reporting`
Helper functions to generate standard reports for popular Survival Analysis problems.
5 changes: 3 additions & 2 deletions auton_survival/datasets.py
@@ -296,8 +296,9 @@ def load_dataset(dataset='SUPPORT', **kwargs):
Returns
----------
tuple: (np.ndarray, np.ndarray, np.ndarray)
A tuple of the form of (x, t, e) where x, t, e are the input covariates,
event times and the censoring indicators respectively.
A tuple of the form of \( (x, t, e) \) where \( x \)
are the input covariates, \( t \) the event times and
\( e \) the censoring indicators.
"""
sequential = kwargs.get('sequential', False)

197 changes: 99 additions & 98 deletions auton_survival/metrics.py
@@ -8,6 +8,105 @@

from tqdm import tqdm

def survival_diff_metric(metric, outcomes, treatment_indicator,
weights=None, horizon=None, interpolate=True,
weights_clip=1e-2,
n_bootstrap=None, size_bootstrap=1.0, random_seed=0):

r"""Metrics for comparing population level survival outcomes across treatment arms.
Parameters
----------
metric : str
The metric to evaluate. One of:
- **`hazard_ratio`**
- **`restricted_mean`**
- **`survival_at`**
outcomes : pd.DataFrame
The outcomes to compare. A pd.DataFrame with columns 'time' and 'event'.
treatment_indicator : np.array
Boolean numpy array of treatment indicators. True means individual was
assigned treatment.
weights : pd.Series
Treatment assignment propensity scores, \( \widehat{\mathbb{P}}(A|X=x) \).
If None, all weights are set to 0.5. Default is None.
horizon : float
The time horizon at which to compare the survival curves.
Must be specified for metric 'restricted_mean' and 'survival_at'.
For 'hazard_ratio' this is ignored.
interpolate : bool
Whether to interpolate the survival curves. Default is True.
weights_clip : float
Weights below this value are clamped. This is to ensure IPTW estimation
is numerically stable. Large weights can result in an estimator with high
variance.
n_bootstrap : int
The number of bootstrap samples to use. Default is None.
If None, no bootstrapping is performed.
size_bootstrap : float
The fraction of the population to sample for each bootstrap sample.
Default is 1.0.
random_seed : int
The random seed to use for bootstrapping. Default is 0.
Returns:
float or list: The metric value(s) for the specified metric.
"""

assert metric in ['median', 'hazard_ratio', 'restricted_mean', 'survival_at', 'time_to']

if metric in ['restricted_mean', 'survival_at', 'time_to']:
assert horizon is not None, "Please specify Event Horizon"

if metric == 'hazard_ratio':
raise Warning("WARNING: You are computing Hazard Ratios.\n Make sure you have tested the PH Assumptions.")
if (n_bootstrap is None) and (weights is not None):
raise Warning("Treatment Propensity weights would be ignored, since no bootstrapping is performed. "+
"In order to incorporate IPTW weights please specify number of bootstrap iterations n_bootstrap>=1")
# Bootstrapping ...
if n_bootstrap is not None:
assert isinstance(n_bootstrap, int), '`n_bootstrap` must be None or int'

if isinstance(n_bootstrap, int):
print('Bootstrapping... ', n_bootstrap,
' number of times. This may take a while. Please be Patient...')

is_treated = treatment_indicator.astype(float)
if weights is None:
weights = 0.5*np.ones(len(outcomes))

weights[weights>1-weights_clip] = 1-weights_clip
weights[weights<weights_clip] = weights_clip

iptw_weights = 1./((is_treated*weights)+((1-is_treated)*(1-weights)))

treated_outcomes = outcomes[treatment_indicator]
control_outcomes = outcomes[~treatment_indicator]

if metric == 'survival_at': _metric = _survival_at_diff
elif metric == 'time_to': _metric = _time_to_diff
elif metric == 'restricted_mean': _metric = _restricted_mean_diff
elif metric == 'median': _metric = _time_to_diff
elif metric == 'hazard_ratio': _metric = _hazard_ratio
else: raise NotImplementedError()

if n_bootstrap is None:
return _metric(treated_outcomes,
control_outcomes,
horizon=horizon,
interpolate=interpolate,
treated_weights=iptw_weights[treatment_indicator],
control_weights=iptw_weights[~treatment_indicator])
else:
return [_metric(treated_outcomes,
control_outcomes,
horizon=horizon,
interpolate=interpolate,
treated_weights=iptw_weights[treatment_indicator],
control_weights=iptw_weights[~treatment_indicator],
size_bootstrap=size_bootstrap,
seed=random_seed*i) for i in range(n_bootstrap)]
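
Note on the weighting step: the `weights_clip` docstring describes clamping propensity scores so that the inverse probability of treatment weights (IPTW) stay bounded. The following is a minimal pure-Python sketch of that idea, assuming the intent is symmetric clipping into `[weights_clip, 1 - weights_clip]`. The helper names `clip_propensity` and `iptw_weights` are hypothetical and are not part of `auton_survival`.

```python
def clip_propensity(scores, clip=1e-2):
    """Clamp each propensity score into [clip, 1 - clip] (assumed intent)."""
    return [min(max(p, clip), 1.0 - clip) for p in scores]


def iptw_weights(treated, scores, clip=1e-2):
    """Inverse probability of treatment weights.

    treated: iterable of bools, True if the individual was assigned treatment.
    scores:  estimated propensity P(A=1 | X=x) for each individual.
    """
    clipped = clip_propensity(scores, clip)
    # Treated individuals are weighted by 1/p, controls by 1/(1 - p).
    return [1.0 / (p if a else 1.0 - p) for a, p in zip(treated, clipped)]


treated = [True, False, True, False]
scores = [0.001, 0.5, 0.999, 0.2]
weights = iptw_weights(treated, scores)
```

With a clip of 0.01 the largest possible weight is 100, which is what keeps the variance of the weighted estimator in check. On NumPy arrays the same clamping can be written with `np.clip`.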


def survival_regression_metric(metric, predictions, outcomes, times,
folds=None, fold=None):
@@ -211,101 +310,3 @@ def _hazard_ratio(treated_outcomes, control_outcomes,
return CoxPHFitter().fit(outcomes,
duration_col='time',
event_col='event').hazard_ratios_['treated']


def survival_diff_metric(metric, outcomes, treatment_indicator,
weights=None, horizon=None, interpolate=True,
weights_clip=1e-2,
n_bootstrap=None, size_bootstrap=1.0, random_seed=0):

"""Metrics for comparing population level survival outcomes across treatment arms.
Parameters
----------
metric : str
The metric to evalute. One of:
- **`hazard_ratio`**
- **`restricted_mean`**
- **`survival_at`**
outcomes : pd.DataFrame
The outcomes to compare. A pd.Daraframe with columns 'time' and 'event'.
treatment_indicator : np.array
Boolean numpy array of treatment indicators. True means individual was
assigned treatment.
weights : pd.Series
Treatment assignment propensity scores.
If None, all weights are set to 0.5. Default is None.
horizon : float
The time horizon at which to compare the survival curves.
Must be specified for metric 'restricted_mean' and 'survival_at'.
For 'hazard_ratio' this is ignored.
interpolate : bool
Whether to interpolate the survival curves. Default is True.
weights_clip : float
Weights below this value are clamped. This is to ensure IPTW estimation
is numerically stable. Large weights can result in estimator with high
variance.
n_bootstrap : int
The number of bootstrap samples to use. Default is None.
If None, no bootrapping is performed.
size_bootstrap : float
The fraction of the population to sample for each bootstrap sample.
Default is 1.0.
random_seed : int
The random seed to use for bootstrapping. Default is 0.
Returns:
float or list: The metric value(s) for the specified metric.
"""

assert metric in ['median', 'hazard_ratio', 'restricted_mean', 'survival_at', 'time_to']

if metric in ['restricted_mean', 'survival_at', 'time_to']:
assert horizon is not None, "Please specify Event Horizon"

if metric == 'hazard_ratio':
raise Warning("WARNING: You are computing Hazard Ratios.\n Make sure you have tested the PH Assumptions.")
if (n_bootstrap is None) and (weights is not None):
raise Warning("Treatment Propensity weights would be ignored, Since no boostrapping is performed."+
"In order to incorporate IPTW weights please specify number of bootstrap iterations n_bootstrap>=1")
# Bootstrapping ...
if n_bootstrap is not None:
assert isinstance(n_bootstrap, int), '`bootstrap` must be None or int'

if isinstance(n_bootstrap, int):
print('Bootstrapping... ', n_bootstrap,
' number of times. This may take a while. Please be Patient...')

is_treated = treatment_indicator.astype(float)
if weights is None:
weights = 0.5*np.ones(len(outcomes))

weights[weights>weights_clip] = 1-weights_clip
weights[weights<weights_clip] = weights_clip

iptw_weights = 1./((is_treated*weights)+((1-is_treated)*(1-weights)))

treated_outcomes = outcomes[treatment_indicator]
control_outcomes = outcomes[~treatment_indicator]

if metric == 'survival_at': _metric = _survival_at_diff
elif metric == 'time_to': _metric = _time_to_diff
elif metric == 'restricted_mean': _metric = _restricted_mean_diff
elif metric == 'median': _metric = _time_to_diff
elif metric == 'hazard_ratio': _metric = _hazard_ratio
else: raise NotImplementedError()

if n_bootstrap is None:
return _metric(treated_outcomes,
control_outcomes,
horizon=horizon,
interpolate=interpolate,
treated_weights=iptw_weights[treatment_indicator],
control_weights=iptw_weights[~treatment_indicator])
else:
return [_metric(treated_outcomes,
control_outcomes,
horizon=horizon,
interpolate=interpolate,
treated_weights=iptw_weights[treatment_indicator],
control_weights=iptw_weights[~treatment_indicator],
size_bootstrap=size_bootstrap,
seed=random_seed*i) for i in range(n_bootstrap)]
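
Note on the two `raise Warning(...)` checks: raising a `Warning` instance aborts the call like any other exception, so a request for `hazard_ratio` would never reach the estimator. If the goal is only to alert the caller, the standard-library `warnings.warn` is the usual tool. The sketch below assumes that goal; `check_metric_inputs` is a hypothetical helper, not a function in `auton_survival`, and the message wording is illustrative.

```python
import warnings


def check_metric_inputs(metric, weights=None, n_bootstrap=None):
    """Hypothetical helper: flag risky settings without aborting the call."""
    if metric == 'hazard_ratio':
        warnings.warn(
            "Computing hazard ratios; make sure the proportional hazards "
            "assumption has been tested.", UserWarning)
    if n_bootstrap is None and weights is not None:
        warnings.warn(
            "Propensity weights are ignored when no bootstrapping is "
            "performed; set n_bootstrap >= 1 to use IPTW.", UserWarning)
    return True


# Record the warnings instead of printing them, to show that execution continues.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ok = check_metric_inputs('hazard_ratio', weights=[0.5], n_bootstrap=None)
n_warnings = len(caught)
```

Callers who do want a hard failure can still opt in with `warnings.simplefilter("error")`, which turns the same warnings into exceptions.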
