Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Implement grouped dataset splits and cross-validation #363

Merged
merged 18 commits into from
Jan 22, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ created.

### Added
- New extensions of SegmentationModelBases `HeadAndNeckBase` and `ProstateBase`. Use these classes to build your own Head&Neck or Prostate models, by just providing a list of foreground classes.
- Grouped dataset splits and k-fold cross-validation. This allows, for example, training on datasets with multiple images per subject without leaking data from the same subject across train/test/validation sets or cross-validation folds.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add here the specific change users will need to make to use this feature in their configs?


### Changed
- The arguments of the `score.py` script changed: `data_root` -> `data_folder`, it no longer assumes a fixed
Expand Down
167 changes: 136 additions & 31 deletions InnerEye/ML/utils/split_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.model_selection import GroupKFold, KFold

from InnerEye.Common import common_util
from InnerEye.ML.common import ModelExecutionMode
Expand All @@ -26,17 +26,36 @@ class DatasetSplits:
val: pd.DataFrame
test: pd.DataFrame
subject_column: str = CSV_SUBJECT_HEADER
group_column: Optional[str] = None
allow_empty: bool = False

def __post_init__(self) -> None:
common_util.check_properties_are_not_none(self)

def pairwise_intersection(*collections: Iterable) -> Set:
"""Returns any element that appears in more than one collection."""
from itertools import combinations
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific reason this import is local?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really; well spotted! I'll move it to the top.

intersection = set()
for col1, col2 in combinations(map(set, collections), 2):
intersection |= col1 & col2
return intersection

# perform dataset split validity assertions
unique_train, unique_test, unique_val = self.unique_subjects()
intersection = set.intersection(set(unique_train), set(unique_test), set(unique_val))
intersection = pairwise_intersection(unique_train, unique_test, unique_val)

if len(intersection) != 0:
raise ValueError("Train, Test, and Val splits must have no intersection, found: {}".format(intersection))

if self.group_column is not None:
groups_train = self.train[self.group_column].unique()
groups_test = self.test[self.group_column].unique()
groups_val = self.val[self.group_column].unique()
group_intersection = pairwise_intersection(groups_train, groups_test, groups_val)
if len(group_intersection) != 0:
raise ValueError("Train, Test, and Val splits must have no intersecting groups, found: {}"
.format(group_intersection))

if (not self.allow_empty) and any([len(x) == 0 for x in [unique_train, unique_val]]):
raise ValueError("train_ids({}), val_ids({}) must have at least one value"
.format(len(unique_train), len(unique_val)))
Expand Down Expand Up @@ -69,12 +88,16 @@ def restrict_subjects(self, restriction_pattern: str) -> DatasetSplits:
"""
Creates a new dataset split that has at most the specified numbers of subjects in train, validation and test
sets respectively.

If `group_column` was specified, this operation may violate the grouping constraints and the resulting splits
object will have `group_column == None`.

:param restriction_pattern: a string containing zero or two commas, and otherwise digits or "+". An empty
substring will result in no restriction for the corresponding dataset. Thus "20,,3" means "restrict to 20
training images and 3 test images, with no restriction on validation". A "+" value means "reassign all
images from the set(s) with a numeric count (there must be at least one) to this set". Thus ",0,+" means "leave
the training set alone, but move all validation images to the test set", and "0,2,+" means "move
all training images and all but 2 validation images to the test set".
substring will result in no restriction for the corresponding dataset. Thus "20,,3" means "restrict to 20
training images and 3 test images, with no restriction on validation". A "+" value means "reassign all
images from the set(s) with a numeric count (there must be at least one) to this set". Thus ",0,+" means
"leave the training set alone, but move all validation images to the test set", and "0,2,+" means "move
all training images and all but 2 validation images to the test set".
:return: A new dataset split object with (at most) the numbers of subjects specified by restrict_pattern
"""

Expand Down Expand Up @@ -183,48 +206,87 @@ def get_subject_ranges_for_splits(population: Sequence[str],
}
return result

@staticmethod
def _from_split_keys(df: pd.DataFrame,
train_keys: Sequence[str],
test_keys: Sequence[str],
val_keys: Sequence[str],
*, # make column names keyword-only arguments to avoid mistakes when providing both
key_column: str,
subject_column: str,
group_column: Optional[str]) -> DatasetSplits:
Comment on lines +215 to +217
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

both column names should have "" as the default value, and group_column=None

"""
Takes a slice of values from each data split train/test/val for the provided keys.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's a slice of values?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just adapted this docstring from from_subject_ids().


:param df: the input DataFrame
:param train_keys: keys for training.
:param test_keys: keys for testing.
:param val_keys: keys for validation.
:param key_column: name of the column the provided keys belong to
:param subject_column: subject id column name
:param group_column: grouping column name; if given, samples from each group will always be
in the same subset (train, val, or test) and cross-validation fold.
:return: Data splits with the respective dataset split ids.
"""
train_df = DatasetSplits.get_df_from_ids(df, train_keys, key_column)
test_df = DatasetSplits.get_df_from_ids(df, test_keys, key_column)
val_df = DatasetSplits.get_df_from_ids(df, val_keys, key_column)

return DatasetSplits(train=train_df, test=test_df, val=val_df,
subject_column=subject_column, group_column=group_column)

@staticmethod
def from_proportions(df: pd.DataFrame,
proportion_train: float,
proportion_test: float,
proportion_val: float,
*, # make column names keyword-only arguments to avoid mistakes when providing both
subject_column: str = CSV_SUBJECT_HEADER,
group_column: Optional[str] = None,
shuffle: bool = True,
random_seed: int = 0) -> DatasetSplits:
"""
Creates a split of a dataset into train, test, and validation set, according to fixed proportions using
the "subject" column in the dataframe.
the "subject" column in the dataframe, or the group column, if given.

:param df: The dataframe containing all subjects.
:param proportion_train: proportion for the train set.
:param proportion_test: proportion for the test set.
:param subject_column: Subject id column name
:param group_column: grouping column name; if given, samples from each group will always be
in the same subset (train, val, or test) and cross-validation fold.
:param proportion_val: proportion for the validation set.
:param shuffle: If True, the subjects in the dataframe will be shuffled before performing splits.
:param random_seed: Random seed to be used for the shuffle; the default is 0.
:return:
"""
subjects = df[subject_column].unique()
key_column: str = subject_column if group_column is None else group_column
split_keys = df[key_column].unique()
if shuffle:
# fix the random seed so we can guarantee reproducibility when working with shuffle
random.Random(random_seed).shuffle(subjects)
random.Random(random_seed).shuffle(split_keys)
ranges = DatasetSplits.get_subject_ranges_for_splits(
subjects,
split_keys,
proportion_train=proportion_train,
proportion_val=proportion_val,
proportion_test=proportion_test
)
return DatasetSplits.from_subject_ids(df,
return DatasetSplits._from_split_keys(df,
list(ranges[ModelExecutionMode.TRAIN]),
list(ranges[ModelExecutionMode.TEST]),
list(ranges[ModelExecutionMode.VAL]),
subject_column)
key_column=key_column,
subject_column=subject_column,
group_column=group_column)

@staticmethod
def from_subject_ids(df: pd.DataFrame,
train_ids: Sequence[str],
test_ids: Sequence[str],
val_ids: Sequence[str],
subject_column: str = CSV_SUBJECT_HEADER) -> DatasetSplits:
*, # make column names keyword-only arguments to avoid mistakes when providing both
subject_column: str = CSV_SUBJECT_HEADER,
group_column: Optional[str] = None) -> DatasetSplits:
"""
Assuming a DataFrame with columns subject
Takes a slice of values from each data split train/test/val for the provided ids.
Expand All @@ -234,14 +296,36 @@ def from_subject_ids(df: pd.DataFrame,
:param test_ids: ids for testing.
:param val_ids: ids for validation.
:param subject_column: subject id column name
:param group_column: grouping column name; if given, samples from each group will always be
in the same subset (train, val, or test) and cross-validation fold.
:return: Data splits with the respective dataset split ids.
"""
return DatasetSplits(
train=DatasetSplits.get_df_from_ids(df, train_ids, subject_column),
test=DatasetSplits.get_df_from_ids(df, test_ids, subject_column),
val=DatasetSplits.get_df_from_ids(df, val_ids, subject_column),
subject_column=subject_column
)
return DatasetSplits._from_split_keys(df, train_ids, test_ids, val_ids, key_column=subject_column,
subject_column=subject_column, group_column=group_column)

@staticmethod
def from_groups(df: pd.DataFrame,
                train_groups: Sequence[str],
                test_groups: Sequence[str],
                val_groups: Sequence[str],
                *, # make column names keyword-only arguments to avoid mistakes when providing both
                group_column: str,
                subject_column: str = CSV_SUBJECT_HEADER) -> DatasetSplits:
    """
    Creates dataset splits from explicitly provided group values for train, test, and validation.

    This delegates to `_from_split_keys`, using `group_column` as the key column that the provided
    groups are matched against. The same `group_column` is passed on to the resulting `DatasetSplits`.

    :param df: the input DataFrame
    :param train_groups: groups for training.
    :param test_groups: groups for testing.
    :param val_groups: groups for validation.
    :param group_column: grouping column name (required, keyword-only). The provided groups are looked up
        in this column, and it is stored as the grouping column of the returned splits.
    :param subject_column: subject id column name
    :return: Data splits built from the provided train/test/val groups.
    :raises ValueError: from the validity checks in `DatasetSplits.__post_init__`, e.g. if a subject or a
        group appears in more than one of the train/test/val splits.
    """
    return DatasetSplits._from_split_keys(df, train_groups, test_groups, val_groups, key_column=group_column,
                                          subject_column=subject_column, group_column=group_column)

@staticmethod
def from_institutions(df: pd.DataFrame,
Expand All @@ -256,9 +340,12 @@ def from_institutions(df: pd.DataFrame,
subject_ids_for_test_only: Optional[Iterable[str]] = None) -> DatasetSplits:
"""
Assuming a DataFrame with columns subject and institutionId

Takes a slice of values from each institution based on the train/test/val proportions provided,
such that for each institution there is at least one subject in each of the train/test/val splits.

This method for creating `DatasetSplits` does not currently support grouping.

:param df: the input DataFrame
:param proportion_train: Proportion of images per institution to be used for training.
:param proportion_val: Proportion of images per institution to be used for validation.
Expand All @@ -267,11 +354,11 @@ def from_institutions(df: pd.DataFrame,
:param shuffle: If True, the subjects in the dataframe will be shuffled before performing splits.
:param random_seed: Random seed to be used for the shuffle; the default is 0.
:param exclude_institutions: If given, all subjects where institutionId has the given value will be
excluded from train, test, and validation set.
excluded from train, test, and validation set.
:param institutions_for_test_only: If given, all subjects where institutionId has the given value will be
placed only in the test set.
placed only in the test set.
:param subject_ids_for_test_only: If given, all images with the provided subject Ids will be placed in the
test set.
test set.
:return: Data splits with respected dataset split proportions per institution.
"""
results: Dict[ModelExecutionMode, pd.DataFrame] = {}
Expand Down Expand Up @@ -347,25 +434,43 @@ def get_df_from_ids(df: pd.DataFrame, ids: Sequence[str],

def get_k_fold_cross_validation_splits(self, n_splits: int, random_seed: int = 0) -> List[DatasetSplits]:
"""
Creates K folds from the Train + Val splits
Creates K folds from the Train + Val splits.

If a group_column has been specified, the folds will be split such that
subjects in a group will not be separated. In this case, the splits are
fully deterministic, and random_seed is ignored.

:param n_splits: number of folds to perform.
:param random_seed: random seed to be used for shuffle 0 is default.
:return: List of K dataset splits
"""
if n_splits <= 0:
raise ValueError("n_splits must be >= 0 found {}".format(n_splits))

# calculate the random split indices
k_folds = KFold(n_splits=n_splits, shuffle=True, random_state=random_seed)
# concatenate train and val, as training set = train + val
cv_dataset = pd.concat([self.train, self.val])
# unique subjects
subject_ids = cv_dataset[self.subject_column].unique()

if self.group_column is None: # perform standard subject-based k-fold cross-validation
# unique subjects
subject_ids = cv_dataset[self.subject_column].unique()
# calculate the random split indices
k_folds = KFold(n_splits=n_splits, shuffle=True, random_state=random_seed)
folds_gen = k_folds.split(subject_ids)
else: # perform grouped k-fold cross-validation
# Here we take all entries, rather than unique, to keep subjects
# matched to groups in the resulting arrays. This works, but could
# perhaps be improved with group-by logic...?
subject_ids = cv_dataset[self.subject_column].values
groups = cv_dataset[self.group_column].values
# scikit-learn uses a deterministic algorithm for grouped splits
# that tries to balance the group sizes in all folds
k_folds = GroupKFold(n_splits=n_splits)
folds_gen = k_folds.split(subject_ids, groups=groups)

ids_from_indices = lambda indices: [subject_ids[x] for x in indices]
# create the number of requested splits of the dataset
return [
DatasetSplits(train=self.get_df_from_ids(cv_dataset, ids_from_indices(train_indices), self.subject_column),
val=self.get_df_from_ids(cv_dataset, ids_from_indices(val_indices), self.subject_column),
test=self.test,
subject_column=self.subject_column) for train_indices, val_indices in
k_folds.split(subject_ids)]
test=self.test, subject_column=self.subject_column, group_column=self.group_column)
for train_indices, val_indices in folds_gen]
Loading