Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move core methods to hestcore #30

Merged
merged 25 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Hest tests

on:
#push:
# branches: [ "main", "develop"]
pull_request:
branches: [ "main" ]

permissions:
contents: read

jobs:
build:

runs-on: ubuntu-latest
env:
HF_READ_TOKEN_PAUL: ${{ secrets.HF_READ_TOKEN_PAUL }}

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v3
with:
python-version: "3.9"

- name: Install python dependencies
run: |
python -m pip install -e .
- name: Install apt dependencies
run: |
sudo apt-get update
sudo apt-get install libvips libvips-dev openslide-tools

- name: Run tests
run: |
python tests/hest_tests.py
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,5 @@ hest_vis
vis
vis2
models/deeplabv3*
.github
htmlcov
models/CellViT-SAM-H-x40.pth
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ dependencies = [
"dask >= 2024.2.1",
"spatial_image >= 0.3.0",
"datasets",
"mygene"
"mygene",
"hestcore == 1.0.0"
]

requires-python = ">=3.9"
Expand Down
156 changes: 39 additions & 117 deletions src/hest/HESTData.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,28 @@

import cv2
import geopandas as gpd
import matplotlib
import numpy as np
from hestcore.wsi import (WSI, CucimWarningSingleton, NumpyWSI,
contours_to_img, wsi_factory)

from hest.io.seg_readers import (TissueContourReader,
write_geojson)
from hest.LazyShapes import LazyShapes, convert_old_to_gpd
from hest.io.seg_readers import TissueContourReader
from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new
from hest.segmentation.TissueMask import TissueMask, load_tissue_mask
from hest.wsi import WSI, CucimWarningSingleton, NumpyWSI, wsi_factory

try:
import openslide
except Exception:
print("Couldn't import openslide, verify that openslide is installed on your system, https://openslide.org/download/")
import pandas as pd
from matplotlib.collections import PatchCollection
from hestcore.segmentation import (apply_otsu_thresholding, mask_to_gdf,
save_pkl, segment_tissue_deep)
from PIL import Image
from shapely import Point
from tqdm import tqdm

from .segmentation.segmentation import (apply_otsu_thresholding,
contours_to_img, get_tissue_vis,
mask_to_contours, save_pkl,
segment_tissue_deep)
from .utils import (ALIGNED_HE_FILENAME, check_arg, deprecated,
find_first_file_endswith, get_path_from_meta_row,
plot_verify_pixel_size, tiff_save, verify_paths)
from .vst_save_utils import initsave_hdf5


class HESTData:
Expand Down Expand Up @@ -137,7 +132,7 @@ def save_spatial_plot(self, save_path: str, name: str='', key='total_counts', pl
filename = f"{name}spatial_plots.png"

# Save the figure
fig.savefig(os.path.join(save_path, filename))
fig.savefig(os.path.join(save_path, filename), dpi=400)
print(f"H&E overlay spatial plots saved in {save_path}")


Expand Down Expand Up @@ -261,21 +256,14 @@ def segment_tissue(
tissue_mask = np.round(cv2.resize(mask, (width, height))).astype(np.uint8)

#TODO directly convert to gpd
gdf_contours = mask_to_contours(tissue_mask, pixel_size=self.pixel_size)
gdf_contours = mask_to_gdf(tissue_mask, pixel_size=self.pixel_size)
self._tissue_contours = gdf_contours

return self.tissue_contours


def save_tissue_contours(self, save_dir: str, name: str) -> None:
write_geojson(
self.tissue_contours,
os.path.join(save_dir, name + '_contours.geojson'),
'tissue_id',
extra_prop=True,
index_key='hole'
)

self.tissue_contours.to_file(os.path.join(save_dir, name + '_contours.geojson'), driver="GeoJSON")

@deprecated
def get_tissue_mask(self) -> np.ndarray:
Expand Down Expand Up @@ -316,6 +304,7 @@ def dump_patches(
"""

import matplotlib.pyplot as plt
dst_pixel_size = target_pixel_size

adata = self.adata.copy()

Expand All @@ -326,106 +315,33 @@ def dump_patches(

src_pixel_size = self.pixel_size

# minimum intersection percecentage with the tissue mask to keep a patch
TISSUE_INTER_THRESH = 0.7
TARGET_VIS_SIZE = 1000

scale_factor = target_pixel_size / src_pixel_size
patch_size_pxl = round(target_patch_size * scale_factor)
patch_count = 0
output_datafile = os.path.join(patch_save_dir, name + '.h5')
h5_path = os.path.join(patch_save_dir, name + '.h5')

assert len(adata.obs) == len(adata.obsm['spatial'])

_, ax = plt.subplots()

mode_HE = 'w'
i = 0
img_width, img_height = self.wsi.get_dimensions()
patch_rectangles = [] # lower corner (x, y) + (widht, height)
downscale_vis = TARGET_VIS_SIZE / img_width

if use_mask:
tissue_mask = np.zeros((img_height, img_width, 3), dtype=np.uint8)
tissue_mask = contours_to_img(
self.tissue_contours,
tissue_mask,
draw_contours=False,
line_color=(1, 1, 1)
)[:, :, 0]
else:
tissue_mask = np.ones((img_height, img_width)).astype(np.uint8)

mask_plot = self.get_tissue_vis()

ax.imshow(mask_plot)
for _, row in tqdm(adata.obs.iterrows(), total=len(adata.obs)):

barcode_spot = row.name

xImage = int(adata.obsm['spatial'][i][0])
yImage = int(adata.obsm['spatial'][i][1])

i += 1

if not(0 <= xImage and xImage < img_width and 0 <= yImage and yImage < img_height):
if verbose:
print('Warning, spot is out of the image, skipping')
continue

if not(0 <= yImage - patch_size_pxl // 2 and yImage + patch_size_pxl // 2 < img_height and \
0 <= xImage - patch_size_pxl // 2 and xImage + patch_size_pxl // 2 < img_width):
if verbose:
print('Warning, patch is out of the image, skipping')
continue

## TODO reimplement now that we use the pyramidal level
image_patch = self.wsi.read_region((xImage - patch_size_pxl // 2, yImage - patch_size_pxl // 2), 0, (patch_size_pxl, patch_size_pxl))
rect_x = (xImage - patch_size_pxl // 2) * downscale_vis
rect_y = (yImage - patch_size_pxl // 2) * downscale_vis
rect_width = patch_size_pxl * downscale_vis
rect_height = patch_size_pxl * downscale_vis

image_patch = np.array(image_patch)
if image_patch.shape[2] == 4:
image_patch = image_patch[:, :, :3]


if use_mask:
patch_mask = tissue_mask[yImage - patch_size_pxl // 2: yImage + patch_size_pxl // 2,
xImage - patch_size_pxl // 2: xImage + patch_size_pxl // 2]
patch_area = patch_mask.shape[0] ** 2
pixel_count = patch_mask.sum()

if pixel_count / patch_area < TISSUE_INTER_THRESH:
continue

patch_rectangles.append(matplotlib.patches.Rectangle((rect_x, rect_y), rect_width, rect_height))

patch_count += 1
image_patch = cv2.resize(image_patch, (target_patch_size, target_patch_size), interpolation=cv2.INTER_CUBIC)


# Save ref patches
assert image_patch.shape == (target_patch_size, target_patch_size, 3)
asset_dict = { 'img': np.expand_dims(image_patch, axis=0), # (1 x w x h x 3)
'coords': np.expand_dims([yImage, xImage], axis=0), # (1 x 2)
'barcode': np.expand_dims([barcode_spot], axis=0)
}
patch_size_src = target_patch_size * (dst_pixel_size / src_pixel_size)
coords_center = adata.obsm['spatial']
coords_topleft = coords_center - patch_size_src // 2
len_tmp = len(coords_topleft)
in_slide_mask = (0 <= coords_topleft[:, 0] + patch_size_src) & (coords_topleft[:, 0] < self.wsi.width) & (0 <= coords_topleft[:, 1] + patch_size_src) & (coords_topleft[:, 1] < self.wsi.height)
coords_topleft = coords_topleft[in_slide_mask]
if len(coords_topleft) < len_tmp:
warnings.warn(f"Filtered {len_tmp - len(coords_topleft)} spots outside the WSI")

barcodes = np.array(adata.obs.index)
barcodes = barcodes[in_slide_mask]
mask = self.tissue_contours if use_mask else None
coords_topleft = np.array(coords_topleft).astype(int)
patcher = self.wsi.create_patcher(target_patch_size, src_pixel_size, dst_pixel_size, mask=mask, custom_coords=coords_topleft)

attr_dict = {}
attr_dict['img'] = {'patch_size': patch_size_pxl,
'factor': scale_factor}
if mask is not None:
valid_barcodes = barcodes[patcher.valid_mask]

initsave_hdf5(output_datafile, asset_dict, attr_dict, mode=mode_HE)
mode_HE = 'a'
patcher.to_h5(h5_path, extra_assets={'barcodes': valid_barcodes})


if dump_visualization:
ax.add_collection(PatchCollection(patch_rectangles, facecolor='none', edgecolor='black', linewidth=0.3))
ax.set_axis_off()
plt.tight_layout()
plt.savefig(os.path.join(patch_save_dir, name + '_patch_vis.png'), dpi=400, bbox_inches = 'tight')
patcher.save_visualization(os.path.join(patch_save_dir, name + '_patch_vis.png'), dpi=400)

if verbose:
print(f'found {patch_count} valid patches')
Expand Down Expand Up @@ -516,8 +432,7 @@ def save_tissue_seg_pkl(self, save_dir: str, name: str) -> None:


def get_tissue_vis(self):
return get_tissue_vis(
self.wsi.img,
return self.wsi.get_tissue_vis(
self.tissue_contours,
line_color=(0, 255, 0),
line_thickness=5,
Expand Down Expand Up @@ -773,8 +688,15 @@ def read_HESTData(
tissue_contours = None
tissue_seg = None
if tissue_contours_path is not None:
tissue_contours = TissueContourReader().read_gdf(tissue_contours_path)
tissue_contours['tissue_id'] = tissue_contours['tissue_id'].astype(int)
with open(tissue_contours_path) as f:
lines = f.read()
if 'hole' in lines:
warnings.warn("this type of .geojson tissue contour file is deprecated, please download the new `tissue_seg` folder on huggingface: https://huggingface.co/datasets/MahmoodLab/hest/tree/main")
gdf = TissueContourReader().read_gdf(tissue_contours_path)
tissue_contours = old_geojson_to_new(gdf)
else:
tissue_contours = gpd.read_file(tissue_contours_path)

elif mask_path_pkl is not None and mask_path_jpg is not None:
tissue_seg = load_tissue_mask(mask_path_pkl, mask_path_jpg, width, height)

Expand Down
33 changes: 21 additions & 12 deletions src/hest/LazyShapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,28 @@ def convert_old_to_gpd(contours_holes, contours_tissue) -> gpd.GeoDataFrame:
types = []
for i in range(len(contours_holes)):
tissue = contours_tissue[i]
shapes.append(Polygon(tissue[:, 0, :]))
tissue_ids.append(i)
types.append('tissue')
holes = contours_holes[i]
if len(holes) > 0:
for hole in holes:
shapes.append(Polygon(hole[:, 0, :]))
tissue_ids.append(i)
types.append('hole')

holes = contours_holes[i] if len(contours_holes[i]) > 0 else None
shapes.append(Polygon(tissue[:, 0, :]), holes=holes)

df = pd.DataFrame(tissue_ids, columns=['tissue_id'])
df['hole'] = types
df['hole'] = df['hole'] == 'hole'

return gpd.GeoDataFrame(df, geometry=shapes)



def old_geojson_to_new(gdf):
polygons = []
keys = []
for key, group in gdf.groupby('tissue_id'):
holes = []
for row in group.values:
if row[2]:
holes.append([coord for coord in row[0].exterior.coords])
else:
exterior = [coord for coord in row[0].exterior.coords]
polygons.append(Polygon(exterior, holes))
keys.append(key)

gdf = gpd.GeoDataFrame(geometry=polygons)
gdf['tissue_id'] = keys
return gdf
6 changes: 2 additions & 4 deletions src/hest/io/seg_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,8 @@ def read_gdf(self, path) -> gpd.GeoDataFrame:

class TissueContourReader(GDFReader):

def read_gdf(self, path) -> gpd.GeoDataFrame:

gdf = _read_geojson(path, class_name='tissue_id', index_key='hole')

def read_gdf(self, path) -> gpd.GeoDataFrame:
gdf = _read_geojson(path, 'tissue_id', extra_props=False, index_key='hole')
return gdf


Expand Down
Loading
Loading