This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Enable custom loss, prediction targets, and reporting in scalar configs #442

Merged · 70 commits · Apr 21, 2021

Commits
dffa9c6
Add first version of cross-validation report
dccastro Feb 9, 2021
f9cd893
Update docstring in notebook_report.py
dccastro Feb 23, 2021
62e3842
Update CHANGELOG
dccastro Feb 23, 2021
db0d91c
Fix mypy
dccastro Feb 23, 2021
a70ff14
Add custom scalar loss type
dccastro Mar 30, 2021
1230d78
Allow target names to differ from class_names
dccastro Mar 30, 2021
53b8f88
Allow control over whether to generate multilabel report
dccastro Mar 30, 2021
5a97033
Enable custom post-hoc label transformation
dccastro Mar 30, 2021
ec9955c
Add checks for classification reports with empty classes
dccastro Mar 30, 2021
02445ec
Reduce wasteful operations in clf report
dccastro Apr 6, 2021
4f82938
Merge remote-tracking branch 'origin/main' into dacoelh/crossval-report
dccastro Apr 7, 2021
97c0238
Add print_table()
dccastro Apr 8, 2021
62252a5
Report classification metrics as table
dccastro Apr 8, 2021
37b56f0
Extract get_labels_and_predictions_from_dataframe()
dccastro Apr 8, 2021
7aa176e
Optionally filter metrics CSV by crossval split
dccastro Apr 8, 2021
845d01b
Extract get_all_metrics() from print_metrics()
dccastro Apr 8, 2021
8ca740d
Optionally filter CSV by Train/Val/Test
dccastro Apr 12, 2021
498a522
Streamline formatting of PR/ROC plot axes
dccastro Apr 12, 2021
040c849
Implement crossval PR/ROC plots
dccastro Apr 12, 2021
23e4e33
Enable generation of crossval metrics table
dccastro Apr 12, 2021
fd78140
Fix positional arg bug in crossval ROC plot
dccastro Apr 12, 2021
4b200dc
Add data_split arg to plot curves & print metrics
dccastro Apr 12, 2021
f0a4459
Add explicit is_crossval_report arg
dccastro Apr 12, 2021
49817ee
Move crossval report trigger into MLRunner
dccastro Apr 12, 2021
d0bb961
Fix data_split in val metrics test file
dccastro Apr 12, 2021
679143c
Add basic test for crossval report generation
dccastro Apr 12, 2021
ef2b12a
Fix mypy and flake8
dccastro Apr 12, 2021
a968a92
Remove obsolete plot_auc()
dccastro Apr 12, 2021
2433941
Add & update classification report docs
dccastro Apr 12, 2021
330ebe3
Update CHANGELOG
dccastro Apr 13, 2021
c0a882b
Merge remote-tracking branch 'origin/main' into dacoelh/crossval-report
dccastro Apr 13, 2021
2c296c7
Add tests for quantiles and metrics CSV filtering
dccastro Apr 13, 2021
379c4c4
Use np.quantile() instead of own implementation
dccastro Apr 15, 2021
fac9cbb
Use more meaningful names than val_* and test_*
dccastro Apr 15, 2021
e4fa7dc
Improve readability of print_metrics_for_all_prediction_targets()
dccastro Apr 15, 2021
4d06e89
Add explicit specificity/sensitivity metrics
dccastro Apr 15, 2021
2d3e07d
Avoid repetition in computing metrics to report
dccastro Apr 15, 2021
f99f35f
Fix flake8 and mypy
dccastro Apr 15, 2021
b633714
Rename get_metrics_for_crossval_split()
dccastro Apr 15, 2021
5a26778
Use pandas.DataFrame to format HTML tables
dccastro Apr 15, 2021
d2b3977
Rename 'interval' to avoid confusion
dccastro Apr 16, 2021
1631f6f
Merge remote-tracking branch 'origin/main' into dacoelh/crossval-report
dccastro Apr 16, 2021
4d2160f
Add crossval report integration test
dccastro Apr 16, 2021
0033a59
Fix mypy
dccastro Apr 16, 2021
71a862a
Add epoch option to read metrics, default last
dccastro Apr 16, 2021
ce01374
Fix mypy and flake8
dccastro Apr 16, 2021
2fe2e7d
Merge remote-tracking branch 'origin/main' into dacoelh/custom-labels
dccastro Apr 16, 2021
4a4a636
Merge remote-tracking branch 'origin/dacoelh/crossval-report' into da…
dccastro Apr 16, 2021
b8d70f3
Fix & add new tests for get_labels_and_predictions()
dccastro Apr 16, 2021
7185c0e
Merge remote-tracking branch 'origin/dacoelh/crossval-report' into da…
dccastro Apr 16, 2021
f14c172
Refactor & add tests for metrics table generation
dccastro Apr 19, 2021
c5c9ae1
Merge branch 'main' into dacoelh/crossval-report
melanibe Apr 19, 2021
6b285ba
Fix mypy
dccastro Apr 19, 2021
0cdf490
Merge remote-tracking branch 'origin/dacoelh/crossval-report' into da…
dccastro Apr 19, 2021
9128358
Merge remote-tracking branch 'origin/main' into dacoelh/custom-labels
dccastro Apr 19, 2021
6e70e02
Merge remote-tracking branch 'origin/main' into dacoelh/custom-labels
dccastro Apr 19, 2021
b53b75b
Add generate_custom_report() (default: no-op)
dccastro Apr 19, 2021
9050333
Use target_names for best/worst samples in report
dccastro Apr 21, 2021
0b74a08
Fix flake8
dccastro Apr 21, 2021
6857be4
Update CHANGELOG
dccastro Apr 21, 2021
142e74c
Merge remote-tracking branch 'origin/main' into dacoelh/custom-labels
dccastro Apr 21, 2021
4c7ed6f
Fix mypy
dccastro Apr 21, 2021
51703b6
Add missing docs in scalar config
dccastro Apr 21, 2021
9761915
Add docstring for generate_custom_report()
dccastro Apr 21, 2021
3f5beaf
Fix mypy
dccastro Apr 21, 2021
d13ea17
Fix mypy
dccastro Apr 21, 2021
4aa29d6
Fix test_train_classification_model()
dccastro Apr 21, 2021
a293f58
Fix flake
dccastro Apr 21, 2021
2ef4f0b
Merge remote-tracking branch 'origin/main' into dacoelh/custom-labels
dccastro Apr 21, 2021
f787a0a
Update docs and CHANGELOG
dccastro Apr 21, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -51,6 +51,9 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
- ([#439](https://github.com/microsoft/InnerEye-DeepLearning/pull/439)) Enable automatic job recovery from last recovery
checkpoint in case of job pre-emption on AML. Give the possibility to the user to keep more than one recovery
checkpoint.
- ([#442](https://github.com/microsoft/InnerEye-DeepLearning/pull/442)) Enable custom scalar loss, prediction targets,
and reporting in scalar configs, providing more flexibility for defining model configs with custom behaviour while
leveraging the existing InnerEye workflows.

### Changed

4 changes: 3 additions & 1 deletion InnerEye/ML/lightning_models.py
@@ -178,6 +178,7 @@ def __init__(self, config: ScalarModelBase, *args: Any, **kwargs: Any) -> None:
super().__init__(config, *args, **kwargs)
self.model = config.create_model()
raw_loss = model_util.create_scalar_loss_function(config)
self.posthoc_label_transform = config.get_posthoc_label_transform()
if isinstance(config, SequenceModelBase):
self.loss_fn = lambda model_output, loss: apply_sequence_model_loss(raw_loss, model_output, loss)
self.target_indices = config.get_target_indices()
Expand All @@ -186,7 +187,7 @@ def __init__(self, config: ScalarModelBase, *args: Any, **kwargs: Any) -> None:
else:
self.loss_fn = raw_loss
self.target_indices = []
self.target_names = config.class_names
self.target_names = config.target_names
self.is_classification_model = config.is_classification_model
self.use_mean_teacher_model = config.compute_mean_teacher_model
self.is_binary_classification_or_regression = True if len(config.class_names) == 1 else False
@@ -269,6 +270,7 @@ def training_or_validation_step(self,
"""
model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model, self.target_indices, sample)
labels = model_inputs_and_labels.labels
labels = self.posthoc_label_transform(labels)
if is_training:
logits = self.model(*model_inputs_and_labels.model_inputs)
else:
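For context: the new `get_posthoc_label_transform()` hook returns a callable that is applied to the labels right after loading, in training, validation, and testing alike, so losses and all logged metrics see the transformed values. A minimal sketch of a config using the hook (the `NoisyLabelsConfig` class and its smoothing transform are illustrative assumptions, not code from this PR):

```python
import torch

from InnerEye.ML.scalar_config import ScalarModelBase


class NoisyLabelsConfig(ScalarModelBase):
    """Hypothetical config that smooths hard 0/1 labels before losses and metrics."""

    def get_posthoc_label_transform(self):
        def smooth(labels: torch.Tensor) -> torch.Tensor:
            # Label smoothing applied post hoc: map {0, 1} targets to {0.05, 0.95}
            # without modifying the dataset or the stored label files.
            return labels * 0.9 + 0.05

        return smooth
```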
14 changes: 14 additions & 0 deletions InnerEye/ML/model_config_base.py
@@ -248,6 +248,20 @@ def set_derived_model_properties(self, model: Any) -> None:
"""
pass

def generate_custom_report(self, report_dir: Path, train_metrics: Path, val_metrics: Path,
test_metrics: Path) -> Path:
"""
Enables creating a custom results report, given the metrics files written during model training and inference.
By default, this method is a no-op.

:param report_dir: The output directory where the generated report should be saved.
:param train_metrics: The CSV file with training metrics.
:param val_metrics: The CSV file with validation metrics.
:param test_metrics: The CSV file with test metrics.
:return: The path to the generated report file.
"""
pass


class ModelTransformsPerExecutionMode:
"""
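Because the default `generate_custom_report()` is a no-op, a config can override it to build any report format from the three metrics CSVs produced during training and inference. A minimal sketch (the `TextReportConfig` class and its plain-text output are illustrative assumptions, not code from this PR):

```python
from pathlib import Path

import pandas as pd

from InnerEye.ML.scalar_config import ScalarModelBase


class TextReportConfig(ScalarModelBase):
    """Hypothetical config that writes a plain-text summary of the metrics files."""

    def generate_custom_report(self, report_dir: Path, train_metrics: Path,
                               val_metrics: Path, test_metrics: Path) -> Path:
        lines = []
        for split, csv_path in [("Train", train_metrics),
                                ("Val", val_metrics),
                                ("Test", test_metrics)]:
            # Summarise each metrics CSV as one line: row count and column names.
            df = pd.read_csv(csv_path)
            lines.append(f"{split}: {len(df)} rows, columns: {list(df.columns)}")
        report = report_dir / "custom_report.txt"
        report.write_text("\n".join(lines))
        return report
```

MLRunner now also calls this hook after the standard reports (see the `run_ml.py` change below), so overriding it is all that is needed to get a custom report into the run outputs.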
4 changes: 3 additions & 1 deletion InnerEye/ML/model_testing.py
@@ -388,7 +388,7 @@ def create_metrics_dict_for_scalar_models(config: ScalarModelBase) -> \
return SequenceMetricsDict.create(is_classification_model=config.is_classification_model,
sequence_target_positions=config.sequence_target_positions)
else:
return ScalarMetricsDict(hues=config.class_names,
return ScalarMetricsDict(hues=config.target_names,
is_classification_metrics=config.is_classification_model)


Expand All @@ -407,6 +407,7 @@ def classification_model_test(config: ScalarModelBase,
:param model_proc: whether we are testing an ensemble or single model
:return: InferenceMetricsForClassification object that contains metrics related for all of the checkpoint epochs.
"""
posthoc_label_transform = config.get_posthoc_label_transform()

def test_epoch(checkpoint_paths: List[Path]) -> Optional[MetricsDict]:
pipeline = create_inference_pipeline(config=config,
Expand All @@ -431,6 +432,7 @@ def test_epoch(checkpoint_paths: List[Path]) -> Optional[MetricsDict]:
result = pipeline.predict(sample)
model_output = result.posteriors
label = result.labels.to(device=model_output.device)
label = posthoc_label_transform(label)
sample_id = result.subject_ids[0]
compute_scalar_metrics(metrics_dict,
subject_ids=[sample_id],
4 changes: 2 additions & 2 deletions InnerEye/ML/reports/classification_crossval_report.ipynb
@@ -220,7 +220,7 @@
"outputs": [],
"source": [
"if not is_crossval_report and val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" for prediction_target in config.class_names:\n",
" for prediction_target in config.target_names:\n",
" print_header(f\"Class: {prediction_target}\", level=3)\n",
" print_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
" k=number_best_and_worst_performing,\n",
Expand All @@ -242,7 +242,7 @@
"outputs": [],
"source": [
"if not is_crossval_report and val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" for prediction_target in config.class_names:\n",
" for prediction_target in config.target_names:\n",
" print_header(f\"Class: {prediction_target}\", level=3)\n",
" plot_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
" k=number_best_and_worst_performing, prediction_target=prediction_target,\n",
4 changes: 2 additions & 2 deletions InnerEye/ML/reports/classification_report.ipynb
@@ -232,7 +232,7 @@
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" for prediction_target in config.class_names:\n",
" for prediction_target in config.target_names:\n",
" print_header(f\"Class {prediction_target}\", level=3)\n",
" print_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
" k=number_best_and_worst_performing,\n",
Expand All @@ -255,7 +255,7 @@
"outputs": [],
"source": [
"if val_metrics_csv.is_file() and test_metrics_csv.is_file():\n",
" for prediction_target in config.class_names:\n",
" for prediction_target in config.target_names:\n",
" print_header(f\"Class {prediction_target}\", level=3)\n",
" plot_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv, test_metrics_csv=test_metrics_csv,\n",
" k=number_best_and_worst_performing, prediction_target=prediction_target, config=config)"
23 changes: 19 additions & 4 deletions InnerEye/ML/reports/classification_report.py
@@ -81,6 +81,7 @@ def check_column_present(dataframe: pd.DataFrame, column: LoggingColumns) -> Non

df = pd.read_csv(csv)
df = df[df[LoggingColumns.Hue.value] == prediction_target] # Filter by prediction target
df = df[~df[LoggingColumns.Label.value].isna()] # Filter missing labels

# Filter by crossval split index
if crossval_split_index is not None:
@@ -279,7 +280,7 @@ def plot_pr_and_roc_curves_from_csv(metrics_csv: Path, config: ScalarModelBase,
:param is_crossval_report: If True, assumes CSV contains results for multiple cross-validation runs and plots the
curves with median and confidence intervals. Otherwise, plots curves for a single run.
"""
for prediction_target in config.class_names:
for prediction_target in config.target_names:
print_header(f"Class: {prediction_target}", level=3)
if is_crossval_report:
all_metrics = [get_labels_and_predictions(metrics_csv, prediction_target,
@@ -469,7 +470,7 @@ def print_metrics_for_all_prediction_targets(csv_to_set_optimal_threshold: Path,
:param is_crossval_report: If True, assumes CSVs contain results for multiple cross-validation runs and prints the
metrics along with means and standard deviations. Otherwise, prints metrics for a single run.
"""
for prediction_target in config.class_names:
for prediction_target in config.target_names:
print_header(f"Class: {prediction_target}", level=3)
rows, header = get_metrics_table_for_prediction_target(
csv_to_set_optimal_threshold=csv_to_set_optimal_threshold,
Expand All @@ -484,7 +485,7 @@ def print_metrics_for_all_prediction_targets(csv_to_set_optimal_threshold: Path,


def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_csv: Path,
prediction_target: str = MetricsDict.DEFAULT_HUE_KEY) -> Results:
prediction_target: str = MetricsDict.DEFAULT_HUE_KEY) -> Optional[Results]:
"""
Given the paths to the metrics files for the validation and test sets, get a list of true positives,
false positives, false negatives and true negatives.
Expand All @@ -495,12 +496,18 @@ def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_c
"""
df_val = read_csv_and_filter_prediction_target(val_metrics_csv, prediction_target)

if len(df_val) == 0:
return None

fpr, tpr, thresholds = roc_curve(df_val[LoggingColumns.Label.value], df_val[LoggingColumns.ModelOutput.value])
optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
optimal_threshold = thresholds[optimal_idx]

df_test = read_csv_and_filter_prediction_target(test_metrics_csv, prediction_target)

if len(df_test) == 0:
return None

df_test["predicted"] = df_test.apply(lambda x: int(x[LoggingColumns.ModelOutput.value] >= optimal_threshold),
axis=1)

Expand All @@ -516,14 +523,16 @@ def get_correct_and_misclassified_examples(val_metrics_csv: Path, test_metrics_c


def get_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int,
prediction_target: str = MetricsDict.DEFAULT_HUE_KEY) -> Results:
prediction_target: str = MetricsDict.DEFAULT_HUE_KEY) -> Optional[Results]:
"""
Get the top "k" best predictions (i.e. correct classifications where the model was the most certain) and the
top "k" worst predictions (i.e. misclassifications where the model was the most confident).
"""
results = get_correct_and_misclassified_examples(val_metrics_csv=val_metrics_csv,
test_metrics_csv=test_metrics_csv,
prediction_target=prediction_target)
if results is None:
return None

# sort by model_output
sorted = Results(true_positives=results.true_positives.sort_values(by=LoggingColumns.ModelOutput.value,
@@ -553,6 +562,9 @@ def print_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: P
test_metrics_csv=test_metrics_csv,
k=k,
prediction_target=prediction_target)
if results is None:
print_header("Empty validation or test set", level=2)
return

print_header(f"Top {k} false positives", level=2)
for index, (subject, model_output) in enumerate(zip(results.false_positives[LoggingColumns.Patient.value],
@@ -729,6 +741,9 @@ def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Pa
test_metrics_csv=test_metrics_csv,
k=k,
prediction_target=prediction_target)
if results is None:
print_header("Empty validation or test set", level=4)
return

test_metrics = pd.read_csv(test_metrics_csv, dtype=str)

7 changes: 6 additions & 1 deletion InnerEye/ML/run_ml.py
@@ -873,7 +873,7 @@ def get_epoch_path(mode: ModelExecutionMode) -> Path:
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test)

if len(config.class_names) > 1:
if config.should_generate_multilabel_report():
generate_classification_multilabel_notebook(
result_notebook=reports_dir / get_ipynb_report_name(
f"{config.model_category.value}_multilabel"),
Expand All @@ -883,6 +883,11 @@ def get_epoch_path(mode: ModelExecutionMode) -> Path:
test_metrics=path_to_best_epoch_test)
else:
logging.info(f"Cannot create report for config of type {type(config)}.")

config.generate_custom_report(report_dir=reports_dir,
train_metrics=path_to_best_epoch_train,
val_metrics=path_to_best_epoch_val,
test_metrics=path_to_best_epoch_test)
except Exception as ex:
print_exception(ex, "Failed to generated reporting notebook.")
raise
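Note that the multilabel report is now gated on `should_generate_multilabel_report()` instead of directly on `len(class_names) > 1`, so a config can opt out (or in) regardless of its number of classes. A minimal illustrative override (the class is hypothetical, not part of this PR):

```python
from InnerEye.ML.scalar_config import ScalarModelBase


class NoMultilabelReportConfig(ScalarModelBase):
    """Hypothetical multi-class config that suppresses the multilabel notebook."""

    def should_generate_multilabel_report(self) -> bool:
        # The standard per-target classification report is sufficient here, so
        # the additional multilabel notebook is skipped even with several classes.
        return False
```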
30 changes: 28 additions & 2 deletions InnerEye/ML/scalar_config.py
@@ -42,12 +42,15 @@ class ScalarLoss(Enum):
BinaryCrossEntropyWithLogits = "BinaryCrossEntropyWithLogits"
WeightedCrossEntropyWithLogits = "WeightedCrossEntropyWithLogits"
MeanSquaredError = "MeanSquaredError"
CustomClassification = "CustomClassification"
CustomRegression = "CustomRegression"

def is_classification_loss(self) -> bool:
return self == self.BinaryCrossEntropyWithLogits or self == self.WeightedCrossEntropyWithLogits
return self in {self.BinaryCrossEntropyWithLogits, self.WeightedCrossEntropyWithLogits,
self.CustomClassification}

def is_regression_loss(self) -> bool:
return self == self.MeanSquaredError
return self in {self.MeanSquaredError, self.CustomRegression}


@unique
@@ -112,6 +115,11 @@ class ScalarModelBase(ModelConfigBase):
"For binary classification, this field must be a list of size 1, and "
"is by default ['Default'], but can optionally be set to a more descriptive "
"name for the positive class.")
target_names: List[str] = param.List(class_=str,
default=None,
bounds=(1, None),
doc="The label names for each output target, used for reporting results. "
"By default this matches class_names.")
aggregation_type: AggregationType = param.ClassSelector(default=AggregationType.Average, class_=AggregationType,
doc="The type of global pooling aggregation to use between"
" the encoder and the classifier.")
@@ -214,6 +222,8 @@ def __init__(self, num_dataset_reader_workers: int = 0, **params: Any) -> None:
"num_dataset_reader_workers to 0 as this is an AML run.")
else:
self.num_dataset_reader_workers = num_dataset_reader_workers
if self.target_names is None:
self.target_names = self.class_names

def validate(self) -> None:
if len(self.class_names) > 1 and not self.is_classification_model:
Expand All @@ -240,6 +250,10 @@ def is_non_imaging_model(self) -> bool:
"""
return len(self.image_channels) == 0

def should_generate_multilabel_report(self) -> bool:
"""Determines whether to produce a multilabel report. Override this to implement custom behaviour."""
return len(self.class_names) > 1

def get_total_number_of_non_imaging_features(self) -> int:
"""Returns the total number of non imaging features expected in the input"""
return self.get_total_number_of_numerical_non_imaging_features() + \
@@ -338,6 +352,12 @@ def get_label_transform(self) -> Union[Callable, List[Callable]]:
"""
return LabelTransformation.identity

def get_posthoc_label_transform(self) -> Callable:
"""Return a transformation or list of transformation to apply to the labels after they are
loaded, for computing losses, metrics, and reports.
"""
return lambda x: x # no-op by default

def read_dataset_into_dataframe_and_pre_process(self) -> None:
assert self.local_dataset is not None
file_path = self.local_dataset / self.dataset_csv
@@ -408,6 +428,12 @@ def get_total_number_of_training_samples(self) -> int:
def create_model(self) -> Any:
pass

def get_loss_function(self) -> Callable:
"""Returns a custom loss function to be used with ScalarLoss.CustomClassification or CustomRegression."""
assert self.loss_type in {ScalarLoss.CustomClassification, ScalarLoss.CustomRegression}, \
f"get_loss_function() should be called only for custom loss types (received {self.loss_type})"
raise NotImplementedError(f"get_loss_function() must be implemented for loss type {self.loss_type}")

def get_post_loss_logits_normalization_function(self) -> Callable:
"""
Post loss normalization function to apply to the logits produced by the model.
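Putting the new options together: a config can select `ScalarLoss.CustomClassification`, supply its own loss via `get_loss_function()`, and report under `target_names` that differ from `class_names`. A sketch under those assumptions (the `FocalLossConfig` class and the focal-loss implementation are illustrative, not part of this PR):

```python
import torch

from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase


class FocalLossConfig(ScalarModelBase):
    """Hypothetical binary classification config with a custom focal loss."""

    def __init__(self, **kwargs):
        super().__init__(loss_type=ScalarLoss.CustomClassification,
                         class_names=["Hemorrhage"],       # model output head
                         target_names=["AnyHemorrhage"],   # name used in metrics and reports
                         **kwargs)

    def get_loss_function(self):
        def focal_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            gamma = 2.0  # focusing parameter: higher values down-weight easy examples more
            bce = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, targets, reduction="none")
            p_t = torch.exp(-bce)  # probability the model assigns to the true class
            return ((1.0 - p_t) ** gamma * bce).mean()

        return focal_loss
```

`create_scalar_loss_function()` in `model_util.py` (next file) returns this callable whenever `loss_type` is one of the two custom values.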
2 changes: 2 additions & 0 deletions InnerEye/ML/utils/model_util.py
@@ -115,6 +115,8 @@ def create_scalar_loss_function(config: ScalarModelBase) -> torch.nn.Module:
num_train_samples=config.get_total_number_of_training_samples())
elif config.loss_type == ScalarLoss.MeanSquaredError:
return MSELoss()
elif config.loss_type == ScalarLoss.CustomClassification or config.loss_type == ScalarLoss.CustomRegression:
return config.get_loss_function() # type: ignore
else:
raise NotImplementedError(f"Loss type {config.loss_type} is not implemented")

2 changes: 1 addition & 1 deletion Tests/ML/models/test_scalar_model.py
@@ -56,7 +56,7 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol
"""
logging_to_stdout(logging.DEBUG)
config = ClassificationModelForTesting()
config.class_names = [class_name]
config.class_names = config.target_names = [class_name]
config.set_output_to(test_output_dirs.root_dir)
# Train for 4 epochs, checkpoints at epochs 2 and 4
config.num_epochs = 4
4 changes: 4 additions & 0 deletions Tests/ML/reports/test_classification_report.py
@@ -301,6 +301,8 @@ def test_get_correct_and_misclassified_examples() -> None:
results = get_correct_and_misclassified_examples(val_metrics_csv=val_metrics_file,
test_metrics_csv=test_metrics_file)

assert results is not None # for mypy

true_positives = [item[LoggingColumns.Patient.value] for _, item in results.true_positives.iterrows()]
assert all([i in true_positives for i in [3, 4, 5]])

Expand All @@ -323,6 +325,8 @@ def test_get_k_best_and_worst_performing() -> None:
test_metrics_csv=test_metrics_file,
k=2)

assert results is not None # for mypy

best_true_positives = [item[LoggingColumns.Patient.value] for _, item in results.true_positives.iterrows()]
assert best_true_positives == [5, 4]
