-
Notifications
You must be signed in to change notification settings - Fork 143
Implement grouped dataset splits and cross-validation #363
Changes from 14 commits
4581725
ef86970
af4cbfd
5644714
b25ffc0
425de87
d966ba1
a98b822
0b39867
fb38e4c
a0a3989
0ed76ba
fd1ea81
7786464
2d4b857
655e4fa
f3af1f5
fe81f02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,7 @@ | |
|
||
import numpy as np | ||
import pandas as pd | ||
from sklearn.model_selection import KFold | ||
from sklearn.model_selection import GroupKFold, KFold | ||
|
||
from InnerEye.Common import common_util | ||
from InnerEye.ML.common import ModelExecutionMode | ||
|
@@ -26,17 +26,36 @@ class DatasetSplits: | |
val: pd.DataFrame | ||
test: pd.DataFrame | ||
subject_column: str = CSV_SUBJECT_HEADER | ||
group_column: Optional[str] = None | ||
allow_empty: bool = False | ||
|
||
def __post_init__(self) -> None: | ||
common_util.check_properties_are_not_none(self) | ||
|
||
def pairwise_intersection(*collections: Iterable) -> Set: | ||
"""Returns any element that appears in more than one collection.""" | ||
from itertools import combinations | ||
Review comment: Is there a specific reason this import is local? Reply: Not really; well spotted! I'll move it to the top. |
||
intersection = set() | ||
for col1, col2 in combinations(map(set, collections), 2): | ||
intersection |= col1 & col2 | ||
return intersection | ||
|
||
# perform dataset split validity assertions | ||
unique_train, unique_test, unique_val = self.unique_subjects() | ||
intersection = set.intersection(set(unique_train), set(unique_test), set(unique_val)) | ||
intersection = pairwise_intersection(unique_train, unique_test, unique_val) | ||
|
||
if len(intersection) != 0: | ||
raise ValueError("Train, Test, and Val splits must have no intersection, found: {}".format(intersection)) | ||
|
||
if self.group_column is not None: | ||
groups_train = self.train[self.group_column].unique() | ||
groups_test = self.test[self.group_column].unique() | ||
groups_val = self.val[self.group_column].unique() | ||
group_intersection = pairwise_intersection(groups_train, groups_test, groups_val) | ||
if len(group_intersection) != 0: | ||
raise ValueError("Train, Test, and Val splits must have no intersecting groups, found: {}" | ||
.format(group_intersection)) | ||
|
||
if (not self.allow_empty) and any([len(x) == 0 for x in [unique_train, unique_val]]): | ||
raise ValueError("train_ids({}), val_ids({}) must have at least one value" | ||
.format(len(unique_train), len(unique_val))) | ||
|
@@ -69,12 +88,16 @@ def restrict_subjects(self, restriction_pattern: str) -> DatasetSplits: | |
""" | ||
Creates a new dataset split that has at most the specified numbers of subjects in train, validation and test | ||
sets respectively. | ||
|
||
If `group_column` was specified, this operation may violate the grouping constraints and the resulting splits | ||
object will have `group_column == None`. | ||
|
||
:param restriction_pattern: a string containing zero or two commas, and otherwise digits or "+". An empty | ||
substring will result in no restriction for the corresponding dataset. Thus "20,,3" means "restrict to 20 | ||
training images and 3 test images, with no restriction on validation". A "+" value means "reassign all | ||
images from the set(s) with a numeric count (there must be at least one) to this set". Thus ",0,+" means "leave | ||
the training set alone, but move all validation images to the test set", and "0,2,+" means "move | ||
all training images and all but 2 validation images to the test set". | ||
substring will result in no restriction for the corresponding dataset. Thus "20,,3" means "restrict to 20 | ||
training images and 3 test images, with no restriction on validation". A "+" value means "reassign all | ||
images from the set(s) with a numeric count (there must be at least one) to this set". Thus ",0,+" means | ||
"leave the training set alone, but move all validation images to the test set", and "0,2,+" means "move | ||
all training images and all but 2 validation images to the test set". | ||
:return: A new dataset split object with (at most) the numbers of subjects specified by restrict_pattern | ||
""" | ||
|
||
|
@@ -183,48 +206,87 @@ def get_subject_ranges_for_splits(population: Sequence[str], | |
} | ||
return result | ||
|
||
@staticmethod | ||
def _from_split_keys(df: pd.DataFrame, | ||
train_keys: Sequence[str], | ||
test_keys: Sequence[str], | ||
val_keys: Sequence[str], | ||
*, # make column names keyword-only arguments to avoid mistakes when providing both | ||
key_column: str, | ||
subject_column: str, | ||
group_column: Optional[str]) -> DatasetSplits: | ||
Review comment on lines +215 to +217: both column names should have "" as the default value, and group_column=None |
||
""" | ||
Takes a slice of values from each data split train/test/val for the provided keys. | ||
Review comment: What's a slice of values? Reply: I just adapted this docstring from |
||
|
||
:param df: the input DataFrame | ||
:param train_keys: keys for training. | ||
:param test_keys: keys for testing. | ||
:param val_keys: keys for validation. | ||
:param key_column: name of the column the provided keys belong to | ||
:param subject_column: subject id column name | ||
:param group_column: grouping column name; if given, samples from each group will always be | ||
in the same subset (train, val, or test) and cross-validation fold. | ||
:return: Data splits with respected dataset split ids. | ||
""" | ||
train_df = DatasetSplits.get_df_from_ids(df, train_keys, key_column) | ||
test_df = DatasetSplits.get_df_from_ids(df, test_keys, key_column) | ||
val_df = DatasetSplits.get_df_from_ids(df, val_keys, key_column) | ||
|
||
return DatasetSplits(train=train_df, test=test_df, val=val_df, | ||
subject_column=subject_column, group_column=group_column) | ||
|
||
@staticmethod
def from_proportions(df: pd.DataFrame,
                     proportion_train: float,
                     proportion_test: float,
                     proportion_val: float,
                     *,  # make column names keyword-only arguments to avoid mistakes when providing both
                     subject_column: str = CSV_SUBJECT_HEADER,
                     group_column: Optional[str] = None,
                     shuffle: bool = True,
                     random_seed: int = 0) -> DatasetSplits:
    """
    Splits a dataset into train, test, and validation sets according to fixed proportions. The proportions
    are applied to the unique values of the subject column, or of the group column if one is given.

    :param df: The dataframe containing all subjects.
    :param proportion_train: proportion for the train set.
    :param proportion_test: proportion for the test set.
    :param proportion_val: proportion for the validation set.
    :param subject_column: Subject id column name
    :param group_column: grouping column name; if given, samples from each group will always be
        in the same subset (train, val, or test) and cross-validation fold.
    :param shuffle: If True, the keys being split (subjects or groups) are shuffled before the split.
    :param random_seed: Random seed to be used for the shuffle, 0 is default.
    :return: The dataset splits built from the keys assigned to train, test, and validation.
    """
    # The split is performed over groups when a group column is given, and over subjects otherwise.
    if group_column is None:
        key_column: str = subject_column
    else:
        key_column = group_column

    keys_to_split = df[key_column].unique()
    if shuffle:
        # A dedicated, seeded generator keeps the in-place shuffle reproducible.
        seeded_rng = random.Random(random_seed)
        seeded_rng.shuffle(keys_to_split)

    keys_per_mode = DatasetSplits.get_subject_ranges_for_splits(
        keys_to_split,
        proportion_train=proportion_train,
        proportion_val=proportion_val,
        proportion_test=proportion_test
    )
    train_keys = list(keys_per_mode[ModelExecutionMode.TRAIN])
    test_keys = list(keys_per_mode[ModelExecutionMode.TEST])
    val_keys = list(keys_per_mode[ModelExecutionMode.VAL])
    return DatasetSplits._from_split_keys(df,
                                          train_keys,
                                          test_keys,
                                          val_keys,
                                          key_column=key_column,
                                          subject_column=subject_column,
                                          group_column=group_column)
|
||
@staticmethod | ||
def from_subject_ids(df: pd.DataFrame, | ||
train_ids: Sequence[str], | ||
test_ids: Sequence[str], | ||
val_ids: Sequence[str], | ||
subject_column: str = CSV_SUBJECT_HEADER) -> DatasetSplits: | ||
*, # make column names keyword-only arguments to avoid mistakes when providing both | ||
subject_column: str = CSV_SUBJECT_HEADER, | ||
group_column: Optional[str] = None) -> DatasetSplits: | ||
""" | ||
Assuming a DataFrame with columns subject | ||
Takes a slice of values from each data split train/test/val for the provided ids. | ||
|
@@ -234,14 +296,36 @@ def from_subject_ids(df: pd.DataFrame, | |
:param test_ids: ids for testing. | ||
:param val_ids: ids for validation. | ||
:param subject_column: subject id column name | ||
:param group_column: grouping column name; if given, samples from each group will always be | ||
in the same subset (train, val, or test) and cross-validation fold. | ||
:return: Data splits with respected dataset split ids. | ||
""" | ||
return DatasetSplits( | ||
train=DatasetSplits.get_df_from_ids(df, train_ids, subject_column), | ||
test=DatasetSplits.get_df_from_ids(df, test_ids, subject_column), | ||
val=DatasetSplits.get_df_from_ids(df, val_ids, subject_column), | ||
subject_column=subject_column | ||
) | ||
return DatasetSplits._from_split_keys(df, train_ids, test_ids, val_ids, key_column=subject_column, | ||
subject_column=subject_column, group_column=group_column) | ||
|
||
@staticmethod
def from_groups(df: pd.DataFrame,
                train_groups: Sequence[str],
                test_groups: Sequence[str],
                val_groups: Sequence[str],
                *,  # make column names keyword-only arguments to avoid mistakes when providing both
                group_column: str,
                subject_column: str = CSV_SUBJECT_HEADER) -> DatasetSplits:
    """
    Builds dataset splits by assigning whole groups to the train, test, and validation sets. The given
    group values are used as the split keys, looked up in `group_column`.

    :param df: the input DataFrame
    :param train_groups: groups for training.
    :param test_groups: groups for testing.
    :param val_groups: groups for validation.
    :param group_column: grouping column name; samples from each group will always be
        in the same subset (train, val, or test) and cross-validation fold.
    :param subject_column: subject id column name
    :return: Data splits containing the rows of the requested groups.
    """
    # The group column doubles as the key column, because the provided keys are group values.
    column_names = dict(key_column=group_column,
                        subject_column=subject_column,
                        group_column=group_column)
    return DatasetSplits._from_split_keys(df, train_groups, test_groups, val_groups, **column_names)
|
||
@staticmethod | ||
def from_institutions(df: pd.DataFrame, | ||
|
@@ -256,9 +340,12 @@ def from_institutions(df: pd.DataFrame, | |
subject_ids_for_test_only: Optional[Iterable[str]] = None) -> DatasetSplits: | ||
""" | ||
Assuming a DataFrame with columns subject and institutionId | ||
|
||
Takes a slice of values from each institution based on the train/test/val proportions provided, | ||
such that for each institution there is at least one subject in each of the train/test/val splits. | ||
|
||
This method for creating `DatasetSplits` does not currently support grouping. | ||
|
||
:param df: the input DataFrame | ||
:param proportion_train: Proportion of images per institution to be used for training. | ||
:param proportion_val: Proportion of images per institution to be used for validation. | ||
|
@@ -267,11 +354,11 @@ def from_institutions(df: pd.DataFrame, | |
:param shuffle: If True the subjects in the dataframe will be shuffle before performing splits. | ||
:param random_seed: Random seed to be used for shuffle 0 is default. | ||
:param exclude_institutions: If given, all subjects where institutionId has the given value will be | ||
excluded from train, test, and validation set. | ||
excluded from train, test, and validation set. | ||
:param institutions_for_test_only: If given, all subjects where institutionId has the given value will be | ||
placed only in the test set. | ||
placed only in the test set. | ||
:param subject_ids_for_test_only: If given, all images with the provided subject Ids will be placed in the | ||
test set. | ||
test set. | ||
:return: Data splits with respected dataset split proportions per institution. | ||
""" | ||
results: Dict[ModelExecutionMode, pd.DataFrame] = {} | ||
|
@@ -347,25 +434,43 @@ def get_df_from_ids(df: pd.DataFrame, ids: Sequence[str], | |
|
||
def get_k_fold_cross_validation_splits(self, n_splits: int, random_seed: int = 0) -> List[DatasetSplits]:
    """
    Creates K folds from the Train + Val splits. The test split is passed unchanged to every fold.

    If a group_column has been specified, the folds will be split such that
    subjects in a group will not be separated. In this case, the splits are
    fully deterministic, and random_seed is ignored.

    :param n_splits: number of folds to perform; must be greater than 0.
    :param random_seed: random seed to be used for shuffle 0 is default.
    :return: List of K dataset splits
    :raises ValueError: if n_splits is not greater than 0.
    """
    if n_splits <= 0:
        # The guard rejects 0 as well, so the message must say "> 0" rather than ">= 0".
        raise ValueError("n_splits must be > 0, found {}".format(n_splits))

    # concatenate train and val, as training set = train + val
    cv_dataset = pd.concat([self.train, self.val])

    if self.group_column is None:  # perform standard subject-based k-fold cross-validation
        # unique subjects
        subject_ids = cv_dataset[self.subject_column].unique()
        # calculate the random split indices
        k_folds = KFold(n_splits=n_splits, shuffle=True, random_state=random_seed)
        folds_gen = k_folds.split(subject_ids)
    else:  # perform grouped k-fold cross-validation
        # Here we take all entries, rather than unique, to keep subjects
        # matched to groups in the resulting arrays. As a consequence the ids
        # selected for a fold may contain duplicates.
        subject_ids = cv_dataset[self.subject_column].values
        groups = cv_dataset[self.group_column].values
        # scikit-learn uses a deterministic algorithm for grouped splits
        # that tries to balance the group sizes in all folds
        k_folds = GroupKFold(n_splits=n_splits)
        folds_gen = k_folds.split(subject_ids, groups=groups)

    def ids_from_indices(indices):
        """Maps positional fold indices back to the subject ids at those positions."""
        return [subject_ids[x] for x in indices]

    # create the number of requested splits of the dataset
    return [
        DatasetSplits(train=self.get_df_from_ids(cv_dataset, ids_from_indices(train_indices), self.subject_column),
                      val=self.get_df_from_ids(cv_dataset, ids_from_indices(val_indices), self.subject_column),
                      test=self.test, subject_column=self.subject_column, group_column=self.group_column)
        for train_indices, val_indices in folds_gen]
Review comment: Can we add here the specific change users will need to make to use this feature in their configs?