diff --git a/.idea/vcs.xml b/.idea/vcs.xml
index 02fe99be5..be6d774f0 100644
--- a/.idea/vcs.xml
+++ b/.idea/vcs.xml
@@ -3,5 +3,7 @@
+
+
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7842737d9..c4733a078 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -34,7 +34,7 @@ jobs that run in AzureML.
-([#603](https://github.com/microsoft/InnerEye-DeepLearning/pull/603)) Add histopathology module
-([#614](https://github.com/microsoft/InnerEye-DeepLearning/pull/614)) Checkpoint downloading falls back to looking into AzureML if no checkpoints on disk
-([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets
-
+-([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests
### Changed
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
diff --git a/InnerEye/ML/Histopathology/datamodules/base_module.py b/InnerEye/ML/Histopathology/datamodules/base_module.py
index 4bf4557e2..cb67c5506 100644
--- a/InnerEye/ML/Histopathology/datamodules/base_module.py
+++ b/InnerEye/ML/Histopathology/datamodules/base_module.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import pickle
from enum import Enum
from pathlib import Path
diff --git a/InnerEye/ML/Histopathology/datamodules/panda_module.py b/InnerEye/ML/Histopathology/datamodules/panda_module.py
index 0ce597377..c28073a25 100644
--- a/InnerEye/ML/Histopathology/datamodules/panda_module.py
+++ b/InnerEye/ML/Histopathology/datamodules/panda_module.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from typing import Tuple
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
diff --git a/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py b/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py
index 35a966b66..cbf238643 100644
--- a/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py
+++ b/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from typing import Tuple, Any
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
diff --git a/InnerEye/ML/Histopathology/datasets/base_dataset.py b/InnerEye/ML/Histopathology/datasets/base_dataset.py
index a7a31afe0..a03d3cb72 100644
--- a/InnerEye/ML/Histopathology/datasets/base_dataset.py
+++ b/InnerEye/ML/Histopathology/datasets/base_dataset.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from pathlib import Path
from typing import Any, Dict, Optional, Union
diff --git a/InnerEye/ML/Histopathology/datasets/default_paths.py b/InnerEye/ML/Histopathology/datasets/default_paths.py
index 0497a7ecd..a57bdff9b 100644
--- a/InnerEye/ML/Histopathology/datasets/default_paths.py
+++ b/InnerEye/ML/Histopathology/datasets/default_paths.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
PANDA_TILES_DATASET_ID = "PANDA_tiles"
TCGA_CRCK_DATASET_ID = "TCGA-CRCk"
TCGA_PRAD_DATASET_ID = "TCGA-PRAD"
diff --git a/InnerEye/ML/Histopathology/datasets/panda_dataset.py b/InnerEye/ML/Histopathology/datasets/panda_dataset.py
index 77ad2f58d..ae13c993a 100644
--- a/InnerEye/ML/Histopathology/datasets/panda_dataset.py
+++ b/InnerEye/ML/Histopathology/datasets/panda_dataset.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from pathlib import Path
from typing import Any, Dict, Union, Optional
diff --git a/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py b/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py
index e382aae5e..43520a8c9 100644
--- a/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py
+++ b/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
diff --git a/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py b/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py
index 783a0278a..f5614cfa8 100644
--- a/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py
+++ b/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py
@@ -1,7 +1,13 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import pandas as pd
+
from torchvision.datasets.vision import VisionDataset
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
diff --git a/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py b/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py
index edb47d644..00da8af9a 100644
--- a/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py
+++ b/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from pathlib import Path
from typing import Any, Dict, Optional, Union
diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py
index 55663d074..c9af8c08d 100644
--- a/InnerEye/ML/Histopathology/models/deepmil.py
+++ b/InnerEye/ML/Histopathology/models/deepmil.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from pathlib import Path
import pandas as pd
import numpy as np
diff --git a/InnerEye/ML/Histopathology/models/encoders.py b/InnerEye/ML/Histopathology/models/encoders.py
index 43f85772d..04f454bba 100644
--- a/InnerEye/ML/Histopathology/models/encoders.py
+++ b/InnerEye/ML/Histopathology/models/encoders.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from pathlib import Path
from typing import Callable, Optional, Sequence, Tuple
diff --git a/InnerEye/ML/Histopathology/models/transforms.py b/InnerEye/ML/Histopathology/models/transforms.py
index 51cda8c4e..e50d088e0 100644
--- a/InnerEye/ML/Histopathology/models/transforms.py
+++ b/InnerEye/ML/Histopathology/models/transforms.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from pathlib import Path
from typing import Mapping, Sequence, Union
diff --git a/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py b/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py
index e1060860a..bbbf4f090 100644
--- a/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py
+++ b/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import functools
import os
import logging
diff --git a/InnerEye/ML/Histopathology/preprocessing/tiling.py b/InnerEye/ML/Histopathology/preprocessing/tiling.py
index b0f8b6c37..ed4a28404 100644
--- a/InnerEye/ML/Histopathology/preprocessing/tiling.py
+++ b/InnerEye/ML/Histopathology/preprocessing/tiling.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
# These tiling implementations are adapted from PANDA Kaggle solutions, for example:
# https://github.com/kentaroy47/Kaggle-PANDA-1st-place-solution/blob/master/src/data_process/a00_save_tiles.py
from typing import Any, Optional, Tuple
@@ -25,7 +30,7 @@ def pad_for_tiling_2d(array: np.ndarray, tile_size: int, channels_first: Optiona
original array to obtain indices for the padded array.
"""
height, width = array.shape[1:] if channels_first else array.shape[:-1]
- padding_h = get_1d_padding(height, tile_size)
+ padding_h = get_1d_padding(height, tile_size)
padding_w = get_1d_padding(width, tile_size)
padding = [padding_h, padding_w]
channels_axis = 0 if channels_first else 2
diff --git a/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py b/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py
index 6fe4551ba..7c96b0349 100644
--- a/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py
+++ b/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
"""
Script to find mean and standard deviation of desired metrics from cross validation child runs.
"""
diff --git a/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py b/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py
index c44f7ae6a..4eb3ef4ab 100644
--- a/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py
+++ b/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
"""
This script is an example of how to use the submit_to_azure_if_needed function from the hi-ml package to run the
main pre-processing function that creates tiles from slides in the PANDA dataset. The advantage of using this script
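For context, the hi-ml pattern this docstring refers to looks roughly like the sketch below. The cluster, datastore, and dataset names are placeholders, and the exact keyword arguments should be checked against the hi-ml version pinned in this repo:

from health_azure import submit_to_azure_if_needed

run_info = submit_to_azure_if_needed(compute_cluster_name="my-gpu-cluster",   # placeholder
                                     default_datastore="my-datastore",        # placeholder
                                     input_datasets=["PANDA"],                # placeholder
                                     wait_for_completion=False)
# When running in AzureML, input_datasets[0] points at the mounted/downloaded dataset folder,
# so the same tiling code can be used inside and outside AzureML.
panda_dir = run_info.input_datasets[0]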
diff --git a/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py b/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py
index 4f9afb064..a94cfa7e7 100644
--- a/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py
+++ b/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py
@@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
+
from health_azure import DatasetConfig
from health_azure.utils import get_workspace
diff --git a/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py b/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py
index 0e9981357..c3b78d09f 100644
--- a/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py
+++ b/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import numpy as np
from typing import List, Any
@@ -85,7 +90,7 @@ def plot_box_whisker(data_list: List[Any], column_names: List[str], show_outlier
def plot_histogram(data: List[Any], title: str = "") -> None:
"""
Plot a histogram given some data
- :param data: data to be plotted
+ :param data: data to be plotted
:param title: plot title string
"""
plt.figure()
diff --git a/InnerEye/ML/Histopathology/utils/download_utils.py b/InnerEye/ML/Histopathology/utils/download_utils.py
index 10b80ebef..1addcbc55 100644
--- a/InnerEye/ML/Histopathology/utils/download_utils.py
+++ b/InnerEye/ML/Histopathology/utils/download_utils.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import os
from pathlib import Path
diff --git a/InnerEye/ML/Histopathology/utils/layer_utils.py b/InnerEye/ML/Histopathology/utils/layer_utils.py
index a2847617d..d3b88d3c0 100644
--- a/InnerEye/ML/Histopathology/utils/layer_utils.py
+++ b/InnerEye/ML/Histopathology/utils/layer_utils.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from typing import Callable, Tuple
from torch import as_tensor, device, nn, prod, rand
@@ -25,7 +30,7 @@ def setup_feature_extractor(pretrained_model: nn.Module,
def load_weights_to_model(weights_url: str, model: nn.Module) -> nn.Module:
"""
Load weights to the histoSSL model from the given URL
- https://github.com/ozanciga/self-supervised-histopathology
+ https://github.com/ozanciga/self-supervised-histopathology
"""
map_location = device('cpu')
state = load_state_dict_from_url(weights_url, map_location=map_location)
diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py
index cc42e9d42..834ac4182 100644
--- a/InnerEye/ML/Histopathology/utils/metrics_utils.py
+++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from typing import Tuple, List, Any, Dict
import torch
import matplotlib.pyplot as plt
diff --git a/InnerEye/ML/Histopathology/utils/naming.py b/InnerEye/ML/Histopathology/utils/naming.py
index b1731ae73..32d46d54d 100644
--- a/InnerEye/ML/Histopathology/utils/naming.py
+++ b/InnerEye/ML/Histopathology/utils/naming.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from enum import Enum
class ResultsKey(str, Enum):
diff --git a/InnerEye/ML/Histopathology/utils/tcga_utils.py b/InnerEye/ML/Histopathology/utils/tcga_utils.py
index c662169d9..64c43a1a0 100644
--- a/InnerEye/ML/Histopathology/utils/tcga_utils.py
+++ b/InnerEye/ML/Histopathology/utils/tcga_utils.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import pandas as pd
diff --git a/InnerEye/ML/Histopathology/utils/viz_utils.py b/InnerEye/ML/Histopathology/utils/viz_utils.py
index 379559e88..1c4bff791 100644
--- a/InnerEye/ML/Histopathology/utils/viz_utils.py
+++ b/InnerEye/ML/Histopathology/utils/viz_utils.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import math
import matplotlib.pyplot as plt
diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
new file mode 100644
index 000000000..29ce29da8
--- /dev/null
+++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
@@ -0,0 +1,112 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
+"""BaseMIL is an abstract container defining basic functionality for running MIL experiments.
+It is responsible for instantiating the encoder and full DeepMIL model. Subclasses should define
+their datamodules and configure experiment-specific parameters.
+"""
+import os
+from pathlib import Path
+from typing import Type
+
+import param
+from torch import nn
+from torchvision.models.resnet import resnet18
+
+from health_azure.utils import CheckpointDownloader, get_workspace
+from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
+from InnerEye.Common import fixed_paths
+from InnerEye.ML.lightning_container import LightningContainer
+from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, TilesDataModule
+from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
+from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder,
+ ImageNetEncoder, ImageNetSimCLREncoder,
+ InnerEyeSSLEncoder, TileEncoder)
+
+
+class BaseMIL(LightningContainer):
+ # Model parameters:
+ pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
+ # l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass
+
+ # Encoder parameters:
+ encoder_type: str = param.String(doc="Name of the encoder class to use.")
+ tile_size: int = param.Integer(224, bounds=(1, None), doc="Tile width/height, in pixels.")
+ n_channels: int = param.Integer(3, bounds=(1, None), doc="Number of channels in the tile.")
+
+ # Data module parameters:
+ batch_size: int = param.Integer(16, bounds=(1, None), doc="Number of slides to load per batch.")
+ max_bag_size: int = param.Integer(1000, bounds=(0, None),
+ doc="Upper bound on number of tiles in each loaded bag. "
+                                          "If 0, will return all samples in each bag. "
+ "If > 0, bags larger than `max_bag_size` will yield "
+ "random subsets of instances.")
+ cache_mode: CacheMode = param.ClassSelector(default=CacheMode.MEMORY, class_=CacheMode,
+ doc="The type of caching to perform: "
+ "'memory' (default), 'disk', or 'none'.")
+ save_precache: bool = param.Boolean(True, doc="Whether to pre-cache the entire transformed "
+ "dataset upfront and save it to disk.")
+ # local_dataset (used as data module root_path) is declared in DatasetParams superclass
+
+ @property
+ def cache_dir(self) -> Path:
+ raise NotImplementedError
+
+ def setup(self) -> None:
+ if self.encoder_type == InnerEyeSSLEncoder.__name__:
+ self.downloader = CheckpointDownloader(
+ aml_workspace=get_workspace(),
+ run_id="updated_transforms:updated_transforms_1636471522_5473e3ff",
+ checkpoint_filename="best_checkpoint.ckpt",
+ download_dir='outputs/'
+ )
+ os.chdir(fixed_paths.repository_root_directory())
+ self.downloader.download_checkpoint_if_necessary()
+
+ self.encoder = self.get_encoder()
+ self.encoder.cuda()
+ self.encoder.eval()
+
+ def get_encoder(self) -> TileEncoder:
+ if self.encoder_type == ImageNetEncoder.__name__:
+ return ImageNetEncoder(feature_extraction_model=resnet18,
+ tile_size=self.tile_size, n_channels=self.n_channels)
+
+ elif self.encoder_type == ImageNetSimCLREncoder.__name__:
+ return ImageNetSimCLREncoder(tile_size=self.tile_size, n_channels=self.n_channels)
+
+ elif self.encoder_type == HistoSSLEncoder.__name__:
+ return HistoSSLEncoder(tile_size=self.tile_size, n_channels=self.n_channels)
+
+ elif self.encoder_type == InnerEyeSSLEncoder.__name__:
+ return InnerEyeSSLEncoder(pl_checkpoint_path=self.downloader.local_checkpoint_path,
+ tile_size=self.tile_size, n_channels=self.n_channels)
+
+ else:
+ raise ValueError(f"Unsupported encoder type: {self.encoder_type}")
+
+ def get_pooling_layer(self) -> Type[nn.Module]:
+ if self.pooling_type == AttentionLayer.__name__:
+ return AttentionLayer
+ elif self.pooling_type == GatedAttentionLayer.__name__:
+ return GatedAttentionLayer
+ else:
+ raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
+
+ def create_model(self) -> DeepMILModule:
+ self.data_module = self.get_data_module()
+ # Encoding is done in the datamodule, so here we provide instead a dummy
+ # no-op IdentityEncoder to be used inside the model
+ return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
+ label_column=self.data_module.train_dataset.LABEL_COLUMN,
+ n_classes=self.data_module.train_dataset.N_CLASSES,
+ pooling_layer=self.get_pooling_layer(),
+ class_weights=self.data_module.class_weights,
+ l_rate=self.l_rate,
+ weight_decay=self.weight_decay,
+ adam_betas=self.adam_betas)
+
+ def get_data_module(self) -> TilesDataModule:
+ raise NotImplementedError
diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py
new file mode 100644
index 000000000..a2c1900f0
--- /dev/null
+++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py
@@ -0,0 +1,121 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
+"""DeepSMILECrck is the container for experiments relating to DeepSMILE using the TCGA-CRCk dataset.
+Run using `python InnerEyePrivate/ML/runner.py --model=DeepSMILECrck --encoder_type=`
+
+For convenience, this module also defines encoder-specific containers that can be invoked without
+additional arguments, e.g. `python InnerEyePrivate/ML/runner.py --model=TcgaCrckImageNetMIL`
+
+Reference:
+- Schirris (2021). DeepSMILE: Self-supervised heterogeneity-aware multiple instance learning for DNA
+damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405
+"""
+from pathlib import Path
+from typing import Any, Dict
+
+from monai.transforms import Compose
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+
+from health_ml.networks.layers.attention_layers import GatedAttentionLayer
+from InnerEye.Common import fixed_paths
+from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL
+from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
+from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import TcgaCrckTilesDataModule
+from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, ImageNetEncoder,
+ ImageNetSimCLREncoder, InnerEyeSSLEncoder)
+from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTilesBatchd
+from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
+
+
+class DeepSMILECrck(BaseMIL):
+ def __init__(self, **kwargs: Any) -> None:
+        # Define dictionary with default params that can be overridden from subclasses or CLI
+ default_kwargs = dict(
+ # declared in BaseMIL:
+ pooling_type=GatedAttentionLayer.__name__,
+
+ # declared in DatasetParams:
+ local_dataset=Path("/tmp/datasets/TCGA-CRCk"),
+ azure_dataset_id="TCGA-CRCk",
+ # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI
+
+ # declared in TrainerParams:
+ num_epochs=16,
+ recovery_checkpoint_save_interval=16,
+ recovery_checkpoints_save_last_k=-1,
+
+ # declared in WorkflowParams:
+ number_of_cross_validation_splits=5,
+ cross_validation_split_index=0,
+
+ # declared in OptimizerParams:
+ l_rate=5e-4,
+ weight_decay=1e-4,
+ adam_betas=(0.9, 0.99),
+ )
+ default_kwargs.update(kwargs)
+ super().__init__(**default_kwargs)
+
+ self.best_checkpoint_filename = "checkpoint_max_val_auroc"
+ self.best_checkpoint_filename_with_suffix = self.best_checkpoint_filename + ".ckpt"
+ self.checkpoint_folder_path = "outputs/checkpoints/"
+
+ best_checkpoint_callback = ModelCheckpoint(dirpath=self.checkpoint_folder_path,
+ monitor='val/auroc',
+ filename=self.best_checkpoint_filename,
+ auto_insert_metric_name=False,
+ mode='max')
+ self.callbacks = best_checkpoint_callback
+
+ @property
+ def cache_dir(self) -> Path:
+ return Path(f"/tmp/innereye_cache/{self.__class__.__name__}-{self.encoder_type}/")
+
+ def get_data_module(self) -> TilesDataModule:
+ image_key = TcgaCrck_TilesDataset.IMAGE_COLUMN
+ transform = Compose([LoadTilesBatchd(image_key, progress=True),
+ EncodeTilesBatchd(image_key, self.encoder)])
+ return TcgaCrckTilesDataModule(root_path=self.local_dataset,
+ max_bag_size=self.max_bag_size,
+ batch_size=self.batch_size,
+ transform=transform,
+ cache_mode=self.cache_mode,
+ save_precache=self.save_precache,
+ cache_dir=self.cache_dir,
+ number_of_cross_validation_splits=self.number_of_cross_validation_splits,
+ cross_validation_split_index=self.cross_validation_split_index)
+
+ def get_trainer_arguments(self) -> Dict[str, Any]:
+ # These arguments will be passed through to the Lightning trainer.
+ return {"callbacks": self.callbacks}
+
+ def get_path_to_best_checkpoint(self) -> Path:
+ """
+ Returns the full path to a checkpoint file that was found to be best during training, whatever criterion
+ was applied there.
+ """
+ # absolute path is required for registering the model.
+ return fixed_paths.repository_root_directory() / self.checkpoint_folder_path / self.best_checkpoint_filename_with_suffix
+
+
+class TcgaCrckImageNetMIL(DeepSMILECrck):
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs)
+
+
+class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck):
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(encoder_type=ImageNetSimCLREncoder.__name__, **kwargs)
+
+
+class TcgaCrckInnerEyeSSLMIL(DeepSMILECrck):
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(encoder_type=InnerEyeSSLEncoder.__name__, **kwargs)
+
+
+class TcgaCrckHistoSSLMIL(DeepSMILECrck):
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(encoder_type=HistoSSLEncoder.__name__, **kwargs)
diff --git a/Tests/Azure/test_azure_util.py b/Tests/Azure/test_azure_util.py
index 612da843f..45cb88465 100644
--- a/Tests/Azure/test_azure_util.py
+++ b/Tests/Azure/test_azure_util.py
@@ -8,6 +8,8 @@
from azureml.core import Run
from azureml.core.workspace import Workspace
+from health_azure.utils import is_run_and_child_runs_completed
+
from InnerEye.Azure.azure_config import AzureConfig
from InnerEye.Azure.azure_runner import create_experiment_name
from InnerEye.Azure.azure_util import DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, fetch_child_runs, fetch_run, \
@@ -16,9 +18,9 @@
from InnerEye.Common.common_util import logging_to_stdout
from InnerEye.Common.fixed_paths import PRIVATE_SETTINGS_FILE, PROJECT_SECRETS_FILE, \
repository_root_directory
+
from Tests.AfterTraining.test_after_training import FALLBACK_ENSEMBLE_RUN, get_most_recent_run, get_most_recent_run_id
from Tests.ML.util import get_default_workspace
-from health_azure.utils import is_run_and_child_runs_completed
def test_os_path_to_azure_friendly_container_path() -> None:
diff --git a/Tests/ML/histopathology/datamodules/test_datamodule_caching.py b/Tests/ML/histopathology/datamodules/test_datamodule_caching.py
index 71575bd68..41325f401 100644
--- a/Tests/ML/histopathology/datamodules/test_datamodule_caching.py
+++ b/Tests/ML/histopathology/datamodules/test_datamodule_caching.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import shutil
from pathlib import Path
from typing import Any, Tuple
diff --git a/Tests/ML/histopathology/datasets/test_tcga_crck_dataset.py b/Tests/ML/histopathology/datasets/test_tcga_crck_dataset.py
index 76509cbd7..c0b6018a7 100644
--- a/Tests/ML/histopathology/datasets/test_tcga_crck_dataset.py
+++ b/Tests/ML/histopathology/datasets/test_tcga_crck_dataset.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import os
import pytest
diff --git a/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py b/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py
index 49c9f8ec9..e23f0e212 100644
--- a/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py
+++ b/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import os
import pytest
diff --git a/Tests/ML/histopathology/models/test_deepmil.py b/Tests/ML/histopathology/models/test_deepmil.py
new file mode 100644
index 000000000..60df2f759
--- /dev/null
+++ b/Tests/ML/histopathology/models/test_deepmil.py
@@ -0,0 +1,194 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
+import os
+from typing import Callable, Dict, List
+
+import pytest
+from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
+from torchvision.models import resnet18
+
+from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
+
+from InnerEye.ML.configs.histo_configs.classification.DeepSMILECrck import DeepSMILECrck
+from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
+from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
+from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
+from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder, TileEncoder
+from InnerEye.ML.Histopathology.utils.naming import ResultsKey
+
+
+def get_supervised_imagenet_encoder() -> TileEncoder:
+ return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
+
+
+@pytest.mark.parametrize("n_classes", [1, 3])
+@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
+@pytest.mark.parametrize("batch_size", [1, 15])
+@pytest.mark.parametrize("max_bag_size", [1, 7])
+@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
+@pytest.mark.parametrize("pool_out_dim", [1, 6])
+def test_lightningmodule(n_classes: int,
+ pooling_layer: Callable[[int, int, int], nn.Module],
+ batch_size: int,
+ max_bag_size: int,
+ pool_hidden_dim: int,
+ pool_out_dim: int) -> None:
+
+ assert n_classes > 0
+
+ # hard-coded here to avoid test explosion; correctness of other encoders is tested elsewhere
+ encoder = get_supervised_imagenet_encoder()
+ module = DeepMILModule(encoder=encoder,
+ label_column='label',
+ n_classes=n_classes,
+ pooling_layer=pooling_layer,
+ pool_hidden_dim=pool_hidden_dim,
+ pool_out_dim=pool_out_dim)
+
+ bag_images = rand([batch_size, max_bag_size, *module.encoder.input_dim])
+ bag_labels_list = []
+ bag_logits_list = []
+ bag_attn_list = []
+ for bag in bag_images:
+ if n_classes > 1:
+ labels = randint(n_classes, size=(max_bag_size,))
+ else:
+ labels = randint(n_classes+1, size=(max_bag_size,))
+ bag_labels_list.append(module.get_bag_label(labels))
+ logit, attn = module(bag)
+ assert logit.shape == (1, n_classes)
+ assert attn.shape == (module.pool_out_dim, max_bag_size)
+ bag_logits_list.append(logit.view(-1))
+ bag_attn_list.append(attn)
+
+ bag_logits = stack(bag_logits_list)
+ bag_labels = stack(bag_labels_list).view(-1)
+
+    assert bag_logits.shape[0] == batch_size
+    assert bag_labels.shape[0] == batch_size
+
+ if module.n_classes > 1:
+ loss = module.loss_fn(bag_logits, bag_labels)
+ else:
+ loss = module.loss_fn(bag_logits.squeeze(1), bag_labels.float())
+
+ assert loss > 0
+ assert loss.shape == ()
+
+ probs = module.activation_fn(bag_logits)
+ assert ((probs >= 0) & (probs <= 1)).all()
+ if n_classes > 1:
+ assert probs.shape == (batch_size, n_classes)
+ else:
+ assert probs.shape[0] == batch_size
+
+ if n_classes > 1:
+ preds = argmax(probs, dim=1)
+ else:
+ preds = round(probs)
+ assert preds.shape[0] == batch_size
+
+ for metric_name, metric_object in module.train_metrics.items():
+        if (batch_size > 1) or (metric_name != 'auroc'):
+            score = metric_object(preds.view(-1, 1), bag_labels.view(-1, 1))
+            assert 0 <= score <= 1
+
+
+def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
+ device = 'cuda' if use_gpu else 'cpu'
+ return {key: [value.to(device) if isinstance(value, Tensor) else value for value in values]
+ for key, values in batch.items()}
+
+
+@pytest.mark.parametrize("use_gpu", [True, False])
+def test_container(use_gpu: bool) -> None:
+ container_type = DeepSMILECrck
+ dataset_dir = TCGA_CRCK_DATASET_DIR
+ if not os.path.isdir(dataset_dir):
+ pytest.skip(f"Dataset for container {container_type.__name__} "
+ f"is unavailable: {dataset_dir}")
+
+ container = DeepSMILECrck(encoder_type=ImageNetEncoder.__name__)
+ container.setup()
+
+ data_module: TilesDataModule = container.get_data_module()
+ data_module.max_bag_size = 10
+ module = container.create_model()
+ if use_gpu:
+ module.cuda()
+
+ train_data_loader = data_module.train_dataloader()
+ for batch_idx, batch in enumerate(train_data_loader):
+ batch = move_batch_to_expected_device(batch, use_gpu)
+ loss = module.training_step(batch, batch_idx)
+ loss.retain_grad()
+ loss.backward()
+ assert loss.grad is not None
+ assert loss.shape == ()
+ assert isinstance(loss, Tensor)
+ break
+
+ val_data_loader = data_module.val_dataloader()
+ for batch_idx, batch in enumerate(val_data_loader):
+ batch = move_batch_to_expected_device(batch, use_gpu)
+ loss = module.validation_step(batch, batch_idx)
+ assert loss.shape == ()
+ assert isinstance(loss, Tensor)
+ break
+
+ test_data_loader = data_module.test_dataloader()
+ for batch_idx, batch in enumerate(test_data_loader):
+ batch = move_batch_to_expected_device(batch, use_gpu)
+ outputs_dict = module.test_step(batch, batch_idx)
+ loss = outputs_dict[ResultsKey.LOSS]
+ assert loss.shape == ()
+ assert isinstance(loss, Tensor)
+ break
+
+
+def test_class_weights_binary() -> None:
+ class_weights = Tensor([0.5, 3.5])
+ n_classes = 1
+ module = DeepMILModule(encoder=get_supervised_imagenet_encoder(),
+ label_column='label',
+ n_classes=n_classes,
+ pooling_layer=AttentionLayer,
+ pool_hidden_dim=5,
+ pool_out_dim=1,
+ class_weights=class_weights)
+ logits = Tensor(randn(1, n_classes))
+ bag_label = randint(n_classes+1, size=(1,))
+
+ pos_weight = Tensor([class_weights[1]/(class_weights[0]+1e-5)])
+ loss_weighted = module.loss_fn(logits.squeeze(1), bag_label.float())
+ criterion_unweighted = nn.BCEWithLogitsLoss()
+ loss_unweighted = criterion_unweighted(logits.squeeze(1), bag_label.float())
+ if bag_label.item() == 1:
+ assert allclose(loss_weighted, pos_weight*loss_unweighted)
+ else:
+ assert allclose(loss_weighted, loss_unweighted)
+
+
+def test_class_weights_multiclass() -> None:
+ class_weights = Tensor([0.33, 0.33, 0.33])
+ n_classes = 3
+ module = DeepMILModule(encoder=get_supervised_imagenet_encoder(),
+ label_column='label',
+ n_classes=n_classes,
+ pooling_layer=AttentionLayer,
+ pool_hidden_dim=5,
+ pool_out_dim=1,
+ class_weights=class_weights)
+ logits = Tensor(randn(1, n_classes))
+ bag_label = randint(n_classes, size=(1,))
+
+ loss_weighted = module.loss_fn(logits, bag_label)
+ criterion_unweighted = nn.CrossEntropyLoss()
+ loss_unweighted = criterion_unweighted(logits, bag_label)
+ # The weighted and unweighted loss functions give the same loss values for batch_size = 1.
+ # https://stackoverflow.com/questions/67639540/pytorch-cross-entropy-loss-weights-not-working
+    # TODO: once batch_size > 1 is supported, extend this test to check the actual weighted-loss computation.
+ assert allclose(loss_weighted, loss_unweighted)
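A note on why the final assertion above holds, since it can look surprising: with the default reduction='mean', torch.nn.CrossEntropyLoss divides the weighted sum of per-sample losses by the sum of the applied class weights, so for a single sample the weight cancels out. A minimal standalone sketch of that behaviour in plain PyTorch, independent of the test fixtures above:

import torch
from torch import nn

class_weights = torch.tensor([0.2, 0.5, 0.3])
logits = torch.randn(1, 3)
target = torch.tensor([1])

# Single sample: the class weight appears in both numerator and denominator and cancels.
weighted = nn.CrossEntropyLoss(weight=class_weights)(logits, target)
unweighted = nn.CrossEntropyLoss()(logits, target)
assert torch.allclose(weighted, unweighted)

# With more than one sample the two generally differ, which is what the TODO above refers to.
logits = torch.randn(4, 3)
target = torch.tensor([0, 1, 2, 1])
print(nn.CrossEntropyLoss(weight=class_weights)(logits, target).item(),
      nn.CrossEntropyLoss()(logits, target).item())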
diff --git a/Tests/ML/histopathology/models/test_encoders.py b/Tests/ML/histopathology/models/test_encoders.py
index 7c102e7d9..a9ad82864 100644
--- a/Tests/ML/histopathology/models/test_encoders.py
+++ b/Tests/ML/histopathology/models/test_encoders.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
from typing import Callable
import pytest
diff --git a/Tests/ML/histopathology/models/test_transforms.py b/Tests/ML/histopathology/models/test_transforms.py
new file mode 100644
index 000000000..ba938169a
--- /dev/null
+++ b/Tests/ML/histopathology/models/test_transforms.py
@@ -0,0 +1,153 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
+import os
+from pathlib import Path
+from typing import Callable, Sequence, Union
+
+import pytest
+import torch
+from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
+from monai.transforms import Compose
+from torch.utils.data import Dataset as TorchDataset
+from torch.utils.data import Subset
+from torchvision.models import resnet18
+
+from health_ml.utils.bag_utils import BagDataset
+from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
+from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
+from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder
+from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd
+from Tests.ML.util import assert_dicts_equal
+
+
+@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
+ reason="TCGA-CRCk tiles dataset is unavailable")
+def test_load_tile() -> None:
+ tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
+ image_key = tiles_dataset.IMAGE_COLUMN
+ load_transform = LoadTiled(image_key)
+ index = 0
+
+ # Test that the transform affects only the image entry in the sample
+ input_sample = tiles_dataset[index]
+ loaded_sample = load_transform(input_sample)
+ assert_dicts_equal(loaded_sample, input_sample, exclude_keys=[image_key])
+
+ # Test that the MONAI Dataset applies the same transform
+ loaded_dataset = Dataset(tiles_dataset, transform=load_transform) # type:ignore
+ same_dataset_sample = loaded_dataset[index]
+ assert_dicts_equal(same_dataset_sample, loaded_sample)
+
+ # Test that loading another sample gives different results
+ different_sample = loaded_dataset[index + 1]
+ assert not torch.allclose(different_sample[image_key], loaded_sample[image_key])
+
+
+@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
+ reason="TCGA-CRCk tiles dataset is unavailable")
+def test_load_tiles_batch() -> None:
+ tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
+ image_key = tiles_dataset.IMAGE_COLUMN
+ max_bag_size = 5
+ bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
+ max_bag_size=max_bag_size)
+ load_batch_transform = LoadTilesBatchd(image_key)
+ loaded_dataset = Dataset(tiles_dataset, transform=LoadTiled(image_key)) # type:ignore
+ image_shape = loaded_dataset[0][image_key].shape
+ index = 0
+
+ # Test that the transform affects only the image entry in the batch,
+ # and that the loaded images have the expected shape
+ bagged_batch = bagged_dataset[index]
+ manually_loaded_batch = load_batch_transform(bagged_batch)
+ assert_dicts_equal(manually_loaded_batch, bagged_batch, exclude_keys=[image_key])
+ assert manually_loaded_batch[image_key].shape == (max_bag_size, *image_shape)
+
+ # Test that the MONAI Dataset applies the same transform
+ loaded_bagged_dataset = Dataset(bagged_dataset, transform=load_batch_transform) # type:ignore
+ loaded_bagged_batch = loaded_bagged_dataset[index]
+ assert_dicts_equal(loaded_bagged_batch, manually_loaded_batch)
+
+ # Test that loading another batch gives different results
+ different_batch = loaded_bagged_dataset[index + 1]
+ assert not torch.allclose(different_batch[image_key], manually_loaded_batch[image_key])
+
+ # Test that loading and bagging commute
+ bagged_loaded_dataset = BagDataset(loaded_dataset, # type: ignore
+ bag_ids=tiles_dataset.slide_ids,
+ max_bag_size=max_bag_size)
+ bagged_loaded_batch = bagged_loaded_dataset[index]
+ assert_dicts_equal(bagged_loaded_batch, loaded_bagged_batch)
+
+
+def _test_cache_and_persistent_datasets(tmp_path: Path,
+ base_dataset: TorchDataset,
+ transform: Union[Sequence[Callable], Callable],
+ cache_subdir: str) -> None:
+ default_dataset = Dataset(base_dataset, transform=transform) # type: ignore
+ cached_dataset = CacheDataset(base_dataset, transform=transform) # type: ignore
+ cache_dir = tmp_path / cache_subdir
+ cache_dir.mkdir(exist_ok=True)
+ persistent_dataset = PersistentDataset(base_dataset, transform=transform, # type: ignore
+ cache_dir=cache_dir)
+
+ for default_sample, cached_sample, persistent_sample \
+ in zip(default_dataset, cached_dataset, persistent_dataset): # type: ignore
+ assert_dicts_equal(cached_sample, default_sample)
+ assert_dicts_equal(persistent_sample, default_sample)
+
+
+@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
+ reason="TCGA-CRCk tiles dataset is unavailable")
+def test_cached_loading(tmp_path: Path) -> None:
+ tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
+ image_key = tiles_dataset.IMAGE_COLUMN
+
+ max_num_tiles = 100
+ tiles_subset = Subset(tiles_dataset, range(max_num_tiles))
+ _test_cache_and_persistent_datasets(tmp_path,
+ tiles_subset,
+ transform=LoadTiled(image_key),
+ cache_subdir="TCGA-CRCk_tiles_cache")
+
+ max_bag_size = 5
+ max_num_bags = max_num_tiles // max_bag_size
+ bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
+ max_bag_size=max_bag_size)
+ bagged_subset = Subset(bagged_dataset, range(max_num_bags))
+ _test_cache_and_persistent_datasets(tmp_path,
+ bagged_subset,
+ transform=LoadTilesBatchd(image_key),
+ cache_subdir="TCGA-CRCk_load_cache")
+
+
+@pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
+ reason="TCGA-CRCk tiles dataset is unavailable")
+@pytest.mark.parametrize('use_gpu', [False, True])
+def test_encode_tiles(tmp_path: Path, use_gpu: bool) -> None:
+ tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
+ image_key = tiles_dataset.IMAGE_COLUMN
+ max_bag_size = 5
+ bagged_dataset = BagDataset(tiles_dataset, bag_ids=tiles_dataset.slide_ids, # type: ignore
+ max_bag_size=max_bag_size)
+
+ encoder = ImageNetEncoder(resnet18, tile_size=224, n_channels=3)
+ if use_gpu:
+ encoder.cuda()
+
+ encode_transform = EncodeTilesBatchd(image_key, encoder)
+ transform = Compose([LoadTilesBatchd(image_key), encode_transform])
+ dataset = Dataset(bagged_dataset, transform=transform) # type: ignore
+ sample = dataset[0]
+ assert sample[image_key].shape == (max_bag_size, encoder.num_encoding)
+ # TODO: Ensure it works in DDP
+
+ max_num_bags = 20
+ bagged_subset = Subset(bagged_dataset, range(max_num_bags))
+ _test_cache_and_persistent_datasets(tmp_path,
+ bagged_subset,
+ transform=transform,
+ cache_subdir="TCGA-CRCk_embed_cache")
diff --git a/Tests/ML/histopathology/preprocessing/test_tiling.py b/Tests/ML/histopathology/preprocessing/test_tiling.py
index 891ac0029..15cddac68 100644
--- a/Tests/ML/histopathology/preprocessing/test_tiling.py
+++ b/Tests/ML/histopathology/preprocessing/test_tiling.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import numpy as np
import pytest
diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py
index aacaf2ab5..a63884477 100644
--- a/Tests/ML/histopathology/utils/test_metrics_utils.py
+++ b/Tests/ML/histopathology/utils/test_metrics_utils.py
@@ -1,3 +1,7 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
import math
from typing import List
diff --git a/Tests/ML/histopathology/utils/test_tcga_utils.py b/Tests/ML/histopathology/utils/test_tcga_utils.py
index 03f6f3441..7d100e4c5 100644
--- a/Tests/ML/histopathology/utils/test_tcga_utils.py
+++ b/Tests/ML/histopathology/utils/test_tcga_utils.py
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
import pandas as pd
from InnerEye.ML.Histopathology.utils.tcga_utils import extract_fields
diff --git a/Tests/ML/util.py b/Tests/ML/util.py
index 40ebf7a61..2b7afb93a 100644
--- a/Tests/ML/util.py
+++ b/Tests/ML/util.py
@@ -4,7 +4,7 @@
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Collection, List, Mapping, Optional, Tuple, Union
import numpy as np
import pandas as pd
@@ -120,6 +120,23 @@ def _assert_line(actual: str, expected: str) -> None:
assert actual == expected, content_mismatch(actual, expected)
+def assert_dicts_equal(d1: Mapping, d2: Mapping, exclude_keys: Collection[Any] = (),
+ rtol: float = 1e-5, atol: float = 1e-8) -> None:
+ assert isinstance(d1, Mapping)
+ assert isinstance(d2, Mapping)
+ keys1 = [key for key in d1 if key not in exclude_keys]
+ keys2 = [key for key in d2 if key not in exclude_keys]
+ assert keys1 == keys2
+ for key in keys1:
+ msg = f"Dictionaries differ for key '{key}': {d1[key]} vs {d2[key]}"
+ if isinstance(d1[key], torch.Tensor):
+ assert torch.allclose(d1[key], d2[key], rtol=rtol, atol=atol, equal_nan=True), msg
+ elif isinstance(d1[key], np.ndarray):
+ assert np.allclose(d1[key], d2[key], rtol=rtol, atol=atol, equal_nan=True), msg
+ else:
+ assert d1[key] == d2[key], msg
+
+
def assert_file_exists(file_path: Path) -> None:
"""
Checks if the given file exists.