From 2bf271c4cf1b614bc0a176a52a9cb082e4ae2d3b Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Wed, 7 Aug 2024 11:26:45 -0400 Subject: [PATCH 01/24] refactor WSIPatcher --- src/hest/segmentation/SegDataset.py | 3 +- src/hest/segmentation/segmentation.py | 9 +- src/hest/wsi.py | 212 +++++++++++++++++++------- 3 files changed, 160 insertions(+), 64 deletions(-) diff --git a/src/hest/segmentation/SegDataset.py b/src/hest/segmentation/SegDataset.py index b9de230..7cbbc87 100644 --- a/src/hest/segmentation/SegDataset.py +++ b/src/hest/segmentation/SegDataset.py @@ -55,13 +55,12 @@ def __init__(self, patcher: WSIPatcher, transform): self.patcher = patcher self.cols, self.rows = self.patcher.get_cols_rows() - self.size = self.cols * self.rows self.transform = transform def __len__(self): - return self.size + return len(self.patcher) def __getitem__(self, index): col = index % self.cols diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py index 0a99d54..263a1ed 100644 --- a/src/hest/segmentation/segmentation.py +++ b/src/hest/segmentation/segmentation.py @@ -63,16 +63,13 @@ def segment_tissue_deep( patch_size_deeplab = 512 - # TODO fix overlap - overlap=0 - scale = pixel_size_src / target_pxl_size patch_size_src = round(patch_size_um / scale) wsi = wsi_factory(wsi) weights_path = get_path_relative(__file__, f'../../../models/{model_name}') - patcher = WSIPatcher(wsi, patch_size_src, patch_size_deeplab) + patcher = wsi.create_patcher(patch_size_src, patch_size_deeplab) eval_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) dataset = SegWSIDataset(patcher, eval_transforms) @@ -134,8 +131,8 @@ def segment_tissue_deep( coord = coords[i] x, y = round(coord[0] * src_to_deeplab_scale), round(coord[1] * src_to_deeplab_scale) - y_end = min(y+patch_size_deeplab + overlap, height) - x_end = min(x+patch_size_deeplab + overlap, width) + y_end = min(y+patch_size_deeplab, height) + x_end = min(x+patch_size_deeplab, width) stitched_img[y:y_end, x:x_end] += pred[:y_end-y, :x_end-x] diff --git a/src/hest/wsi.py b/src/hest/wsi.py index e922573..a40b83f 100644 --- a/src/hest/wsi.py +++ b/src/hest/wsi.py @@ -3,9 +3,9 @@ from typing import Tuple import cv2 +import geopandas as gpd import numpy as np import openslide -from openslide.deepzoom import DeepZoomGenerator class CucimWarningSingleton: @@ -29,6 +29,7 @@ def is_cuimage(img): class WSI: + def __init__(self, img): self.img = img @@ -58,6 +59,10 @@ def __repr__(self) -> str: return f"" + @abstractmethod + def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None): + pass + def wsi_factory(img) -> WSI: try: @@ -102,6 +107,9 @@ def read_region(self, location, level, size) -> np.ndarray: def get_thumbnail(self, width, height) -> np.ndarray: return cv2.resize(self.img, (width, height)) + def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None): + return NumpyWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask) + class OpenSlideWSI(WSI): def __init__(self, img: openslide.OpenSlide): @@ -128,6 +136,9 @@ def level_dimensions(self): def level_downsamples(self): return self.img.level_downsamples + def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None): + return OpenSlideWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask) + class 
CuImageWSI(WSI): def __init__(self, img: 'CuImage'): super().__init__(img) @@ -171,44 +182,110 @@ def level_dimensions(self): def level_downsamples(self): return self.img.resolutions['level_downsamples'] + + def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None): + return CuImageWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask) class WSIPatcher: - def __init__(self, wsi: WSI, patch_size_src: int, patch_size_target: int = None): + """ Iterator class to handle patching, patch scaling and tissue mask intersection """ + + def __init__( + self, + wsi: WSI, + patch_size: int, + patch_size_target: int = None, + overlap: int = 0, + mask: gpd.GeoDataFrame = None + ): + """ Initialize patcher, compute number of (masked) rows, columns. + + Args: + wsi (WSI): wsi to patch + patch_size (int): patch width/height in pixel on the slide before rescaling + patch_size_target (int, optional): largest patch size in pixel after rescaling. Defaults to None. + overlap (int, optional): overlap size in pixel before rescaling. Defaults to 0. + mask (gpd.GeoDataFrame, optional): geopandas dataframe of Polygons. Defaults to None. + """ self.wsi = wsi - self.patch_size_src = patch_size_src - self.overlap = 0 + self.patch_size = patch_size + self.overlap = overlap self.width, self.height = self.wsi.get_dimensions() self.patch_size_target = patch_size_target - self.downsample = patch_size_src / patch_size_target - - self._compute_cols_rows() - - def _compute_cols_rows(self) -> None: - img = self.wsi.img - if isinstance(img, openslide.OpenSlide): - self.level = self.wsi.get_best_level_for_downsample(self.downsample) - self.level_dimensions = self.wsi.level_dimensions()[self.level] - self.level_downsample = self.wsi.level_downsamples()[self.level] - self.patch_size_level = round(self.patch_size_src / self.level_downsample) - self.dz = DeepZoomGenerator(img, self.patch_size_level, self.overlap) - self.nb_levels = len(self.dz.level_tiles) - self.cols, self.rows = self.dz.level_tiles[self.nb_levels - self.level - 1] - elif isinstance(img, np.ndarray): - self.cols, self.rows = round(np.ceil((self.width - self.overlap / 2) / (self.patch_size_src - self.overlap / 2))), round(np.ceil((self.height - self.overlap / 2) / (self.patch_size_src - self.overlap / 2))) - self.level = -1 - self.level_dimensions = (self.width, self.height) - elif is_cuimage(img): - self.level = self.wsi.get_best_level_for_downsample(self.downsample) - self.level_downsample = self.wsi.level_downsamples()[self.level] - self.level_dimensions = self.wsi.level_dimensions()[self.level] - self.patch_size_level = round(self.patch_size_src / self.level_downsample) - level_width, level_height = self.level_dimensions - self.cols, self.rows = round(np.ceil((level_width - self.overlap / 2) / (self.patch_size_level - self.overlap / 2))), round(np.ceil((level_height - self.overlap / 2) / (self.patch_size_level - self.overlap / 2))) + self.mask = mask + self.i = 0 + + if patch_size_target is None: + self.downsample = 1. 
+ else: + self.downsample = patch_size / patch_size_target + + self.level, self.patch_size_level, self.overlap_level = self._prepare() + self.cols, self.rows = self._compute_cols_rows() + + col_rows = np.array([ + [col, row] + for col in range(self.cols) + for row in range(self.rows) + ]) + + if self.mask is not None: + self.valid_patches_nb, self.valid_col_rows = self._compute_masked(col_rows) + else: + self.valid_patches_nb, self.valid_col_rows = len(col_rows), col_rows + + def _colrow_to_xy(self, col, row): + """ Convert col row of a tile to its top-left coordinates before rescaling (x, y) """ + x = col * (self.patch_size) if col == 0 else col * (self.patch_size) - self.overlap + y = row * (self.patch_size) if row == 0 else row * (self.patch_size) - self.overlap + return (x, y) + + def _compute_masked(self, col_rows) -> None: + """ Compute tiles which center falls under the tissue """ + + xy_topleft = np.array([self._colrow_to_xy(xy[0], xy[1]) for xy in col_rows]) + + # Note: we don't take into account the overlap size we calculating centers + xy_centers = xy_topleft + self.patch_size_level // 2 + + union_mask = self.mask.unary_union + + points = gpd.points_from_xy(xy_centers) + valid_mask = gpd.GeoSeries(points).within(union_mask).values + valid_patches_nb = valid_mask.sum() + valid_col_rows = col_rows[valid_mask] + return valid_patches_nb, valid_col_rows + + def __len__(self): + return self.valid_patches_nb + + def __iter__(self): + self.i = 0 + return self + def __next__(self): + if self.i >= self.valid_patches_nb: + raise StopIteration + tile, x, y = self.__getitem__(self.i) + self.i += 1 + return tile, x, y + + def __getitem__(self, index): + if 0 <= index < len(self.valid_col_rows): + col_row = self.valid_col_rows[self.i] + col, row = col_row[0], col_row[1] + tile, x, y = self.get_tile(col, row) + return tile, x, y + else: + raise IndexError("Index out of range") + + @abstractmethod + def _prepare(self) -> None: + pass + def get_cols_rows(self) -> Tuple[int, int]: - """ Get the number of columns and rows the associated WSI + """ Get the number of columns and rows in the associated WSI Returns: Tuple[int, int]: (nb_columns, nb_rows) @@ -223,33 +300,56 @@ def get_tile(self, col: int, row: int) -> Tuple[np.ndarray, int, int]: row (int): row Returns: - Tuple[np.ndarray, int, int]: (tile, pixel x of top-left corner, pixel_y of top-left corner) + Tuple[np.ndarray, int, int]: (tile, pixel x of top-left corner (before rescaling), pixel_y of top-left corner (before rescaling)) """ - img = self.wsi.img - if isinstance(img, openslide.OpenSlide): - raw_tile = self.dz.get_tile(self.nb_levels - self.level - 1, (col, row)) - addr = self.dz.get_tile_coordinates(self.nb_levels - self.level - 1, (col, row)) - pxl_x, pxl_y = addr[0] - if pxl_x == 556 and pxl_y == 556: - a = 1 - elif isinstance(img, np.ndarray): - x_begin = round(col * (self.patch_size_src - self.overlap)) - x_end = min(x_begin + self.patch_size_src + self.overlap, self.width) - y_begin = round(row * (self.patch_size_src - self.overlap)) - y_end = min(y_begin + self.patch_size_src + self.overlap, self.height) - tmp_tile = np.zeros((self.patch_size_src, self.patch_size_src, 3), dtype=np.uint8) - tmp_tile[:y_end-y_begin, :x_end-x_begin] += img[y_begin:y_end, x_begin:x_end] - pxl_x, pxl_y = x_begin, y_begin - raw_tile = tmp_tile - elif is_cuimage(img): - x_begin = round(col * (self.patch_size_src - self.overlap)) - y_begin = round(row * (self.patch_size_src - self.overlap)) - raw_tile = self.wsi.read_region(location=(x_begin, y_begin), 
level=self.level, size=(self.patch_size_level + self.overlap, self.patch_size_level + self.overlap)) - pxl_x = x_begin - pxl_y = y_begin - + x, y = self._colrow_to_xy(col, row) + raw_tile = self.wsi.read_region(location=(x, y), level=self.level, size=(self.patch_size_level, self.patch_size_level)) tile = np.array(raw_tile) if self.patch_size_target is not None: tile = cv2.resize(tile, (self.patch_size_target, self.patch_size_target)) - assert pxl_x < self.width and pxl_y < self.height - return tile, pxl_x, pxl_y \ No newline at end of file + assert x < self.width and y < self.height + return tile, x, y + + def _compute_cols_rows(self) -> Tuple[int, int]: + col = 0 + row = 0 + x, y = self._colrow_to_xy(col, row) + while x < self.width: + col += 1 + x, _ = self._colrow_to_xy(col, row) + cols = col - 1 + while y < self.height: + row += 1 + _, y = self._colrow_to_xy(col, row) + rows = row - 1 + return cols, rows + + +class OpenSlideWSIPatcher(WSIPatcher): + wsi: OpenSlideWSI + + def _prepare(self) -> None: + level = self.wsi.get_best_level_for_downsample(self.downsample) + level_downsample = self.wsi.level_downsamples()[level] + patch_size_level = round(self.patch_size / level_downsample) + overlap_level = round(self.overlap / level_downsample) + return level, patch_size_level, overlap_level + +class CuImageWSIPatcher(WSIPatcher): + wsi: CuImageWSI + + def _prepare(self) -> None: + level = self.wsi.get_best_level_for_downsample(self.downsample) + level_downsample = self.wsi.level_downsamples()[level] + patch_size_level = round(self.patch_size / level_downsample) + overlap_level = round(self.overlap / level_downsample) + return level, patch_size_level, overlap_level + +class NumpyWSIPatcher(WSIPatcher): + WSI: NumpyWSI + + def _prepare(self) -> None: + patch_size_level = self.patch_size + overlap_level = self.overlap + level = -1 + return level, patch_size_level, overlap_level \ No newline at end of file From aadfb153639617da04b51b2cd9a921a022bb3cef Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Wed, 7 Aug 2024 11:52:54 -0400 Subject: [PATCH 02/24] put holes directly inside polygons --- src/hest/HESTData.py | 13 +++---------- src/hest/LazyShapes.py | 14 +++----------- src/hest/io/seg_readers.py | 6 ++---- src/hest/segmentation/segmentation.py | 26 +++++++++++++------------- 4 files changed, 21 insertions(+), 38 deletions(-) diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py index e6ce886..8089d75 100644 --- a/src/hest/HESTData.py +++ b/src/hest/HESTData.py @@ -29,7 +29,7 @@ from .segmentation.segmentation import (apply_otsu_thresholding, contours_to_img, get_tissue_vis, - mask_to_contours, save_pkl, + mask_to_gdf, save_pkl, segment_tissue_deep) from .utils import (ALIGNED_HE_FILENAME, check_arg, deprecated, find_first_file_endswith, get_path_from_meta_row, @@ -261,21 +261,14 @@ def segment_tissue( tissue_mask = np.round(cv2.resize(mask, (width, height))).astype(np.uint8) #TODO directly convert to gpd - gdf_contours = mask_to_contours(tissue_mask, pixel_size=self.pixel_size) + gdf_contours = mask_to_gdf(tissue_mask, pixel_size=self.pixel_size) self._tissue_contours = gdf_contours return self.tissue_contours def save_tissue_contours(self, save_dir: str, name: str) -> None: - write_geojson( - self.tissue_contours, - os.path.join(save_dir, name + '_contours.geojson'), - 'tissue_id', - extra_prop=True, - index_key='hole' - ) - + self.tissue_contours.to_file(os.path.join(save_dir, name + '_contours.geojson'), driver="GeoJSON") @deprecated def get_tissue_mask(self) -> np.ndarray: diff --git 
a/src/hest/LazyShapes.py b/src/hest/LazyShapes.py index 0927c7b..8905382 100644 --- a/src/hest/LazyShapes.py +++ b/src/hest/LazyShapes.py @@ -45,19 +45,11 @@ def convert_old_to_gpd(contours_holes, contours_tissue) -> gpd.GeoDataFrame: types = [] for i in range(len(contours_holes)): tissue = contours_tissue[i] - shapes.append(Polygon(tissue[:, 0, :])) tissue_ids.append(i) - types.append('tissue') - holes = contours_holes[i] - if len(holes) > 0: - for hole in holes: - shapes.append(Polygon(hole[:, 0, :])) - tissue_ids.append(i) - types.append('hole') - + holes = contours_holes[i] if len(contours_holes[i]) > 0 else None + shapes.append(Polygon(tissue[:, 0, :]), holes=holes) + df = pd.DataFrame(tissue_ids, columns=['tissue_id']) - df['hole'] = types - df['hole'] = df['hole'] == 'hole' return gpd.GeoDataFrame(df, geometry=shapes) \ No newline at end of file diff --git a/src/hest/io/seg_readers.py b/src/hest/io/seg_readers.py index 94e7ec1..c77cf8d 100644 --- a/src/hest/io/seg_readers.py +++ b/src/hest/io/seg_readers.py @@ -97,10 +97,8 @@ def read_gdf(self, path) -> gpd.GeoDataFrame: class TissueContourReader(GDFReader): - def read_gdf(self, path) -> gpd.GeoDataFrame: - - gdf = _read_geojson(path, class_name='tissue_id', index_key='hole') - + def read_gdf(self, path) -> gpd.GeoDataFrame: + gdf = gpd.read_file(path) return gdf diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py index 263a1ed..07aadd0 100644 --- a/src/hest/segmentation/segmentation.py +++ b/src/hest/segmentation/segmentation.py @@ -138,7 +138,7 @@ def segment_tissue_deep( mask = (stitched_img > 0).astype(np.uint8) - gdf_contours = mask_to_contours(mask, max_nb_holes=5, pixel_size=pixel_size_src, contour_scale=1 / src_to_deeplab_scale) + gdf_contours = mask_to_gdf(mask, max_nb_holes=5, pixel_size=pixel_size_src, contour_scale=1 / src_to_deeplab_scale) return gdf_contours @@ -204,13 +204,14 @@ def contours_to_img( for _, row in group.iterrows(): cont = np.array([[round(x * downsample), round(y * downsample)] for x, y in row.geometry.exterior.coords]) + holes = np.array([[round(x * downsample), round(y * downsample)] for hole in row.geometry.interiors for x, y in hole.coords]) - if row['hole']: - draw_cont_fill(image=img, contours=[cont], color=(0, 0, 0)) - else: - draw_cont_fill(image=img, contours=[cont], color=line_color) + draw_cont_fill(image=img, contours=[cont], color=line_color) if draw_contours: draw_cont(image=img, contours=[cont], color=line_color) + + if len(holes) > 0: + draw_cont_fill(image=img, contours=[holes], color=(0, 0, 0)) return img @@ -432,7 +433,7 @@ def filter_contours(contours, hierarchy, filter_params, scale, pixel_size): return foreground_contours, hole_contours -def mask_to_contours(mask: np.ndarray, keep_ids = [], exclude_ids=[], max_nb_holes=0, min_contour_area=1000, pixel_size=1, contour_scale=1.): +def mask_to_gdf(mask: np.ndarray, keep_ids = [], exclude_ids=[], max_nb_holes=0, min_contour_area=1000, pixel_size=1, contour_scale=1.): TARGET_EDGE_SIZE = 2000 scale = TARGET_EDGE_SIZE / mask.shape[0] @@ -471,15 +472,14 @@ def mask_to_contours(mask: np.ndarray, keep_ids = [], exclude_ids=[], max_nb_hol else: contour_ids = set(np.arange(len(contours_tissue))) - set(exclude_ids) - tissue_poly = [Polygon(contours_tissue[i].squeeze(1)) for i in contour_ids] - hole_poly = [Polygon(contours_holes[i][0].squeeze(1)) for i in contour_ids if len(contours_holes[i]) > 0] - geometry = tissue_poly + hole_poly tissue_ids = [i for i in contour_ids] + [i for i in contour_ids if 
len(contours_holes[i]) > 0] - tissue_types = ['tissue' for _ in contour_ids] + ['hole' for i in contour_ids if len(contours_holes[i]) > 0] + polygons = [] + for i in contour_ids: + holes = contours_holes[i][0].squeeze(1) if len(contours_holes[i]) > 0 else None + polygon = Polygon(contours_tissue[i].squeeze(1), holes=holes) + polygons.append(polygon) - gdf_contours = gpd.GeoDataFrame(pd.DataFrame(tissue_ids, columns=['tissue_id']), geometry=geometry) - gdf_contours['hole'] = tissue_types - gdf_contours['hole'] = gdf_contours['hole'] == 'hole' + gdf_contours = gpd.GeoDataFrame(pd.DataFrame(tissue_ids, columns=['tissue_id']), geometry=polygons) return gdf_contours From 653cb73c104ebd964cb81c451aa14aa0cef72942 Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Wed, 7 Aug 2024 12:09:54 -0400 Subject: [PATCH 03/24] correct typo in getitem --- src/hest/wsi.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/hest/wsi.py b/src/hest/wsi.py index a40b83f..ad7cc98 100644 --- a/src/hest/wsi.py +++ b/src/hest/wsi.py @@ -60,7 +60,7 @@ def __repr__(self) -> str: return f"" @abstractmethod - def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None): + def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): pass @@ -107,8 +107,8 @@ def read_region(self, location, level, size) -> np.ndarray: def get_thumbnail(self, width, height) -> np.ndarray: return cv2.resize(self.img, (width, height)) - def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None): - return NumpyWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask) + def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + return NumpyWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask, coords_only) class OpenSlideWSI(WSI): @@ -136,8 +136,8 @@ def level_dimensions(self): def level_downsamples(self): return self.img.level_downsamples - def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None): - return OpenSlideWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask) + def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + return OpenSlideWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask, coords_only) class CuImageWSI(WSI): def __init__(self, img: 'CuImage'): @@ -183,8 +183,8 @@ def level_dimensions(self): def level_downsamples(self): return self.img.resolutions['level_downsamples'] - def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None): - return CuImageWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask) + def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + return CuImageWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask, coords_only) class WSIPatcher: @@ -196,7 +196,8 @@ def __init__( patch_size: int, patch_size_target: int = None, overlap: int = 0, - mask: gpd.GeoDataFrame = None + mask: gpd.GeoDataFrame = None, + coords_only = False ): """ Initialize patcher, compute number of 
(masked) rows, columns. @@ -206,6 +207,7 @@ def __init__( patch_size_target (int, optional): largest patch size in pixel after rescaling. Defaults to None. overlap (int, optional): overlap size in pixel before rescaling. Defaults to 0. mask (gpd.GeoDataFrame, optional): geopandas dataframe of Polygons. Defaults to None. + coords_only (bool, optional): whenever to extract only the coordinates insteaf of coordinates + tile. Default to False. """ self.wsi = wsi self.patch_size = patch_size @@ -214,6 +216,7 @@ def __init__( self.patch_size_target = patch_size_target self.mask = mask self.i = 0 + self.coords_only = coords_only if patch_size_target is None: self.downsample = 1. @@ -266,14 +269,16 @@ def __iter__(self): def __next__(self): if self.i >= self.valid_patches_nb: raise StopIteration - tile, x, y = self.__getitem__(self.i) + x = self.__getitem__(self.i) self.i += 1 - return tile, x, y + return x def __getitem__(self, index): if 0 <= index < len(self.valid_col_rows): - col_row = self.valid_col_rows[self.i] + col_row = self.valid_col_rows[index] col, row = col_row[0], col_row[1] + if self.coords_only: + return self._colrow_to_xy(col, row) tile, x, y = self.get_tile(col, row) return tile, x, y else: From 801655a60ef5a9189e50dac3089cb3c83e4e643e Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Wed, 7 Aug 2024 14:17:23 -0400 Subject: [PATCH 04/24] refactor patcher arguments --- src/hest/segmentation/segmentation.py | 16 +++++------ src/hest/wsi.py | 41 ++++++++++++++------------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py index 07aadd0..4a6f856 100644 --- a/src/hest/segmentation/segmentation.py +++ b/src/hest/segmentation/segmentation.py @@ -26,7 +26,7 @@ def segment_tissue_deep( wsi: Union[np.ndarray, openslide.OpenSlide, CuImage, WSI], # type: ignore pixel_size: float, fast_mode=False, - target_pxl_size=1, + dst_pixel_size=1, patch_size_um=512, model_name='deeplabv3_seg_v4.ckpt', batch_size=8, @@ -40,7 +40,7 @@ def segment_tissue_deep( pixel_size (float): pixel size in um/px for the wsi fast_mode (bool, optional): in fast mode the inference is done at 2 um/px instead of 1 um/px, note that the inference pixel size is overwritten by the `target_pxl_size` argument if != 1. Defaults to False. - target_pxl_size (int, optional): patches are scaled to this pixel size in um/px for inference. Defaults to 1. + dst_pixel_size (int, optional): patches are scaled to this pixel size in um/px for inference. Defaults to 1. patch_size_um (int, optional): patch size in um. Defaults to 512. model_name (str, optional): model name in `HEST/models` dir. Defaults to 'deeplabv3_seg_v4.ckpt'. batch_size (int, optional): batch size for inference. Defaults to 8. 
@@ -56,20 +56,20 @@ def segment_tissue_deep( from torchvision import transforms from hest.segmentation.SegDataset import SegWSIDataset - pixel_size_src = pixel_size + src_pixel_size = pixel_size - if fast_mode and target_pxl_size == 1: - target_pxl_size = 2 + if fast_mode and dst_pixel_size == 1: + dst_pixel_size = 2 patch_size_deeplab = 512 - scale = pixel_size_src / target_pxl_size + scale = src_pixel_size / dst_pixel_size patch_size_src = round(patch_size_um / scale) wsi = wsi_factory(wsi) weights_path = get_path_relative(__file__, f'../../../models/{model_name}') - patcher = wsi.create_patcher(patch_size_src, patch_size_deeplab) + patcher = wsi.create_patcher(patch_size_deeplab, src_pixel_size, dst_pixel_size) eval_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) dataset = SegWSIDataset(patcher, eval_transforms) @@ -138,7 +138,7 @@ def segment_tissue_deep( mask = (stitched_img > 0).astype(np.uint8) - gdf_contours = mask_to_gdf(mask, max_nb_holes=5, pixel_size=pixel_size_src, contour_scale=1 / src_to_deeplab_scale) + gdf_contours = mask_to_gdf(mask, max_nb_holes=5, pixel_size=src_pixel_size, contour_scale=1 / src_to_deeplab_scale) return gdf_contours diff --git a/src/hest/wsi.py b/src/hest/wsi.py index ad7cc98..762aead 100644 --- a/src/hest/wsi.py +++ b/src/hest/wsi.py @@ -60,7 +60,7 @@ def __repr__(self) -> str: return f"" @abstractmethod - def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): pass @@ -107,8 +107,8 @@ def read_region(self, location, level, size) -> np.ndarray: def get_thumbnail(self, width, height) -> np.ndarray: return cv2.resize(self.img, (width, height)) - def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): - return NumpyWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask, coords_only) + def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + return NumpyWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only) class OpenSlideWSI(WSI): @@ -136,8 +136,8 @@ def level_dimensions(self): def level_downsamples(self): return self.img.level_downsamples - def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): - return OpenSlideWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask, coords_only) + def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + return OpenSlideWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only) class CuImageWSI(WSI): def __init__(self, img: 'CuImage'): @@ -183,8 +183,8 @@ def level_dimensions(self): def level_downsamples(self): return self.img.resolutions['level_downsamples'] - def create_patcher(self, patch_size_src: int, patch_size_target: int = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): - return CuImageWSIPatcher(self, patch_size_src, patch_size_target, overlap, mask, coords_only) + def 
create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + return CuImageWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only) class WSIPatcher: @@ -194,7 +194,8 @@ def __init__( self, wsi: WSI, patch_size: int, - patch_size_target: int = None, + src_pixel_size: float, + dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False @@ -203,25 +204,27 @@ def __init__( Args: wsi (WSI): wsi to patch - patch_size (int): patch width/height in pixel on the slide before rescaling - patch_size_target (int, optional): largest patch size in pixel after rescaling. Defaults to None. + patch_size (int): patch width/height in pixel on the slide after rescaling + src_pixel_size (float, optional): pixel size in um/px of the slide before rescaling. Defaults to None. + dst_pixel_size (float, optional): pixel size in um/px of the slide after rescaling. Defaults to None. overlap (int, optional): overlap size in pixel before rescaling. Defaults to 0. mask (gpd.GeoDataFrame, optional): geopandas dataframe of Polygons. Defaults to None. coords_only (bool, optional): whenever to extract only the coordinates insteaf of coordinates + tile. Default to False. """ self.wsi = wsi - self.patch_size = patch_size self.overlap = overlap self.width, self.height = self.wsi.get_dimensions() - self.patch_size_target = patch_size_target + self.patch_size_target = patch_size self.mask = mask self.i = 0 self.coords_only = coords_only - if patch_size_target is None: + if dst_pixel_size is None: self.downsample = 1. else: - self.downsample = patch_size / patch_size_target + self.downsample = dst_pixel_size / src_pixel_size + + self.patch_size_src = round(patch_size * self.downsample) self.level, self.patch_size_level, self.overlap_level = self._prepare() self.cols, self.rows = self._compute_cols_rows() @@ -239,8 +242,8 @@ def __init__( def _colrow_to_xy(self, col, row): """ Convert col row of a tile to its top-left coordinates before rescaling (x, y) """ - x = col * (self.patch_size) if col == 0 else col * (self.patch_size) - self.overlap - y = row * (self.patch_size) if row == 0 else row * (self.patch_size) - self.overlap + x = col * (self.patch_size_src) if col == 0 else col * (self.patch_size_src) - self.overlap + y = row * (self.patch_size_src) if row == 0 else row * (self.patch_size_src) - self.overlap return (x, y) def _compute_masked(self, col_rows) -> None: @@ -336,7 +339,7 @@ class OpenSlideWSIPatcher(WSIPatcher): def _prepare(self) -> None: level = self.wsi.get_best_level_for_downsample(self.downsample) level_downsample = self.wsi.level_downsamples()[level] - patch_size_level = round(self.patch_size / level_downsample) + patch_size_level = round(self.patch_size_src / level_downsample) overlap_level = round(self.overlap / level_downsample) return level, patch_size_level, overlap_level @@ -346,7 +349,7 @@ class CuImageWSIPatcher(WSIPatcher): def _prepare(self) -> None: level = self.wsi.get_best_level_for_downsample(self.downsample) level_downsample = self.wsi.level_downsamples()[level] - patch_size_level = round(self.patch_size / level_downsample) + patch_size_level = round(self.patch_size_src / level_downsample) overlap_level = round(self.overlap / level_downsample) return level, patch_size_level, overlap_level @@ -354,7 +357,7 @@ class NumpyWSIPatcher(WSIPatcher): WSI: NumpyWSI def _prepare(self) -> None: - patch_size_level = self.patch_size + 
patch_size_level = self.patch_size_src overlap_level = self.overlap level = -1 return level, patch_size_level, overlap_level \ No newline at end of file From 528ca8f7288b3cf654fbce4f007291787944bc80 Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Wed, 7 Aug 2024 21:50:20 -0400 Subject: [PATCH 05/24] commit before erasing history --- src/hest/HESTData.py | 16 +-- src/hest/io/seg_readers.py | 2 +- src/hest/segmentation/SegDataset.py | 9 +- src/hest/segmentation/segmentation.py | 71 ---------- src/hest/wsi.py | 190 +++++++++++++++++++++----- 5 files changed, 163 insertions(+), 125 deletions(-) diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py index 8089d75..e79836a 100644 --- a/src/hest/HESTData.py +++ b/src/hest/HESTData.py @@ -8,29 +8,25 @@ import cv2 import geopandas as gpd -import matplotlib import numpy as np -from hest.io.seg_readers import (TissueContourReader, - write_geojson) +from hest.io.seg_readers import TissueContourReader, write_geojson from hest.LazyShapes import LazyShapes, convert_old_to_gpd from hest.segmentation.TissueMask import TissueMask, load_tissue_mask -from hest.wsi import WSI, CucimWarningSingleton, NumpyWSI, wsi_factory +from hest.wsi import (WSI, CucimWarningSingleton, NumpyWSI, contours_to_img, + get_tissue_vis, wsi_factory) try: import openslide except Exception: print("Couldn't import openslide, verify that openslide is installed on your system, https://openslide.org/download/") import pandas as pd -from matplotlib.collections import PatchCollection from PIL import Image from shapely import Point from tqdm import tqdm -from .segmentation.segmentation import (apply_otsu_thresholding, - contours_to_img, get_tissue_vis, - mask_to_gdf, save_pkl, - segment_tissue_deep) +from .segmentation.segmentation import (apply_otsu_thresholding, mask_to_gdf, + save_pkl, segment_tissue_deep) from .utils import (ALIGNED_HE_FILENAME, check_arg, deprecated, find_first_file_endswith, get_path_from_meta_row, plot_verify_pixel_size, tiff_save, verify_paths) @@ -308,7 +304,9 @@ def dump_patches( use_mask (bool, optional): whenever to take into account the tissue mask. Defaults to True. 
""" + import matplotlib import matplotlib.pyplot as plt + from matplotlib.collections import PatchCollection adata = self.adata.copy() diff --git a/src/hest/io/seg_readers.py b/src/hest/io/seg_readers.py index c77cf8d..c863c52 100644 --- a/src/hest/io/seg_readers.py +++ b/src/hest/io/seg_readers.py @@ -98,7 +98,7 @@ def read_gdf(self, path) -> gpd.GeoDataFrame: class TissueContourReader(GDFReader): def read_gdf(self, path) -> gpd.GeoDataFrame: - gdf = gpd.read_file(path) + gdf = _read_geojson(path, 'tissue_id') return gdf diff --git a/src/hest/segmentation/SegDataset.py b/src/hest/segmentation/SegDataset.py index 7cbbc87..f61abd0 100644 --- a/src/hest/segmentation/SegDataset.py +++ b/src/hest/segmentation/SegDataset.py @@ -47,15 +47,10 @@ def __getitem__(self, index): class SegWSIDataset(Dataset): - masks = [] - patches = [] - coords = [] def __init__(self, patcher: WSIPatcher, transform): self.patcher = patcher - self.cols, self.rows = self.patcher.get_cols_rows() - self.transform = transform @@ -63,9 +58,7 @@ def __len__(self): return len(self.patcher) def __getitem__(self, index): - col = index % self.cols - row = index // self.cols - tile, x, y = self.patcher.get_tile(col, row) + tile, x, y = self.patcher[index] if self.transform: tile = self.transform(tile) diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py index 4a6f856..d94cda8 100644 --- a/src/hest/segmentation/segmentation.py +++ b/src/hest/segmentation/segmentation.py @@ -1,7 +1,6 @@ from __future__ import annotations import pickle -from functools import partial from typing import Union import cv2 @@ -186,76 +185,6 @@ def keep_largest_area(mask: np.ndarray) -> np.ndarray: largest_mask[label_image == largest_label] = True mask[~largest_mask] = 0 return mask - - -def contours_to_img( - contours: gpd.GeoDataFrame, - img: np.ndarray, - draw_contours=False, - thickness=1, - downsample=1., - line_color=(0, 255, 0) -) -> np.ndarray: - draw_cont = partial(cv2.drawContours, contourIdx=-1, thickness=thickness, lineType=cv2.LINE_8) - draw_cont_fill = partial(cv2.drawContours, contourIdx=-1, thickness=cv2.FILLED) - - groups = contours.groupby('tissue_id') - for _, group in groups: - - for _, row in group.iterrows(): - cont = np.array([[round(x * downsample), round(y * downsample)] for x, y in row.geometry.exterior.coords]) - holes = np.array([[round(x * downsample), round(y * downsample)] for hole in row.geometry.interiors for x, y in hole.coords]) - - draw_cont_fill(image=img, contours=[cont], color=line_color) - if draw_contours: - draw_cont(image=img, contours=[cont], color=line_color) - - if len(holes) > 0: - draw_cont_fill(image=img, contours=[holes], color=(0, 0, 0)) - return img - - -def get_tissue_vis( - img: Union[np.ndarray, openslide.OpenSlide, CuImage, WSI], - tissue_contours: gpd.GeoDataFrame, - line_color=(0, 255, 0), - line_thickness=5, - target_width=1000, - seg_display=True, - ) -> Image: - tissue_contours = tissue_contours.copy() - - wsi = wsi_factory(img) - - width, height = wsi.get_dimensions() - downsample = target_width / width - - top_left = (0,0) - - img = wsi.get_thumbnail(round(width * downsample), round(height * downsample)) - - if tissue_contours is None: - return Image.fromarray(img) - - downscaled_mask = np.zeros(img.shape[:2], dtype=np.uint8) - downscaled_mask = np.expand_dims(downscaled_mask, axis=-1) - downscaled_mask = downscaled_mask * np.array([0, 0, 0]).astype(np.uint8) - - if tissue_contours is not None and seg_display: - downscaled_mask = contours_to_img( - 
tissue_contours, - downscaled_mask, - draw_contours=True, - thickness=line_thickness, - downsample=downsample, - line_color=line_color - ) - - alpha = 0.4 - img = cv2.addWeighted(img, 1 - alpha, downscaled_mask, alpha, 0) - img = img.astype(np.uint8) - - return Image.fromarray(img) @deprecated diff --git a/src/hest/wsi.py b/src/hest/wsi.py index 762aead..01baa0c 100644 --- a/src/hest/wsi.py +++ b/src/hest/wsi.py @@ -1,11 +1,15 @@ +from __future__ import annotations + import warnings from abc import abstractmethod -from typing import Tuple +from functools import partial +from typing import Tuple, Union import cv2 import geopandas as gpd import numpy as np import openslide +from PIL import Image class CucimWarningSingleton: @@ -60,7 +64,7 @@ def __repr__(self) -> str: return f"" @abstractmethod - def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False) -> WSIPatcher: pass @@ -107,7 +111,7 @@ def read_region(self, location, level, size) -> np.ndarray: def get_thumbnail(self, width, height) -> np.ndarray: return cv2.resize(self.img, (width, height)) - def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False) -> WSIPatcher: return NumpyWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only) @@ -136,7 +140,7 @@ def level_dimensions(self): def level_downsamples(self): return self.img.level_downsamples - def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False) -> WSIPatcher: return OpenSlideWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only) class CuImageWSI(WSI): @@ -183,7 +187,7 @@ def level_dimensions(self): def level_downsamples(self): return self.img.resolutions['level_downsamples'] - def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False): + def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False) -> WSIPatcher: return CuImageWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only) @@ -198,7 +202,8 @@ def __init__( dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, - coords_only = False + coords_only = False, + custom_coords = None ): """ Initialize patcher, compute number of (masked) rows, columns. @@ -218,6 +223,7 @@ def __init__( self.mask = mask self.i = 0 self.coords_only = coords_only + self.custom_coords = custom_coords if dst_pixel_size is None: self.downsample = 1. 
@@ -226,41 +232,45 @@ def __init__( self.patch_size_src = round(patch_size * self.downsample) - self.level, self.patch_size_level, self.overlap_level = self._prepare() - self.cols, self.rows = self._compute_cols_rows() + self.level, self.patch_size_level, self.overlap_level = self._prepare() - col_rows = np.array([ - [col, row] - for col in range(self.cols) - for row in range(self.rows) - ]) + if custom_coords is None: + self.cols, self.rows = self._compute_cols_rows() + + col_rows = np.array([ + [col, row] + for col in range(self.cols) + for row in range(self.rows) + ]) + coords = np.array([self._colrow_to_xy(xy[0], xy[1]) for xy in col_rows]) + else: + coords = custom_coords if self.mask is not None: - self.valid_patches_nb, self.valid_col_rows = self._compute_masked(col_rows) + self.valid_patches_nb, self.valid_coords = self._compute_masked(coords) else: - self.valid_patches_nb, self.valid_col_rows = len(col_rows), col_rows + self.valid_patches_nb, self.valid_coords = len(coords), coords def _colrow_to_xy(self, col, row): """ Convert col row of a tile to its top-left coordinates before rescaling (x, y) """ - x = col * (self.patch_size_src) if col == 0 else col * (self.patch_size_src) - self.overlap - y = row * (self.patch_size_src) if row == 0 else row * (self.patch_size_src) - self.overlap + x = col * (self.patch_size_src) - self.overlap * np.clip(col - 1, 0, None) + y = row * (self.patch_size_src) - self.overlap * np.clip(row - 1, 0, None) return (x, y) - def _compute_masked(self, col_rows) -> None: + def _compute_masked(self, coords) -> None: """ Compute tiles which center falls under the tissue """ - xy_topleft = np.array([self._colrow_to_xy(xy[0], xy[1]) for xy in col_rows]) - + # TODO spots are already at the center # Note: we don't take into account the overlap size we calculating centers - xy_centers = xy_topleft + self.patch_size_level // 2 + xy_centers = coords + self.patch_size_level // 2 union_mask = self.mask.unary_union - points = gpd.points_from_xy(xy_centers) + points = gpd.points_from_xy(xy_centers[:, 0], xy_centers[:, 1]) valid_mask = gpd.GeoSeries(points).within(union_mask).values valid_patches_nb = valid_mask.sum() - valid_col_rows = col_rows[valid_mask] - return valid_patches_nb, valid_col_rows + valid_coords = coords[valid_mask] + return valid_patches_nb, valid_coords def __len__(self): return self.valid_patches_nb @@ -277,12 +287,12 @@ def __next__(self): return x def __getitem__(self, index): - if 0 <= index < len(self.valid_col_rows): - col_row = self.valid_col_rows[index] - col, row = col_row[0], col_row[1] + if 0 <= index < len(self): + xy = self.valid_coords[index] + x, y = xy[0], xy[1] if self.coords_only: - return self._colrow_to_xy(col, row) - tile, x, y = self.get_tile(col, row) + return x, y + tile, x, y = self.get_tile_xy(x, y) return tile, x, y else: raise IndexError("Index out of range") @@ -299,6 +309,14 @@ def get_cols_rows(self) -> Tuple[int, int]: Tuple[int, int]: (nb_columns, nb_rows) """ return self.cols, self.rows + + def get_tile_xy(self, x: int, y: int) -> Tuple[np.ndarray, int, int]: + raw_tile = self.wsi.read_region(location=(x, y), level=self.level, size=(self.patch_size_level, self.patch_size_level)) + tile = np.array(raw_tile) + if self.patch_size_target is not None: + tile = cv2.resize(tile, (self.patch_size_target, self.patch_size_target)) + assert x < self.width and y < self.height + return tile, x, y def get_tile(self, col: int, row: int) -> Tuple[np.ndarray, int, int]: """ get tile at position (column, row) @@ -310,13 +328,11 @@ def 
get_tile(self, col: int, row: int) -> Tuple[np.ndarray, int, int]: Returns: Tuple[np.ndarray, int, int]: (tile, pixel x of top-left corner (before rescaling), pixel_y of top-left corner (before rescaling)) """ + if self.custom_coords is not None: + raise ValueError("Can't use get_tile as 'custom_coords' was passed to the constructor") + x, y = self._colrow_to_xy(col, row) - raw_tile = self.wsi.read_region(location=(x, y), level=self.level, size=(self.patch_size_level, self.patch_size_level)) - tile = np.array(raw_tile) - if self.patch_size_target is not None: - tile = cv2.resize(tile, (self.patch_size_target, self.patch_size_target)) - assert x < self.width and y < self.height - return tile, x, y + return self.get_tile_xy(x, y) def _compute_cols_rows(self) -> Tuple[int, int]: col = 0 @@ -332,6 +348,36 @@ def _compute_cols_rows(self) -> Tuple[int, int]: rows = row - 1 return cols, rows + def save_visualization(self, path, vis_width=1000, dpi=150): + mask_plot = get_tissue_vis( + self.wsi, + self.mask, + line_color=(0, 255, 0), + line_thickness=5, + target_width=vis_width, + seg_display=True, + ) + import matplotlib.pyplot as plt + from matplotlib.collections import PatchCollection + from matplotlib.patches import Rectangle + + downscale_vis = vis_width / self.width + + _, ax = plt.subplots() + ax.imshow(mask_plot) + + patch_rectangles = [] + for xy in self.valid_coords: + x, y = xy[0], xy[1] + x, y = x * downscale_vis, y * downscale_vis + + patch_rectangles.append(Rectangle((x, y), self.patch_size_src * downscale_vis, self.patch_size_src * downscale_vis)) + + ax.add_collection(PatchCollection(patch_rectangles, facecolor='none', edgecolor='black', linewidth=0.3)) + ax.set_axis_off() + plt.tight_layout() + plt.savefig(path, dpi=dpi) + class OpenSlideWSIPatcher(WSIPatcher): wsi: OpenSlideWSI @@ -360,4 +406,76 @@ def _prepare(self) -> None: patch_size_level = self.patch_size_src overlap_level = self.overlap level = -1 - return level, patch_size_level, overlap_level \ No newline at end of file + return level, patch_size_level, overlap_level + + + +def contours_to_img( + contours: gpd.GeoDataFrame, + img: np.ndarray, + draw_contours=False, + thickness=1, + downsample=1., + line_color=(0, 255, 0) +) -> np.ndarray: + draw_cont = partial(cv2.drawContours, contourIdx=-1, thickness=thickness, lineType=cv2.LINE_8) + draw_cont_fill = partial(cv2.drawContours, contourIdx=-1, thickness=cv2.FILLED) + + groups = contours.groupby('tissue_id') + for _, group in groups: + + for _, row in group.iterrows(): + cont = np.array([[round(x * downsample), round(y * downsample)] for x, y in row.geometry.exterior.coords]) + holes = np.array([[round(x * downsample), round(y * downsample)] for hole in row.geometry.interiors for x, y in hole.coords]) + + draw_cont_fill(image=img, contours=[cont], color=line_color) + if draw_contours: + draw_cont(image=img, contours=[cont], color=line_color) + + if len(holes) > 0: + draw_cont_fill(image=img, contours=[holes], color=(0, 0, 0)) + return img + + +def get_tissue_vis( + img: Union[np.ndarray, openslide.OpenSlide, CuImage, WSI], + tissue_contours: gpd.GeoDataFrame, + line_color=(0, 255, 0), + line_thickness=5, + target_width=1000, + seg_display=True, + ) -> Image: + + wsi = wsi_factory(img) + + width, height = wsi.get_dimensions() + downsample = target_width / width + + top_left = (0,0) + + img = wsi.get_thumbnail(round(width * downsample), round(height * downsample)) + + if tissue_contours is None: + return Image.fromarray(img) + + tissue_contours = tissue_contours.copy() + + 
downscaled_mask = np.zeros(img.shape[:2], dtype=np.uint8) + downscaled_mask = np.expand_dims(downscaled_mask, axis=-1) + downscaled_mask = downscaled_mask * np.array([0, 0, 0]).astype(np.uint8) + + if tissue_contours is not None and seg_display: + downscaled_mask = contours_to_img( + tissue_contours, + downscaled_mask, + draw_contours=True, + thickness=line_thickness, + downsample=downsample, + line_color=line_color + ) + + alpha = 0.4 + img = cv2.addWeighted(img, 1 - alpha, downscaled_mask, alpha, 0) + img = img.astype(np.uint8) + + return Image.fromarray(img) \ No newline at end of file From a5065d106c42204024d2d858188a63158d1e2c5f Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 10:25:31 -0400 Subject: [PATCH 06/24] reimplement dunp_patches with the new wsipatcher --- src/hest/HESTData.py | 88 ++++++--------------------- src/hest/segmentation/segmentation.py | 4 +- src/hest/wsi.py | 63 +++++++++++++++---- 3 files changed, 69 insertions(+), 86 deletions(-) diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py index e79836a..cdc00f1 100644 --- a/src/hest/HESTData.py +++ b/src/hest/HESTData.py @@ -133,7 +133,7 @@ def save_spatial_plot(self, save_path: str, name: str='', key='total_counts', pl filename = f"{name}spatial_plots.png" # Save the figure - fig.savefig(os.path.join(save_path, filename)) + fig.savefig(os.path.join(save_path, filename), dpi=400) print(f"H&E overlay spatial plots saved in {save_path}") @@ -304,9 +304,8 @@ def dump_patches( use_mask (bool, optional): whenever to take into account the tissue mask. Defaults to True. """ - import matplotlib import matplotlib.pyplot as plt - from matplotlib.collections import PatchCollection + dst_pixel_size = target_pixel_size adata = self.adata.copy() @@ -333,75 +332,25 @@ def dump_patches( mode_HE = 'w' i = 0 img_width, img_height = self.wsi.get_dimensions() - patch_rectangles = [] # lower corner (x, y) + (widht, height) - downscale_vis = TARGET_VIS_SIZE / img_width - if use_mask: - tissue_mask = np.zeros((img_height, img_width, 3), dtype=np.uint8) - tissue_mask = contours_to_img( - self.tissue_contours, - tissue_mask, - draw_contours=False, - line_color=(1, 1, 1) - )[:, :, 0] - else: - tissue_mask = np.ones((img_height, img_width)).astype(np.uint8) - - mask_plot = self.get_tissue_vis() - - ax.imshow(mask_plot) - for _, row in tqdm(adata.obs.iterrows(), total=len(adata.obs)): - - barcode_spot = row.name - - xImage = int(adata.obsm['spatial'][i][0]) - yImage = int(adata.obsm['spatial'][i][1]) - - i += 1 - - if not(0 <= xImage and xImage < img_width and 0 <= yImage and yImage < img_height): - if verbose: - print('Warning, spot is out of the image, skipping') - continue - - if not(0 <= yImage - patch_size_pxl // 2 and yImage + patch_size_pxl // 2 < img_height and \ - 0 <= xImage - patch_size_pxl // 2 and xImage + patch_size_pxl // 2 < img_width): - if verbose: - print('Warning, patch is out of the image, skipping') - continue - - ## TODO reimplement now that we use the pyramidal level - image_patch = self.wsi.read_region((xImage - patch_size_pxl // 2, yImage - patch_size_pxl // 2), 0, (patch_size_pxl, patch_size_pxl)) - rect_x = (xImage - patch_size_pxl // 2) * downscale_vis - rect_y = (yImage - patch_size_pxl // 2) * downscale_vis - rect_width = patch_size_pxl * downscale_vis - rect_height = patch_size_pxl * downscale_vis - - image_patch = np.array(image_patch) - if image_patch.shape[2] == 4: - image_patch = image_patch[:, :, :3] - - - if use_mask: - patch_mask = tissue_mask[yImage - patch_size_pxl // 2: yImage + 
patch_size_pxl // 2, - xImage - patch_size_pxl // 2: xImage + patch_size_pxl // 2] - patch_area = patch_mask.shape[0] ** 2 - pixel_count = patch_mask.sum() - if pixel_count / patch_area < TISSUE_INTER_THRESH: - continue + coords_center = adata.obsm['spatial'] + coords_topleft = coords_center - target_patch_size // 2 + barcodes = np.array(adata.obs.index) + mask = self.tissue_contours if use_mask else None + patcher = self.wsi.create_patcher(target_patch_size, src_pixel_size, dst_pixel_size, mask=mask, custom_coords=coords_topleft) - patch_rectangles.append(matplotlib.patches.Rectangle((rect_x, rect_y), rect_width, rect_height)) - - patch_count += 1 - image_patch = cv2.resize(image_patch, (target_patch_size, target_patch_size), interpolation=cv2.INTER_CUBIC) + i = 0 + for tile, x, y in tqdm(patcher): + center_x = x + target_patch_size // 2 + center_y = y + target_patch_size // 2 # Save ref patches - assert image_patch.shape == (target_patch_size, target_patch_size, 3) - asset_dict = { 'img': np.expand_dims(image_patch, axis=0), # (1 x w x h x 3) - 'coords': np.expand_dims([yImage, xImage], axis=0), # (1 x 2) - 'barcode': np.expand_dims([barcode_spot], axis=0) + assert tile.shape == (target_patch_size, target_patch_size, 3) + asset_dict = { 'img': np.expand_dims(tile, axis=0), # (1 x w x h x 3) + 'coords': np.expand_dims([center_y, center_x], axis=0), # (1 x 2) + 'barcode': np.expand_dims([barcodes[i]], axis=0) } attr_dict = {} @@ -410,13 +359,10 @@ def dump_patches( initsave_hdf5(output_datafile, asset_dict, attr_dict, mode=mode_HE) mode_HE = 'a' + i += 1 - if dump_visualization: - ax.add_collection(PatchCollection(patch_rectangles, facecolor='none', edgecolor='black', linewidth=0.3)) - ax.set_axis_off() - plt.tight_layout() - plt.savefig(os.path.join(patch_save_dir, name + '_patch_vis.png'), dpi=400, bbox_inches = 'tight') + patcher.save_visualization(os.path.join(patch_save_dir, name + '_patch_vis.png'), dpi=400) if verbose: print(f'found {patch_count} valid patches') diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py index d94cda8..adad80d 100644 --- a/src/hest/segmentation/segmentation.py +++ b/src/hest/segmentation/segmentation.py @@ -401,10 +401,10 @@ def mask_to_gdf(mask: np.ndarray, keep_ids = [], exclude_ids=[], max_nb_holes=0, else: contour_ids = set(np.arange(len(contours_tissue))) - set(exclude_ids) - tissue_ids = [i for i in contour_ids] + [i for i in contour_ids if len(contours_holes[i]) > 0] + tissue_ids = [i for i in contour_ids] polygons = [] for i in contour_ids: - holes = contours_holes[i][0].squeeze(1) if len(contours_holes[i]) > 0 else None + holes = [contours_holes[i][j].squeeze(1) for j in range(len(contours_holes[i]))] if len(contours_holes[i]) > 0 else None polygon = Polygon(contours_tissue[i].squeeze(1), holes=holes) polygons.append(polygon) diff --git a/src/hest/wsi.py b/src/hest/wsi.py index 01baa0c..5033657 100644 --- a/src/hest/wsi.py +++ b/src/hest/wsi.py @@ -64,7 +64,16 @@ def __repr__(self) -> str: return f"" @abstractmethod - def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False) -> WSIPatcher: + def create_patcher( + self, + patch_size: int, + src_pixel_size: float, + dst_pixel_size: float = None, + overlap: int = 0, + mask: gpd.GeoDataFrame = None, + coords_only = False, + custom_coords = None + ) -> WSIPatcher: pass @@ -111,8 +120,17 @@ def read_region(self, location, level, size) -> np.ndarray: def 
get_thumbnail(self, width, height) -> np.ndarray: return cv2.resize(self.img, (width, height)) - def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False) -> WSIPatcher: - return NumpyWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only) + def create_patcher( + self, + patch_size: int, + src_pixel_size: float, + dst_pixel_size: float = None, + overlap: int = 0, + mask: gpd.GeoDataFrame = None, + coords_only = False, + custom_coords = None + ) -> WSIPatcher: + return NumpyWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only, custom_coords) class OpenSlideWSI(WSI): @@ -140,8 +158,17 @@ def level_dimensions(self): def level_downsamples(self): return self.img.level_downsamples - def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False) -> WSIPatcher: - return OpenSlideWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only) + def create_patcher( + self, + patch_size: int, + src_pixel_size: float, + dst_pixel_size: float = None, + overlap: int = 0, + mask: gpd.GeoDataFrame = None, + coords_only = False, + custom_coords = None + ) -> WSIPatcher: + return OpenSlideWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only, custom_coords) class CuImageWSI(WSI): def __init__(self, img: 'CuImage'): @@ -187,8 +214,17 @@ def level_dimensions(self): def level_downsamples(self): return self.img.resolutions['level_downsamples'] - def create_patcher(self, patch_size: int, src_pixel_size: float, dst_pixel_size: float = None, overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False) -> WSIPatcher: - return CuImageWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only) + def create_patcher( + self, + patch_size: int, + src_pixel_size: float, + dst_pixel_size: float = None, + overlap: int = 0, + mask: gpd.GeoDataFrame = None, + coords_only = False, + custom_coords = None + ) -> WSIPatcher: + return CuImageWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only, custom_coords) class WSIPatcher: @@ -264,7 +300,7 @@ def _compute_masked(self, coords) -> None: # Note: we don't take into account the overlap size we calculating centers xy_centers = coords + self.patch_size_level // 2 - union_mask = self.mask.unary_union + union_mask = self.mask.union_all() points = gpd.points_from_xy(xy_centers[:, 0], xy_centers[:, 1]) valid_mask = gpd.GeoSeries(points).within(union_mask).values @@ -376,7 +412,7 @@ def save_visualization(self, path, vis_width=1000, dpi=150): ax.add_collection(PatchCollection(patch_rectangles, facecolor='none', edgecolor='black', linewidth=0.3)) ax.set_axis_off() plt.tight_layout() - plt.savefig(path, dpi=dpi) + plt.savefig(path, dpi=dpi, bbox_inches = 'tight') class OpenSlideWSIPatcher(WSIPatcher): @@ -426,14 +462,15 @@ def contours_to_img( for _, row in group.iterrows(): cont = np.array([[round(x * downsample), round(y * downsample)] for x, y in row.geometry.exterior.coords]) - holes = np.array([[round(x * downsample), round(y * downsample)] for hole in row.geometry.interiors for x, y in hole.coords]) + holes = [np.array([[round(x * downsample), round(y * downsample)] for x, y in hole.coords]) for hole in row.geometry.interiors] draw_cont_fill(image=img, contours=[cont], 
color=line_color) + + for hole in holes: + draw_cont_fill(image=img, contours=[hole], color=(0, 0, 0)) + if draw_contours: draw_cont(image=img, contours=[cont], color=line_color) - - if len(holes) > 0: - draw_cont_fill(image=img, contours=[holes], color=(0, 0, 0)) return img From bf50ae459a371f8d88a3f35a59b1eb9544937a06 Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 10:31:26 -0400 Subject: [PATCH 07/24] delete hest bench From 837574ff6fc156819c174256a2ced071aaf719db Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 16:00:10 -0400 Subject: [PATCH 08/24] clean imports --- .github/workflows/python-app.yml | 38 +++++++++++++++++++++++++++ src/hest/HESTData.py | 15 ++++++++--- src/hest/LazyShapes.py | 19 +++++++++++++- src/hest/io/seg_readers.py | 2 +- src/hest/segmentation/segmentation.py | 12 +++++++++ src/hest/wsi.py | 6 ++--- tests/hest_tests.py | 28 +++++++++++++++++++- 7 files changed, 110 insertions(+), 10 deletions(-) create mode 100644 .github/workflows/python-app.yml diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml new file mode 100644 index 0000000..e174238 --- /dev/null +++ b/.github/workflows/python-app.yml @@ -0,0 +1,38 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Hest tests + +on: + #push: + # branches: [ "main", "develop"] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.9 + uses: actions/setup-python@v3 + with: + python-version: "3.9" + + - name: Install python dependencies + run: | + python -m pip install -e . 
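Note on the per-hole handling introduced in contours_to_img above (fill the exterior ring, then re-fill each interior ring with the background color): the same idea can be reproduced in isolation. The geometry and colors below are invented for illustration and are not the hest call sites.

import cv2
import numpy as np
from shapely.geometry import Polygon

# Toy polygon with one hole; contours_to_img applies the same steps per row of a GeoDataFrame.
poly = Polygon(
    [(10, 10), (190, 10), (190, 190), (10, 190)],
    holes=[[(80, 80), (120, 80), (120, 120), (80, 120)]],
)

img = np.zeros((200, 200, 3), dtype=np.uint8)
exterior = np.array(poly.exterior.coords).round().astype(np.int32).reshape(-1, 1, 2)
cv2.drawContours(image=img, contours=[exterior], contourIdx=-1, color=(0, 255, 0), thickness=cv2.FILLED)

for hole in poly.interiors:
    # each hole is filled back to background, one contour at a time
    hole_pts = np.array(hole.coords).round().astype(np.int32).reshape(-1, 1, 2)
    cv2.drawContours(image=img, contours=[hole_pts], contourIdx=-1, color=(0, 0, 0), thickness=cv2.FILLED)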
+ - name: Install apt dependencies + run: | + sudo apt install libvips libvips-dev openslide-tools + + - name: Run tests + run: | + python tests/hest_tests.py + env: + HF_READ_TOKEN_PAUL: ${{ secrets.HF_READ_TOKEN_PAUL }} diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py index cdc00f1..35f3e72 100644 --- a/src/hest/HESTData.py +++ b/src/hest/HESTData.py @@ -10,8 +10,8 @@ import geopandas as gpd import numpy as np -from hest.io.seg_readers import TissueContourReader, write_geojson -from hest.LazyShapes import LazyShapes, convert_old_to_gpd +from hest.io.seg_readers import TissueContourReader +from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new from hest.segmentation.TissueMask import TissueMask, load_tissue_mask from hest.wsi import (WSI, CucimWarningSingleton, NumpyWSI, contours_to_img, get_tissue_vis, wsi_factory) @@ -710,8 +710,15 @@ def read_HESTData( tissue_contours = None tissue_seg = None if tissue_contours_path is not None: - tissue_contours = TissueContourReader().read_gdf(tissue_contours_path) - tissue_contours['tissue_id'] = tissue_contours['tissue_id'].astype(int) + with open(tissue_contours_path) as f: + lines = f.read() + if 'hole' in lines: + warnings.warn("this type of .geojson tissue contour file is deprecated, please download the new `tissue_seg` folder on huggingface: https://huggingface.co/datasets/MahmoodLab/hest/tree/main") + gdf = TissueContourReader().read_gdf(tissue_contours_path) + tissue_contours = old_geojson_to_new(gdf) + else: + tissue_contours = gpd.read_file(tissue_contours_path) + elif mask_path_pkl is not None and mask_path_jpg is not None: tissue_seg = load_tissue_mask(mask_path_pkl, mask_path_jpg, width, height) diff --git a/src/hest/LazyShapes.py b/src/hest/LazyShapes.py index 8905382..8e95397 100644 --- a/src/hest/LazyShapes.py +++ b/src/hest/LazyShapes.py @@ -52,4 +52,21 @@ def convert_old_to_gpd(contours_holes, contours_tissue) -> gpd.GeoDataFrame: df = pd.DataFrame(tissue_ids, columns=['tissue_id']) return gpd.GeoDataFrame(df, geometry=shapes) - \ No newline at end of file + + +def old_geojson_to_new(gdf): + polygons = [] + keys = [] + for key, group in gdf.groupby('tissue_id'): + holes = [] + for row in group.values: + if row[2]: + holes.append([coord for coord in row[0].exterior.coords]) + else: + exterior = [coord for coord in row[0].exterior.coords] + polygons.append(Polygon(exterior, holes)) + keys.append(key) + + gdf = gpd.GeoDataFrame(geometry=polygons) + gdf['tissue_id'] = keys + return gdf \ No newline at end of file diff --git a/src/hest/io/seg_readers.py b/src/hest/io/seg_readers.py index c863c52..89746d6 100644 --- a/src/hest/io/seg_readers.py +++ b/src/hest/io/seg_readers.py @@ -98,7 +98,7 @@ def read_gdf(self, path) -> gpd.GeoDataFrame: class TissueContourReader(GDFReader): def read_gdf(self, path) -> gpd.GeoDataFrame: - gdf = _read_geojson(path, 'tissue_id') + gdf = _read_geojson(path, 'tissue_id', extra_props=False, index_key='hole') return gdf diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py index adad80d..a52dec3 100644 --- a/src/hest/segmentation/segmentation.py +++ b/src/hest/segmentation/segmentation.py @@ -360,6 +360,14 @@ def filter_contours(contours, hierarchy, filter_params, scale, pixel_size): hole_contours.append(filtered_holes) return foreground_contours, hole_contours + + +def make_valid(polygon): + for i in [0, 0.1, -0.1, 0.2]: + new_polygon = polygon.buffer(i) + if isinstance(new_polygon, Polygon) and new_polygon.is_valid: + return new_polygon + 
raise Exception("Failed to make a valid polygon") def mask_to_gdf(mask: np.ndarray, keep_ids = [], exclude_ids=[], max_nb_holes=0, min_contour_area=1000, pixel_size=1, contour_scale=1.): @@ -406,6 +414,10 @@ def mask_to_gdf(mask: np.ndarray, keep_ids = [], exclude_ids=[], max_nb_holes=0, for i in contour_ids: holes = [contours_holes[i][j].squeeze(1) for j in range(len(contours_holes[i]))] if len(contours_holes[i]) > 0 else None polygon = Polygon(contours_tissue[i].squeeze(1), holes=holes) + if not polygon.is_valid: + # TODO replace by shapely make_valid after 2.1 + if not polygon.is_valid: + polygon = make_valid(polygon) polygons.append(polygon) gdf_contours = gpd.GeoDataFrame(pd.DataFrame(tissue_ids, columns=['tissue_id']), geometry=polygons) diff --git a/src/hest/wsi.py b/src/hest/wsi.py index 5033657..70be21b 100644 --- a/src/hest/wsi.py +++ b/src/hest/wsi.py @@ -298,7 +298,7 @@ def _compute_masked(self, coords) -> None: # TODO spots are already at the center # Note: we don't take into account the overlap size we calculating centers - xy_centers = coords + self.patch_size_level // 2 + xy_centers = coords + self.patch_size_src // 2 union_mask = self.mask.union_all() @@ -377,11 +377,11 @@ def _compute_cols_rows(self) -> Tuple[int, int]: while x < self.width: col += 1 x, _ = self._colrow_to_xy(col, row) - cols = col - 1 + cols = col while y < self.height: row += 1 _, y = self._colrow_to_xy(col, row) - rows = row - 1 + rows = row return cols, rows def save_visualization(self, path, vis_width=1000, dpi=150): diff --git a/tests/hest_tests.py b/tests/hest_tests.py index 685a9ee..f89081e 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -1,6 +1,7 @@ import os import unittest from os.path import join as _j +import warnings import hest from hest import HESTData, read_HESTData @@ -90,14 +91,39 @@ class TestHESTData(unittest.TestCase): @classmethod def setUpClass(self): + download = True self.cur_dir = get_path_relative(__file__, '') cur_dir = self.cur_dir self.output_dir = _j(cur_dir, 'output_tests/hestdata_tests') os.makedirs(self.output_dir, exist_ok=True) + from huggingface_hub import login + + token = os.getenv('HF_READ_TOKEN_PAUL') + if token is None: + warnings.warn("Please setup huggingface token 'HF_READ_TOKEN_PAUL'") + else: + login(token=token) + id_list = ['TENX24', 'SPA154', 'TENX96', 'TENX131'] - self.sts = hest.load_hest('hest_data', id_list) + if download: + import datasets + + local_dir = os.path.join(cur_dir, 'hest_data_test') + + ids_to_query = id_list + list_patterns = [f"*{id}[_.]**" for id in ids_to_query] + datasets.load_dataset( + 'MahmoodLab/hest', + cache_dir=local_dir, + patterns=list_patterns, + download_mode='force_redownload' + ) + + self.sts = hest.load_hest(local_dir, id_list) + else: + self.sts = hest.load_hest('hest_data', id_list) def test_tissue_seg(self): From 4c3b41b71ebafa901750a4206fb71c3674bd7587 Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 16:02:33 -0400 Subject: [PATCH 09/24] change gitignore --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index c8a31e5..9a6229e 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,5 @@ hest_vis vis vis2 models/deeplabv3* -.github htmlcov models/CellViT-SAM-H-x40.pth From 589707956f77f6de93eaf98544a91910f0b8b35d Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 16:08:45 -0400 Subject: [PATCH 10/24] apt-get update before install --- .github/workflows/python-app.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
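Note on the make_valid helper above: it relies on the classic buffer trick, where buffering an invalid ring by 0 (or a tiny offset) usually yields a valid polygon. A standalone illustration with a made-up self-intersecting "bowtie" ring:

from shapely.geometry import Polygon

# A self-intersecting ring produces an invalid polygon
bowtie = Polygon([(0, 0), (2, 2), (2, 0), (0, 2)])
assert not bowtie.is_valid

# buffer(0) re-nodes the ring and usually repairs it; make_valid above retries
# with the offsets [0, 0.1, -0.1, 0.2] until a valid Polygon comes back.
repaired = bowtie.buffer(0)
assert repaired.is_valid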
a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index e174238..70d581b 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -29,7 +29,8 @@ jobs: python -m pip install -e . - name: Install apt dependencies run: | - sudo apt install libvips libvips-dev openslide-tools + sudo apt-get update + sudo apt-get install libvips libvips-dev openslide-tools - name: Run tests run: | From 9500974bddb781985ba733c4b2260a5fd77c16e3 Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 16:17:51 -0400 Subject: [PATCH 11/24] modify secret --- .github/workflows/python-app.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 70d581b..9fb5ecf 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -16,6 +16,8 @@ jobs: build: runs-on: ubuntu-latest + env: + HF_READ_TOKEN_PAUL: ${{ secrets.HF_READ_TOKEN_PAUL }} steps: - uses: actions/checkout@v4 @@ -35,5 +37,3 @@ jobs: - name: Run tests run: | python tests/hest_tests.py - env: - HF_READ_TOKEN_PAUL: ${{ secrets.HF_READ_TOKEN_PAUL }} From b4c43ea2e22e4613418961786309d8a09f49b4a0 Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 16:24:47 -0400 Subject: [PATCH 12/24] allow code exec --- tests/hest_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/hest_tests.py b/tests/hest_tests.py index f89081e..8c931ee 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -121,7 +121,7 @@ def setUpClass(self): download_mode='force_redownload' ) - self.sts = hest.load_hest(local_dir, id_list) + self.sts = hest.load_hest(local_dir, id_list, trust_remote_code=True) else: self.sts = hest.load_hest('hest_data', id_list) From 8c938f4825bac5b25da9faf73e83d35f5341f54f Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 16:29:43 -0400 Subject: [PATCH 13/24] trust_remote_code=True --- tests/hest_tests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/hest_tests.py b/tests/hest_tests.py index 8c931ee..cd01bc9 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -118,10 +118,11 @@ def setUpClass(self): 'MahmoodLab/hest', cache_dir=local_dir, patterns=list_patterns, - download_mode='force_redownload' + download_mode='force_redownload', + trust_remote_code=True ) - self.sts = hest.load_hest(local_dir, id_list, trust_remote_code=True) + self.sts = hest.load_hest(local_dir, id_list) else: self.sts = hest.load_hest('hest_data', id_list) From 639ca00c81b940d3edbdd7c81a1f8433efa72abb Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 16:37:50 -0400 Subject: [PATCH 14/24] reduce number of samples for test --- tests/hest_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/hest_tests.py b/tests/hest_tests.py index cd01bc9..5a8963d 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -105,7 +105,7 @@ def setUpClass(self): else: login(token=token) - id_list = ['TENX24', 'SPA154', 'TENX96', 'TENX131'] + id_list = ['TENX24', 'SPA154'] if download: import datasets From bccbb3ea381e86ae60779b6cb1c0c59d5587f8af Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 17:07:50 -0400 Subject: [PATCH 15/24] force custom_coords to be int --- src/hest/HESTData.py | 1 + src/hest/wsi.py | 4 +++- tests/hest_tests.py | 2 ++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py index 35f3e72..ab9627d 100644 --- 
a/src/hest/HESTData.py +++ b/src/hest/HESTData.py @@ -338,6 +338,7 @@ def dump_patches( coords_topleft = coords_center - target_patch_size // 2 barcodes = np.array(adata.obs.index) mask = self.tissue_contours if use_mask else None + coords_topleft = np.array(coords_topleft).astype(int) patcher = self.wsi.create_patcher(target_patch_size, src_pixel_size, dst_pixel_size, mask=mask, custom_coords=coords_topleft) i = 0 diff --git a/src/hest/wsi.py b/src/hest/wsi.py index 70be21b..7a92e40 100644 --- a/src/hest/wsi.py +++ b/src/hest/wsi.py @@ -280,6 +280,8 @@ def __init__( ]) coords = np.array([self._colrow_to_xy(xy[0], xy[1]) for xy in col_rows]) else: + if round(custom_coords[0][0]) != custom_coords[0][0]: + raise ValueError("custom_coords must be a (N, 2) array of int") coords = custom_coords if self.mask is not None: @@ -352,7 +354,7 @@ def get_tile_xy(self, x: int, y: int) -> Tuple[np.ndarray, int, int]: if self.patch_size_target is not None: tile = cv2.resize(tile, (self.patch_size_target, self.patch_size_target)) assert x < self.width and y < self.height - return tile, x, y + return tile[:, :, :3], x, y def get_tile(self, col: int, row: int) -> Tuple[np.ndarray, int, int]: """ get tile at position (column, row) diff --git a/tests/hest_tests.py b/tests/hest_tests.py index 5a8963d..d6dd4a7 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -181,4 +181,6 @@ def test_saving(self): loader = unittest.TestLoader() suite = loader.loadTestsFromTestCase(TestHESTData) + #suite = unittest.TestSuite() + #suite.addTest(TestHESTData('test_patching')) unittest.TextTestRunner(verbosity=2).run(suite) From a95116868e0569cd2a9992ee1358aec86844b7d9 Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 17:15:23 -0400 Subject: [PATCH 16/24] remove torch sync and throw exception on test fail --- src/hest/segmentation/segmentation.py | 2 -- tests/hest_tests.py | 20 +++++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py index a52dec3..8075cf0 100644 --- a/src/hest/segmentation/segmentation.py +++ b/src/hest/segmentation/segmentation.py @@ -118,8 +118,6 @@ def segment_tissue_deep( imgs = imgs.cuda() masks = model(imgs)['out'] preds = masks.argmax(1).to(torch.uint8).detach() - - torch.cuda.synchronize() preds = preds.cpu().numpy() coords = np.column_stack((coords[0], coords[1])) diff --git a/tests/hest_tests.py b/tests/hest_tests.py index d6dd4a7..e49ff86 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -153,11 +153,11 @@ def test_spatialdata(self): def test_patching(self): - for idx, st in enumerate(self.sts): - with self.subTest(st_object=idx): - name = '' - name += st.meta['id'] - st.dump_patches(self.output_dir, name=name) + for idx, st in enumerate(self.sts): + with self.subTest(st_object=idx): + name = '' + name += st.meta['id'] + st.dump_patches(self.output_dir, name=name) def test_saving(self): @@ -180,7 +180,9 @@ def test_saving(self): #TestHESTReader() loader = unittest.TestLoader() - suite = loader.loadTestsFromTestCase(TestHESTData) - #suite = unittest.TestSuite() - #suite.addTest(TestHESTData('test_patching')) - unittest.TextTestRunner(verbosity=2).run(suite) + #suite = loader.loadTestsFromTestCase(TestHESTData) + suite = unittest.TestSuite() + suite.addTest(TestHESTData('test_patching')) + result = unittest.TextTestRunner(verbosity=2).run(suite) + if not result.wasSuccessful(): + raise Exception('Test failed') \ No newline at end of file From 
68303b1e860434a753abbb8c65353435c2cd03ad Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Thu, 8 Aug 2024 17:23:00 -0400 Subject: [PATCH 17/24] include all tests --- tests/hest_tests.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/hest_tests.py b/tests/hest_tests.py index e49ff86..214304f 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -153,11 +153,11 @@ def test_spatialdata(self): def test_patching(self): - for idx, st in enumerate(self.sts): - with self.subTest(st_object=idx): - name = '' - name += st.meta['id'] - st.dump_patches(self.output_dir, name=name) + for idx, st in enumerate(self.sts): + with self.subTest(st_object=idx): + name = '' + name += st.meta['id'] + st.dump_patches(self.output_dir, name=name) def test_saving(self): @@ -180,9 +180,9 @@ def test_saving(self): #TestHESTReader() loader = unittest.TestLoader() - #suite = loader.loadTestsFromTestCase(TestHESTData) - suite = unittest.TestSuite() - suite.addTest(TestHESTData('test_patching')) + suite = loader.loadTestsFromTestCase(TestHESTData) + #suite = unittest.TestSuite() + #suite.addTest(TestHESTData('test_patching')) result = unittest.TextTestRunner(verbosity=2).run(suite) if not result.wasSuccessful(): raise Exception('Test failed') \ No newline at end of file From f5c83cf153fca61ca84435c0a8f86df097453e5d Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Fri, 9 Aug 2024 10:41:43 -0400 Subject: [PATCH 18/24] fix dump_patches --- src/hest/HESTData.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py index ab9627d..615d3a3 100644 --- a/src/hest/HESTData.py +++ b/src/hest/HESTData.py @@ -316,26 +316,24 @@ def dump_patches( src_pixel_size = self.pixel_size - # minimum intersection percecentage with the tissue mask to keep a patch - TISSUE_INTER_THRESH = 0.7 - TARGET_VIS_SIZE = 1000 - scale_factor = target_pixel_size / src_pixel_size patch_size_pxl = round(target_patch_size * scale_factor) patch_count = 0 output_datafile = os.path.join(patch_save_dir, name + '.h5') assert len(adata.obs) == len(adata.obsm['spatial']) - - _, ax = plt.subplots() mode_HE = 'w' i = 0 - img_width, img_height = self.wsi.get_dimensions() - + patch_size_src = target_patch_size * (dst_pixel_size / src_pixel_size) coords_center = adata.obsm['spatial'] - coords_topleft = coords_center - target_patch_size // 2 + coords_topleft = coords_center - patch_size_src // 2 + len_tmp = len(coords_topleft) + coords_topleft = coords_topleft[(0 <= coords_topleft[:, 0] + patch_size_src) & (coords_topleft[:, 0] < self.wsi.width) & (0 <= coords_topleft[:, 1] + patch_size_src) & (coords_topleft[:, 1] < self.wsi.height)] + if len(coords_topleft) < len_tmp: + warnings.warn(f"Filtered {len_tmp - len(coords_topleft)} spots outside the WSI") + barcodes = np.array(adata.obs.index) mask = self.tissue_contours if use_mask else None coords_topleft = np.array(coords_topleft).astype(int) @@ -344,8 +342,8 @@ def dump_patches( i = 0 for tile, x, y in tqdm(patcher): - center_x = x + target_patch_size // 2 - center_y = y + target_patch_size // 2 + center_x = x + patch_size_src // 2 + center_y = y + patch_size_src // 2 # Save ref patches assert tile.shape == (target_patch_size, target_patch_size, 3) From a6dabc5c2cb3fe9979c5185e853575568b558b9b Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Fri, 9 Aug 2024 10:42:20 -0400 Subject: [PATCH 19/24] fix wsi --- src/hest/wsi.py | 61 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 50 
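The coordinate handling in the dump_patches fix above reduces to: derive the patch edge at slide resolution, shift spot centers to top-left corners, and drop spots whose patch cannot overlap the slide. A toy restatement of that logic (all numbers are invented):

import numpy as np

width, height = 1000, 800            # pretend slide size in pixels
src_pixel_size, dst_pixel_size = 0.5, 1.0
target_patch_size = 32
patch_size_src = int(target_patch_size * (dst_pixel_size / src_pixel_size))  # 64 px on the slide

centers = np.array([[50, 50], [990, 40], [-100, 300]])   # spot centers from adata.obsm['spatial']
topleft = centers - patch_size_src // 2

in_slide = (
    (0 <= topleft[:, 0] + patch_size_src) & (topleft[:, 0] < width) &
    (0 <= topleft[:, 1] + patch_size_src) & (topleft[:, 1] < height)
)
topleft = topleft[in_slide].astype(int)   # create_patcher expects integer custom_coords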
insertions(+), 11 deletions(-) diff --git a/src/hest/wsi.py b/src/hest/wsi.py index 7a92e40..0cc719e 100644 --- a/src/hest/wsi.py +++ b/src/hest/wsi.py @@ -10,6 +10,7 @@ import numpy as np import openslide from PIL import Image +from shapely import Polygon class CucimWarningSingleton: @@ -115,7 +116,17 @@ def read_region(self, location, level, size) -> np.ndarray: img = self.img x_start, y_start = location[0], location[1] x_size, y_size = size[0], size[1] - return img[y_start:y_start + y_size, x_start:x_start + x_size] + x_end, y_end = x_start + x_size, y_start + y_size + padding_left = max(0 - x_start, 0) + padding_top = max(0 - y_start, 0) + padding_right = max(x_start + x_size - self.width, 0) + padding_bottom = max(y_start + y_size - self.height, 0) + x_start, y_start = max(x_start, 0), max(y_start, 0) + x_end, y_end = min(x_end, self.width), min(y_end, self.height) + tile = img[y_start:y_end, x_start:x_end] + padded_tile = np.pad(tile, ((padding_top, padding_bottom), (padding_left, padding_right), (0, 0)), mode='constant', constant_values=0) + + return padded_tile def get_thumbnail(self, width, height) -> np.ndarray: return cv2.resize(self.img, (width, height)) @@ -239,7 +250,8 @@ def __init__( overlap: int = 0, mask: gpd.GeoDataFrame = None, coords_only = False, - custom_coords = None + custom_coords = None, + threshold = 0.15 ): """ Initialize patcher, compute number of (masked) rows, columns. @@ -251,6 +263,8 @@ def __init__( overlap (int, optional): overlap size in pixel before rescaling. Defaults to 0. mask (gpd.GeoDataFrame, optional): geopandas dataframe of Polygons. Defaults to None. coords_only (bool, optional): whenever to extract only the coordinates insteaf of coordinates + tile. Default to False. + threshold (float, optional): minimum proportion of the patch under tissue to be kept. 
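The NumpyWSI.read_region change above clamps the requested window to the array and zero-pads the result back to the requested size, so reads that fall partly outside the slide no longer come back smaller than asked. The same logic on a toy array:

import numpy as np

img = np.arange(5 * 5 * 3, dtype=np.uint8).reshape(5, 5, 3)
x, y, size = -2, 3, 4                      # window partly outside the image
h, w = img.shape[:2]

pad_left, pad_top = max(-x, 0), max(-y, 0)
pad_right, pad_bottom = max(x + size - w, 0), max(y + size - h, 0)
x0, y0 = max(x, 0), max(y, 0)
x1, y1 = min(x + size, w), min(y + size, h)

tile = img[y0:y1, x0:x1]
tile = np.pad(tile, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
              mode='constant', constant_values=0)
assert tile.shape == (size, size, 3)       # always the requested size, zero-filled outside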
+ This argument is ignored if mask=None, passing threshold=0 will be faster """ self.wsi = wsi self.overlap = overlap @@ -283,9 +297,8 @@ def __init__( if round(custom_coords[0][0]) != custom_coords[0][0]: raise ValueError("custom_coords must be a (N, 2) array of int") coords = custom_coords - if self.mask is not None: - self.valid_patches_nb, self.valid_coords = self._compute_masked(coords) + self.valid_patches_nb, self.valid_coords = self._compute_masked(coords, threshold) else: self.valid_patches_nb, self.valid_coords = len(coords), coords @@ -295,17 +308,43 @@ def _colrow_to_xy(self, col, row): y = row * (self.patch_size_src) - self.overlap * np.clip(row - 1, 0, None) return (x, y) - def _compute_masked(self, coords) -> None: + def _compute_masked(self, coords, threshold) -> None: """ Compute tiles which center falls under the tissue """ - # TODO spots are already at the center - # Note: we don't take into account the overlap size we calculating centers - xy_centers = coords + self.patch_size_src // 2 + # Filter coordinates by bounding boxes of mask polygons + bounding_boxes = self.mask.geometry.bounds + valid_coords = [] + for _, bbox in bounding_boxes.iterrows(): + bbox_coords = coords[ + (coords[:, 0] >= bbox['minx'] - self.patch_size_src) & (coords[:, 0] <= bbox['maxx'] + self.patch_size_src) & + (coords[:, 1] >= bbox['miny'] - self.patch_size_src) & (coords[:, 1] <= bbox['maxy'] + self.patch_size_src) + ] + valid_coords.append(bbox_coords) + + if len(valid_coords) > 0: + coords = np.vstack(valid_coords) + coords = np.unique(coords, axis=0) + else: + coords = np.array([]) + union_mask = self.mask.union_all() - - points = gpd.points_from_xy(xy_centers[:, 0], xy_centers[:, 1]) - valid_mask = gpd.GeoSeries(points).within(union_mask).values + + squares = [ + Polygon([ + (xy[0], xy[1]), + (xy[0] + self.patch_size_src, xy[1]), + (xy[0] + self.patch_size_src, xy[1] + self.patch_size_src), + (xy[0], xy[1] + self.patch_size_src)]) + for xy in coords + ] + if threshold == 0: + valid_mask = gpd.GeoSeries(squares).intersects(union_mask).values + else: + gdf = gpd.GeoSeries(squares) + areas = gdf.area + valid_mask = gdf.intersection(union_mask).area >= threshold * areas + valid_patches_nb = valid_mask.sum() valid_coords = coords[valid_mask] return valid_patches_nb, valid_coords From dc451be9a369582abd4856e7ee0bce6ff02092d3 Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Fri, 9 Aug 2024 14:46:14 -0400 Subject: [PATCH 20/24] final changes wsi --- src/hest/HESTData.py | 35 +++------- src/hest/segmentation/SegDataset.py | 45 +----------- src/hest/segmentation/segmentation.py | 4 +- src/hest/wsi.py | 99 ++++++++++++++++++++++----- tests/hest_tests.py | 6 +- 5 files changed, 97 insertions(+), 92 deletions(-) diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py index 615d3a3..4b3f3ad 100644 --- a/src/hest/HESTData.py +++ b/src/hest/HESTData.py @@ -316,49 +316,30 @@ def dump_patches( src_pixel_size = self.pixel_size - scale_factor = target_pixel_size / src_pixel_size - patch_size_pxl = round(target_patch_size * scale_factor) patch_count = 0 - output_datafile = os.path.join(patch_save_dir, name + '.h5') + h5_path = os.path.join(patch_save_dir, name + '.h5') assert len(adata.obs) == len(adata.obsm['spatial']) - mode_HE = 'w' - i = 0 - patch_size_src = target_patch_size * (dst_pixel_size / src_pixel_size) coords_center = adata.obsm['spatial'] coords_topleft = coords_center - patch_size_src // 2 len_tmp = len(coords_topleft) - coords_topleft = coords_topleft[(0 <= coords_topleft[:, 0] + 
patch_size_src) & (coords_topleft[:, 0] < self.wsi.width) & (0 <= coords_topleft[:, 1] + patch_size_src) & (coords_topleft[:, 1] < self.wsi.height)] + in_slide_mask = (0 <= coords_topleft[:, 0] + patch_size_src) & (coords_topleft[:, 0] < self.wsi.width) & (0 <= coords_topleft[:, 1] + patch_size_src) & (coords_topleft[:, 1] < self.wsi.height) + coords_topleft = coords_topleft[in_slide_mask] if len(coords_topleft) < len_tmp: warnings.warn(f"Filtered {len_tmp - len(coords_topleft)} spots outside the WSI") barcodes = np.array(adata.obs.index) + barcodes = barcodes[in_slide_mask] mask = self.tissue_contours if use_mask else None coords_topleft = np.array(coords_topleft).astype(int) patcher = self.wsi.create_patcher(target_patch_size, src_pixel_size, dst_pixel_size, mask=mask, custom_coords=coords_topleft) - i = 0 - for tile, x, y in tqdm(patcher): - - center_x = x + patch_size_src // 2 - center_y = y + patch_size_src // 2 - - # Save ref patches - assert tile.shape == (target_patch_size, target_patch_size, 3) - asset_dict = { 'img': np.expand_dims(tile, axis=0), # (1 x w x h x 3) - 'coords': np.expand_dims([center_y, center_x], axis=0), # (1 x 2) - 'barcode': np.expand_dims([barcodes[i]], axis=0) - } - - attr_dict = {} - attr_dict['img'] = {'patch_size': patch_size_pxl, - 'factor': scale_factor} - - initsave_hdf5(output_datafile, asset_dict, attr_dict, mode=mode_HE) - mode_HE = 'a' - i += 1 + if mask is not None: + valid_barcodes = barcodes[patcher.valid_mask] + + patcher.to_h5(h5_path, extra_assets={'barcodes': valid_barcodes}) if dump_visualization: patcher.save_visualization(os.path.join(patch_save_dir, name + '_patch_vis.png'), dpi=400) diff --git a/src/hest/segmentation/SegDataset.py b/src/hest/segmentation/SegDataset.py index f61abd0..f0631d8 100644 --- a/src/hest/segmentation/SegDataset.py +++ b/src/hest/segmentation/SegDataset.py @@ -1,52 +1,9 @@ -import os - -import numpy as np -from PIL import Image from torch.utils.data import Dataset -from tqdm import tqdm from hest.wsi import WSIPatcher -class SegFileDataset(Dataset): - masks = [] - patches = [] - coords = [] - - def __init__(self, root_path, transform): - self._load_paths(root_path) - - self.transform = transform - - def _load_paths(self, root_path): - self.mask_paths = [] - self.patch_paths = [] - self.coords = [] - for mask_filename in tqdm(os.listdir(root_path)): - name = mask_filename.split('.')[0] - pxl_x, pxl_y = int(name.split('_')[0]), int(name.split('_')[1]) - self.patch_paths.append(os.path.join(root_path, mask_filename)) - self.coords.append([pxl_x, pxl_y]) - - - def __len__(self): - return len(self.patch_paths) - - def __getitem__(self, index): - - with Image.open(self.patch_paths[index]) as patch: - patch = np.array(patch) - coord = self.coords[index] - - sample = patch - - if self.transform: - sample = self.transform(sample) - - return sample, coord - - -class SegWSIDataset(Dataset): +class WSIPatcherDataset(Dataset): def __init__(self, patcher: WSIPatcher, transform): self.patcher = patcher diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py index 8075cf0..6674a1d 100644 --- a/src/hest/segmentation/segmentation.py +++ b/src/hest/segmentation/segmentation.py @@ -53,7 +53,7 @@ def segment_tissue_deep( from torch import nn from torch.utils.data import DataLoader from torchvision import transforms - from hest.segmentation.SegDataset import SegWSIDataset + from hest.segmentation.SegDataset import WSIPatcherDataset src_pixel_size = pixel_size @@ -71,7 +71,7 @@ def segment_tissue_deep( 
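The threshold-based filtering added to _compute_masked a few hunks above amounts to: turn every candidate patch into a square polygon and keep it when its intersection with the tissue geometry covers at least `threshold` of the patch area. A minimal sketch with made-up geometry (the real code first unions the mask GeoDataFrame and pre-filters by bounding box):

import geopandas as gpd
import numpy as np
from shapely.geometry import Polygon

tissue = Polygon([(0, 0), (100, 0), (100, 100), (0, 100)])   # pretend tissue mask
patch, threshold = 40, 0.15
coords = np.array([[10, 10], [90, 90], [300, 300]])          # candidate top-left corners

squares = gpd.GeoSeries([
    Polygon([(x, y), (x + patch, y), (x + patch, y + patch), (x, y + patch)])
    for x, y in coords
])
keep = squares.intersection(tissue).area >= threshold * squares.area
print(coords[keep.values])   # [[10 10]]: the second patch covers ~6% of its area, the third none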
patcher = wsi.create_patcher(patch_size_deeplab, src_pixel_size, dst_pixel_size) eval_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) - dataset = SegWSIDataset(patcher, eval_transforms) + dataset = WSIPatcherDataset(patcher, eval_transforms) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50') model.classifier[4] = nn.Conv2d( diff --git a/src/hest/wsi.py b/src/hest/wsi.py index 0cc719e..d976596 100644 --- a/src/hest/wsi.py +++ b/src/hest/wsi.py @@ -11,6 +11,9 @@ import openslide from PIL import Image from shapely import Polygon +from tqdm import tqdm + +from hest.vst_save_utils import initsave_hdf5 class CucimWarningSingleton: @@ -37,6 +40,7 @@ class WSI: def __init__(self, img): self.img = img + self.name = None if not (isinstance(img, openslide.OpenSlide) or isinstance(img, np.ndarray) or is_cuimage(img)) : raise ValueError(f"Invalid type for img {type(img)}") @@ -55,6 +59,10 @@ def get_dimensions(self): def read_region(self, location, level, size) -> np.ndarray: pass + @abstractmethod + def read_region_pil(self, location, level, size) -> np.ndarray: + pass + @abstractmethod def get_thumbnail(self, width, height): pass @@ -128,6 +136,9 @@ def read_region(self, location, level, size) -> np.ndarray: return padded_tile + def read_region_pil(self, location, level, size) -> Image: + return Image.fromarray(self.read_region(location, level, size)) + def get_thumbnail(self, width, height) -> np.ndarray: return cv2.resize(self.img, (width, height)) @@ -155,7 +166,10 @@ def get_dimensions(self): return self.img.dimensions def read_region(self, location, level, size) -> np.ndarray: - return np.array(self.img.read_region(location, level, size)) + return np.array(self.read_region_pil(location, level, size)) + + def read_region_pil(self, location, level, size) -> Image: + return self.img.read_region(location, level, size) def get_thumbnail(self, width, height): return np.array(self.img.get_thumbnail((width, height))) @@ -192,7 +206,10 @@ def get_dimensions(self): return self.img.resolutions['level_dimensions'][0] def read_region(self, location, level, size) -> np.ndarray: - return np.array(self.img.read_region(location=location, level=level, size=size)) + return np.asarray(self.img.read_region(location=location, level=level, size=size)) + + def read_region_pil(self, location, level, size) -> Image: + return Image.fromarray(self.read_region(location, level, size)) def get_thumbnail(self, width, height): downsample = self.width / width @@ -251,7 +268,8 @@ def __init__( mask: gpd.GeoDataFrame = None, coords_only = False, custom_coords = None, - threshold = 0.15 + threshold = 0.15, + pil=False ): """ Initialize patcher, compute number of (masked) rows, columns. @@ -264,7 +282,8 @@ def __init__( mask (gpd.GeoDataFrame, optional): geopandas dataframe of Polygons. Defaults to None. coords_only (bool, optional): whenever to extract only the coordinates insteaf of coordinates + tile. Default to False. threshold (float, optional): minimum proportion of the patch under tissue to be kept. - This argument is ignored if mask=None, passing threshold=0 will be faster + This argument is ignored if mask=None, passing threshold=0 will be faster. Defaults to 0.15 + pil (bool, optional): whenever to get patches as `PIL.Image` (numpy array by default). 
Defaults to False """ self.wsi = wsi self.overlap = overlap @@ -274,6 +293,8 @@ def __init__( self.i = 0 self.coords_only = coords_only self.custom_coords = custom_coords + self.pil = pil + self.src_pixel_size = src_pixel_size if dst_pixel_size is None: self.downsample = 1. @@ -313,19 +334,18 @@ def _compute_masked(self, coords, threshold) -> None: # Filter coordinates by bounding boxes of mask polygons bounding_boxes = self.mask.geometry.bounds - valid_coords = [] + bbox_masks = [] for _, bbox in bounding_boxes.iterrows(): - bbox_coords = coords[ + bbox_mask = ( (coords[:, 0] >= bbox['minx'] - self.patch_size_src) & (coords[:, 0] <= bbox['maxx'] + self.patch_size_src) & (coords[:, 1] >= bbox['miny'] - self.patch_size_src) & (coords[:, 1] <= bbox['maxy'] + self.patch_size_src) - ] - valid_coords.append(bbox_coords) + ) + bbox_masks.append(bbox_mask) - if len(valid_coords) > 0: - coords = np.vstack(valid_coords) - coords = np.unique(coords, axis=0) + if len(bbox_masks) > 0: + bbox_mask = np.vstack(bbox_masks)[0] else: - coords = np.array([]) + bbox_mask = np.zeros(len(coords), dtype=bool) union_mask = self.mask.union_all() @@ -336,7 +356,7 @@ def _compute_masked(self, coords, threshold) -> None: (xy[0] + self.patch_size_src, xy[1]), (xy[0] + self.patch_size_src, xy[1] + self.patch_size_src), (xy[0], xy[1] + self.patch_size_src)]) - for xy in coords + for xy in coords[bbox_mask] ] if threshold == 0: valid_mask = gpd.GeoSeries(squares).intersects(union_mask).values @@ -344,9 +364,13 @@ def _compute_masked(self, coords, threshold) -> None: gdf = gpd.GeoSeries(squares) areas = gdf.area valid_mask = gdf.intersection(union_mask).area >= threshold * areas + + full_mask = bbox_mask + full_mask[bbox_mask] &= valid_mask - valid_patches_nb = valid_mask.sum() - valid_coords = coords[valid_mask] + valid_patches_nb = full_mask.sum() + self.valid_mask = full_mask + valid_coords = coords[full_mask] return valid_patches_nb, valid_coords def __len__(self): @@ -388,7 +412,10 @@ def get_cols_rows(self) -> Tuple[int, int]: return self.cols, self.rows def get_tile_xy(self, x: int, y: int) -> Tuple[np.ndarray, int, int]: - raw_tile = self.wsi.read_region(location=(x, y), level=self.level, size=(self.patch_size_level, self.patch_size_level)) + if self.pil: + raw_tile = self.wsi.read_region_pil(location=(x, y), level=self.level, size=(self.patch_size_level, self.patch_size_level)) + else: + raw_tile = self.wsi.read_region(location=(x, y), level=self.level, size=(self.patch_size_level, self.patch_size_level)) tile = np.array(raw_tile) if self.patch_size_target is not None: tile = cv2.resize(tile, (self.patch_size_target, self.patch_size_target)) @@ -454,6 +481,46 @@ def save_visualization(self, path, vis_width=1000, dpi=150): ax.set_axis_off() plt.tight_layout() plt.savefig(path, dpi=dpi, bbox_inches = 'tight') + + + def to_h5(self, path, extra_assets={}): + mode_HE = 'w' + i = 0 + + if extra_assets != {}: + for _, value in extra_assets.items(): + if len(value) != len(self): + raise ValueError("Each value in extra_assets must have the same length as the patcher object") + + if not (path.endswith('.h5') or path.endswith('.h5ad')): + path = path + '.h5' + + for tile, x, y in tqdm(self): + + # Save ref patches + assert tile.shape == (self.patch_size_target, self.patch_size_target, 3) + + asset_dict = {} + if not self.coords_only: + asset_dict['img'] = np.expand_dims(tile, axis=0) # (1 x w x h x 3) + asset_dict['coords'] = np.expand_dims([x, y], axis=0) # (1 x 2) + + extra_asset_dict = {key: np.expand_dims([value[i]], 
axis=0) for key, value in extra_assets.items()} + asset_dict = {**asset_dict, **extra_asset_dict} + + attr_dict = {} + attr_dict['img'] = { + 'patch_size': self.patch_size_target, + 'factor': self.downsample, + 'pixel_size': self.src_pixel_size, + } + + if self.wsi.name is not None: + attr_dict['img']['name'] = self.wsi.name + + initsave_hdf5(path, asset_dict, attr_dict, mode=mode_HE) + mode_HE = 'a' + i += 1 class OpenSlideWSIPatcher(WSIPatcher): diff --git a/tests/hest_tests.py b/tests/hest_tests.py index 214304f..2897b48 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -91,7 +91,7 @@ class TestHESTData(unittest.TestCase): @classmethod def setUpClass(self): - download = True + download = False self.cur_dir = get_path_relative(__file__, '') cur_dir = self.cur_dir self.output_dir = _j(cur_dir, 'output_tests/hestdata_tests') @@ -181,8 +181,8 @@ def test_saving(self): loader = unittest.TestLoader() suite = loader.loadTestsFromTestCase(TestHESTData) - #suite = unittest.TestSuite() - #suite.addTest(TestHESTData('test_patching')) + # suite = unittest.TestSuite() + # suite.addTest(TestHESTData('test_patching')) result = unittest.TextTestRunner(verbosity=2).run(suite) if not result.wasSuccessful(): raise Exception('Test failed') \ No newline at end of file From 1c48fa0a743b7fa317758661ee1307ec4fe3947e Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Sat, 10 Aug 2024 19:57:41 -0400 Subject: [PATCH 21/24] download true --- tests/hest_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/hest_tests.py b/tests/hest_tests.py index 2897b48..b065722 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -91,7 +91,7 @@ class TestHESTData(unittest.TestCase): @classmethod def setUpClass(self): - download = False + download = True self.cur_dir = get_path_relative(__file__, '') cur_dir = self.cur_dir self.output_dir = _j(cur_dir, 'output_tests/hestdata_tests') From 7449e1a6abee7f1132fd36d4c17f54d55dbb701a Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Mon, 12 Aug 2024 10:09:11 -0400 Subject: [PATCH 22/24] clean dependencies --- src/hest/HESTData.py | 1 - src/hest/segmentation/segmentation.py | 23 +++++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py index 4b3f3ad..2fe8b71 100644 --- a/src/hest/HESTData.py +++ b/src/hest/HESTData.py @@ -30,7 +30,6 @@ from .utils import (ALIGNED_HE_FILENAME, check_arg, deprecated, find_first_file_endswith, get_path_from_meta_row, plot_verify_pixel_size, tiff_save, verify_paths) -from .vst_save_utils import initsave_hdf5 class HESTData: diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py index 6674a1d..37b87e1 100644 --- a/src/hest/segmentation/segmentation.py +++ b/src/hest/segmentation/segmentation.py @@ -11,8 +11,7 @@ from PIL import Image from shapely import Polygon -from hest.utils import deprecated, get_path_relative -from hest.wsi import WSI, WSIPatcher, wsi_factory +from hest.wsi import WSI, wsi_factory try: import openslide @@ -183,6 +182,21 @@ def keep_largest_area(mask: np.ndarray) -> np.ndarray: largest_mask[label_image == largest_label] = True mask[~largest_mask] = 0 return mask + + +def deprecated(func): + """This is a decorator which can be used to mark functions + as deprecated. 
It will result in a warning being emitted + when the function is used.""" + @functools.wraps(func) + def new_func(*args, **kwargs): + warnings.simplefilter('always', DeprecationWarning) # turn off filter + warnings.warn("Call to deprecated function {}.".format(func.__name__), + category=DeprecationWarning, + stacklevel=2) + warnings.simplefilter('default', DeprecationWarning) # reset filter + return func(*args, **kwargs) + return new_func @deprecated @@ -292,6 +306,11 @@ def apply_otsu_thresholding(tile: np.ndarray) -> np.ndarray: return otsu_thr +def get_path_relative(file, path) -> str: + curr_dir = os.path.dirname(os.path.abspath(file)) + return os.path.join(curr_dir, path) + + def filter_contours(contours, hierarchy, filter_params, scale, pixel_size): """ Filter contours by: area From 169d3a8cea79dbcc24c32181b7ab207409f86842 Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Mon, 12 Aug 2024 15:27:48 -0400 Subject: [PATCH 23/24] add hestcore as dependency --- pyproject.toml | 3 +- src/hest/HESTData.py | 11 +- src/hest/segmentation/SegDataset.py | 23 - src/hest/segmentation/cell_segmenters.py | 2 +- src/hest/segmentation/segmentation.py | 454 ---------------- src/hest/utils.py | 2 +- src/hest/vst_save_utils.py | 57 --- src/hest/wsi.py | 626 ----------------------- tests/hest_tests.py | 4 +- 9 files changed, 11 insertions(+), 1171 deletions(-) delete mode 100644 src/hest/segmentation/SegDataset.py delete mode 100644 src/hest/segmentation/segmentation.py delete mode 100644 src/hest/vst_save_utils.py delete mode 100644 src/hest/wsi.py diff --git a/pyproject.toml b/pyproject.toml index 91ef57e..a1f05cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,8 @@ dependencies = [ "dask >= 2024.2.1", "spatial_image >= 0.3.0", "datasets", - "mygene" + "mygene", + "hestcore == 1.0.0" ] requires-python = ">=3.9" diff --git a/src/hest/HESTData.py b/src/hest/HESTData.py index 2fe8b71..396db73 100644 --- a/src/hest/HESTData.py +++ b/src/hest/HESTData.py @@ -9,24 +9,24 @@ import cv2 import geopandas as gpd import numpy as np +from hestcore.wsi import (WSI, CucimWarningSingleton, NumpyWSI, + contours_to_img, wsi_factory) from hest.io.seg_readers import TissueContourReader from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new from hest.segmentation.TissueMask import TissueMask, load_tissue_mask -from hest.wsi import (WSI, CucimWarningSingleton, NumpyWSI, contours_to_img, - get_tissue_vis, wsi_factory) try: import openslide except Exception: print("Couldn't import openslide, verify that openslide is installed on your system, https://openslide.org/download/") import pandas as pd +from hestcore.segmentation import (apply_otsu_thresholding, mask_to_gdf, + save_pkl, segment_tissue_deep) from PIL import Image from shapely import Point from tqdm import tqdm -from .segmentation.segmentation import (apply_otsu_thresholding, mask_to_gdf, - save_pkl, segment_tissue_deep) from .utils import (ALIGNED_HE_FILENAME, check_arg, deprecated, find_first_file_endswith, get_path_from_meta_row, plot_verify_pixel_size, tiff_save, verify_paths) @@ -432,8 +432,7 @@ def save_tissue_seg_pkl(self, save_dir: str, name: str) -> None: def get_tissue_vis(self): - return get_tissue_vis( - self.wsi.img, + return self.wsi.get_tissue_vis( self.tissue_contours, line_color=(0, 255, 0), line_thickness=5, diff --git a/src/hest/segmentation/SegDataset.py b/src/hest/segmentation/SegDataset.py deleted file mode 100644 index f0631d8..0000000 --- a/src/hest/segmentation/SegDataset.py +++ /dev/null @@ -1,23 +0,0 @@ 
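The to_h5 method added in the earlier wsi.py hunks streams tiles into an HDF5 file through initsave_hdf5, which boils down to h5py datasets created with an unbounded first axis and grown on each append. A minimal version of that pattern (file name and shapes are placeholders):

import h5py
import numpy as np

batches = [np.random.randint(0, 255, (2, 8, 8, 3), dtype=np.uint8) for _ in range(3)]

with h5py.File('patches_example.h5', 'w') as f:
    for batch in batches:
        if 'img' not in f:
            # first write: resizable along axis 0, chunked per tile
            f.create_dataset('img', data=batch,
                             maxshape=(None,) + batch.shape[1:],
                             chunks=(1,) + batch.shape[1:])
        else:
            # subsequent writes: grow the dataset and append the batch
            dset = f['img']
            dset.resize(len(dset) + len(batch), axis=0)
            dset[-len(batch):] = batch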
-from torch.utils.data import Dataset - -from hest.wsi import WSIPatcher - - -class WSIPatcherDataset(Dataset): - - def __init__(self, patcher: WSIPatcher, transform): - self.patcher = patcher - - self.transform = transform - - - def __len__(self): - return len(self.patcher) - - def __getitem__(self, index): - tile, x, y = self.patcher[index] - - if self.transform: - tile = self.transform(tile) - - return tile, (x, y) \ No newline at end of file diff --git a/src/hest/segmentation/cell_segmenters.py b/src/hest/segmentation/cell_segmenters.py index 31ce02f..b6af574 100644 --- a/src/hest/segmentation/cell_segmenters.py +++ b/src/hest/segmentation/cell_segmenters.py @@ -20,7 +20,7 @@ from hest.io.seg_readers import GeojsonCellReader from hest.utils import get_path_relative, verify_paths -from hest.wsi import wsi_factory +from hestcore.wsi import wsi_factory def cellvit_light_error(): diff --git a/src/hest/segmentation/segmentation.py b/src/hest/segmentation/segmentation.py deleted file mode 100644 index 37b87e1..0000000 --- a/src/hest/segmentation/segmentation.py +++ /dev/null @@ -1,454 +0,0 @@ -from __future__ import annotations - -import pickle -from typing import Union - -import cv2 -import numpy as np -import pandas as pd -from geopandas import gpd -from huggingface_hub import snapshot_download -from PIL import Image -from shapely import Polygon - -from hest.wsi import WSI, wsi_factory - -try: - import openslide -except Exception: - print("Couldn't import openslide, verify that openslide is installed on your system, https://openslide.org/download/") -from tqdm import tqdm - - -def segment_tissue_deep( - wsi: Union[np.ndarray, openslide.OpenSlide, CuImage, WSI], # type: ignore - pixel_size: float, - fast_mode=False, - dst_pixel_size=1, - patch_size_um=512, - model_name='deeplabv3_seg_v4.ckpt', - batch_size=8, - auto_download=True, - num_workers=8 -) -> gpd.GeoDataFrame: - """ Segment the tissue using a DeepLabV3 model - - Args: - wsi (Union[np.ndarray, openslide.OpenSlide, CuImage, WSI]): wsi - pixel_size (float): pixel size in um/px for the wsi - fast_mode (bool, optional): in fast mode the inference is done at 2 um/px instead of 1 um/px, - note that the inference pixel size is overwritten by the `target_pxl_size` argument if != 1. Defaults to False. - dst_pixel_size (int, optional): patches are scaled to this pixel size in um/px for inference. Defaults to 1. - patch_size_um (int, optional): patch size in um. Defaults to 512. - model_name (str, optional): model name in `HEST/models` dir. Defaults to 'deeplabv3_seg_v4.ckpt'. - batch_size (int, optional): batch size for inference. Defaults to 8. - auto_download (bool, optional): whenever to download the model weights automatically if not found. Defaults to True. - num_workers (int, optional): number of workers for the dataloader during inference. Defaults to 8. 
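segment_tissue_deep, deleted here in favour of the hestcore implementation, stitches per-patch argmax predictions into one array and then thresholds it. The accumulation step is roughly the following (toy sizes, no model involved):

import numpy as np

patch = 4
cols, rows = 3, 2
stitched = np.zeros((rows * patch, cols * patch), dtype=np.uint8)

# pretend per-patch predictions keyed by their top-left (x, y) in the stitched frame
preds = {(0, 0): np.ones((patch, patch), np.uint8),
         (4, 4): np.ones((patch, patch), np.uint8)}

for (x, y), pred in preds.items():
    y_end = min(y + patch, stitched.shape[0])
    x_end = min(x + patch, stitched.shape[1])
    stitched[y:y_end, x:x_end] += pred[:y_end - y, :x_end - x]

mask = (stitched > 0).astype(np.uint8)   # binary mask then handed to mask_to_gdf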
- - Returns: - gpd.GeoDataFrame: a geodataframe of the tissue contours, contains a column `tissue_id` indicating to which tissue the contour belongs to - """ - import torch - from torch import nn - from torch.utils.data import DataLoader - from torchvision import transforms - from hest.segmentation.SegDataset import WSIPatcherDataset - - src_pixel_size = pixel_size - - if fast_mode and dst_pixel_size == 1: - dst_pixel_size = 2 - - patch_size_deeplab = 512 - - scale = src_pixel_size / dst_pixel_size - patch_size_src = round(patch_size_um / scale) - wsi = wsi_factory(wsi) - - weights_path = get_path_relative(__file__, f'../../../models/{model_name}') - - patcher = wsi.create_patcher(patch_size_deeplab, src_pixel_size, dst_pixel_size) - - eval_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) - dataset = WSIPatcherDataset(patcher, eval_transforms) - dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) - model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50') - model.classifier[4] = nn.Conv2d( - in_channels=256, - out_channels=2, - kernel_size=1, - stride=1 - ) - - if auto_download: - model_dir = get_path_relative(__file__, f'../../../models') - snapshot_download(repo_id="MahmoodLab/hest-tissue-seg", repo_type='model', local_dir=model_dir, allow_patterns=model_name) - - if torch.cuda.is_available(): - checkpoint = torch.load(weights_path) - else: - checkpoint = torch.load(weights_path, map_location=torch.device('cpu')) - - new_state_dict = {} - for key in checkpoint['state_dict']: - if 'aux' in key: - continue - new_key = key.replace('model.', '') - new_state_dict[new_key] = checkpoint['state_dict'][key] - model.load_state_dict(new_state_dict) - - if torch.cuda.is_available(): - model.cuda() - - model.eval() - - cols, rows = patcher.get_cols_rows() - width, height = patch_size_deeplab * cols, patch_size_deeplab * rows - stitched_img = np.zeros((height, width), dtype=np.uint8) - src_to_deeplab_scale = patch_size_deeplab / patch_size_src - - with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16): - - for batch in tqdm(dataloader, total=len(dataloader)): - - # coords are top left coords of patch - imgs, coords = batch - if torch.cuda.is_available(): - imgs = imgs.cuda() - masks = model(imgs)['out'] - preds = masks.argmax(1).to(torch.uint8).detach() - - preds = preds.cpu().numpy() - coords = np.column_stack((coords[0], coords[1])) - - # stitch the patches - for i in range(preds.shape[0]): - pred = preds[i] - coord = coords[i] - x, y = round(coord[0] * src_to_deeplab_scale), round(coord[1] * src_to_deeplab_scale) - - y_end = min(y+patch_size_deeplab, height) - x_end = min(x+patch_size_deeplab, width) - stitched_img[y:y_end, x:x_end] += pred[:y_end-y, :x_end-x] - - - mask = (stitched_img > 0).astype(np.uint8) - - gdf_contours = mask_to_gdf(mask, max_nb_holes=5, pixel_size=src_pixel_size, contour_scale=1 / src_to_deeplab_scale) - - return gdf_contours - - -def save_pkl(filename, save_object): - writer = open(filename,'wb') - pickle.dump(save_object, writer) - writer.close() - - -def mask_rgb(rgb: np.ndarray, mask: np.ndarray) -> np.ndarray: - """Mask an RGB image - - Args: - rgb (np.ndarray): RGB image to mask with shape (height, width, 3) - mask (np.ndarray): Binary mask with shape (height, width) - - Returns: - np.ndarray: Masked image - """ - assert ( - rgb.shape[:-1] == mask.shape - ), "Mask and RGB shape are different. 
Cannot mask when source and mask have different dimension." - mask_positive = np.dstack([mask, mask, mask]) - mask_negative = np.dstack([~mask, ~mask, ~mask]) - positive = rgb * mask_positive - negative = rgb * mask_negative - negative = 255 * (negative > 0.0001).astype(int) - - masked_image = positive + negative - - return np.clip(masked_image, a_min=0, a_max=255) - - -def keep_largest_area(mask: np.ndarray) -> np.ndarray: - label_image, num_labels = sk_measure.label(mask, background=0, return_num=True) - largest_label = 0 - largest_area = 0 - for label in range(1, num_labels + 1): - area = np.sum(label_image == label) - if area > largest_area: - largest_label = label - largest_area = area - largest_mask = np.zeros_like(mask, dtype=bool) - largest_mask[label_image == largest_label] = True - mask[~largest_mask] = 0 - return mask - - -def deprecated(func): - """This is a decorator which can be used to mark functions - as deprecated. It will result in a warning being emitted - when the function is used.""" - @functools.wraps(func) - def new_func(*args, **kwargs): - warnings.simplefilter('always', DeprecationWarning) # turn off filter - warnings.warn("Call to deprecated function {}.".format(func.__name__), - category=DeprecationWarning, - stacklevel=2) - warnings.simplefilter('default', DeprecationWarning) # reset filter - return func(*args, **kwargs) - return new_func - - -@deprecated -def visualize_tissue_seg( - img, - tissue_mask, - contours_tissue, - contour_holes, - line_color=(0, 255, 0), - hole_color=(0, 0, 255), - line_thickness=5, - target_width=1000, - seg_display=True, - ): - hole_fill_color = (0, 0, 0) - - - wsi = wsi_factory(img) - - width, height = wsi.get_dimensions() - downsample = target_width / width - - top_left = (0,0) - scale = [downsample, downsample] - - img = wsi.get_thumbnail(round(width * downsample), round(height * downsample)) - #img = cv2.resize(img, (round(width * downsample), round(height * downsample))) - if tissue_mask is None and contours_tissue is None and contour_holes is None: - return Image.fromarray(img) - - downscaled_mask = cv2.resize(tissue_mask, (img.shape[1], img.shape[0])) - downscaled_mask = np.expand_dims(downscaled_mask, axis=-1) - downscaled_mask = downscaled_mask * np.array([0, 0, 0]).astype(np.uint8) - - - draw_cont = partial(cv2.drawContours, contourIdx=-1, thickness=line_thickness, lineType=cv2.LINE_8) - draw_cont_fill = partial(cv2.drawContours, contourIdx=-1, thickness=cv2.FILLED) - - if contours_tissue is not None and seg_display: - for _, cont in enumerate(contours_tissue): - cont = np.array(scale_contour_dim(cont, scale)) - draw_cont(image=img, contours=[cont], color=line_color) - draw_cont_fill(image=downscaled_mask, contours=[cont], color=line_color) - - ### Draw hole contours - for cont in contour_holes: - cont = scale_contour_dim(cont, scale) - draw_cont(image=img, contours=cont, color=hole_color) - draw_cont_fill(image=downscaled_mask, contours=cont, color=hole_fill_color) - - alpha = 0.4 - downscaled_mask = downscaled_mask - tissue_mask = cv2.resize(downscaled_mask, (width, height)).round().astype(np.uint8) - img = cv2.addWeighted(img, 1 - alpha, downscaled_mask, alpha, 0) - img = img.astype(np.uint8) - - return Image.fromarray(img) - - -def apply_otsu_thresholding(tile: np.ndarray) -> np.ndarray: - """Generate a binary tissue mask by using Otsu thresholding - - Args: - tile (np.ndarray): Tile with tissue with shape (height, width, 3) - - Returns: - np.ndarray: Binary mask with shape (height, width) - """ - import skimage.color 
as sk_color - import skimage.filters as sk_filters - import skimage.measure as sk_measure - import skimage.morphology as sk_morphology - - # this is to remove the black border padding in some images - black_pixels = np.all(tile == [0, 0, 0], axis=-1) - tile[black_pixels] = [255, 255, 255] - - - hsv_img = cv2.cvtColor(tile.astype(np.uint8), cv2.COLOR_RGB2HSV) - gray_mask = cv2.inRange(hsv_img, (0, 0, 70), (180, 10, 255)) - black_mask = cv2.inRange(hsv_img, (0, 0, 0), (180, 255, 85)) - # Set all grey/black pixels to white - full_tile_bg = np.copy(tile) - full_tile_bg[np.where(gray_mask | black_mask)] = 255 - - # apply otsu mask first time for removing larger artifacts - masked_image_gray = 255 * sk_color.rgb2gray(full_tile_bg) - thresh = sk_filters.threshold_otsu(masked_image_gray) - otsu_masking = masked_image_gray < thresh - # improving mask - otsu_masking = sk_morphology.remove_small_objects(otsu_masking, 60) - #otsu_masking = sk_morphology.dilation(otsu_masking, sk_morphology.square(12)) - #otsu_masking = sk_morphology.closing(otsu_masking, sk_morphology.square(5)) - #otsu_masking = sk_morphology.remove_small_holes(otsu_masking, 250) - tile = mask_rgb(tile, otsu_masking).astype(np.uint8) - - # apply otsu mask second time for removing small artifacts - masked_image_gray = 255 * sk_color.rgb2gray(tile) - thresh = sk_filters.threshold_otsu(masked_image_gray) - otsu_masking = masked_image_gray < thresh - otsu_masking = sk_morphology.remove_small_holes(otsu_masking, 5000) - otsu_thr = ~otsu_masking - otsu_thr = otsu_thr.astype(np.uint8) - - #Image.fromarray(np.expand_dims(otsu_thr, axis=-1) * np.array([255, 255, 255]).astype(np.uint8)).save('otsu_thr.png') - - return otsu_thr - - -def get_path_relative(file, path) -> str: - curr_dir = os.path.dirname(os.path.abspath(file)) - return os.path.join(curr_dir, path) - - -def filter_contours(contours, hierarchy, filter_params, scale, pixel_size): - """ - Filter contours by: area - """ - filtered = [] - - # find indices of foreground contours (parent == -1) - if len(hierarchy) == 0: - hierarchy_1 = [] - else: - hierarchy_1 = np.flatnonzero(hierarchy[:,1] == -1) - all_holes = [] - - # loop through foreground contour indices - for cont_idx in hierarchy_1: - # actual contour - cont = contours[cont_idx] - # indices of holes contained in this contour (children of parent contour) - holes = np.flatnonzero(hierarchy[:, 1] == cont_idx) - # take contour area (includes holes) - a = cv2.contourArea(cont) - # calculate the contour area of each hole - hole_areas = [cv2.contourArea(contours[hole_idx]) for hole_idx in holes] - # actual area of foreground contour region - a = a - np.array(hole_areas).sum() - a *= pixel_size ** 2 - - if a == 0: continue - - - - if tuple((filter_params['a_t'],)) < tuple((a,)): - - if (filter_params['filter_color_mode'] == 'none') or (filter_params['filter_color_mode'] is None): - filtered.append(cont_idx) - holes = [hole_idx for hole_idx in holes if cv2.contourArea(contours[hole_idx]) * pixel_size ** 2 > filter_params['min_hole_area']] - all_holes.append(holes) - else: - raise Exception() - - - # for parent in filtered: - # all_holes.append(np.flatnonzero(hierarchy[:, 1] == parent)) - - ##### TODO: re-implement this in a single for-loop that - ##### loops through both parent contours and holes - - foreground_contours = [contours[cont_idx] for cont_idx in filtered] - - hole_contours = [] - - for hole_ids in all_holes: - unfiltered_holes = [contours[idx] for idx in hole_ids ] - unfilered_holes = sorted(unfiltered_holes, 
key=cv2.contourArea, reverse=True) - # take max_n_holes largest holes by area - filtered_holes = unfilered_holes[:filter_params['max_n_holes']] - #filtered_holes = [] - - # filter these holes - #for hole in unfilered_holes: - # if cv2.contourArea(hole) > filter_params['a_h']: - # filtered_holes.append(hole) - - hole_contours.append(filtered_holes) - - return foreground_contours, hole_contours - - -def make_valid(polygon): - for i in [0, 0.1, -0.1, 0.2]: - new_polygon = polygon.buffer(i) - if isinstance(new_polygon, Polygon) and new_polygon.is_valid: - return new_polygon - raise Exception("Failed to make a valid polygon") - - -def mask_to_gdf(mask: np.ndarray, keep_ids = [], exclude_ids=[], max_nb_holes=0, min_contour_area=1000, pixel_size=1, contour_scale=1.): - TARGET_EDGE_SIZE = 2000 - scale = TARGET_EDGE_SIZE / mask.shape[0] - - downscaled_mask = cv2.resize(mask, (round(mask.shape[1] * scale), round(mask.shape[0] * scale))) - - # Find and filter contours - if max_nb_holes == 0: - contours, hierarchy = cv2.findContours(downscaled_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) - else: - contours, hierarchy = cv2.findContours(downscaled_mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) # Find contours - #print('Num Contours Before Filtering:', len(contours)) - if hierarchy is None: - hierarchy = [] - else: - hierarchy = np.squeeze(hierarchy, axis=(0,))[:, 2:] - - filter_params = { - 'filter_color_mode': 'none', - 'max_n_holes': max_nb_holes, - 'a_t': min_contour_area * pixel_size ** 2, - 'min_hole_area': 4000 * pixel_size ** 2 - } - - if filter_params: - foreground_contours, hole_contours = filter_contours(contours, hierarchy, filter_params, scale, pixel_size) # Necessary for filtering out artifacts - - - if len(foreground_contours) == 0: - raise Exception('no contour detected') - else: - contours_tissue = scale_contour_dim(foreground_contours, contour_scale / scale) - contours_holes = scale_holes_dim(hole_contours, contour_scale / scale) - - if len(keep_ids) > 0: - contour_ids = set(keep_ids) - set(exclude_ids) - else: - contour_ids = set(np.arange(len(contours_tissue))) - set(exclude_ids) - - tissue_ids = [i for i in contour_ids] - polygons = [] - for i in contour_ids: - holes = [contours_holes[i][j].squeeze(1) for j in range(len(contours_holes[i]))] if len(contours_holes[i]) > 0 else None - polygon = Polygon(contours_tissue[i].squeeze(1), holes=holes) - if not polygon.is_valid: - # TODO replace by shapely make_valid after 2.1 - if not polygon.is_valid: - polygon = make_valid(polygon) - polygons.append(polygon) - - gdf_contours = gpd.GeoDataFrame(pd.DataFrame(tissue_ids, columns=['tissue_id']), geometry=polygons) - - return gdf_contours - - -def scale_holes_dim(contours, scale): - r""" - """ - return [[np.array(hole * scale, dtype = 'int32') for hole in holes] for holes in contours] - - -def scale_contour_dim(contours, scale): - r""" - """ - return [np.array(cont * scale, dtype='int32') for cont in contours] \ No newline at end of file diff --git a/src/hest/utils.py b/src/hest/utils.py index e7ac36c..2ec426e 100644 --- a/src/hest/utils.py +++ b/src/hest/utils.py @@ -19,7 +19,7 @@ from scipy import sparse from tqdm import tqdm -from hest.wsi import WSI, NumpyWSI, WSIPatcher, wsi_factory +from hestcore.wsi import WSI, NumpyWSI, WSIPatcher, wsi_factory Image.MAX_IMAGE_PIXELS = 93312000000 ALIGNED_HE_FILENAME = 'aligned_fullres_HE.tif' diff --git a/src/hest/vst_save_utils.py b/src/hest/vst_save_utils.py deleted file mode 100644 index 39be25d..0000000 --- a/src/hest/vst_save_utils.py +++ /dev/null @@ 
-1,57 +0,0 @@ -##### -# Utils for h5 file saving -#### - -import h5py - - -def initsave_hdf5(output_fpath, asset_dict, attr_dict= None, mode='a', static_shape=None, chunk_as_max_shape=False, verbose=0): - with h5py.File(output_fpath, mode) as f: - for key, val in asset_dict.items(): - data_shape = val.shape - - # if len(data_shape) == 1: - # val = np.expand_dims(val, axis=1) - # data_shape = val.shape - - if key not in f: - data_type = val.dtype - - if data_type.kind == 'U': # This is for catching numpy array of unicode strings - chunk_shape = (1, 1) - max_shape = (None, 1) - data_type = h5py.string_dtype(encoding='utf-8') - else: - if chunk_as_max_shape: - chunk_shape = data_shape - max_shape = data_shape - elif static_shape is None: - chunk_shape = (1,) + data_shape[1:] - max_shape = (None,) + data_shape[1:] - # else: - # chunk_shape = static_shape - # max_shape = static_shape - - if verbose: - print(key, data_shape, chunk_shape, max_shape) - - dset = f.create_dataset(key, - shape=data_shape, - maxshape=max_shape, - chunks=chunk_shape, - dtype=data_type) - - dset[:] = val - - ### Save attribute dictionary - if attr_dict is not None: - if key in attr_dict.keys(): - for attr_key, attr_val in attr_dict[key].items(): - dset.attrs[attr_key] = attr_val - else: - dset = f[key] - dset.resize(len(dset) + data_shape[0], axis=0) - #assert dset.dtype == val.dtype - dset[-data_shape[0]:] = val - - return output_fpath \ No newline at end of file diff --git a/src/hest/wsi.py b/src/hest/wsi.py deleted file mode 100644 index d976596..0000000 --- a/src/hest/wsi.py +++ /dev/null @@ -1,626 +0,0 @@ -from __future__ import annotations - -import warnings -from abc import abstractmethod -from functools import partial -from typing import Tuple, Union - -import cv2 -import geopandas as gpd -import numpy as np -import openslide -from PIL import Image -from shapely import Polygon -from tqdm import tqdm - -from hest.vst_save_utils import initsave_hdf5 - - -class CucimWarningSingleton: - _warned_cucim = False - - @classmethod - def warn(cls): - if cls._warned_cucim is False: - warnings.warn("CuImage is not available. 
Ensure you have a GPU and cucim installed to use GPU acceleration.") - cls._warned_cucim = True - return cls._warned_cucim - - -def is_cuimage(img): - try: - from cucim import CuImage - except ImportError: - CuImage = None - CucimWarningSingleton.warn() - return CuImage is not None and isinstance(img, CuImage) # type: ignore - - -class WSI: - - def __init__(self, img): - self.img = img - self.name = None - - if not (isinstance(img, openslide.OpenSlide) or isinstance(img, np.ndarray) or is_cuimage(img)) : - raise ValueError(f"Invalid type for img {type(img)}") - - self.width, self.height = self.get_dimensions() - - @abstractmethod - def numpy(self) -> np.ndarray: - pass - - @abstractmethod - def get_dimensions(self): - pass - - @abstractmethod - def read_region(self, location, level, size) -> np.ndarray: - pass - - @abstractmethod - def read_region_pil(self, location, level, size) -> np.ndarray: - pass - - @abstractmethod - def get_thumbnail(self, width, height): - pass - - def __repr__(self) -> str: - width, height = self.get_dimensions() - - return f"" - - @abstractmethod - def create_patcher( - self, - patch_size: int, - src_pixel_size: float, - dst_pixel_size: float = None, - overlap: int = 0, - mask: gpd.GeoDataFrame = None, - coords_only = False, - custom_coords = None - ) -> WSIPatcher: - pass - - -def wsi_factory(img) -> WSI: - try: - from cucim import CuImage - except ImportError: - CuImage = None - CucimWarningSingleton.warn() - - if isinstance(img, WSI): - return img - elif isinstance(img, openslide.OpenSlide): - return OpenSlideWSI(img) - elif isinstance(img, np.ndarray): - return NumpyWSI(img) - elif is_cuimage(img): - return CuImageWSI(img) - elif isinstance(img, str): - if CuImage is not None: - return CuImageWSI(CuImage(img)) - else: - warnings.warn("Cucim isn't available, opening the image with OpenSlide (will be slower)") - return OpenSlideWSI(openslide.OpenSlide(img)) - else: - raise ValueError(f'type {type(img)} is not supported') - -class NumpyWSI(WSI): - def __init__(self, img: np.ndarray): - super().__init__(img) - - def numpy(self) -> np.ndarray: - return self.img - - def get_dimensions(self): - return self.img.shape[1], self.img.shape[0] - - def read_region(self, location, level, size) -> np.ndarray: - img = self.img - x_start, y_start = location[0], location[1] - x_size, y_size = size[0], size[1] - x_end, y_end = x_start + x_size, y_start + y_size - padding_left = max(0 - x_start, 0) - padding_top = max(0 - y_start, 0) - padding_right = max(x_start + x_size - self.width, 0) - padding_bottom = max(y_start + y_size - self.height, 0) - x_start, y_start = max(x_start, 0), max(y_start, 0) - x_end, y_end = min(x_end, self.width), min(y_end, self.height) - tile = img[y_start:y_end, x_start:x_end] - padded_tile = np.pad(tile, ((padding_top, padding_bottom), (padding_left, padding_right), (0, 0)), mode='constant', constant_values=0) - - return padded_tile - - def read_region_pil(self, location, level, size) -> Image: - return Image.fromarray(self.read_region(location, level, size)) - - def get_thumbnail(self, width, height) -> np.ndarray: - return cv2.resize(self.img, (width, height)) - - def create_patcher( - self, - patch_size: int, - src_pixel_size: float, - dst_pixel_size: float = None, - overlap: int = 0, - mask: gpd.GeoDataFrame = None, - coords_only = False, - custom_coords = None - ) -> WSIPatcher: - return NumpyWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only, custom_coords) - - -class OpenSlideWSI(WSI): - def __init__(self, 
img: openslide.OpenSlide): - super().__init__(img) - - def numpy(self) -> np.ndarray: - return self.get_thumbnail(self.width, self.height) - - def get_dimensions(self): - return self.img.dimensions - - def read_region(self, location, level, size) -> np.ndarray: - return np.array(self.read_region_pil(location, level, size)) - - def read_region_pil(self, location, level, size) -> Image: - return self.img.read_region(location, level, size) - - def get_thumbnail(self, width, height): - return np.array(self.img.get_thumbnail((width, height))) - - def get_best_level_for_downsample(self, downsample): - return self.img.get_best_level_for_downsample(downsample) - - def level_dimensions(self): - return self.img.level_dimensions - - def level_downsamples(self): - return self.img.level_downsamples - - def create_patcher( - self, - patch_size: int, - src_pixel_size: float, - dst_pixel_size: float = None, - overlap: int = 0, - mask: gpd.GeoDataFrame = None, - coords_only = False, - custom_coords = None - ) -> WSIPatcher: - return OpenSlideWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only, custom_coords) - -class CuImageWSI(WSI): - def __init__(self, img: 'CuImage'): - super().__init__(img) - - def numpy(self) -> np.ndarray: - return self.get_thumbnail(self.width, self.height) - - def get_dimensions(self): - return self.img.resolutions['level_dimensions'][0] - - def read_region(self, location, level, size) -> np.ndarray: - return np.asarray(self.img.read_region(location=location, level=level, size=size)) - - def read_region_pil(self, location, level, size) -> Image: - return Image.fromarray(self.read_region(location, level, size)) - - def get_thumbnail(self, width, height): - downsample = self.width / width - downsamples = self.img.resolutions['level_downsamples'] - closest = 0 - for i in range(len(downsamples)): - if downsamples[i] > downsample: - break - closest = i - - curr_width, curr_height = self.img.resolutions['level_dimensions'][closest] - thumbnail = np.array(self.img.read_region(location=(0, 0), level=closest, size=(curr_width, curr_height))) - thumbnail = cv2.resize(thumbnail, (width, height)) - - return thumbnail - - def get_best_level_for_downsample(self, downsample): - downsamples = self.img.resolutions['level_downsamples'] - last = 0 - for i in range(len(downsamples)): - down = downsamples[i] - if downsample < down: - return last - last = i - return last - - def level_dimensions(self): - return self.img.resolutions['level_dimensions'] - - def level_downsamples(self): - return self.img.resolutions['level_downsamples'] - - def create_patcher( - self, - patch_size: int, - src_pixel_size: float, - dst_pixel_size: float = None, - overlap: int = 0, - mask: gpd.GeoDataFrame = None, - coords_only = False, - custom_coords = None - ) -> WSIPatcher: - return CuImageWSIPatcher(self, patch_size, src_pixel_size, dst_pixel_size, overlap, mask, coords_only, custom_coords) - - -class WSIPatcher: - """ Iterator class to handle patching, patch scaling and tissue mask intersection """ - - def __init__( - self, - wsi: WSI, - patch_size: int, - src_pixel_size: float, - dst_pixel_size: float = None, - overlap: int = 0, - mask: gpd.GeoDataFrame = None, - coords_only = False, - custom_coords = None, - threshold = 0.15, - pil=False - ): - """ Initialize patcher, compute number of (masked) rows, columns. 
- - Args: - wsi (WSI): wsi to patch - patch_size (int): patch width/height in pixel on the slide after rescaling - src_pixel_size (float, optional): pixel size in um/px of the slide before rescaling. Defaults to None. - dst_pixel_size (float, optional): pixel size in um/px of the slide after rescaling. Defaults to None. - overlap (int, optional): overlap size in pixel before rescaling. Defaults to 0. - mask (gpd.GeoDataFrame, optional): geopandas dataframe of Polygons. Defaults to None. - coords_only (bool, optional): whenever to extract only the coordinates insteaf of coordinates + tile. Default to False. - threshold (float, optional): minimum proportion of the patch under tissue to be kept. - This argument is ignored if mask=None, passing threshold=0 will be faster. Defaults to 0.15 - pil (bool, optional): whenever to get patches as `PIL.Image` (numpy array by default). Defaults to False - """ - self.wsi = wsi - self.overlap = overlap - self.width, self.height = self.wsi.get_dimensions() - self.patch_size_target = patch_size - self.mask = mask - self.i = 0 - self.coords_only = coords_only - self.custom_coords = custom_coords - self.pil = pil - self.src_pixel_size = src_pixel_size - - if dst_pixel_size is None: - self.downsample = 1. - else: - self.downsample = dst_pixel_size / src_pixel_size - - self.patch_size_src = round(patch_size * self.downsample) - - self.level, self.patch_size_level, self.overlap_level = self._prepare() - - if custom_coords is None: - self.cols, self.rows = self._compute_cols_rows() - - col_rows = np.array([ - [col, row] - for col in range(self.cols) - for row in range(self.rows) - ]) - coords = np.array([self._colrow_to_xy(xy[0], xy[1]) for xy in col_rows]) - else: - if round(custom_coords[0][0]) != custom_coords[0][0]: - raise ValueError("custom_coords must be a (N, 2) array of int") - coords = custom_coords - if self.mask is not None: - self.valid_patches_nb, self.valid_coords = self._compute_masked(coords, threshold) - else: - self.valid_patches_nb, self.valid_coords = len(coords), coords - - def _colrow_to_xy(self, col, row): - """ Convert col row of a tile to its top-left coordinates before rescaling (x, y) """ - x = col * (self.patch_size_src) - self.overlap * np.clip(col - 1, 0, None) - y = row * (self.patch_size_src) - self.overlap * np.clip(row - 1, 0, None) - return (x, y) - - def _compute_masked(self, coords, threshold) -> None: - """ Compute tiles which center falls under the tissue """ - - # Filter coordinates by bounding boxes of mask polygons - bounding_boxes = self.mask.geometry.bounds - bbox_masks = [] - for _, bbox in bounding_boxes.iterrows(): - bbox_mask = ( - (coords[:, 0] >= bbox['minx'] - self.patch_size_src) & (coords[:, 0] <= bbox['maxx'] + self.patch_size_src) & - (coords[:, 1] >= bbox['miny'] - self.patch_size_src) & (coords[:, 1] <= bbox['maxy'] + self.patch_size_src) - ) - bbox_masks.append(bbox_mask) - - if len(bbox_masks) > 0: - bbox_mask = np.vstack(bbox_masks)[0] - else: - bbox_mask = np.zeros(len(coords), dtype=bool) - - - union_mask = self.mask.union_all() - - squares = [ - Polygon([ - (xy[0], xy[1]), - (xy[0] + self.patch_size_src, xy[1]), - (xy[0] + self.patch_size_src, xy[1] + self.patch_size_src), - (xy[0], xy[1] + self.patch_size_src)]) - for xy in coords[bbox_mask] - ] - if threshold == 0: - valid_mask = gpd.GeoSeries(squares).intersects(union_mask).values - else: - gdf = gpd.GeoSeries(squares) - areas = gdf.area - valid_mask = gdf.intersection(union_mask).area >= threshold * areas - - full_mask = bbox_mask - 
full_mask[bbox_mask] &= valid_mask - - valid_patches_nb = full_mask.sum() - self.valid_mask = full_mask - valid_coords = coords[full_mask] - return valid_patches_nb, valid_coords - - def __len__(self): - return self.valid_patches_nb - - def __iter__(self): - self.i = 0 - return self - - def __next__(self): - if self.i >= self.valid_patches_nb: - raise StopIteration - x = self.__getitem__(self.i) - self.i += 1 - return x - - def __getitem__(self, index): - if 0 <= index < len(self): - xy = self.valid_coords[index] - x, y = xy[0], xy[1] - if self.coords_only: - return x, y - tile, x, y = self.get_tile_xy(x, y) - return tile, x, y - else: - raise IndexError("Index out of range") - - - @abstractmethod - def _prepare(self) -> None: - pass - - def get_cols_rows(self) -> Tuple[int, int]: - """ Get the number of columns and rows in the associated WSI - - Returns: - Tuple[int, int]: (nb_columns, nb_rows) - """ - return self.cols, self.rows - - def get_tile_xy(self, x: int, y: int) -> Tuple[np.ndarray, int, int]: - if self.pil: - raw_tile = self.wsi.read_region_pil(location=(x, y), level=self.level, size=(self.patch_size_level, self.patch_size_level)) - else: - raw_tile = self.wsi.read_region(location=(x, y), level=self.level, size=(self.patch_size_level, self.patch_size_level)) - tile = np.array(raw_tile) - if self.patch_size_target is not None: - tile = cv2.resize(tile, (self.patch_size_target, self.patch_size_target)) - assert x < self.width and y < self.height - return tile[:, :, :3], x, y - - def get_tile(self, col: int, row: int) -> Tuple[np.ndarray, int, int]: - """ get tile at position (column, row) - - Args: - col (int): column - row (int): row - - Returns: - Tuple[np.ndarray, int, int]: (tile, pixel x of top-left corner (before rescaling), pixel_y of top-left corner (before rescaling)) - """ - if self.custom_coords is not None: - raise ValueError("Can't use get_tile as 'custom_coords' was passed to the constructor") - - x, y = self._colrow_to_xy(col, row) - return self.get_tile_xy(x, y) - - def _compute_cols_rows(self) -> Tuple[int, int]: - col = 0 - row = 0 - x, y = self._colrow_to_xy(col, row) - while x < self.width: - col += 1 - x, _ = self._colrow_to_xy(col, row) - cols = col - while y < self.height: - row += 1 - _, y = self._colrow_to_xy(col, row) - rows = row - return cols, rows - - def save_visualization(self, path, vis_width=1000, dpi=150): - mask_plot = get_tissue_vis( - self.wsi, - self.mask, - line_color=(0, 255, 0), - line_thickness=5, - target_width=vis_width, - seg_display=True, - ) - import matplotlib.pyplot as plt - from matplotlib.collections import PatchCollection - from matplotlib.patches import Rectangle - - downscale_vis = vis_width / self.width - - _, ax = plt.subplots() - ax.imshow(mask_plot) - - patch_rectangles = [] - for xy in self.valid_coords: - x, y = xy[0], xy[1] - x, y = x * downscale_vis, y * downscale_vis - - patch_rectangles.append(Rectangle((x, y), self.patch_size_src * downscale_vis, self.patch_size_src * downscale_vis)) - - ax.add_collection(PatchCollection(patch_rectangles, facecolor='none', edgecolor='black', linewidth=0.3)) - ax.set_axis_off() - plt.tight_layout() - plt.savefig(path, dpi=dpi, bbox_inches = 'tight') - - - def to_h5(self, path, extra_assets={}): - mode_HE = 'w' - i = 0 - - if extra_assets != {}: - for _, value in extra_assets.items(): - if len(value) != len(self): - raise ValueError("Each value in extra_assets must have the same length as the patcher object") - - if not (path.endswith('.h5') or path.endswith('.h5ad')): - path = path + 
'.h5' - - for tile, x, y in tqdm(self): - - # Save ref patches - assert tile.shape == (self.patch_size_target, self.patch_size_target, 3) - - asset_dict = {} - if not self.coords_only: - asset_dict['img'] = np.expand_dims(tile, axis=0) # (1 x w x h x 3) - asset_dict['coords'] = np.expand_dims([x, y], axis=0) # (1 x 2) - - extra_asset_dict = {key: np.expand_dims([value[i]], axis=0) for key, value in extra_assets.items()} - asset_dict = {**asset_dict, **extra_asset_dict} - - attr_dict = {} - attr_dict['img'] = { - 'patch_size': self.patch_size_target, - 'factor': self.downsample, - 'pixel_size': self.src_pixel_size, - } - - if self.wsi.name is not None: - attr_dict['img']['name'] = self.wsi.name - - initsave_hdf5(path, asset_dict, attr_dict, mode=mode_HE) - mode_HE = 'a' - i += 1 - - -class OpenSlideWSIPatcher(WSIPatcher): - wsi: OpenSlideWSI - - def _prepare(self) -> None: - level = self.wsi.get_best_level_for_downsample(self.downsample) - level_downsample = self.wsi.level_downsamples()[level] - patch_size_level = round(self.patch_size_src / level_downsample) - overlap_level = round(self.overlap / level_downsample) - return level, patch_size_level, overlap_level - -class CuImageWSIPatcher(WSIPatcher): - wsi: CuImageWSI - - def _prepare(self) -> None: - level = self.wsi.get_best_level_for_downsample(self.downsample) - level_downsample = self.wsi.level_downsamples()[level] - patch_size_level = round(self.patch_size_src / level_downsample) - overlap_level = round(self.overlap / level_downsample) - return level, patch_size_level, overlap_level - -class NumpyWSIPatcher(WSIPatcher): - WSI: NumpyWSI - - def _prepare(self) -> None: - patch_size_level = self.patch_size_src - overlap_level = self.overlap - level = -1 - return level, patch_size_level, overlap_level - - - -def contours_to_img( - contours: gpd.GeoDataFrame, - img: np.ndarray, - draw_contours=False, - thickness=1, - downsample=1., - line_color=(0, 255, 0) -) -> np.ndarray: - draw_cont = partial(cv2.drawContours, contourIdx=-1, thickness=thickness, lineType=cv2.LINE_8) - draw_cont_fill = partial(cv2.drawContours, contourIdx=-1, thickness=cv2.FILLED) - - groups = contours.groupby('tissue_id') - for _, group in groups: - - for _, row in group.iterrows(): - cont = np.array([[round(x * downsample), round(y * downsample)] for x, y in row.geometry.exterior.coords]) - holes = [np.array([[round(x * downsample), round(y * downsample)] for x, y in hole.coords]) for hole in row.geometry.interiors] - - draw_cont_fill(image=img, contours=[cont], color=line_color) - - for hole in holes: - draw_cont_fill(image=img, contours=[hole], color=(0, 0, 0)) - - if draw_contours: - draw_cont(image=img, contours=[cont], color=line_color) - return img - - -def get_tissue_vis( - img: Union[np.ndarray, openslide.OpenSlide, CuImage, WSI], - tissue_contours: gpd.GeoDataFrame, - line_color=(0, 255, 0), - line_thickness=5, - target_width=1000, - seg_display=True, - ) -> Image: - - wsi = wsi_factory(img) - - width, height = wsi.get_dimensions() - downsample = target_width / width - - top_left = (0,0) - - img = wsi.get_thumbnail(round(width * downsample), round(height * downsample)) - - if tissue_contours is None: - return Image.fromarray(img) - - tissue_contours = tissue_contours.copy() - - downscaled_mask = np.zeros(img.shape[:2], dtype=np.uint8) - downscaled_mask = np.expand_dims(downscaled_mask, axis=-1) - downscaled_mask = downscaled_mask * np.array([0, 0, 0]).astype(np.uint8) - - if tissue_contours is not None and seg_display: - downscaled_mask = contours_to_img( - 
tissue_contours, - downscaled_mask, - draw_contours=True, - thickness=line_thickness, - downsample=downsample, - line_color=line_color - ) - - alpha = 0.4 - img = cv2.addWeighted(img, 1 - alpha, downscaled_mask, alpha, 0) - img = img.astype(np.uint8) - - return Image.fromarray(img) \ No newline at end of file diff --git a/tests/hest_tests.py b/tests/hest_tests.py index b065722..a63a8ed 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -8,7 +8,7 @@ from hest.autoalign import autoalign_visium from hest.readers import VisiumReader from hest.utils import get_path_relative, load_image -from hest.wsi import WSI, CucimWarningSingleton, wsi_factory +from hestcore.wsi import WSI, CucimWarningSingleton, wsi_factory try: from cucim import CuImage @@ -91,7 +91,7 @@ class TestHESTData(unittest.TestCase): @classmethod def setUpClass(self): - download = True + download = False self.cur_dir = get_path_relative(__file__, '') cur_dir = self.cur_dir self.output_dir = _j(cur_dir, 'output_tests/hestdata_tests') From ad1ac5ba939f5b3d50f2de5a3556289c1e6cc6a3 Mon Sep 17 00:00:00 2001 From: Paul Doucet Date: Mon, 12 Aug 2024 15:32:50 -0400 Subject: [PATCH 24/24] download --- tests/hest_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/hest_tests.py b/tests/hest_tests.py index a63a8ed..09e6cf3 100644 --- a/tests/hest_tests.py +++ b/tests/hest_tests.py @@ -91,7 +91,7 @@ class TestHESTData(unittest.TestCase): @classmethod def setUpClass(self): - download = False + download = True self.cur_dir = get_path_relative(__file__, '') cur_dir = self.cur_dir self.output_dir = _j(cur_dir, 'output_tests/hestdata_tests')
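
For reference, a minimal usage sketch of the patcher workflow after this refactor. It follows the WSIPatcher API shown in the removed src/hest/wsi.py and assumes the new hestcore.wsi imports resolve to the same classes and signatures; the slide path and pixel sizes below are hypothetical.

    # Sketch only: assumes hestcore.wsi exposes the same wsi_factory / create_patcher
    # API as the deleted src/hest/wsi.py above. The path and pixel sizes are made up.
    from hestcore.wsi import wsi_factory

    wsi = wsi_factory("slide.tif")          # accepts a path, np.ndarray, OpenSlide or CuImage

    patcher = wsi.create_patcher(
        patch_size=224,                     # patch width/height in px after rescaling
        src_pixel_size=0.25,                # um/px of the slide
        dst_pixel_size=0.5,                 # um/px wanted for the patches (downsample = 2)
        overlap=0,
        mask=None,                          # optionally a gpd.GeoDataFrame of tissue polygons
    )

    print(len(patcher), "patches kept")     # number of (masked) patches
    for tile, x, y in patcher:              # tile: (224, 224, 3) array,
        pass                                # (x, y): top-left corner in slide pixels, before rescaling
    patcher.to_h5("patches.h5")             # writes 'img' and 'coords' datasets

Passing a mask (e.g. the GeoDataFrame returned by mask_to_gdf) restricts iteration to patches whose footprint overlaps the tissue polygons by at least the given threshold (0.15 by default in the deleted implementation).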