diff --git a/.gitattributes b/.gitattributes
index 0be747b84..136bd36a2 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -15,3 +15,4 @@
 *.dcm filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.jpg filter=lfs diff=lfs merge=lfs -text
+*.tiff filter=lfs diff=lfs merge=lfs -text
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4541ed468..ef596be14 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,6 +38,7 @@ jobs that run in AzureML.
 - ([#614](https://github.com/microsoft/InnerEye-DeepLearning/pull/614)) Checkpoint downloading falls back to looking into AzureML if no checkpoints on disk
 - ([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets
 - ([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests
+- ([#621](https://github.com/microsoft/InnerEye-DeepLearning/pull/621)) Add WSI preprocessing functions and enable tiling of more generic slide datasets
 ### Changed
 - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
diff --git a/InnerEye/ML/Histopathology/datasets/base_dataset.py b/InnerEye/ML/Histopathology/datasets/base_dataset.py
index a03d3cb72..8f4fdc31c 100644
--- a/InnerEye/ML/Histopathology/datasets/base_dataset.py
+++ b/InnerEye/ML/Histopathology/datasets/base_dataset.py
@@ -4,7 +4,7 @@
 # ------------------------------------------------------------------------------------------
 
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -12,6 +12,8 @@ from sklearn.utils.class_weight import compute_class_weight
 from torch.utils.data import Dataset
 
+from InnerEye.ML.Histopathology.utils.naming import SlideKey
+
 
 class TilesDataset(Dataset):
     """Base class for datasets of WSI tiles, iterating dictionaries of image paths and metadata.
@@ -71,7 +73,7 @@ def __init__(self,
             self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
             dataset_df = pd.read_csv(self.dataset_csv)
 
-        columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN, self.LABEL_COLUMN,
+        columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN,
                    self.SPLIT_COLUMN, self.TILE_X_COLUMN, self.TILE_Y_COLUMN]
         for column in columns:
             if column is not None and column not in dataset_df.columns:
@@ -110,3 +112,109 @@ def get_class_weights(self) -> torch.Tensor:
         classes = np.unique(slide_labels)
         class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
         return torch.as_tensor(class_weights)
+
+
+class SlidesDataset(Dataset):
+    """Base class for datasets of WSIs, iterating dictionaries of image paths and metadata.
+
+    The output dictionaries are indexed by `..utils.naming.SlideKey`.
+
+    :param SLIDE_ID_COLUMN: CSV column name for slide ID.
+    :param IMAGE_COLUMN: CSV column name for relative path to image file.
+    :param LABEL_COLUMN: CSV column name for slide label.
+    :param SPLIT_COLUMN: CSV column name for train/test split (optional).
+    :param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`.
+    :param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`.
+    :param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset root directory.
+    :param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`.
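+
+    For illustration, a minimal `dataset.csv` using the default column names might look as
+    follows (values are hypothetical, not taken from any real dataset):
+
+        slide_id,image,label
+        slide_001,slides/slide_001.tiff,0
+        slide_002,slides/slide_002.tiff,1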
+    """
+    SLIDE_ID_COLUMN: str = 'slide_id'
+    IMAGE_COLUMN: str = 'image'
+    LABEL_COLUMN: str = 'label'
+    MASK_COLUMN: Optional[str] = None
+    SPLIT_COLUMN: Optional[str] = None
+
+    TRAIN_SPLIT_LABEL: str = 'train'
+    TEST_SPLIT_LABEL: str = 'test'
+
+    METADATA_COLUMNS: Tuple[str, ...] = ()
+
+    DEFAULT_CSV_FILENAME: str = "dataset.csv"
+
+    N_CLASSES: int = 1  # binary classification by default
+
+    def __init__(self,
+                 root: Union[str, Path],
+                 dataset_csv: Optional[Union[str, Path]] = None,
+                 dataset_df: Optional[pd.DataFrame] = None,
+                 train: Optional[bool] = None,
+                 validate_columns: bool = True) -> None:
+        """
+        :param root: Root directory of the dataset.
+        :param dataset_csv: Full path to a dataset CSV file, containing at least
+        `SLIDE_ID_COLUMN` and `IMAGE_COLUMN`. If omitted, the CSV will be read
+        from `"{root}/{DEFAULT_CSV_FILENAME}"`.
+        :param dataset_df: A potentially pre-processed dataframe in the same format as would be read
+        from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
+        :param train: If `True`, loads only the training split; if `False`, only the test split. By
+        default (`None`), loads the entire dataset as-is.
+        :param validate_columns: Whether to call `validate_columns()` at the end of `__init__()`.
+        """
+        if self.SPLIT_COLUMN is None and train is not None:
+            raise ValueError("Train/test split was specified but dataset has no split column")
+
+        self.root_dir = Path(root)
+
+        if dataset_df is not None:
+            self.dataset_csv = None
+        else:
+            self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
+            dataset_df = pd.read_csv(self.dataset_csv)
+
+        dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
+        if train is None:
+            self.dataset_df = dataset_df
+        else:
+            split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL
+            self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split]
+
+        if validate_columns:
+            self.validate_columns()
+
+    def validate_columns(self) -> None:
+        """Check that the loaded dataframe contains the expected columns, raising `ValueError` otherwise.
+
+        If the constructor is overridden in a subclass, you can pass `validate_columns=False` and
+        call `validate_columns()` after creating derived columns, for example.
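+
+        A minimal sketch of that pattern, for a hypothetical subclass whose CSV lacks an
+        explicit image column (the derived path format is illustrative):
+
+            class MySlidesDataset(SlidesDataset):
+                def __init__(self, root):
+                    super().__init__(root, validate_columns=False)
+                    # Derive image paths from the slide IDs, then check all expected columns:
+                    self.dataset_df[self.IMAGE_COLUMN] = self.dataset_df.index + ".tiff"
+                    self.validate_columns()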
+ """ + columns = [self.IMAGE_COLUMN, self.LABEL_COLUMN, self.MASK_COLUMN, + self.SPLIT_COLUMN] + list(self.METADATA_COLUMNS) + for column in columns: + if column is not None and column not in self.dataset_df.columns: + raise ValueError(f"Expected column '{column}' not found in the dataframe") + + def __len__(self) -> int: + return self.dataset_df.shape[0] + + def __getitem__(self, index: int) -> Dict[SlideKey, Any]: + slide_id = self.dataset_df.index[index] + slide_row = self.dataset_df.loc[slide_id] + sample = {SlideKey.SLIDE_ID: slide_id} + + rel_image_path = slide_row[self.IMAGE_COLUMN] + sample[SlideKey.IMAGE] = str(self.root_dir / rel_image_path) + # we're replicating this column because we want to propagate the path to the batch + sample[SlideKey.IMAGE_PATH] = sample[SlideKey.IMAGE] + + if self.MASK_COLUMN: + rel_mask_path = slide_row[self.MASK_COLUMN] + sample[SlideKey.MASK] = str(self.root_dir / rel_mask_path) + sample[SlideKey.MASK_PATH] = sample[SlideKey.MASK] + + sample[SlideKey.LABEL] = slide_row[self.LABEL_COLUMN] + sample[SlideKey.METADATA] = {col: slide_row[col] for col in self.METADATA_COLUMNS} + return sample + + @classmethod + def has_mask(cls) -> bool: + return cls.MASK_COLUMN is not None diff --git a/InnerEye/ML/Histopathology/datasets/default_paths.py b/InnerEye/ML/Histopathology/datasets/default_paths.py index a57bdff9b..a731daf48 100644 --- a/InnerEye/ML/Histopathology/datasets/default_paths.py +++ b/InnerEye/ML/Histopathology/datasets/default_paths.py @@ -3,11 +3,13 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +PANDA_DATASET_ID = "PANDA" PANDA_TILES_DATASET_ID = "PANDA_tiles" TCGA_CRCK_DATASET_ID = "TCGA-CRCk" TCGA_PRAD_DATASET_ID = "TCGA-PRAD" DEFAULT_DATASET_LOCATION = "/tmp/datasets/" +PANDA_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_DATASET_ID PANDA_TILES_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_TILES_DATASET_ID TCGA_CRCK_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_CRCK_DATASET_ID TCGA_PRAD_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_PRAD_DATASET_ID diff --git a/InnerEye/ML/Histopathology/datasets/panda_dataset.py b/InnerEye/ML/Histopathology/datasets/panda_dataset.py index ae13c993a..b84571257 100644 --- a/InnerEye/ML/Histopathology/datasets/panda_dataset.py +++ b/InnerEye/ML/Histopathology/datasets/panda_dataset.py @@ -7,50 +7,42 @@ from typing import Any, Dict, Union, Optional import pandas as pd +from cucim import CuImage +from health_ml.utils import box_utils from monai.config import KeysCollection from monai.data.image_reader import ImageReader, WSIReader from monai.transforms import MapTransform -from openslide import OpenSlide -from torch.utils.data import Dataset -from health_ml.utils import box_utils +from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset -class PandaDataset(Dataset): +class PandaDataset(SlidesDataset): """Dataset class for loading files from the PANDA challenge dataset. - Iterating over this dataset returns a dictionary containing the `'image_id'`, paths to the `'image'` - and `'mask'` files, and the remaining meta-data from the original dataset (`'data_provider'`, - `'isup_grade'`, and `'gleason_score'`). + Iterating over this dataset returns a dictionary following the `SlideKey` schema plus meta-data + from the original dataset (`'data_provider'`, `'isup_grade'`, and `'gleason_score'`). 
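+
+    A minimal usage sketch (the dataset path is illustrative; `SlideKey` comes from
+    `..utils.naming`):
+
+        dataset = PandaDataset("/tmp/datasets/PANDA")
+        sample = dataset[0]
+        sample[SlideKey.IMAGE]     # .../train_images/<image_id>.tiff
+        sample[SlideKey.METADATA]  # {'data_provider': ..., 'isup_grade': ..., 'gleason_score': ...}
+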
Ref.: https://www.kaggle.com/c/prostate-cancer-grade-assessment/overview """ - def __init__(self, root_dir: Union[str, Path], n_slides: Optional[int] = None, - frac_slides: Optional[float] = None) -> None: - super().__init__() - self.root_dir = Path(root_dir) - self.train_df = pd.read_csv(self.root_dir / "train.csv", index_col='image_id') - if n_slides or frac_slides: - self.train_df = self.train_df.sample(n=n_slides, frac=frac_slides, replace=False, - random_state=1234) - - def __len__(self) -> int: - return self.train_df.shape[0] - - def _get_image_path(self, image_id: str) -> Path: - return self.root_dir / "train_images" / f"{image_id}.tiff" - - def _get_mask_path(self, image_id: str) -> Path: - return self.root_dir / "train_label_masks" / f"{image_id}_mask.tiff" - - def __getitem__(self, index: int) -> Dict: - image_id = self.train_df.index[index] - return { - 'image_id': image_id, - 'image': str(self._get_image_path(image_id).absolute()), - 'mask': str(self._get_mask_path(image_id).absolute()), - **self.train_df.loc[image_id].to_dict() - } + SLIDE_ID_COLUMN = 'image_id' + IMAGE_COLUMN = 'image' + MASK_COLUMN = 'mask' + LABEL_COLUMN = 'isup_grade' + + METADATA_COLUMNS = ('data_provider', 'isup_grade', 'gleason_score') + + DEFAULT_CSV_FILENAME = "train.csv" + + def __init__(self, + root: Union[str, Path], + dataset_csv: Optional[Union[str, Path]] = None, + dataset_df: Optional[pd.DataFrame] = None) -> None: + super().__init__(root, dataset_csv, dataset_df, validate_columns=False) + # PANDA CSV does not come with paths for image and mask files + slide_ids = self.dataset_df.index + self.dataset_df[self.IMAGE_COLUMN] = "train_images/" + slide_ids + ".tiff" + self.dataset_df[self.MASK_COLUMN] = "train_label_masks/" + slide_ids + "_mask.tiff" + self.validate_columns() # MONAI's convention is that dictionary transforms have a 'd' suffix in the class name @@ -96,10 +88,10 @@ def __init__(self, reader: WSIReader, image_key: str = 'image', mask_key: str = self.margin = margin self.kwargs = kwargs - def _get_bounding_box(self, mask_obj: OpenSlide) -> box_utils.Box: + def _get_bounding_box(self, mask_obj: CuImage) -> box_utils.Box: # Estimate bounding box at the lowest resolution (i.e. 
highest level) - highest_level = mask_obj.level_count - 1 - scale = mask_obj.level_downsamples[highest_level] + highest_level = mask_obj.resolutions['level_count'] - 1 + scale = mask_obj.resolutions['level_downsamples'][highest_level] mask, _ = self.reader.get_data(mask_obj, level=highest_level) # loaded as RGB PIL image foreground_mask = mask[0] > 0 # PANDA segmentation mask is in 'R' channel @@ -107,14 +99,14 @@ def _get_bounding_box(self, mask_obj: OpenSlide) -> box_utils.Box: return bbox def __call__(self, data: Dict) -> Dict: - mask_obj: OpenSlide = self.reader.read(data[self.mask_key]) - image_obj: OpenSlide = self.reader.read(data[self.image_key]) + mask_obj: CuImage = self.reader.read(data[self.mask_key]) + image_obj: CuImage = self.reader.read(data[self.image_key]) level0_bbox = self._get_bounding_box(mask_obj) - # OpenSlide takes absolute location coordinates in the level 0 reference frame, + # cuCIM/OpenSlide take absolute location coordinates in the level 0 reference frame, # but relative region size in pixels at the chosen level - scale = mask_obj.level_downsamples[self.level] + scale = mask_obj.resolutions['level_downsamples'][self.level] scaled_bbox = level0_bbox / scale get_data_kwargs = dict(location=(level0_bbox.x, level0_bbox.y), size=(scaled_bbox.w, scaled_bbox.h), diff --git a/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py b/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py index 00da8af9a..4bd1590a4 100644 --- a/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py +++ b/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py @@ -4,13 +4,14 @@ # ------------------------------------------------------------------------------------------ from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Optional, Union import pandas as pd -from torch.utils.data import Dataset +from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset -class TcgaPradDataset(Dataset): + +class TcgaPradDataset(SlidesDataset): """Dataset class for loading TCGA-PRAD slides. Iterating over this dataset returns a dictionary containing: @@ -19,16 +20,14 @@ class TcgaPradDataset(Dataset): - `'image_path'` (str): absolute slide image path - `'label'` (int, 0 or 1): label for predicting positive or negative """ - SLIDE_ID_COLUMN: str = 'slide_id' - CASE_ID_COLUMN: str = 'case_id' IMAGE_COLUMN: str = 'image_path' LABEL_COLUMN: str = 'label' DEFAULT_CSV_FILENAME: str = "dataset.csv" - def __init__(self, root_dir: Union[str, Path], + def __init__(self, root: Union[str, Path], dataset_csv: Optional[Union[str, Path]] = None, - dataset_df: Optional[pd.DataFrame] = None,) -> None: + dataset_df: Optional[pd.DataFrame] = None) -> None: """ :param root: Root directory of the dataset. :param dataset_csv: Full path to a dataset CSV file. If omitted, the CSV will be read from @@ -36,27 +35,8 @@ def __init__(self, root_dir: Union[str, Path], :param dataset_df: A potentially pre-processed dataframe in the same format as would be read from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`. 
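 
        A sketch of the pre-filtering use case mentioned above (the column used for
        filtering is illustrative):
 
            df = pd.read_csv("/tmp/datasets/TCGA-PRAD/dataset.csv")
            dataset = TcgaPradDataset("/tmp/datasets/TCGA-PRAD", dataset_df=df[df['label1_mutation'] == 1])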
""" - self.root_dir = Path(root_dir) - - if dataset_df is not None: - self.dataset_csv = None - else: - self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME - dataset_df = pd.read_csv(self.dataset_csv) - - dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN) - dataset_df[self.LABEL_COLUMN] = (dataset_df['label1_mutation'] - | dataset_df['label2_mutation']).astype(int) - self.dataset_df = dataset_df - - def __len__(self) -> int: - return self.dataset_df.shape[0] - - def __getitem__(self, index: int) -> Dict[str, Any]: - slide_id = self.dataset_df.index[index] - sample = { - self.SLIDE_ID_COLUMN: slide_id, - **self.dataset_df.loc[slide_id].to_dict() - } - sample[self.IMAGE_COLUMN] = str(self.root_dir / sample.pop(self.IMAGE_COLUMN)) - return sample + super().__init__(root, dataset_csv, dataset_df, validate_columns=False) + # Example of how to define a custom label column from existing columns: + self.dataset_df[self.LABEL_COLUMN] = (self.dataset_df['label1'] + | self.dataset_df['label2']).astype(int) + self.validate_columns() diff --git a/InnerEye/ML/Histopathology/preprocessing/create_panda_tiles_dataset.py b/InnerEye/ML/Histopathology/preprocessing/create_panda_tiles_dataset.py new file mode 100644 index 000000000..f1dc3d65a --- /dev/null +++ b/InnerEye/ML/Histopathology/preprocessing/create_panda_tiles_dataset.py @@ -0,0 +1,230 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +"""This script is specific to PANDA and is kept only for retrocompatibility. +`create_tiles_dataset.py` is the new supported way to process slide datasets. +""" +import functools +import os +import logging +import shutil +import traceback +import warnings +from pathlib import Path +from typing import Sequence, Tuple, Union + +import numpy as np +import PIL +from monai.data import Dataset +from monai.data.image_reader import WSIReader +from tqdm import tqdm + +from InnerEye.ML.Histopathology.preprocessing import tiling +from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId + + +CSV_COLUMNS = ['slide_id', 'tile_id', 'image', 'mask', 'tile_x', 'tile_y', 'occupancy', + 'data_provider', 'slide_isup_grade', 'slide_gleason_score'] +TMP_SUFFIX = "_tmp" + +logging.basicConfig(format='%(asctime)s %(message)s', filemode='w') +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + + +def select_tile(mask_tile: np.ndarray, occupancy_threshold: float) \ + -> Union[Tuple[bool, float], Tuple[np.ndarray, np.ndarray]]: + if occupancy_threshold < 0. 
+        raise ValueError("Tile occupancy threshold must be between 0 and 1")
+    foreground_mask = mask_tile > 0
+    occupancy = foreground_mask.mean(axis=(-2, -1))
+    return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze()
+
+
+def get_tile_descriptor(tile_location: Sequence[int]) -> str:
+    return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y"
+
+
+def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str:
+    return f"{slide_id}.{get_tile_descriptor(tile_location)}"
+
+
+def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze()
+    pil_image = PIL.Image.fromarray(array_hwc)
+    pil_image.convert('RGB').save(path)
+    return pil_image
+
+
+def generate_tiles(sample: dict, tile_size: int, occupancy_threshold: float) \
+        -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
+    image_tiles, tile_locations = tiling.tile_array_2d(sample['image'], tile_size=tile_size,
+                                                       constant_values=255)
+    mask_tiles, _ = tiling.tile_array_2d(sample['mask'], tile_size=tile_size, constant_values=0)
+
+    selected: np.ndarray
+    occupancies: np.ndarray
+    selected, occupancies = select_tile(mask_tiles, occupancy_threshold)
+    n_discarded = (~selected).sum()
+    logging.info(f"Percentage tiles discarded: {n_discarded / len(selected) * 100:.2f}")
+
+    image_tiles = image_tiles[selected]
+    mask_tiles = mask_tiles[selected]
+    tile_locations = tile_locations[selected]
+    occupancies = occupancies[selected]
+
+    abs_tile_locations = (sample['scale'] * tile_locations + sample['location']).astype(int)
+
+    return image_tiles, mask_tiles, abs_tile_locations, occupancies, n_discarded
+
+
+# TODO refactor this to separate metadata identification from saving. We might want the metadata
+# even if the saving fails
+def save_tile(sample: dict, image_tile: np.ndarray, mask_tile: np.ndarray,
+              tile_location: Sequence[int], output_dir: Path) -> dict:
+    slide_id = sample['image_id']
+    descriptor = get_tile_descriptor(tile_location)
+    image_tile_filename = f"train_images/{descriptor}.png"
+    mask_tile_filename = f"train_label_masks/{descriptor}_mask.png"
+
+    save_image(image_tile, output_dir / image_tile_filename)
+    save_image(mask_tile, output_dir / mask_tile_filename)
+
+    tile_metadata = {
+        'slide_id': slide_id,
+        'tile_id': get_tile_id(slide_id, tile_location),
+        'image': image_tile_filename,
+        'mask': mask_tile_filename,
+        'tile_x': tile_location[0],
+        'tile_y': tile_location[1],
+        'data_provider': sample['data_provider'],
+        'slide_isup_grade': sample['isup_grade'],
+        'slide_gleason_score': sample['gleason_score'],
+    }
+
+    return tile_metadata
+
+
+def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: float,
+                  output_dir: Path, tile_progress: bool = False) -> None:
+    slide_id = sample['image_id']
+    slide_dir: Path = output_dir / (slide_id + "/")
+    logging.info(f">>> Slide dir {slide_dir}")
+    if slide_dir.exists():  # already processed slide - skip
+        logging.info(f">>> Skipping {slide_dir} - already processed")
+        return
+    else:
+        try:
+            slide_dir.mkdir(parents=True)
+
+            dataset_csv_path = slide_dir / "dataset.csv"
+            dataset_csv_file = dataset_csv_path.open('w')
+            dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n')  # write CSV header
+
+            tiles_failure = 0
+            failed_tiles_csv_path = slide_dir / "failed_tiles.csv"
+            failed_tiles_file = failed_tiles_csv_path.open('w')
+            failed_tiles_file.write('tile_id' + '\n')
+
+            logging.info(f"Loading slide {slide_id} ...")
+            loader = LoadPandaROId(WSIReader(), level=level, margin=margin)
+            sample = loader(sample)  # load 'image' and 'mask' from disk
+
+            logging.info(f"Tiling slide {slide_id} ...")
+            image_tiles, mask_tiles, tile_locations, occupancies, _ = \
+                generate_tiles(sample, tile_size, occupancy_threshold)
+            n_tiles = image_tiles.shape[0]
+
+            for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress):
+                try:
+                    tile_metadata = save_tile(sample, image_tiles[i], mask_tiles[i], tile_locations[i],
+                                              slide_dir)
+                    tile_metadata['occupancy'] = occupancies[i]
+                    tile_metadata['image'] = os.path.join(slide_dir.name, tile_metadata['image'])
+                    tile_metadata['mask'] = os.path.join(slide_dir.name, tile_metadata['mask'])
+                    dataset_row = ','.join(str(tile_metadata[column]) for column in CSV_COLUMNS)
+                    dataset_csv_file.write(dataset_row + '\n')
+                except Exception as e:
+                    tiles_failure += 1
+                    descriptor = get_tile_descriptor(tile_locations[i]) + '\n'
+                    failed_tiles_file.write(descriptor)
+                    traceback.print_exc()
+                    warnings.warn(f"An error occurred while saving tile "
+                                  f"{get_tile_id(slide_id, tile_locations[i])}: {e}")
+
+            dataset_csv_file.close()
+            failed_tiles_file.close()
+            if tiles_failure > 0:
+                # TODO what we want to do with slides that have some failed tiles?
+                logging.warning(f"{slide_id} is incomplete. 
{tiles_failure} tiles failed.") + except Exception as e: + traceback.print_exc() + warnings.warn(f"An error occurred while processing slide {slide_id}: {e}") + + +def merge_dataset_csv_files(dataset_dir: Path) -> Path: + full_csv = dataset_dir / "dataset.csv" + # TODO change how we retrieve these filenames, probably because mounted, the operation is slow + # and it seems to find many more files + # print("List of files") + # print([str(file) + '\n' for file in dataset_dir.glob("*/dataset.csv")]) + with full_csv.open('w') as full_csv_file: + # full_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header + first_file = True + for slide_csv in tqdm(dataset_dir.glob("*/dataset.csv"), desc="Merging dataset.csv", unit='file'): + logging.info(f"Merging slide {slide_csv}") + content = slide_csv.read_text() + if not first_file: + content = content[content.index('\n') + 1:] # discard header row for all but the first file + full_csv_file.write(content) + first_file = False + return full_csv + + +def main(panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level: int, tile_size: int, + margin: int, occupancy_threshold: float, parallel: bool = False, overwrite: bool = False) -> None: + + # Ignoring some types here because mypy is getting confused with the MONAI Dataset class + # to select a subsample use keyword n_slides + dataset = Dataset(PandaDataset(panda_dir)) # type: ignore + + output_dir = Path(root_output_dir) / f"panda_tiles_level{level}_{tile_size}" + logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} PANDA tiles at: {output_dir}") + + if overwrite and output_dir.exists(): + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=not overwrite) + + func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size, + occupancy_threshold=occupancy_threshold, output_dir=output_dir, + tile_progress=not parallel) + + if parallel: + import multiprocessing + + pool = multiprocessing.Pool() + map_func = pool.imap_unordered # type: ignore + else: + map_func = map # type: ignore + + list(tqdm(map_func(func, dataset), desc="Slides", unit="img", total=len(dataset))) # type: ignore + + if parallel: + pool.close() + + logging.info("Merging slide files in a single file") + merge_dataset_csv_files(output_dir) + + +if __name__ == '__main__': + main(panda_dir="/tmp/datasets/PANDA", + root_output_dir="/datadrive", + level=1, + tile_size=224, + margin=64, + occupancy_threshold=0.05, + parallel=True, + overwrite=False) diff --git a/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py b/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py index bbbf4f090..87e670fc8 100644 --- a/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py +++ b/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py @@ -4,13 +4,12 @@ # ------------------------------------------------------------------------------------------ import functools -import os import logging import shutil import traceback import warnings from pathlib import Path -from typing import Sequence, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union import numpy as np import PIL @@ -18,37 +17,43 @@ from monai.data.image_reader import WSIReader from tqdm import tqdm +from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset from InnerEye.ML.Histopathology.preprocessing import tiling -from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId - - -CSV_COLUMNS = 
['slide_id', 'tile_id', 'image', 'mask', 'tile_x', 'tile_y', 'occupancy', - 'data_provider', 'slide_isup_grade', 'slide_gleason_score'] -TMP_SUFFIX = "_tmp" +from InnerEye.ML.Histopathology.preprocessing.loading import LoadROId, segment_foreground +from InnerEye.ML.Histopathology.utils.naming import SlideKey, TileKey logging.basicConfig(format='%(asctime)s %(message)s', filemode='w') logger = logging.getLogger() logger.setLevel(logging.DEBUG) -def select_tile(mask_tile: np.ndarray, occupancy_threshold: float) \ - -> Union[Tuple[bool, float], Tuple[np.ndarray, np.ndarray]]: +def select_tiles(foreground_mask: np.ndarray, occupancy_threshold: float) \ + -> Tuple[np.ndarray, np.ndarray]: + """Exclude tiles that are mostly background based on estimated occupancy. + + :param foreground_mask: Boolean array of shape (*, H, W). + :param occupancy_threshold: Tiles with lower occupancy (between 0 and 1) will be discarded. + :return: A tuple containing which tiles were selected and the estimated occupancies. These will + be boolean and float arrays of shape (*,), or scalars if `foreground_mask` is a single tile. + """ if occupancy_threshold < 0. or occupancy_threshold > 1.: raise ValueError("Tile occupancy threshold must be between 0 and 1") - foreground_mask = mask_tile > 0 occupancy = foreground_mask.mean(axis=(-2, -1)) - return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze() + return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze() # type: ignore def get_tile_descriptor(tile_location: Sequence[int]) -> str: + """Format the XY tile coordinates into a tile descriptor.""" return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y" def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str: + """Format the slide ID and XY tile coordinates into a unique tile ID.""" return f"{slide_id}.{get_tile_descriptor(tile_location)}" def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image: + """Save an image array in (C, H, W) format to disk.""" path.parent.mkdir(parents=True, exist_ok=True) array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze() pil_image = PIL.Image.fromarray(array_hwc) @@ -56,59 +61,102 @@ def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image: return pil_image -def generate_tiles(sample: dict, tile_size: int, occupancy_threshold: float) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]: - image_tiles, tile_locations = tiling.tile_array_2d(sample['image'], tile_size=tile_size, +def generate_tiles(slide_image: np.ndarray, tile_size: int, foreground_threshold: float, + occupancy_threshold: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]: + """Split the foreground of an input slide image into tiles. + + :param slide_image: The RGB image array in (C, H, W) format. + :param tile_size: Lateral dimensions of each tile, in pixels. + :param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy. + :param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard. + :return: A tuple containing the image tiles (N, C, H, W), tile coordinates (N, 2), occupancies + (N,), and total number of discarded empty tiles. 
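+
+    A sketch of the expected call and output shapes (argument values are illustrative):
+
+        tiles, locations, occupancies, n_discarded = generate_tiles(
+            slide_image, tile_size=224, foreground_threshold=200., occupancy_threshold=0.05)
+        # tiles.shape == (N, 3, 224, 224), locations.shape == (N, 2), occupancies.shape == (N,)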
+    """
+    image_tiles, tile_locations = tiling.tile_array_2d(slide_image, tile_size=tile_size,
                                                        constant_values=255)
-    mask_tiles, _ = tiling.tile_array_2d(sample['mask'], tile_size=tile_size, constant_values=0)
+    foreground_mask, _ = segment_foreground(image_tiles, foreground_threshold)
 
-    selected: np.ndarray
-    occupancies: np.ndarray
-    selected, occupancies = select_tile(mask_tiles, occupancy_threshold)
+    selected, occupancies = select_tiles(foreground_mask, occupancy_threshold)
     n_discarded = (~selected).sum()
-    logging.info(f"Percentage tiles discarded: {round(selected.sum() / n_discarded * 100, 2)}")
+    logging.info(f"Percentage tiles discarded: {n_discarded / len(selected) * 100:.2f}")
 
     image_tiles = image_tiles[selected]
-    mask_tiles = mask_tiles[selected]
     tile_locations = tile_locations[selected]
     occupancies = occupancies[selected]
 
-    abs_tile_locations = (sample['scale'] * tile_locations + sample['location']).astype(int)
+    return image_tiles, tile_locations, occupancies, n_discarded
 
-    return image_tiles, mask_tiles, abs_tile_locations, occupancies, n_discarded
 
+def get_tile_info(sample: Dict[SlideKey, Any], occupancy: float, tile_location: Sequence[int],
+                  rel_slide_dir: Path) -> Dict[TileKey, Any]:
+    """Map slide information and tiling outputs into tile-specific information dictionary.
 
-# TODO refactor this to separate metadata identification from saving. We might want the metadata
-# even if the saving fails
-def save_tile(sample: dict, image_tile: np.ndarray, mask_tile: np.ndarray,
-              tile_location: Sequence[int], output_dir: Path) -> dict:
-    slide_id = sample['image_id']
+    :param sample: Slide dictionary.
+    :param occupancy: Estimated tile foreground occupancy.
+    :param tile_location: Tile XY coordinates.
+    :param rel_slide_dir: Directory where tiles are saved, relative to dataset root.
+    :return: Tile information dictionary.
+    """
+    slide_id = sample[SlideKey.SLIDE_ID]
     descriptor = get_tile_descriptor(tile_location)
-    image_tile_filename = f"train_images/{descriptor}.png"
-    mask_tile_filename = f"train_label_masks/{descriptor}_mask.png"
-
-    save_image(image_tile, output_dir / image_tile_filename)
-    save_image(mask_tile, output_dir / mask_tile_filename)
-
-    tile_metadata = {
-        'slide_id': slide_id,
-        'tile_id': get_tile_id(slide_id, tile_location),
-        'image': image_tile_filename,
-        'mask': mask_tile_filename,
-        'tile_x': tile_location[0],
-        'tile_y': tile_location[1],
-        'data_provider': sample['data_provider'],
-        'slide_isup_grade': sample['isup_grade'],
-        'slide_gleason_score': sample['gleason_score'],
+    rel_image_path = f"{rel_slide_dir}/{descriptor}.png"
+
+    tile_info = {
+        TileKey.SLIDE_ID: slide_id,
+        TileKey.TILE_ID: get_tile_id(slide_id, tile_location),
+        TileKey.IMAGE: rel_image_path,
+        TileKey.LABEL: sample[SlideKey.LABEL],
+        TileKey.TILE_X: tile_location[0],
+        TileKey.TILE_Y: tile_location[1],
+        TileKey.OCCUPANCY: occupancy,
+        TileKey.SLIDE_METADATA: {TileKey.from_slide_metadata_key(key): value
+                                 for key, value in sample[SlideKey.METADATA].items()}
     }
-
-    return tile_metadata
-
-
-def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: int,
-                  output_dir: Path, tile_progress: bool = False) -> None:
-    slide_id = sample['image_id']
-    slide_dir: Path = output_dir / (slide_id + "/")
+    return tile_info
+
+
+def format_csv_row(tile_info: Dict[TileKey, Any], keys_to_save: Iterable[TileKey],
+                   metadata_keys: Iterable[str]) -> str:
+    """Format tile information dictionary as a row to write to a dataset CSV file.
+ + :param tile_info: Tile information dictionary. + :param keys_to_save: Which main keys to include in the row, and in which order. + :param metadata_keys: Likewise for metadata keys. + :return: The formatted CSV row. + """ + tile_slide_metadata = tile_info.pop(TileKey.SLIDE_METADATA) + fields = [str(tile_info[key]) for key in keys_to_save] + fields.extend(str(tile_slide_metadata[key]) for key in metadata_keys) + dataset_row = ','.join(fields) + return dataset_row + + +def process_slide(sample: Dict[SlideKey, Any], level: int, margin: int, tile_size: int, + foreground_threshold: Optional[float], occupancy_threshold: float, output_dir: Path, + tile_progress: bool = False) -> None: + """Load and process a slide, saving tile images and information to a CSV file. + + :param sample: Slide information dictionary, returned by the input slide dataset. + :param level: Magnification level at which to process the slide. + :param margin: Margin around the foreground bounding box, in pixels at lowest resolution. + :param tile_size: Lateral dimensions of each tile, in pixels. + :param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy. + If `None` (default), an optimal threshold will be estimated automatically. + :param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard. + :param output_dir: Root directory for the output dataset; outputs for a single slide will be + saved inside `output_dir/slide_id/`. + :param tile_progress: Whether to display a progress bar in the terminal. + """ + slide_metadata: Dict[str, Any] = sample[SlideKey.METADATA] + keys_to_save = (TileKey.SLIDE_ID, TileKey.TILE_ID, TileKey.IMAGE, TileKey.LABEL, + TileKey.TILE_X, TileKey.TILE_Y, TileKey.OCCUPANCY) + metadata_keys = tuple(TileKey.from_slide_metadata_key(key) for key in slide_metadata) + csv_columns: Tuple[str, ...] 
= (*keys_to_save, *metadata_keys) + + slide_id: str = sample[SlideKey.SLIDE_ID] + rel_slide_dir = Path(slide_id) + slide_dir = output_dir / rel_slide_dir logging.info(f">>> Slide dir {slide_dir}") if slide_dir.exists(): # already processed slide - skip logging.info(f">>> Skipping {slide_dir} - already processed") @@ -119,50 +167,57 @@ def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupan dataset_csv_path = slide_dir / "dataset.csv" dataset_csv_file = dataset_csv_path.open('w') - dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header + dataset_csv_file.write(','.join(csv_columns) + '\n') # write CSV header - tiles_failure = 0 + n_failed_tiles = 0 failed_tiles_csv_path = slide_dir / "failed_tiles.csv" failed_tiles_file = failed_tiles_csv_path.open('w') failed_tiles_file.write('tile_id' + '\n') logging.info(f"Loading slide {slide_id} ...") - loader = LoadPandaROId(WSIReader(), level=level, margin=margin) - sample = loader(sample) # load 'image' and 'mask' from disk + loader = LoadROId(WSIReader('cuCIM'), level=level, margin=margin, + foreground_threshold=foreground_threshold) + sample = loader(sample) # load 'image' from disk logging.info(f"Tiling slide {slide_id} ...") - image_tiles, mask_tiles, tile_locations, occupancies, _ = \ - generate_tiles(sample, tile_size, occupancy_threshold) + image_tiles, rel_tile_locations, occupancies, _ = \ + generate_tiles(sample[SlideKey.IMAGE], tile_size, + sample[SlideKey.FOREGROUND_THRESHOLD], + occupancy_threshold) + + tile_locations = (sample[SlideKey.SCALE] * rel_tile_locations + + sample[SlideKey.ORIGIN]).astype(int) + n_tiles = image_tiles.shape[0] + logging.info(f"Saving tiles for slide {slide_id} ...") for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress): try: - tile_metadata = save_tile(sample, image_tiles[i], mask_tiles[i], tile_locations[i], - slide_dir) - tile_metadata['occupancy'] = occupancies[i] - tile_metadata['image'] = os.path.join(slide_dir.name, tile_metadata['image']) - tile_metadata['mask'] = os.path.join(slide_dir.name, tile_metadata['mask']) - dataset_row = ','.join(str(tile_metadata[column]) for column in CSV_COLUMNS) + tile_info = get_tile_info(sample, occupancies[i], tile_locations[i], rel_slide_dir) + save_image(image_tiles[i], output_dir / tile_info[TileKey.IMAGE]) + dataset_row = format_csv_row(tile_info, keys_to_save, metadata_keys) dataset_csv_file.write(dataset_row + '\n') except Exception as e: - tiles_failure += 1 - descriptor = get_tile_descriptor(tile_locations[i]) + '\n' - failed_tiles_file.write(descriptor) + n_failed_tiles += 1 + descriptor = get_tile_descriptor(tile_locations[i]) + failed_tiles_file.write(descriptor + '\n') traceback.print_exc() warnings.warn(f"An error occurred while saving tile " f"{get_tile_id(slide_id, tile_locations[i])}: {e}") dataset_csv_file.close() failed_tiles_file.close() - if tiles_failure > 0: + if n_failed_tiles > 0: # TODO what we want to do with slides that have some failed tiles? - logging.warning(f"{slide_id} is incomplete. {tiles_failure} tiles failed.") + logging.warning(f"{slide_id} is incomplete. 
{n_failed_tiles} tiles failed.")
+            logging.info(f"Finished processing slide {slide_id}")
         except Exception as e:
             traceback.print_exc()
             warnings.warn(f"An error occurred while processing slide {slide_id}: {e}")
 
 
 def merge_dataset_csv_files(dataset_dir: Path) -> Path:
+    """Combines all "*/dataset.csv" files into a single "dataset.csv" file in the given directory."""
     full_csv = dataset_dir / "dataset.csv"
     # TODO change how we retrieve these filenames, probably because mounted, the operation is slow
     # and it seems to find many more files
@@ -181,21 +236,40 @@
     return full_csv
 
 
-def main(panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level: int, tile_size: int,
-         margin: int, occupancy_threshold: float, parallel: bool = False, overwrite: bool = False) -> None:
+def main(slides_dataset: SlidesDataset, root_output_dir: Union[str, Path],
+         level: int, tile_size: int, margin: int, foreground_threshold: Optional[float],
+         occupancy_threshold: float, parallel: bool = False, overwrite: bool = False,
+         n_slides: Optional[int] = None) -> None:
+    """Process a slides dataset to produce a tiles dataset.
+
+    :param slides_dataset: Input slides dataset object.
+    :param root_output_dir: The root directory of the output tiles dataset.
+    :param level: Magnification level at which to process the slide.
+    :param tile_size: Lateral dimensions of each tile, in pixels.
+    :param margin: Margin around the foreground bounding box, in pixels at lowest resolution.
+    :param foreground_threshold: Luminance threshold (0 to 255) to determine tile occupancy.
+    If `None` (default), an optimal threshold will be estimated automatically.
+    :param occupancy_threshold: Threshold (between 0 and 1) to determine empty tiles to discard.
+    :param parallel: Whether slides should be processed in parallel with multiprocessing.
+    :param overwrite: Whether to overwrite an existing output tiles dataset. If `True`, will delete
+    and recreate `root_output_dir`, otherwise will resume by skipping already processed slides.
+    :param n_slides: If given, limit the total number of slides for debugging.
+ """ # Ignoring some types here because mypy is getting confused with the MONAI Dataset class # to select a subsample use keyword n_slides - dataset = Dataset(PandaDataset(panda_dir)) # type: ignore + dataset = Dataset(slides_dataset)[:n_slides] # type: ignore - output_dir = Path(root_output_dir) / f"panda_tiles_level{level}_{tile_size}" - logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} PANDA tiles at: {output_dir}") + output_dir = Path(root_output_dir) + logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} " + f"{slides_dataset.__class__.__name__} tiles at: {output_dir}") if overwrite and output_dir.exists(): shutil.rmtree(output_dir) output_dir.mkdir(parents=True, exist_ok=not overwrite) func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size, + foreground_threshold=foreground_threshold, occupancy_threshold=occupancy_threshold, output_dir=output_dir, tile_progress=not parallel) @@ -217,11 +291,16 @@ def main(panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level: if __name__ == '__main__': - main(panda_dir="/tmp/datasets/PANDA", - root_output_dir="/datadrive", - level=1, + from InnerEye.ML.Histopathology.datasets.tcga_prad_dataset import TcgaPradDataset + + # Example set up for an existing slides dataset: + main(slides_dataset=TcgaPradDataset("/tmp/datasets/TCGA-PRAD"), + root_output_dir="/datadrive/TCGA-PRAD_tiles", + n_slides=5, + level=3, tile_size=224, margin=64, + foreground_threshold=None, occupancy_threshold=0.05, - parallel=True, - overwrite=False) + parallel=False, + overwrite=True) diff --git a/InnerEye/ML/Histopathology/preprocessing/loading.py b/InnerEye/ML/Histopathology/preprocessing/loading.py new file mode 100644 index 000000000..77942e555 --- /dev/null +++ b/InnerEye/ML/Histopathology/preprocessing/loading.py @@ -0,0 +1,108 @@ +from typing import Dict, Optional, Tuple + +import numpy as np +import skimage.filters +from cucim import CuImage +from health_ml.utils import box_utils +from monai.data.image_reader import WSIReader +from monai.transforms import MapTransform + +from InnerEye.ML.Histopathology.utils.naming import SlideKey + + +def get_luminance(slide: np.ndarray) -> np.ndarray: + """Compute a grayscale version of the input slide. + + :param slide: The RGB image array in (*, C, H, W) format. + :return: The single-channel luminance array as (*, H, W). + """ + # TODO: Consider more sophisticated luminance calculation if necessary + return slide.mean(axis=-3) # type: ignore + + +def segment_foreground(slide: np.ndarray, threshold: Optional[float] = None) \ + -> Tuple[np.ndarray, float]: + """Segment the given slide by thresholding its luminance. + + :param slide: The RGB image array in (*, C, H, W) format. + :param threshold: Pixels with luminance below this value will be considered foreground. + If `None` (default), an optimal threshold will be estimated automatically using Otsu's method. + :return: A tuple containing the boolean output array in (*, H, W) format and the threshold used. + """ + luminance = get_luminance(slide) + if threshold is None: + threshold = skimage.filters.threshold_otsu(luminance) + return luminance < threshold, threshold + + +def load_slide_at_level(reader: WSIReader, slide_obj: CuImage, level: int) -> np.ndarray: + """Load full slide array at the given magnification level. 
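+
+    A usage sketch (the file path is illustrative; cf. `test_load_slide` in the tests):
+
+        reader = WSIReader('cuCIM')
+        slide_obj = reader.read("/path/to/slide.tiff")
+        image = load_slide_at_level(reader, slide_obj, level=2)  # (C, H, W) numpy array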
+
+    This is a manual workaround for a MONAI bug (https://github.com/Project-MONAI/MONAI/issues/3415)
+    fixed in a currently unreleased PR (https://github.com/Project-MONAI/MONAI/pull/3417).
+
+    :param reader: A MONAI `WSIReader` using cuCIM backend.
+    :param slide_obj: The cuCIM image object returned by `reader.read()`.
+    :param level: Index of the desired magnification level as defined in the `slide_obj` headers.
+    :return: The loaded image array in (C, H, W) format.
+    """
+    size = slide_obj.resolutions['level_dimensions'][level][::-1]
+    slide, _ = reader.get_data(slide_obj, size=size, level=level)  # loaded as RGB PIL image
+    return slide
+
+
+class LoadROId(MapTransform):
+    """Transform that loads a pathology slide, cropped to an estimated bounding box (ROI).
+
+    Operates on dictionaries, replacing the file path in `image_key` with the loaded array in
+    (C, H, W) format. Also adds the following entries:
+    - `SlideKey.ORIGIN` (tuple): top-left coordinates of the bounding box
+    - `SlideKey.SCALE` (float): corresponding scale, loaded from the file
+    - `SlideKey.FOREGROUND_THRESHOLD` (float): threshold used to segment the foreground
+    """
+    def __init__(self, reader: WSIReader, image_key: str = SlideKey.IMAGE, level: int = 0,
+                 margin: int = 0, foreground_threshold: Optional[float] = None) -> None:
+        """
+        :param reader: An instance of MONAI's `WSIReader`.
+        :param image_key: Image key in the input and output dictionaries.
+        :param level: Magnification level to load from the raw multi-scale file.
+        :param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping.
+        :param foreground_threshold: Pixels with luminance below this value will be considered foreground.
+        If `None` (default), an optimal threshold will be estimated automatically using Otsu's method.
+        """
+        super().__init__([image_key], allow_missing_keys=False)
+        self.reader = reader
+        self.image_key = image_key
+        self.level = level
+        self.margin = margin
+        self.foreground_threshold = foreground_threshold
+
+    def _get_bounding_box(self, slide_obj: CuImage) -> Tuple[box_utils.Box, float]:
+        # Estimate bounding box at the lowest resolution (i.e. 
highest level) + highest_level = slide_obj.resolutions['level_count'] - 1 + scale = slide_obj.resolutions['level_downsamples'][highest_level] + slide = load_slide_at_level(self.reader, slide_obj, level=highest_level) + + foreground_mask, threshold = segment_foreground(slide, self.foreground_threshold) + bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin) + return bbox, threshold + + def __call__(self, data: Dict) -> Dict: + image_obj: CuImage = self.reader.read(data[self.image_key]) + + level0_bbox, threshold = self._get_bounding_box(image_obj) + + # cuCIM/OpenSlide takes absolute location coordinates in the level 0 reference frame, + # but relative region size in pixels at the chosen level + origin = (level0_bbox.x, level0_bbox.y) + scale = image_obj.resolutions['level_downsamples'][self.level] + scaled_bbox = level0_bbox / scale + + data[self.image_key], _ = self.reader.get_data(image_obj, location=origin, level=self.level, + size=(scaled_bbox.w, scaled_bbox.h)) + data[SlideKey.ORIGIN] = origin + data[SlideKey.SCALE] = scale + data[SlideKey.FOREGROUND_THRESHOLD] = threshold + + image_obj.close() + return data diff --git a/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py b/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py deleted file mode 100644 index 4eb3ef4ab..000000000 --- a/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py +++ /dev/null @@ -1,61 +0,0 @@ -# ------------------------------------------------------------------------------------------ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. -# ------------------------------------------------------------------------------------------ - -""" -This script is an example of how to use the submit_to_azure_if_needed function from the hi-ml package to run the -main pre-processing function that creates tiles from slides in the PANDA dataset. The advantage of using this script -is the ability to submit to a cluster on azureml and to have the output files directly saved as a registered dataset. - -To run execute, from inside the pre-processing folder, -python azure_tiles_creation.py --azureml - -A json configuration file containing the credentials to the Azure workspace and an environment.yml file are expected -in input. - -This has been tested on hi-mlv0.1.4. 
-""" - -from pathlib import Path -import sys -import time - -current_file = Path(__file__) -radiomics_root = current_file.absolute().parent.parent.parent.parent.parent -sys.path.append(str(radiomics_root)) -from health_azure.himl import submit_to_azure_if_needed, DatasetConfig # noqa -from InnerEye.ML.Histopathology.preprocessing.create_tiles_dataset import main # noqa - -# Pre-built environment file that contains all the requirements (RadiomicsNN + histo) -# Assuming ENV_NAME is a complete environment, `conda env export -n ENV_NAME -f ENV_NAME.yml` will create the desired file -ENVIRONMENT_FILE = radiomics_root.joinpath(Path("/envs/innereyeprivatetiles.yml")) -DATASET_NAME = "PANDA_tiles" -timestr = time.strftime("%Y%m%d-%H%M%S") -folder_name = DATASET_NAME + '_' + timestr - -if __name__ == '__main__': - print(f"Running {str(current_file)}") - input_dataset = DatasetConfig(name="PANDA", datastore="innereyedatasets", local_folder=Path("/tmp/datasets/PANDA"), use_mounting=True) - output_dataset = DatasetConfig(name=DATASET_NAME, datastore="innereyedatasets", local_folder=Path("/datadrive/"), use_mounting=True) - run_info = submit_to_azure_if_needed(entry_script=current_file, - snapshot_root_directory=radiomics_root, - workspace_config_file=Path("config.json"), - compute_cluster_name='training-pr-nc12', # training-nd24 - default_datastore="innereyedatasets", - conda_environment_file=Path(ENVIRONMENT_FILE), - input_datasets=[input_dataset], - output_datasets=[output_dataset], - ) - input_folder = run_info.input_datasets[0] - output_folder = Path(run_info.output_datasets[0], folder_name) - print(f'This will be the final ouput folder {str(output_folder)}') - - main(panda_dir=str(input_folder), - root_output_dir=str(output_folder), - level=1, - tile_size=224, - margin=64, - occupancy_threshold=0.05, - parallel=True, - overwrite=False) diff --git a/InnerEye/ML/Histopathology/utils/naming.py b/InnerEye/ML/Histopathology/utils/naming.py index 1af40015b..9f1b237df 100644 --- a/InnerEye/ML/Histopathology/utils/naming.py +++ b/InnerEye/ML/Histopathology/utils/naming.py @@ -5,6 +5,41 @@ from enum import Enum + +class SlideKey(str, Enum): + SLIDE_ID = 'slide_id' + IMAGE = 'image' + IMAGE_PATH = 'image_path' + MASK = 'mask' + MASK_PATH = 'mask_path' + LABEL = 'label' + SPLIT = 'split' + SCALE = 'scale' + ORIGIN = 'origin' + FOREGROUND_THRESHOLD = 'foreground_threshold' + METADATA = 'metadata' + + +class TileKey(str, Enum): + TILE_ID = 'tile_id' + SLIDE_ID = 'slide_id' + IMAGE = 'image' + IMAGE_PATH = 'image_path' + MASK = 'mask' + MASK_PATH = 'mask_path' + LABEL = 'label' + SPLIT = 'split' + TILE_X = 'tile_x' + TILE_Y = 'tile_y' + OCCUPANCY = 'occupancy' + FOREGROUND_THRESHOLD = 'foreground_threshold' + SLIDE_METADATA = 'slide_metadata' + + @staticmethod + def from_slide_metadata_key(slide_metadata_key: str) -> str: + return 'slide_' + slide_metadata_key + + class ResultsKey(str, Enum): SLIDE_ID = 'slide_id' TILE_ID = 'tile_id' diff --git a/InnerEye/ML/Histopathology/utils/viz_utils.py b/InnerEye/ML/Histopathology/utils/viz_utils.py index 1c4bff791..763e59368 100644 --- a/InnerEye/ML/Histopathology/utils/viz_utils.py +++ b/InnerEye/ML/Histopathology/utils/viz_utils.py @@ -4,29 +4,32 @@ # ------------------------------------------------------------------------------------------ import math -import matplotlib.pyplot as plt +from typing import Any, Dict +import matplotlib.pyplot as plt +from monai.data.dataset import Dataset from monai.data.image_reader import WSIReader from torch.utils.data import 
DataLoader from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId +from InnerEye.ML.Histopathology.utils.naming import SlideKey -def load_image_dict(sample: dict, level: int, margin: int) -> dict: +def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any]: """ Load image from metadata dictionary - param sample: dict describing image metadata. Example: + :param sample: dict describing image metadata. Example: {'image_id': ['1ca999adbbc948e69783686e5b5414e4'], 'image': ['/tmp/datasets/PANDA/train_images/1ca999adbbc948e69783686e5b5414e4.tiff'], 'mask': ['/tmp/datasets/PANDA/train_label_masks/1ca999adbbc948e69783686e5b5414e4_mask.tiff'], 'data_provider': ['karolinska'], 'isup_grade': tensor([0]), 'gleason_score': ['0+0']} - param level: level of resolution to be loaded - param margin: margin to be included - return: a dict containing the image data and metadata + :param level: level of resolution to be loaded + :param margin: margin to be included + :return: a dict containing the image data and metadata """ - loader = LoadPandaROId(WSIReader(), level=level, margin=margin) + loader = LoadPandaROId(WSIReader('cuCIM'), level=level, margin=margin) img = loader(sample) return img @@ -34,25 +37,25 @@ def load_image_dict(sample: dict, level: int, margin: int) -> dict: def plot_panda_data_sample(panda_dir: str, nsamples: int, ncols: int, level: int, margin: int, title_key: str = 'data_provider') -> None: """ - param panda_dir: path to the dataset, it's expected a file called "train.csv" exists at the path. + :param panda_dir: path to the dataset, it's expected a file called "train.csv" exists at the path. Look at the PandaDataset for more detail - param nsamples: number of random samples to be visualized - param ncols: number of columns in the figure grid. Nrows is automatically inferred - param level: level of resolution to be loaded - param margin: margin to be included - param title_key: key in image_dict used to label each subplot + :param nsamples: number of random samples to be visualized + :param ncols: number of columns in the figure grid. 
Nrows is automatically inferred + :param level: level of resolution to be loaded + :param margin: margin to be included + :param title_key: metadata key in image_dict used to label each subplot """ - panda_dataset = PandaDataset(root_dir=panda_dir, n_slides=nsamples) + panda_dataset = Dataset(PandaDataset(root=panda_dir))[:nsamples] # type: ignore loader = DataLoader(panda_dataset, batch_size=1) nrows = math.ceil(nsamples/ncols) fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(9, 9)) for dict_images, ax in zip(loader, axes.flat): - slide_id = dict_images['image_id'] - title = dict_images[title_key] + slide_id = dict_images[SlideKey.SLIDE_ID] + title = dict_images[SlideKey.METADATA][title_key] print(f">>> Slide {slide_id}") img = load_image_dict(dict_images, level=level, margin=margin) - ax.imshow(img['image'].transpose(1, 2, 0)) + ax.imshow(img[SlideKey.IMAGE].transpose(1, 2, 0)) ax.set_title(title) fig.tight_layout() diff --git a/Tests/ML/datasets/test_dataset.py b/Tests/ML/datasets/test_dataset.py index 6fc446f2d..7228c97ac 100644 --- a/Tests/ML/datasets/test_dataset.py +++ b/Tests/ML/datasets/test_dataset.py @@ -8,7 +8,7 @@ import pandas as pd import pytest import torch -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.utilities.data import extract_batch_size from InnerEye.Common import common_util from InnerEye.ML.config import PaddingMode, SegmentationModelBase @@ -502,7 +502,7 @@ def test_sample_metadata_field() -> None: assert SAMPLE_METADATA_FIELD in fields # Lightning attempts to determine the batch size by trying to find a tensor field in the sample. # This only works if any field other than Metadata is first. - assert Result.unpack_batch_size(fields) == batch_size + assert extract_batch_size(fields) == batch_size def test_custom_collate() -> None: diff --git a/Tests/ML/histopathology/datasets/test_slides_dataset.py b/Tests/ML/histopathology/datasets/test_slides_dataset.py new file mode 100644 index 000000000..99273998b --- /dev/null +++ b/Tests/ML/histopathology/datasets/test_slides_dataset.py @@ -0,0 +1,40 @@ +import os + +import pandas as pd + +from InnerEye.Common.fixed_paths_for_tests import tests_root_directory +from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset +from InnerEye.ML.Histopathology.utils.naming import SlideKey + +HISTO_TEST_DATA_DIR = str(tests_root_directory("ML/histopathology/test_data")) + + +class MockSlidesDataset(SlidesDataset): + DEFAULT_CSV_FILENAME = "test_slides_dataset.csv" + METADATA_COLUMNS = ('meta1', 'meta2') + + def __init__(self) -> None: + super().__init__(root=HISTO_TEST_DATA_DIR) + + +def test_slides_dataset() -> None: + dataset = MockSlidesDataset() + assert isinstance(dataset.dataset_df, pd.DataFrame) + assert dataset.dataset_df.index.name == dataset.SLIDE_ID_COLUMN + assert len(dataset) == len(dataset.dataset_df) + + sample = dataset[0] + assert isinstance(sample, dict) + assert all(isinstance(key, SlideKey) for key in sample) + + expected_keys = [SlideKey.SLIDE_ID, SlideKey.IMAGE, SlideKey.IMAGE_PATH, SlideKey.LABEL, + SlideKey.METADATA] + assert all(key in sample for key in expected_keys) + + image_path = sample[SlideKey.IMAGE_PATH] + assert isinstance(image_path, str) + assert os.path.isfile(image_path) + + metadata = sample[SlideKey.METADATA] + assert isinstance(metadata, dict) + assert all(meta_col in metadata for meta_col in type(dataset).METADATA_COLUMNS) diff --git a/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py 
diff --git a/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py b/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py
deleted file mode 100644
index e23f0e212..000000000
--- a/Tests/ML/histopathology/datasets/test_tcga_prad_dataset.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# ------------------------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
-# ------------------------------------------------------------------------------------------
-
-import os
-
-import pytest
-
-from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_PRAD_DATASET_DIR
-from InnerEye.ML.Histopathology.datasets.tcga_prad_dataset import TcgaPradDataset
-
-
-@pytest.mark.skipif(not os.path.isdir(TCGA_PRAD_DATASET_DIR),
-                    reason="TCGA-PRAD dataset is unavailable")
-def test_dataset() -> None:
-    dataset = TcgaPradDataset(TCGA_PRAD_DATASET_DIR)
-
-    expected_length = 449
-    assert len(dataset) == expected_length
-
-    expected_num_positives = 10
-    assert dataset.dataset_df[dataset.LABEL_COLUMN].sum() == expected_num_positives
-
-    sample = dataset[0]
-    assert isinstance(sample, dict)
-
-    expected_keys = [dataset.SLIDE_ID_COLUMN, dataset.CASE_ID_COLUMN,
-                     dataset.IMAGE_COLUMN, dataset.LABEL_COLUMN]
-    assert all(key in sample for key in expected_keys)
-
-    image_path = sample[dataset.IMAGE_COLUMN]
-    assert isinstance(image_path, str)
-    assert os.path.isfile(image_path)
diff --git a/Tests/ML/histopathology/preprocessing/test_slide_loading.py b/Tests/ML/histopathology/preprocessing/test_slide_loading.py
new file mode 100644
index 000000000..60d9717ed
--- /dev/null
+++ b/Tests/ML/histopathology/preprocessing/test_slide_loading.py
@@ -0,0 +1,160 @@
+from typing import Optional
+
+import numpy as np
+import pytest
+from cucim import CuImage
+from monai.data.image_reader import WSIReader
+
+from InnerEye.Common.fixed_paths_for_tests import tests_root_directory
+from InnerEye.ML.Histopathology.preprocessing.tiling import tile_array_2d
+from InnerEye.ML.Histopathology.preprocessing.loading import LoadROId, get_luminance, load_slide_at_level, segment_foreground
+from InnerEye.ML.Histopathology.utils.naming import SlideKey
+from Tests.ML.histopathology.datasets.test_slides_dataset import MockSlidesDataset
+
+TEST_IMAGE_PATH = str(tests_root_directory("ML/histopathology/test_data/panda_wsi_example.tiff"))
+
+
+def test_load_slide() -> None:
+    level = 2
+    reader = WSIReader('cuCIM')
+    slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
+    dims = slide_obj.resolutions['level_dimensions'][level][::-1]
+
+    slide = load_slide_at_level(reader, slide_obj, level)
+    assert isinstance(slide, np.ndarray)
+    expected_shape = (3, *dims)
+    assert slide.shape == expected_shape
+    frac_empty = (slide == 0).mean()
+    assert frac_empty == 0.0
+
+    larger_dims = (2 * dims[0], 2 * dims[1])
+    larger_slide, _ = reader.get_data(slide_obj, size=larger_dims, level=level)
+    assert isinstance(larger_slide, np.ndarray)
+    assert larger_slide.shape == (3, *larger_dims)
+    # Overlapping parts match exactly
+    assert np.array_equal(larger_slide[:, :dims[0], :dims[1]], slide)
+    # Non-overlapping parts are all empty
+    empty_fill_value = 0  # fill value seems to depend on the image
+    assert np.array_equiv(larger_slide[:, dims[0]:, :], empty_fill_value)
+    assert np.array_equiv(larger_slide[:, :, dims[1]:], empty_fill_value)
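The next test only pins down the interface of get_luminance: it must map a (..., C, H, W) image to a (..., H, W) map within [0, 255], for whole slides and batches of tiles alike. A channel-wise mean is one function that satisfies these constraints; a sketch under that assumption, not necessarily the actual implementation:

    import numpy as np

    def get_luminance_sketch(slide: np.ndarray) -> np.ndarray:
        # Averaging over the channel axis (third from the end) keeps uint8 inputs
        # within [0, 255] and works for (C, H, W) slides and (N, C, H, W) tiles alike.
        return slide.mean(axis=-3)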
+def test_get_luminance() -> None:
+    level = 2  # here we only need to test at a single resolution
+    reader = WSIReader('cuCIM')
+    slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
+
+    slide = load_slide_at_level(reader, slide_obj, level)
+    slide_luminance = get_luminance(slide)
+    assert isinstance(slide_luminance, np.ndarray)
+    assert slide_luminance.shape == slide.shape[1:]
+    assert (slide_luminance <= 255).all() and (slide_luminance >= 0).all()
+
+    tiles, _ = tile_array_2d(slide, tile_size=224, constant_values=255)
+    tiles_luminance = get_luminance(tiles)
+    assert isinstance(tiles_luminance, np.ndarray)
+    assert tiles_luminance.shape == (tiles.shape[0], *tiles.shape[2:])
+    assert (tiles_luminance <= 255).all() and (tiles_luminance >= 0).all()
+
+    slide_luminance_tiles, _ = tile_array_2d(np.expand_dims(slide_luminance, axis=0),
+                                             tile_size=224, constant_values=255)
+    assert np.array_equal(slide_luminance_tiles.squeeze(1), tiles_luminance)
+
+
+def test_segment_foreground() -> None:
+    level = 2  # here we only need to test at a single resolution
+    reader = WSIReader('cuCIM')
+    slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
+    slide = load_slide_at_level(reader, slide_obj, level)
+
+    auto_mask, auto_threshold = segment_foreground(slide, threshold=None)
+    assert isinstance(auto_mask, np.ndarray)
+    assert auto_mask.dtype == bool
+    assert auto_mask.shape == slide.shape[1:]
+    assert 0 < auto_mask.sum() < auto_mask.size  # auto-seg should not produce trivial mask
+    luminance = get_luminance(slide)
+    assert luminance.min() < auto_threshold < luminance.max()
+
+    mask, returned_threshold = segment_foreground(slide, threshold=auto_threshold)
+    assert isinstance(mask, np.ndarray)
+    assert mask.dtype == bool
+    assert mask.shape == slide.shape[1:]
+    assert np.array_equal(mask, auto_mask)
+    assert returned_threshold == auto_threshold
+
+    tiles, _ = tile_array_2d(slide, tile_size=224, constant_values=255)
+    tiles_mask, _ = segment_foreground(tiles, threshold=auto_threshold)
+    assert isinstance(tiles_mask, np.ndarray)
+    assert tiles_mask.dtype == bool
+    assert tiles_mask.shape == (tiles.shape[0], *tiles.shape[2:])
+
+    slide_mask_tiles, _ = tile_array_2d(np.expand_dims(mask, axis=0),
+                                        tile_size=224, constant_values=False)
+    assert np.array_equal(slide_mask_tiles.squeeze(1), tiles_mask)
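test_segment_foreground constrains the API without showing the implementation: with threshold=None, an automatic threshold strictly between the luminance extremes is chosen and returned; with an explicit threshold, the same boolean mask results. Since padding the image with 255 corresponds to padding the mask with False, foreground must be the darker, below-threshold pixels. An Otsu-based sketch consistent with these constraints (an assumption, not necessarily the library's code):

    from typing import Optional, Tuple

    import numpy as np
    from skimage.filters import threshold_otsu

    def segment_foreground_sketch(slide: np.ndarray,
                                  threshold: Optional[float] = None) -> Tuple[np.ndarray, float]:
        luminance = slide.mean(axis=-3)  # channel-mean luminance, as sketched above
        if threshold is None:
            threshold = threshold_otsu(luminance)  # data-driven foreground/background split
        return luminance < threshold, threshold  # foreground = darker-than-threshold pixels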
+@pytest.mark.parametrize('level', [1, 2])
+@pytest.mark.parametrize('foreground_threshold', [None, 215])
+def test_get_bounding_box(level: int, foreground_threshold: Optional[float]) -> None:
+    margin = 0
+    reader = WSIReader('cuCIM')
+    loader = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
+                      foreground_threshold=foreground_threshold)
+    slide_obj: CuImage = reader.read(TEST_IMAGE_PATH)
+    level0_bbox, _ = loader._get_bounding_box(slide_obj)
+
+    highest_level = slide_obj.resolutions['level_count'] - 1
+    slide = load_slide_at_level(reader, slide_obj, level=level)
+    scale = slide_obj.resolutions['level_downsamples'][level]
+    bbox = level0_bbox / scale
+    assert bbox.x >= 0 and bbox.y >= 0
+    assert bbox.x + bbox.w <= slide.shape[1]
+    assert bbox.y + bbox.h <= slide.shape[2]
+
+    # Now with nonzero margin
+    margin = 42
+    loader_margin = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
+                             foreground_threshold=foreground_threshold)
+    level0_bbox_margin, _ = loader_margin._get_bounding_box(slide_obj)
+    # Here we test the box differences at the highest resolution (level 0), because the margin
+    # is specified in low-res pixels. Otherwise, the checks could fail due to rounding errors.
+    level0_scale: float = slide_obj.resolutions['level_downsamples'][highest_level]
+    level0_margin = int(level0_scale * margin)
+    assert level0_bbox_margin.x == level0_bbox.x - level0_margin
+    assert level0_bbox_margin.y == level0_bbox.y - level0_margin
+    assert level0_bbox_margin.w == level0_bbox.w + 2 * level0_margin
+    assert level0_bbox_margin.h == level0_bbox.h + 2 * level0_margin
+
+
+@pytest.mark.parametrize('level', [1, 2])
+@pytest.mark.parametrize('margin', [0, 42])
+@pytest.mark.parametrize('foreground_threshold', [None, 215])
+def test_load_roi(level: int, margin: int, foreground_threshold: Optional[float]) -> None:
+    dataset = MockSlidesDataset()
+    sample = dataset[0]
+    reader = WSIReader('cuCIM')
+    loader = LoadROId(reader, image_key=SlideKey.IMAGE, level=level, margin=margin,
+                      foreground_threshold=foreground_threshold)
+    loaded_sample = loader(sample)
+    assert isinstance(loaded_sample, dict)
+    # Check that none of the input keys were removed
+    assert all(key in loaded_sample for key in sample)
+
+    # Check that the expected new keys were inserted
+    additional_keys = [SlideKey.ORIGIN, SlideKey.SCALE, SlideKey.FOREGROUND_THRESHOLD]
+    assert all(key in loaded_sample for key in additional_keys)
+
+    assert isinstance(loaded_sample[SlideKey.IMAGE], np.ndarray)
+    image_shape = loaded_sample[SlideKey.IMAGE].shape
+    assert len(image_shape) == 3
+    assert image_shape[0] == 3
+
+    origin = loaded_sample[SlideKey.ORIGIN]
+    assert isinstance(origin, tuple)
+    assert len(origin) == 2
+    assert all(isinstance(coord, int) for coord in origin)
+
+    assert isinstance(loaded_sample[SlideKey.SCALE], (int, float))
+    assert loaded_sample[SlideKey.SCALE] >= 1.0
+
+    assert isinstance(loaded_sample[SlideKey.FOREGROUND_THRESHOLD], (int, float))
diff --git a/Tests/ML/histopathology/test_data/panda_wsi_example.tiff b/Tests/ML/histopathology/test_data/panda_wsi_example.tiff
new file mode 100644
index 000000000..0334007f4
--- /dev/null
+++ b/Tests/ML/histopathology/test_data/panda_wsi_example.tiff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06eb0acaa2883181e9b6ab976863f71cc843a75ed9175fae8fe9b879635af1b0
+size 816563
diff --git a/Tests/ML/histopathology/test_data/test_slides_dataset.csv b/Tests/ML/histopathology/test_data/test_slides_dataset.csv
new file mode 100644
index 000000000..e9f865454
--- /dev/null
+++ b/Tests/ML/histopathology/test_data/test_slides_dataset.csv
@@ -0,0 +1,2 @@
+slide_id,image,label,meta1,meta2
+foo,panda_wsi_example.tiff,0,bar,baz
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
index 13073808b..99028b0c8 100644
--- a/environment.yml
+++ b/environment.yml
@@ -20,6 +20,7 @@ dependencies:
   - azureml-tensorboard==1.36.0
   - conda-merge==0.1.5
   - cryptography==3.3.2
+  - cucim==21.10.1; platform_system=="Linux"
   - dataclasses-json==0.5.2
   - docker==4.3.1
   - flake8==3.8.3
diff --git a/pytest.ini b/pytest.ini
index 85174886c..05089f9b9 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,6 +1,6 @@
 [pytest]
 testpaths=Tests TestsOutsidePackage TestSubmodule
-norecursedirs=azure-pipelines docs datasets sphinx-docs InnerEye logs outputs test_data
+norecursedirs=azure-pipelines docs sphinx-docs InnerEye logs outputs test_data Tests/ML/datasets
 addopts=--strict-markers
 markers=
     gpu: Test needs a GPU to run
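Taken together, the tests above describe the full loading path added by this PR. A minimal end-to-end sketch using only names that appear in the diff (cuCIM is Linux-only per the environment.yml pin; the level and margin values are illustrative):

    import numpy as np
    from monai.data.image_reader import WSIReader

    from InnerEye.ML.Histopathology.preprocessing.loading import LoadROId
    from InnerEye.ML.Histopathology.utils.naming import SlideKey
    from Tests.ML.histopathology.datasets.test_slides_dataset import MockSlidesDataset

    reader = WSIReader('cuCIM')
    loader = LoadROId(reader, image_key=SlideKey.IMAGE, level=2, margin=42,
                      foreground_threshold=None)  # None: estimate the threshold automatically
    sample = loader(MockSlidesDataset()[0])
    image: np.ndarray = sample[SlideKey.IMAGE]  # CHW crop around the foreground bounding box
    print(image.shape, sample[SlideKey.ORIGIN], sample[SlideKey.SCALE])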