Add class-wise metrics logging and confusion matrix to DeepMIL #647
Conversation
@@ -53,7 +55,8 @@ def __init__(self,
verbose: bool = False,
slide_dataset: SlidesDataset = None,
tile_size: int = 224,
level: int = 1) -> None:
level: int = 1,
class_names: List[str] = None) -> None:
Type annotation has a small discrepancy - if you set the default to None, it should be "Optional[List[str]]".
Thanks, now changed
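For reference, a minimal sketch of what the corrected annotation looks like (the class name and surrounding parameters here are illustrative stand-ins, not the actual DeepMIL signature):

from typing import List, Optional

class DeepMILModuleSketch:  # illustrative stand-in for the actual module class
    def __init__(self,
                 tile_size: int = 224,
                 level: int = 1,
                 # a default of None calls for Optional[...] in the annotation
                 class_names: Optional[List[str]] = None) -> None:
        self.class_names = class_names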
'precision': Precision(),
'recall': Recall(),
'f1score': F1(),
'confusion_matrix': ConfusionMatrix(num_classes=self.n_classes+1)})
Seeing this here: It would be good to add documentation for n_classes beyond what you have now. For two classes "0" and "1", n_classes should be set to 1, correct?
Thanks, added to the docstring
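To make the n_classes convention concrete, a hedged sketch assuming the torchmetrics API quoted above (older torchmetrics versions expose F1; newer ones rename it F1Score):

from torchmetrics import ConfusionMatrix, F1, Precision, Recall

n_classes = 1  # binary classification with labels "0" and "1" uses n_classes = 1
metrics = {
    'precision': Precision(),
    'recall': Recall(),
    'f1score': F1(),
    # the confusion matrix needs the actual number of classes, hence n_classes + 1
    'confusion_matrix': ConfusionMatrix(num_classes=n_classes + 1),
}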
def log_metrics(self,
stage: str) -> None:
valid_stages = ['train', 'test', 'val']
if stage not in valid_stages:
raise Exception(f"Invalid stage. Chose one of {valid_stages}")
for metric_name, metric_object in self.get_metrics_dict(stage).items():
self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True)
if not metric_name == "confusion_matrix":
self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True)
Can you replace the self.log calls with hi-ml's "log_on_epoch" function? That gives you a simpler interface and handles sync_dist better.
(also in the else branch)
Wondering if log_on_epoch will work with torchmetrics Metric objects?
This is now changed, thanks
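A hedged sketch of the refactored logging, assuming log_on_epoch is importable from health_ml.utils and accepts torchmetrics objects the same way self.log does; the confusion-matrix branch shown here is only indicative:

from health_ml.utils import log_on_epoch  # assumed import path for the hi-ml helper

def log_metrics(self, stage: str) -> None:
    valid_stages = ['train', 'test', 'val']
    if stage not in valid_stages:
        raise Exception(f"Invalid stage. Choose one of {valid_stages}")
    for metric_name, metric_object in self.get_metrics_dict(stage).items():
        if metric_name == "confusion_matrix":
            # handled separately, e.g. computed and plotted at the end of the test epoch
            continue
        # log_on_epoch aggregates per epoch and takes care of sync_dist internally
        log_on_epoch(self, f'{stage}/{metric_name}', metric_object)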
def log_metrics(self,
stage: str) -> None:
valid_stages = ['train', 'test', 'val']
if stage not in valid_stages:
raise Exception(f"Invalid stage. Chose one of {valid_stages}")
for metric_name, metric_object in self.get_metrics_dict(stage).items():
self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True)
if not metric_name == "confusion_matrix":
I see you have some if statements of the form "if not A: something() else: otherthing()". If you are handling both cases anyway, the code is easier to read if you do not negate the condition and instead swap the if/else branches.
Thanks, I swapped if/else
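A small sketch of the readability change being suggested (the two helper functions are placeholders, not part of the PR):

def log_scalar_metric(name: str, value: object) -> None:
    print(f"log {name}: {value}")  # placeholder for the real logging call

def handle_confusion_matrix(value: object) -> None:
    print("confusion matrix handled separately")  # placeholder

metric_name, metric_object = "f1score", 0.9

# Before: negated condition
if not metric_name == "confusion_matrix":
    log_scalar_metric(metric_name, metric_object)
else:
    handle_confusion_matrix(metric_object)

# After: positive condition first, identical behaviour, easier to read
if metric_name == "confusion_matrix":
    handle_confusion_matrix(metric_object)
else:
    log_scalar_metric(metric_name, metric_object)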
@@ -338,6 +357,18 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore
fig = plot_scores_hist(results)
self.save_figure(fig=fig, figpath=outputs_fig_path / 'hist_scores.png')

print("Computing and saving confusion matrix...")
metrics_dict = self.get_metrics_dict('test')
cf_matrix = metrics_dict["confusion_matrix"].compute()
This literal "confusion_matrix" needs to be in sync with your other uses of "confusion_matrix". This is an extremely common source of errors: you change the constant somewhere and forget to change it somewhere else (as trivial/benign as this may sound). Your code will be a lot safer (and require fewer tests) if you define those literals as constants, e.g. CONF_MATRIX_METRIC = "confusion_matrix", and re-use them.
Thanks for the suggestion, the metric names are now defined as constants
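A hedged sketch of the constant-based approach, using the reviewer's suggested name (the surrounding test_epoch_end context mirrors the diff above):

# Module-level constants so each metric key is defined in exactly one place
CONF_MATRIX_METRIC = "confusion_matrix"
F1_METRIC = "f1score"

# ...later, inside test_epoch_end, the same constant is reused instead of a literal:
metrics_dict = self.get_metrics_dict('test')
cf_matrix = metrics_dict[CONF_MATRIX_METRIC].compute()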
absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(),
self.checkpoint_folder_path,
self.best_checkpoint_filename_with_suffix)
if absolute_checkpoint_path_parent.is_file():
return absolute_checkpoint_path_parent
Confused. This and the variable above are exactly the same?
I have taken this from our other config.
InnerEye-DeepLearning/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py
Line 148 in fb258d5
def get_path_to_best_checkpoint(self) -> Path:
absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(),
self.checkpoint_folder_path,
self.best_checkpoint_filename_with_suffix)
if absolute_checkpoint_path.is_file():
return absolute_checkpoint_path
I'm puzzled. Checkpoints are normally stored in the "outputs" folder, so that they are also available at the end of an AzureML run. Why are the checkpoints here accessed as part of repository root?
I have taken this from our other config
InnerEye-DeepLearning/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py
Line 148 in fb258d5
def get_path_to_best_checkpoint(self) -> Path:
It was only added to enable the Crck config to work and may not be related to this PR. I will remove it for now and address it in another PR.
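For context, a sketch of how the two quoted lookups could fit together as a fallback, using only the identifiers visible in the snippets above (the actual config may differ, and this code is slated for removal anyway):

def get_path_to_best_checkpoint(self) -> Path:
    # First try the checkpoint folder relative to the repository root...
    absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(),
                                    self.checkpoint_folder_path,
                                    self.best_checkpoint_filename_with_suffix)
    if absolute_checkpoint_path.is_file():
        return absolute_checkpoint_path
    # ...then fall back to the repository's parent directory
    absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(),
                                           self.checkpoint_folder_path,
                                           self.best_checkpoint_filename_with_suffix)
    if absolute_checkpoint_path_parent.is_file():
        return absolute_checkpoint_path_parent
    raise FileNotFoundError("Could not find the best checkpoint file.")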
assert file.exists()
expected = full_ml_test_data_path("histo_heatmaps") / f"confusion_matrix_{n_classes}.png"
# To update the stored results, uncomment this line:
expected.write_bytes(file.read_bytes())
This should not be checked in - your test will always pass now!
Suggested change:
expected.write_bytes(file.read_bytes())
# expected.write_bytes(file.read_bytes())
Thanks for spotting this, changed now
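A hedged sketch of the resulting test pattern (names follow the quoted snippet; the final byte comparison is an assumption about how the test asserts equality):

assert file.exists()
expected = full_ml_test_data_path("histo_heatmaps") / f"confusion_matrix_{n_classes}.png"
# To regenerate the stored expected output, temporarily uncomment the next line,
# run the test once, then re-comment it before committing:
# expected.write_bytes(file.read_bytes())
assert file.read_bytes() == expected.read_bytes()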
Great job! No other comments 👍
In this PR: