This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Add class-wise metrics logging and confusion matrix to DeepMIL #647

Merged: 9 commits merged into main from hsharma/perclassmetrics on Feb 1, 2022

Conversation

harshita-s (Contributor):

In this PR:

  • Absolute confusion matrix for test set is printed along with other metrics at the end of the run
  • Normalized confusion matrix for test set is saved as a heatmap figure in outputs/fig
  • Per-class accuracy is logged every epoch for train and validation stages (plots are visible along with other metrics)
  • Class names are added as an argument to DeepMIL (could be passed through the containers depending on task, default None)
  • Added a test for normalizing and plotting the normalized confusion matrix

@@ -53,7 +55,8 @@ def __init__(self,
              verbose: bool = False,
              slide_dataset: SlidesDataset = None,
              tile_size: int = 224,
-             level: int = 1) -> None:
+             level: int = 1,
+             class_names: List[str] = None) -> None:
Contributor:
Type annotation has a small discrepancy - if you set the default to None, it should be "Optional[List[str]]".

Contributor Author:
Thanks, now changed
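
For reference, a minimal sketch of the corrected signature (only the two arguments from the diff are kept; the class name and body are placeholders, not the actual DeepMIL code):

    from typing import List, Optional

    class DeepMILSketch:
        """Trimmed-down stand-in showing only the annotation fix."""
        def __init__(self,
                     level: int = 1,
                     class_names: Optional[List[str]] = None) -> None:
            # A default of None requires Optional[...] rather than plain List[str].
            self.level = level
            self.class_names = class_names if class_names is not None else []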

InnerEye/ML/Histopathology/models/deepmil.py
'precision': Precision(),
'recall': Recall(),
'f1score': F1(),
'confusion_matrix': ConfusionMatrix(num_classes=self.n_classes+1)})
Contributor:
Seeing this here: It would be good to add documentation for n_classes beyond what you have now. For two classes "0" and "1", n_classes should be set to 1, correct?

Contributor Author:
Thanks, added to the docstring
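
For illustration, a hypothetical docstring fragment for n_classes (the wording is mine, not necessarily the exact text added in the PR):

    class DeepMILSketch:
        def __init__(self, n_classes: int) -> None:
            """
            :param n_classes: Number of output units of the classifier. For binary
                classification with classes "0" and "1", set n_classes to 1; the test-time
                confusion matrix is then created with num_classes = n_classes + 1 = 2, as in
                the snippet above. For multi-class tasks, set it to the number of classes.
            """
            self.n_classes = n_classes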


 def log_metrics(self,
                 stage: str) -> None:
     valid_stages = ['train', 'test', 'val']
     if stage not in valid_stages:
         raise Exception(f"Invalid stage. Chose one of {valid_stages}")
     for metric_name, metric_object in self.get_metrics_dict(stage).items():
-        self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True)
+        if not metric_name == "confusion_matrix":
+            self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True)
Contributor:
Can you replace the self.log calls with hi-ml's "log_on_epoch" function? That gives you a simpler interface and handles sync_dist better.

Contributor:
(also in the else branch)

Contributor Author:
Wondering if log_on_epoch will work with torchmetrics Metric objects?

Contributor Author:
This is now changed, thanks
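
As a minimal sketch of the switch, assuming hi-ml's log_on_epoch helper in health_ml.utils accepts a LightningModule, a metric name, and a torchmetrics object (the else branch handling the confusion matrix in the PR is not reproduced here):

    from health_ml.utils import log_on_epoch  # hi-ml helper wrapping self.log with on_epoch=True and sync_dist handling

    def log_metrics(self, stage: str) -> None:
        valid_stages = ['train', 'test', 'val']
        if stage not in valid_stages:
            raise Exception(f"Invalid stage. Choose one of {valid_stages}")
        for metric_name, metric_object in self.get_metrics_dict(stage).items():
            if metric_name != "confusion_matrix":
                log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
            # The confusion matrix itself is computed and saved separately at test time.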


 def log_metrics(self,
                 stage: str) -> None:
     valid_stages = ['train', 'test', 'val']
     if stage not in valid_stages:
         raise Exception(f"Invalid stage. Chose one of {valid_stages}")
     for metric_name, metric_object in self.get_metrics_dict(stage).items():
-        self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True)
+        if not metric_name == "confusion_matrix":
Contributor:
I see you have some if statements of the form "if not A: something() else: other_thing()". If you are handling both cases anyway, the code is easier to read if you do not negate the condition and instead swap the if/else branches.

Contributor Author:
Thanks, I swapped if/else
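
A small, self-contained illustration of the readability point (the names and branch bodies are hypothetical):

    metric_name = "confusion_matrix"

    # Harder to read: negated condition followed by an else branch.
    if not metric_name == "confusion_matrix":
        print("log scalar metric")
    else:
        print("handle the confusion matrix separately")

    # Easier to read: positive condition, branches swapped.
    if metric_name == "confusion_matrix":
        print("handle the confusion matrix separately")
    else:
        print("log scalar metric")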

@@ -338,6 +357,18 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:  # type: ignore
    fig = plot_scores_hist(results)
    self.save_figure(fig=fig, figpath=outputs_fig_path / 'hist_scores.png')

    print("Computing and saving confusion matrix...")
    metrics_dict = self.get_metrics_dict('test')
    cf_matrix = metrics_dict["confusion_matrix"].compute()
Contributor:
This literal "confusion_matrix" needs to stay in sync with your other uses of "confusion_matrix". This is an extremely common source of errors - you change the string somewhere and forget to change it somewhere else (as trivial/benign as that may sound). Your code will be a lot safer (and require fewer tests) if you define such literals as constants, e.g. CONF_MATRIX_METRIC = "confusion_matrix", and re-use them.

Contributor Author:
Thanks for the suggestion, the metric names are now defined as constants.
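
A sketch of the suggested pattern, using the constant name from the review comment; the metric construction and update call are illustrative only and assume the torchmetrics API used in the diff above:

    import torch
    from torchmetrics import ConfusionMatrix

    CONF_MATRIX_METRIC = "confusion_matrix"  # single source of truth for the metric name

    metrics = {CONF_MATRIX_METRIC: ConfusionMatrix(num_classes=2)}
    metrics[CONF_MATRIX_METRIC].update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))
    cf_matrix = metrics[CONF_MATRIX_METRIC].compute()  # same constant everywhere, so the key cannot drift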

Comment on lines 144 to 148
absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(),
                                       self.checkpoint_folder_path,
                                       self.best_checkpoint_filename_with_suffix)
if absolute_checkpoint_path_parent.is_file():
    return absolute_checkpoint_path_parent
Contributor:
Confused. This and the variable above are exactly the same?

Contributor Author:
I have taken this from our other config.

It was only there to enable this config to work and may not be related to this PR. I will remove it for now and address it in another PR.

Comment on lines 138 to 142
absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(),
                                self.checkpoint_folder_path,
                                self.best_checkpoint_filename_with_suffix)
if absolute_checkpoint_path.is_file():
    return absolute_checkpoint_path
Contributor:
I'm puzzled. Checkpoints are normally stored in the "outputs" folder, so that they are also available at the end of an AzureML run. Why are the checkpoints here accessed as part of repository root?

Contributor Author:
I have taken this from our other config.

It was only there to enable the Crck config to work and may not be related to this PR. I will remove it for now and address it in another PR.

assert file.exists()
expected = full_ml_test_data_path("histo_heatmaps") / f"confusion_matrix_{n_classes}.png"
# To update the stored results, uncomment this line:
expected.write_bytes(file.read_bytes())
Contributor:
This should not be checked in - your test will always pass now!

Suggested change:
-    expected.write_bytes(file.read_bytes())
+    # expected.write_bytes(file.read_bytes())

Contributor Author:
Thanks for spotting this, changed now
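
For context, a hypothetical helper showing the pattern being discussed: the baseline-regeneration line stays commented out so the comparison actually runs (this is not the actual test code):

    from pathlib import Path

    def check_against_baseline(file: Path, expected: Path) -> None:
        assert file.exists()
        # To update the stored baseline, uncomment the next line, run the test once, then re-comment it:
        # expected.write_bytes(file.read_bytes())
        assert expected.read_bytes() == file.read_bytes(), "Output differs from the stored baseline"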

@dccastro (Member) left a comment:
Great job! No other comments 👍

@ant0nsc merged commit 710bc36 into main on Feb 1, 2022
@ant0nsc deleted the hsharma/perclassmetrics branch on February 1, 2022 13:13