diff --git a/auton_survival/datasets.py b/auton_survival/datasets.py index e39b5e0..c4e0923 100644 --- a/auton_survival/datasets.py +++ b/auton_survival/datasets.py @@ -38,7 +38,9 @@ import torchvision -def increase_censoring(e, t, p): +def increase_censoring(e, t, p, random_seed=0): + + np.random.seed(random_seed) uncens = np.where(e == 1)[0] mask = np.random.choice([False, True], len(uncens), p=[1-p, p])