Implement grouped dataset splits and cross-validation #363
The expression 'all([len(x[mode]) >= 1] for mode in x.keys())' will always evaluate to True, because 'bool([False]) == True'.
Previously, it erroneously checked for empty three-way intersection of train, test, and val, whereas the correct check is for pairwise intersections: train-test, train-val, and test-val.
This employs scikit-learn's GroupKFold class.
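The first fix above can be reproduced in isolation. The snippet below is a minimal sketch: the dictionary `x` and its contents are made up for illustration, and only the two `all(...)` expressions mirror the check described in this PR.

```python
# Hypothetical split dictionary; the "test" split is deliberately empty.
x = {"train": ["s1", "s2"], "test": [], "val": ["s3"]}

# Buggy form: each generator item is a one-element list such as [False],
# and any non-empty list is truthy, so all() can never return False.
buggy = all([len(x[mode]) >= 1] for mode in x.keys())

# Fixed form: the generator yields the booleans themselves.
fixed = all(len(x[mode]) >= 1 for mode in x.keys())

print(buggy)  # True, even though "test" is empty
print(fixed)  # False
```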
For consistency across configs, we should consider having a documented default for the group column name in the csv. For example, we have a default subject column name for segmentation dataset csv files, and we have a parameter in the config which is used to specify the subject column name in scalar datasets.
InnerEye/ML/utils/split_dataset.py
def pairwise_intersection(*collections: Iterable) -> Set:
    """Returns any element that appears in more than one collection."""
    from itertools import combinations
Is there a specific reason this import is local?
Not really; well spotted! I'll move it to the top.
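For readers following this thread, here is one way a function with that signature and docstring could be completed, with the import at module level as agreed. Only the signature and docstring come from the diff; the body is an assumption and may differ from what was merged. The example data also shows why the earlier three-way check was too weak.

```python
from itertools import combinations
from typing import Iterable, Set


def pairwise_intersection(*collections: Iterable) -> Set:
    """Returns any element that appears in more than one collection."""
    shared: Set = set()
    # Assumed body: union of the intersections of every pair of collections.
    for first, second in combinations(map(set, collections), 2):
        shared |= first & second
    return shared


# Made-up subject IDs: "b" leaks between train and test only.
train, test, val = {"a", "b"}, {"b", "c"}, {"d"}

# The three-way intersection is empty, so it misses the leak.
three_way = train & test & val
print(three_way)  # set()

# The pairwise check reports the leaked subject.
print(pairwise_intersection(train, test, val))  # {'b'}
```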
CHANGELOG.md
@@ -12,6 +12,7 @@ created.

### Added
- New extensions of SegmentationModelBases `HeadAndNeckBase` and `ProstateBase`. Use these classes to build your own Head&Neck or Prostate models, by just providing a list of foreground classes.
- Grouped dataset splits and k-fold cross-validation. This allows, for example, training on datasets with multiple images per subject without leaking data from the same subject across train/test/validation sets or cross-validation folds.
Can we add here the specific change users will need to make to use this feature in their configs?
I thought about adding a default group column name, but couldn't immediately come up with a good catch-all solution... I believe in most cases that would be [...]
key_column: str,
subject_column: str,
group_column: Optional[str]) -> DatasetSplits:
Both column names should have "" as the default value, and group_column should default to None.
subject_column: str,
group_column: Optional[str]) -> DatasetSplits:
"""
Takes a slice of values from each data split train/test/val for the provided keys.
What's a slice of values?
I just adapted this docstring from `from_subject_ids()`.
This PR adds the ability to specify a `group_column` in addition to the primary `subject_column` when creating `DatasetSplits`. If given, this ensures that subjects within each group cannot be in separate training/test/validation sets or cross-validation folds.
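The guarantee described above can be illustrated with a small standard-library sketch. This is not the code from this PR (which, per the description, relies on scikit-learn's GroupKFold for the cross-validation folds); the function name `grouped_split`, its arguments, and the sample data are all hypothetical. The idea is to shuffle and partition the groups rather than the subjects, so every subject follows its group into exactly one set.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def grouped_split(subject_to_group: Dict[str, str],
                  proportions: Sequence[float] = (0.5, 0.25, 0.25),
                  seed: int = 0) -> Tuple[List[str], List[str], List[str]]:
    """Assigns whole groups to train/test/val so no group spans two sets."""
    members: Dict[str, List[str]] = defaultdict(list)
    for subject, group in subject_to_group.items():
        members[group].append(subject)

    # Shuffle group IDs (not subjects) with a fixed seed for reproducibility.
    group_ids = sorted(members)
    random.Random(seed).shuffle(group_ids)

    # Cut points are counted in groups, so the resulting set sizes are only
    # approximately proportional when groups differ in size.
    n = len(group_ids)
    cut1 = round(proportions[0] * n)
    cut2 = cut1 + round(proportions[1] * n)
    chunks = (group_ids[:cut1], group_ids[cut1:cut2], group_ids[cut2:])
    train, test, val = ([s for g in chunk for s in members[g]] for chunk in chunks)
    return train, test, val


# Made-up data: eight subjects (e.g. images) in four groups (e.g. patients).
subject_to_group = {"s1": "g1", "s2": "g1", "s3": "g2", "s4": "g2",
                    "s5": "g3", "s6": "g3", "s7": "g4", "s8": "g4"}
train, test, val = grouped_split(subject_to_group)
```

Splitting at the group level trades exact set proportions for the no-leakage guarantee; with very uneven group sizes the sets can end up noticeably unbalanced.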