diff --git a/.idea/vcs.xml b/.idea/vcs.xml index 02fe99be5..be6d774f0 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -3,5 +3,7 @@ + + \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 7842737d9..c4733a078 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,7 +34,7 @@ jobs that run in AzureML. -([#603](https://github.com/microsoft/InnerEye-DeepLearning/pull/603)) Add histopathology module -([#614](https://github.com/microsoft/InnerEye-DeepLearning/pull/614)) Checkpoint downloading falls back to looking into AzureML if no checkpoints on disk -([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets - +-([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests ### Changed - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files. diff --git a/InnerEye/ML/Histopathology/datamodules/base_module.py b/InnerEye/ML/Histopathology/datamodules/base_module.py index 4bf4557e2..cb67c5506 100644 --- a/InnerEye/ML/Histopathology/datamodules/base_module.py +++ b/InnerEye/ML/Histopathology/datamodules/base_module.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + import pickle from enum import Enum from pathlib import Path diff --git a/InnerEye/ML/Histopathology/datamodules/panda_module.py b/InnerEye/ML/Histopathology/datamodules/panda_module.py index 0ce597377..c28073a25 100644 --- a/InnerEye/ML/Histopathology/datamodules/panda_module.py +++ b/InnerEye/ML/Histopathology/datamodules/panda_module.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from typing import Tuple from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule diff --git a/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py b/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py index 35a966b66..cbf238643 100644 --- a/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py +++ b/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+# ------------------------------------------------------------------------------------------ + from typing import Tuple, Any from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule diff --git a/InnerEye/ML/Histopathology/datasets/base_dataset.py b/InnerEye/ML/Histopathology/datasets/base_dataset.py index a7a31afe0..a03d3cb72 100644 --- a/InnerEye/ML/Histopathology/datasets/base_dataset.py +++ b/InnerEye/ML/Histopathology/datasets/base_dataset.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from pathlib import Path from typing import Any, Dict, Optional, Union diff --git a/InnerEye/ML/Histopathology/datasets/default_paths.py b/InnerEye/ML/Histopathology/datasets/default_paths.py index 0497a7ecd..a57bdff9b 100644 --- a/InnerEye/ML/Histopathology/datasets/default_paths.py +++ b/InnerEye/ML/Histopathology/datasets/default_paths.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + PANDA_TILES_DATASET_ID = "PANDA_tiles" TCGA_CRCK_DATASET_ID = "TCGA-CRCk" TCGA_PRAD_DATASET_ID = "TCGA-PRAD" diff --git a/InnerEye/ML/Histopathology/datasets/panda_dataset.py b/InnerEye/ML/Histopathology/datasets/panda_dataset.py index 77ad2f58d..ae13c993a 100644 --- a/InnerEye/ML/Histopathology/datasets/panda_dataset.py +++ b/InnerEye/ML/Histopathology/datasets/panda_dataset.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from pathlib import Path from typing import Any, Dict, Union, Optional diff --git a/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py b/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py index e382aae5e..43520a8c9 100644 --- a/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py +++ b/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+# ------------------------------------------------------------------------------------------ + from pathlib import Path from typing import Any, Callable, Optional, Tuple, Union diff --git a/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py b/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py index 783a0278a..f5614cfa8 100644 --- a/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py +++ b/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py @@ -1,7 +1,13 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from pathlib import Path from typing import Any, Callable, Optional, Tuple, Union import pandas as pd + from torchvision.datasets.vision import VisionDataset from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset diff --git a/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py b/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py index edb47d644..00da8af9a 100644 --- a/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py +++ b/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from pathlib import Path from typing import Any, Dict, Optional, Union diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 55663d074..c9af8c08d 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from pathlib import Path import pandas as pd import numpy as np diff --git a/InnerEye/ML/Histopathology/models/encoders.py b/InnerEye/ML/Histopathology/models/encoders.py index 43f85772d..04f454bba 100644 --- a/InnerEye/ML/Histopathology/models/encoders.py +++ b/InnerEye/ML/Histopathology/models/encoders.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+# ------------------------------------------------------------------------------------------ + from pathlib import Path from typing import Callable, Optional, Sequence, Tuple diff --git a/InnerEye/ML/Histopathology/models/transforms.py b/InnerEye/ML/Histopathology/models/transforms.py index 51cda8c4e..e50d088e0 100644 --- a/InnerEye/ML/Histopathology/models/transforms.py +++ b/InnerEye/ML/Histopathology/models/transforms.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from pathlib import Path from typing import Mapping, Sequence, Union diff --git a/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py b/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py index e1060860a..bbbf4f090 100644 --- a/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py +++ b/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + import functools import os import logging diff --git a/InnerEye/ML/Histopathology/preprocessing/tiling.py b/InnerEye/ML/Histopathology/preprocessing/tiling.py index b0f8b6c37..ed4a28404 100644 --- a/InnerEye/ML/Histopathology/preprocessing/tiling.py +++ b/InnerEye/ML/Histopathology/preprocessing/tiling.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + # These tiling implementations are adapted from PANDA Kaggle solutions, for example: # https://github.com/kentaroy47/Kaggle-PANDA-1st-place-solution/blob/master/src/data_process/a00_save_tiles.py from typing import Any, Optional, Tuple @@ -25,7 +30,7 @@ def pad_for_tiling_2d(array: np.ndarray, tile_size: int, channels_first: Optiona original array to obtain indices for the padded array. """ height, width = array.shape[1:] if channels_first else array.shape[:-1] - padding_h = get_1d_padding(height, tile_size) + padding_h = get_1d_padding(height, tile_size) padding_w = get_1d_padding(width, tile_size) padding = [padding_h, padding_w] channels_axis = 0 if channels_first else 2 diff --git a/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py b/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py index 6fe4551ba..7c96b0349 100644 --- a/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py +++ b/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+# ------------------------------------------------------------------------------------------ + """ Script to find mean and standard deviation of desired metrics from cross validation child runs. """ diff --git a/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py b/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py index c44f7ae6a..4eb3ef4ab 100644 --- a/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py +++ b/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + """ This script is an example of how to use the submit_to_azure_if_needed function from the hi-ml package to run the main pre-processing function that creates tiles from slides in the PANDA dataset. The advantage of using this script diff --git a/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py b/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py index 4f9afb064..a94cfa7e7 100644 --- a/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py +++ b/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ + from health_azure import DatasetConfig from health_azure.utils import get_workspace diff --git a/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py b/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py index 0e9981357..c3b78d09f 100644 --- a/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py +++ b/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + import numpy as np from typing import List, Any @@ -85,7 +90,7 @@ def plot_box_whisker(data_list: List[Any], column_names: List[str], show_outlier def plot_histogram(data: List[Any], title: str = "") -> None: """ Plot a histogram given some data - :param data: data to be plotted + :param data: data to be plotted :param title: plot title string """ plt.figure() diff --git a/InnerEye/ML/Histopathology/utils/download_utils.py b/InnerEye/ML/Histopathology/utils/download_utils.py index 10b80ebef..1addcbc55 100644 --- a/InnerEye/ML/Histopathology/utils/download_utils.py +++ b/InnerEye/ML/Histopathology/utils/download_utils.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+# ------------------------------------------------------------------------------------------ + import os from pathlib import Path diff --git a/InnerEye/ML/Histopathology/utils/layer_utils.py b/InnerEye/ML/Histopathology/utils/layer_utils.py index a2847617d..d3b88d3c0 100644 --- a/InnerEye/ML/Histopathology/utils/layer_utils.py +++ b/InnerEye/ML/Histopathology/utils/layer_utils.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from typing import Callable, Tuple from torch import as_tensor, device, nn, prod, rand @@ -25,7 +30,7 @@ def setup_feature_extractor(pretrained_model: nn.Module, def load_weights_to_model(weights_url: str, model: nn.Module) -> nn.Module: """ Load weights to the histoSSL model from the given URL - https://github.com/ozanciga/self-supervised-histopathology + https://github.com/ozanciga/self-supervised-histopathology """ map_location = device('cpu') state = load_state_dict_from_url(weights_url, map_location=map_location) diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index cc42e9d42..834ac4182 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from typing import Tuple, List, Any, Dict import torch import matplotlib.pyplot as plt diff --git a/InnerEye/ML/Histopathology/utils/naming.py b/InnerEye/ML/Histopathology/utils/naming.py index b1731ae73..32d46d54d 100644 --- a/InnerEye/ML/Histopathology/utils/naming.py +++ b/InnerEye/ML/Histopathology/utils/naming.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from enum import Enum class ResultsKey(str, Enum): diff --git a/InnerEye/ML/Histopathology/utils/tcga_utils.py b/InnerEye/ML/Histopathology/utils/tcga_utils.py index c662169d9..64c43a1a0 100644 --- a/InnerEye/ML/Histopathology/utils/tcga_utils.py +++ b/InnerEye/ML/Histopathology/utils/tcga_utils.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+# ------------------------------------------------------------------------------------------ + import pandas as pd diff --git a/InnerEye/ML/Histopathology/utils/viz_utils.py b/InnerEye/ML/Histopathology/utils/viz_utils.py index 379559e88..1c4bff791 100644 --- a/InnerEye/ML/Histopathology/utils/viz_utils.py +++ b/InnerEye/ML/Histopathology/utils/viz_utils.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + import math import matplotlib.pyplot as plt diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py new file mode 100644 index 000000000..29ce29da8 --- /dev/null +++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +"""BaseMIL is an abstract container defining basic functionality for running MIL experiments. +It is responsible for instantiating the encoder and full DeepMIL model. Subclasses should define +their datamodules and configure experiment-specific parameters. +""" +import os +from pathlib import Path +from typing import Type + +import param +from torch import nn +from torchvision.models.resnet import resnet18 + +from health_azure.utils import CheckpointDownloader, get_workspace +from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer +from InnerEye.Common import fixed_paths +from InnerEye.ML.lightning_container import LightningContainer +from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, TilesDataModule +from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule +from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder, + ImageNetEncoder, ImageNetSimCLREncoder, + InnerEyeSSLEncoder, TileEncoder) + + +class BaseMIL(LightningContainer): + # Model parameters: + pooling_type: str = param.String(doc="Name of the pooling layer class to use.") + # l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass + + # Encoder parameters: + encoder_type: str = param.String(doc="Name of the encoder class to use.") + tile_size: int = param.Integer(224, bounds=(1, None), doc="Tile width/height, in pixels.") + n_channels: int = param.Integer(3, bounds=(1, None), doc="Number of channels in the tile.") + + # Data module parameters: + batch_size: int = param.Integer(16, bounds=(1, None), doc="Number of slides to load per batch.") + max_bag_size: int = param.Integer(1000, bounds=(0, None), + doc="Upper bound on number of tiles in each loaded bag. " + "If 0 (default), will return all samples in each bag. 
" + "If > 0, bags larger than `max_bag_size` will yield " + "random subsets of instances.") + cache_mode: CacheMode = param.ClassSelector(default=CacheMode.MEMORY, class_=CacheMode, + doc="The type of caching to perform: " + "'memory' (default), 'disk', or 'none'.") + save_precache: bool = param.Boolean(True, doc="Whether to pre-cache the entire transformed " + "dataset upfront and save it to disk.") + # local_dataset (used as data module root_path) is declared in DatasetParams superclass + + @property + def cache_dir(self) -> Path: + raise NotImplementedError + + def setup(self) -> None: + if self.encoder_type == InnerEyeSSLEncoder.__name__: + self.downloader = CheckpointDownloader( + aml_workspace=get_workspace(), + run_id="updated_transforms:updated_transforms_1636471522_5473e3ff", + checkpoint_filename="best_checkpoint.ckpt", + download_dir='outputs/' + ) + os.chdir(fixed_paths.repository_root_directory()) + self.downloader.download_checkpoint_if_necessary() + + self.encoder = self.get_encoder() + self.encoder.cuda() + self.encoder.eval() + + def get_encoder(self) -> TileEncoder: + if self.encoder_type == ImageNetEncoder.__name__: + return ImageNetEncoder(feature_extraction_model=resnet18, + tile_size=self.tile_size, n_channels=self.n_channels) + + elif self.encoder_type == ImageNetSimCLREncoder.__name__: + return ImageNetSimCLREncoder(tile_size=self.tile_size, n_channels=self.n_channels) + + elif self.encoder_type == HistoSSLEncoder.__name__: + return HistoSSLEncoder(tile_size=self.tile_size, n_channels=self.n_channels) + + elif self.encoder_type == InnerEyeSSLEncoder.__name__: + return InnerEyeSSLEncoder(pl_checkpoint_path=self.downloader.local_checkpoint_path, + tile_size=self.tile_size, n_channels=self.n_channels) + + else: + raise ValueError(f"Unsupported encoder type: {self.encoder_type}") + + def get_pooling_layer(self) -> Type[nn.Module]: + if self.pooling_type == AttentionLayer.__name__: + return AttentionLayer + elif self.pooling_type == GatedAttentionLayer.__name__: + return GatedAttentionLayer + else: + raise ValueError(f"Unsupported pooling type: {self.pooling_type}") + + def create_model(self) -> DeepMILModule: + self.data_module = self.get_data_module() + # Encoding is done in the datamodule, so here we provide instead a dummy + # no-op IdentityEncoder to be used inside the model + return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), + label_column=self.data_module.train_dataset.LABEL_COLUMN, + n_classes=self.data_module.train_dataset.N_CLASSES, + pooling_layer=self.get_pooling_layer(), + class_weights=self.data_module.class_weights, + l_rate=self.l_rate, + weight_decay=self.weight_decay, + adam_betas=self.adam_betas) + + def get_data_module(self) -> TilesDataModule: + raise NotImplementedError diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py new file mode 100644 index 000000000..a2c1900f0 --- /dev/null +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py @@ -0,0 +1,121 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +"""DeepSMILECrck is the container for experiments relating to DeepSMILE using the TCGA-CRCk dataset. 
+Run using `python InnerEyePrivate/ML/runner.py --model=DeepSMILECrck --encoder_type=` + +For convenience, this module also defines encoder-specific containers that can be invoked without +additional arguments, e.g. `python InnerEyePrivate/ML/runner.py --model=TcgaCrckImageNetMIL` + +Reference: +- Schirris (2021). DeepSMILE: Self-supervised heterogeneity-aware multiple instance learning for DNA +damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405 +""" +from pathlib import Path +from typing import Any, Dict + +from monai.transforms import Compose +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + +from health_ml.networks.layers.attention_layers import GatedAttentionLayer +from InnerEye.Common import fixed_paths +from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL +from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule +from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import TcgaCrckTilesDataModule +from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, ImageNetEncoder, + ImageNetSimCLREncoder, InnerEyeSSLEncoder) +from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTilesBatchd +from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset + + +class DeepSMILECrck(BaseMIL): + def __init__(self, **kwargs: Any) -> None: + # Define dictionary with default params that can be overriden from subclasses or CLI + default_kwargs = dict( + # declared in BaseMIL: + pooling_type=GatedAttentionLayer.__name__, + + # declared in DatasetParams: + local_dataset=Path("/tmp/datasets/TCGA-CRCk"), + azure_dataset_id="TCGA-CRCk", + # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI + + # declared in TrainerParams: + num_epochs=16, + recovery_checkpoint_save_interval=16, + recovery_checkpoints_save_last_k=-1, + + # declared in WorkflowParams: + number_of_cross_validation_splits=5, + cross_validation_split_index=0, + + # declared in OptimizerParams: + l_rate=5e-4, + weight_decay=1e-4, + adam_betas=(0.9, 0.99), + ) + default_kwargs.update(kwargs) + super().__init__(**default_kwargs) + + self.best_checkpoint_filename = "checkpoint_max_val_auroc" + self.best_checkpoint_filename_with_suffix = self.best_checkpoint_filename + ".ckpt" + self.checkpoint_folder_path = "outputs/checkpoints/" + + best_checkpoint_callback = ModelCheckpoint(dirpath=self.checkpoint_folder_path, + monitor='val/auroc', + filename=self.best_checkpoint_filename, + auto_insert_metric_name=False, + mode='max') + self.callbacks = best_checkpoint_callback + + @property + def cache_dir(self) -> Path: + return Path(f"/tmp/innereye_cache/{self.__class__.__name__}-{self.encoder_type}/") + + def get_data_module(self) -> TilesDataModule: + image_key = TcgaCrck_TilesDataset.IMAGE_COLUMN + transform = Compose([LoadTilesBatchd(image_key, progress=True), + EncodeTilesBatchd(image_key, self.encoder)]) + return TcgaCrckTilesDataModule(root_path=self.local_dataset, + max_bag_size=self.max_bag_size, + batch_size=self.batch_size, + transform=transform, + cache_mode=self.cache_mode, + save_precache=self.save_precache, + cache_dir=self.cache_dir, + number_of_cross_validation_splits=self.number_of_cross_validation_splits, + cross_validation_split_index=self.cross_validation_split_index) + + def get_trainer_arguments(self) -> Dict[str, Any]: + # These arguments will be passed through to the Lightning trainer. 
+ return {"callbacks": self.callbacks} + + def get_path_to_best_checkpoint(self) -> Path: + """ + Returns the full path to a checkpoint file that was found to be best during training, whatever criterion + was applied there. + """ + # absolute path is required for registering the model. + return fixed_paths.repository_root_directory() / self.checkpoint_folder_path / self.best_checkpoint_filename_with_suffix + + +class TcgaCrckImageNetMIL(DeepSMILECrck): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs) + + +class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=ImageNetSimCLREncoder.__name__, **kwargs) + + +class TcgaCrckInnerEyeSSLMIL(DeepSMILECrck): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=InnerEyeSSLEncoder.__name__, **kwargs) + + +class TcgaCrckHistoSSLMIL(DeepSMILECrck): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=HistoSSLEncoder.__name__, **kwargs) diff --git a/Tests/Azure/test_azure_util.py b/Tests/Azure/test_azure_util.py index 612da843f..45cb88465 100644 --- a/Tests/Azure/test_azure_util.py +++ b/Tests/Azure/test_azure_util.py @@ -8,6 +8,8 @@ from azureml.core import Run from azureml.core.workspace import Workspace +from health_azure.utils import is_run_and_child_runs_completed + from InnerEye.Azure.azure_config import AzureConfig from InnerEye.Azure.azure_runner import create_experiment_name from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, fetch_child_runs, fetch_run, \ @@ -16,9 +18,9 @@ from InnerEye.Common.common_util import logging_to_stdout from InnerEye.Common.fixed_paths import PRIVATE_SETTINGS_FILE, PROJECT_SECRETS_FILE, \ repository_root_directory + from Tests.AfterTraining.test_after_training import FALLBACK_ENSEMBLE_RUN, get_most_recent_run, get_most_recent_run_id from Tests.ML.util import get_default_workspace -from health_azure.utils import is_run_and_child_runs_completed def test_os_path_to_azure_friendly_container_path() -> None: diff --git a/Tests/ML/histopathology/datamodules/test_datamodule_caching.py b/Tests/ML/histopathology/datamodules/test_datamodule_caching.py index 71575bd68..41325f401 100644 --- a/Tests/ML/histopathology/datamodules/test_datamodule_caching.py +++ b/Tests/ML/histopathology/datamodules/test_datamodule_caching.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + import shutil from pathlib import Path from typing import Any, Tuple diff --git a/Tests/ML/histopathology/datasets/test_tcga_crck_dataset.py b/Tests/ML/histopathology/datasets/test_tcga_crck_dataset.py index 76509cbd7..c0b6018a7 100644 --- a/Tests/ML/histopathology/datasets/test_tcga_crck_dataset.py +++ b/Tests/ML/histopathology/datasets/test_tcga_crck_dataset.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+# ------------------------------------------------------------------------------------------ + import os import pytest diff --git a/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py b/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py index 49c9f8ec9..e23f0e212 100644 --- a/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py +++ b/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + import os import pytest diff --git a/Tests/ML/histopathology/models/test_deepmil.py b/Tests/ML/histopathology/models/test_deepmil.py new file mode 100644 index 000000000..60df2f759 --- /dev/null +++ b/Tests/ML/histopathology/models/test_deepmil.py @@ -0,0 +1,194 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +import os +from typing import Callable, Dict, List + +import pytest +from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose +from torchvision.models import resnet18 + +from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer + +from InnerEye.ML.configs.histo_configs.classification.DeepSMILECrck import DeepSMILECrck +from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule +from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR +from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule +from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder, TileEncoder +from InnerEye.ML.Histopathology.utils.naming import ResultsKey + + +def get_supervised_imagenet_encoder() -> TileEncoder: + return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224) + + +@pytest.mark.parametrize("n_classes", [1, 3]) +@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer]) +@pytest.mark.parametrize("batch_size", [1, 15]) +@pytest.mark.parametrize("max_bag_size", [1, 7]) +@pytest.mark.parametrize("pool_hidden_dim", [1, 5]) +@pytest.mark.parametrize("pool_out_dim", [1, 6]) +def test_lightningmodule(n_classes: int, + pooling_layer: Callable[[int, int, int], nn.Module], + batch_size: int, + max_bag_size: int, + pool_hidden_dim: int, + pool_out_dim: int) -> None: + + assert n_classes > 0 + + # hard-coded here to avoid test explosion; correctness of other encoders is tested elsewhere + encoder = get_supervised_imagenet_encoder() + module = DeepMILModule(encoder=encoder, + label_column='label', + n_classes=n_classes, + pooling_layer=pooling_layer, + pool_hidden_dim=pool_hidden_dim, + pool_out_dim=pool_out_dim) + + bag_images = rand([batch_size, max_bag_size, *module.encoder.input_dim]) + bag_labels_list = [] + bag_logits_list = [] + bag_attn_list = [] + for bag in bag_images: + if n_classes > 1: + labels = randint(n_classes, size=(max_bag_size,)) + else: + labels = randint(n_classes+1, size=(max_bag_size,)) + bag_labels_list.append(module.get_bag_label(labels)) + logit, attn = module(bag) + 
assert logit.shape == (1, n_classes) + assert attn.shape == (module.pool_out_dim, max_bag_size) + bag_logits_list.append(logit.view(-1)) + bag_attn_list.append(attn) + + bag_logits = stack(bag_logits_list) + bag_labels = stack(bag_labels_list).view(-1) + + assert bag_logits.shape[0] == (batch_size) + assert bag_labels.shape[0] == (batch_size) + + if module.n_classes > 1: + loss = module.loss_fn(bag_logits, bag_labels) + else: + loss = module.loss_fn(bag_logits.squeeze(1), bag_labels.float()) + + assert loss > 0 + assert loss.shape == () + + probs = module.activation_fn(bag_logits) + assert ((probs >= 0) & (probs <= 1)).all() + if n_classes > 1: + assert probs.shape == (batch_size, n_classes) + else: + assert probs.shape[0] == batch_size + + if n_classes > 1: + preds = argmax(probs, dim=1) + else: + preds = round(probs) + assert preds.shape[0] == batch_size + + for metric_name, metric_object in module.train_metrics.items(): + if (batch_size > 1) or (not metric_name == 'auroc'): + score = metric_object(preds.view(-1, 1), bag_labels.view(-1, 1)) + assert score >= 0 and score <= 1 + + +def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict: + device = 'cuda' if use_gpu else 'cpu' + return {key: [value.to(device) if isinstance(value, Tensor) else value for value in values] + for key, values in batch.items()} + + +@pytest.mark.parametrize("use_gpu", [True, False]) +def test_container(use_gpu: bool) -> None: + container_type = DeepSMILECrck + dataset_dir = TCGA_CRCK_DATASET_DIR + if not os.path.isdir(dataset_dir): + pytest.skip(f"Dataset for container {container_type.__name__} " + f"is unavailable: {dataset_dir}") + + container = DeepSMILECrck(encoder_type=ImageNetEncoder.__name__) + container.setup() + + data_module: TilesDataModule = container.get_data_module() + data_module.max_bag_size = 10 + module = container.create_model() + if use_gpu: + module.cuda() + + train_data_loader = data_module.train_dataloader() + for batch_idx, batch in enumerate(train_data_loader): + batch = move_batch_to_expected_device(batch, use_gpu) + loss = module.training_step(batch, batch_idx) + loss.retain_grad() + loss.backward() + assert loss.grad is not None + assert loss.shape == () + assert isinstance(loss, Tensor) + break + + val_data_loader = data_module.val_dataloader() + for batch_idx, batch in enumerate(val_data_loader): + batch = move_batch_to_expected_device(batch, use_gpu) + loss = module.validation_step(batch, batch_idx) + assert loss.shape == () + assert isinstance(loss, Tensor) + break + + test_data_loader = data_module.test_dataloader() + for batch_idx, batch in enumerate(test_data_loader): + batch = move_batch_to_expected_device(batch, use_gpu) + outputs_dict = module.test_step(batch, batch_idx) + loss = outputs_dict[ResultsKey.LOSS] + assert loss.shape == () + assert isinstance(loss, Tensor) + break + + +def test_class_weights_binary() -> None: + class_weights = Tensor([0.5, 3.5]) + n_classes = 1 + module = DeepMILModule(encoder=get_supervised_imagenet_encoder(), + label_column='label', + n_classes=n_classes, + pooling_layer=AttentionLayer, + pool_hidden_dim=5, + pool_out_dim=1, + class_weights=class_weights) + logits = Tensor(randn(1, n_classes)) + bag_label = randint(n_classes+1, size=(1,)) + + pos_weight = Tensor([class_weights[1]/(class_weights[0]+1e-5)]) + loss_weighted = module.loss_fn(logits.squeeze(1), bag_label.float()) + criterion_unweighted = nn.BCEWithLogitsLoss() + loss_unweighted = criterion_unweighted(logits.squeeze(1), bag_label.float()) + if bag_label.item() 
== 1: + assert allclose(loss_weighted, pos_weight*loss_unweighted) + else: + assert allclose(loss_weighted, loss_unweighted) + + +def test_class_weights_multiclass() -> None: + class_weights = Tensor([0.33, 0.33, 0.33]) + n_classes = 3 + module = DeepMILModule(encoder=get_supervised_imagenet_encoder(), + label_column='label', + n_classes=n_classes, + pooling_layer=AttentionLayer, + pool_hidden_dim=5, + pool_out_dim=1, + class_weights=class_weights) + logits = Tensor(randn(1, n_classes)) + bag_label = randint(n_classes, size=(1,)) + + loss_weighted = module.loss_fn(logits, bag_label) + criterion_unweighted = nn.CrossEntropyLoss() + loss_unweighted = criterion_unweighted(logits, bag_label) + # The weighted and unweighted loss functions give the same loss values for batch_size = 1. + # https://stackoverflow.com/questions/67639540/pytorch-cross-entropy-loss-weights-not-working + # TODO: the test should reflect actual weighted loss operation for the class weights after batch_size > 1 is implemented. + assert allclose(loss_weighted, loss_unweighted) diff --git a/Tests/ML/histopathology/models/test_encoders.py b/Tests/ML/histopathology/models/test_encoders.py index 7c102e7d9..a9ad82864 100644 --- a/Tests/ML/histopathology/models/test_encoders.py +++ b/Tests/ML/histopathology/models/test_encoders.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from typing import Callable import pytest diff --git a/Tests/ML/histopathology/models/test_transforms.py b/Tests/ML/histopathology/models/test_transforms.py new file mode 100644 index 000000000..ba938169a --- /dev/null +++ b/Tests/ML/histopathology/models/test_transforms.py @@ -0,0 +1,153 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+# ------------------------------------------------------------------------------------------ + +import os +from pathlib import Path +from typing import Callable, Sequence, Union + +import pytest +import torch +from monai.data.dataset import CacheDataset, Dataset, PersistentDataset +from monai.transforms import Compose +from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import Subset +from torchvision.models import resnet18 + +from health_ml.utils.bag_utils import BagDataset +from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR +from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset +from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder +from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd +from Tests.ML.util import assert_dicts_equal + + +@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR), + reason="TCGA-CRCk tiles dataset is unavailable") +def test_load_tile() -> None: + tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR) + image_key = tiles_dataset.IMAGE_COLUMN + load_transform = LoadTiled(image_key) + index = 0 + + # Test that the transform affects only the image entry in the sample + input_sample = tiles_dataset[index] + loaded_sample = load_transform(input_sample) + assert_dicts_equal(loaded_sample, input_sample, exclude_keys=[image_key]) + + # Test that the MONAI Dataset applies the same transform + loaded_dataset = Dataset(tiles_dataset, transform=load_transform) # type:ignore + same_dataset_sample = loaded_dataset[index] + assert_dicts_equal(same_dataset_sample, loaded_sample) + + # Test that loading another sample gives different results + different_sample = loaded_dataset[index + 1] + assert not torch.allclose(different_sample[image_key], loaded_sample[image_key]) + + +@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR), + reason="TCGA-CRCk tiles dataset is unavailable") +def test_load_tiles_batch() -> None: + tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR) + image_key = tiles_dataset.IMAGE_COLUMN + max_bag_size = 5 + bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore + max_bag_size=max_bag_size) + load_batch_transform = LoadTilesBatchd(image_key) + loaded_dataset = Dataset(tiles_dataset, transform=LoadTiled(image_key)) # type:ignore + image_shape = loaded_dataset[0][image_key].shape + index = 0 + + # Test that the transform affects only the image entry in the batch, + # and that the loaded images have the expected shape + bagged_batch = bagged_dataset[index] + manually_loaded_batch = load_batch_transform(bagged_batch) + assert_dicts_equal(manually_loaded_batch, bagged_batch, exclude_keys=[image_key]) + assert manually_loaded_batch[image_key].shape == (max_bag_size, *image_shape) + + # Test that the MONAI Dataset applies the same transform + loaded_bagged_dataset = Dataset(bagged_dataset, transform=load_batch_transform) # type:ignore + loaded_bagged_batch = loaded_bagged_dataset[index] + assert_dicts_equal(loaded_bagged_batch, manually_loaded_batch) + + # Test that loading another batch gives different results + different_batch = loaded_bagged_dataset[index + 1] + assert not torch.allclose(different_batch[image_key], manually_loaded_batch[image_key]) + + # Test that loading and bagging commute + bagged_loaded_dataset = BagDataset(loaded_dataset, # type: ignore + bag_ids=tiles_dataset.slide_ids, + max_bag_size=max_bag_size) + 
bagged_loaded_batch = bagged_loaded_dataset[index] + assert_dicts_equal(bagged_loaded_batch, loaded_bagged_batch) + + +def _test_cache_and_persistent_datasets(tmp_path: Path, + base_dataset: TorchDataset, + transform: Union[Sequence[Callable], Callable], + cache_subdir: str) -> None: + default_dataset = Dataset(base_dataset, transform=transform) # type: ignore + cached_dataset = CacheDataset(base_dataset, transform=transform) # type: ignore + cache_dir = tmp_path / cache_subdir + cache_dir.mkdir(exist_ok=True) + persistent_dataset = PersistentDataset(base_dataset, transform=transform, # type: ignore + cache_dir=cache_dir) + + for default_sample, cached_sample, persistent_sample \ + in zip(default_dataset, cached_dataset, persistent_dataset): # type: ignore + assert_dicts_equal(cached_sample, default_sample) + assert_dicts_equal(persistent_sample, default_sample) + + +@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR), + reason="TCGA-CRCk tiles dataset is unavailable") +def test_cached_loading(tmp_path: Path) -> None: + tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR) + image_key = tiles_dataset.IMAGE_COLUMN + + max_num_tiles = 100 + tiles_subset = Subset(tiles_dataset, range(max_num_tiles)) + _test_cache_and_persistent_datasets(tmp_path, + tiles_subset, + transform=LoadTiled(image_key), + cache_subdir="TCGA-CRCk_tiles_cache") + + max_bag_size = 5 + max_num_bags = max_num_tiles // max_bag_size + bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore + max_bag_size=max_bag_size) + bagged_subset = Subset(bagged_dataset, range(max_num_bags)) + _test_cache_and_persistent_datasets(tmp_path, + bagged_subset, + transform=LoadTilesBatchd(image_key), + cache_subdir="TCGA-CRCk_load_cache") + + +@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR), + reason="TCGA-CRCk tiles dataset is unavailable") +@pytest.mark.parametrize('use_gpu', [False, True]) +def test_encode_tiles(tmp_path: Path, use_gpu: bool) -> None: + tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR) + image_key = tiles_dataset.IMAGE_COLUMN + max_bag_size = 5 + bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore + max_bag_size=max_bag_size) + + encoder = ImageNetEncoder(resnet18, tile_size=224, n_channels=3) + if use_gpu: + encoder.cuda() + + encode_transform = EncodeTilesBatchd(image_key, encoder) + transform = Compose([LoadTilesBatchd(image_key), encode_transform]) + dataset = Dataset(bagged_dataset, transform=transform) # type: ignore + sample = dataset[0] + assert sample[image_key].shape == (max_bag_size, encoder.num_encoding) + # TODO: Ensure it works in DDP + + max_num_bags = 20 + bagged_subset = Subset(bagged_dataset, range(max_num_bags)) + _test_cache_and_persistent_datasets(tmp_path, + bagged_subset, + transform=transform, + cache_subdir="TCGA-CRCk_embed_cache") diff --git a/Tests/ML/histopathology/preprocessing/test_tiling.py b/Tests/ML/histopathology/preprocessing/test_tiling.py index 891ac0029..15cddac68 100644 --- a/Tests/ML/histopathology/preprocessing/test_tiling.py +++ b/Tests/ML/histopathology/preprocessing/test_tiling.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
+# ------------------------------------------------------------------------------------------ + import numpy as np import pytest diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py index aacaf2ab5..a63884477 100644 --- a/Tests/ML/histopathology/utils/test_metrics_utils.py +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ import math from typing import List diff --git a/Tests/ML/histopathology/utils/test_tcga_utils.py b/Tests/ML/histopathology/utils/test_tcga_utils.py index 03f6f3441..7d100e4c5 100644 --- a/Tests/ML/histopathology/utils/test_tcga_utils.py +++ b/Tests/ML/histopathology/utils/test_tcga_utils.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + import pandas as pd from InnerEye.ML.Histopathology.utils.tcga_utils import extract_fields diff --git a/Tests/ML/util.py b/Tests/ML/util.py index 40ebf7a61..2b7afb93a 100644 --- a/Tests/ML/util.py +++ b/Tests/ML/util.py @@ -4,7 +4,7 @@ # ------------------------------------------------------------------------------------------ import logging from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Collection, List, Mapping, Optional, Tuple, Union import numpy as np import pandas as pd @@ -120,6 +120,23 @@ def _assert_line(actual: str, expected: str) -> None: assert actual == expected, content_mismatch(actual, expected) +def assert_dicts_equal(d1: Mapping, d2: Mapping, exclude_keys: Collection[Any] = (), + rtol: float = 1e-5, atol: float = 1e-8) -> None: + assert isinstance(d1, Mapping) + assert isinstance(d2, Mapping) + keys1 = [key for key in d1 if key not in exclude_keys] + keys2 = [key for key in d2 if key not in exclude_keys] + assert keys1 == keys2 + for key in keys1: + msg = f"Dictionaries differ for key '{key}': {d1[key]} vs {d2[key]}" + if isinstance(d1[key], torch.Tensor): + assert torch.allclose(d1[key], d2[key], rtol=rtol, atol=atol, equal_nan=True), msg + elif isinstance(d1[key], np.ndarray): + assert np.allclose(d1[key], d2[key], rtol=rtol, atol=atol, equal_nan=True), msg + else: + assert d1[key] == d2[key], msg + + def assert_file_exists(file_path: Path) -> None: """ Checks if the given file exists.
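Usage note (not part of the diff above): a minimal sketch of how the new assert_dicts_equal helper added to Tests/ML/util.py can be called, assuming a repository checkout where Tests.ML.util is importable; the dictionary contents and the "path" key below are illustrative only.

import numpy as np
import torch

from Tests.ML.util import assert_dicts_equal

# Two dictionaries whose tensor/array entries match within the default tolerances.
sample = {"image": torch.ones(3, 224, 224), "coords": np.array([1, 2]), "label": 0}
loaded = {"image": torch.ones(3, 224, 224), "coords": np.array([1, 2]), "label": 0,
          "path": "tile_001.png"}  # hypothetical extra metadata entry

# Keys listed in exclude_keys are skipped, so the extra "path" entry does not fail the comparison.
# Per the helper's implementation in the diff, tensors are compared with torch.allclose,
# numpy arrays with np.allclose, and all other values with plain equality.
assert_dicts_equal(loaded, sample, exclude_keys=["path"])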