Existing methods of survival analysis on multi-scale WSIs still face two major problems: high computational cost and the overlooked semantic gap in multi-resolution feature fusion. Inspired by modern CNNs, this work proposes to efficiently exploit WSI pyramids from a new perspective, the dual-stream network with cross-attention (DSCA). Our key idea is to use two sub-streams to process WSI patches at two resolutions, where a square pooling is devised in the high-resolution stream to significantly reduce computational costs, and a cross-attention based method is proposed to properly handle the fusion of dual-stream features. Our scheme can be easily extended to multiple branches for multi-resolution WSIs. The experiments on three publicly available datasets confirm the effectiveness of DSCA in predictive performance and computational cost.
Our experiments are run on a machine with
- Linux (Ubuntu 18.04);
- 2 NVIDIA V100 (32G) GPUs, CUDA version 11.6 and cuDNN version 10.2;
- Python packages: torch (1.9.0+cu111), numpy (1.19.5), pyyaml (6.0), pandas (1.1.5), h5py (3.1.0), scikit-learn (0.24.2), nystrom-attention (0.0.11).
Full requirements are provided in requirements.txt.
As an alternative, you could use our Docker image:
docker pull yuukilp/deepath
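If you set up the environment by hand rather than with Docker, a quick check of the installed package versions can save debugging time later. The following is only a minimal sketch: the `PINNED` table is copied from the package list above, `check_versions` is a hypothetical helper (not part of this repo), and it assumes Python 3.8+ for `importlib.metadata`.

```python
from importlib import metadata  # standard library since Python 3.8

# Versions copied from the package list above; local build tags such as
# "+cu111" are ignored when comparing.
PINNED = {
    "torch": "1.9.0",
    "numpy": "1.19.5",
    "pyyaml": "6.0",
    "pandas": "1.1.5",
    "h5py": "3.1.0",
    "scikit-learn": "0.24.2",
    "nystrom-attention": "0.0.11",
}


def check_versions(pinned):
    """Return {package: (expected, installed)} for every mismatch.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pinned.items():
        try:
            # Drop a local build tag like "+cu111" before comparing.
            installed = metadata.version(name).split("+")[0]
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_versions(PINNED).items():
        print(f"{name}: expected {expected}, found {installed}")
```

An empty output means every listed package matches the version used in our experiments.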
Here we show how to run DSCA for cancer prognosis using WSI pyramids. The dataset TCGA-BRCA will be taken as an example.
First of all, you should download the slides of the project TCGA-BRCA from the official TCGA website. The details regarding slide download can be found in the first tutorial - Downloading-Slides-from-TCGA and the second tutorial - Reorganizing-Slides-at-Patient-Level. Note that the patient survival labels of this dataset are available at ./data_split/tcga_brca/tcga_brca_path_full.csv, so you can skip some steps of label data acquisition in the second tutorial.
Then, just like most procedures of WSI analysis, we process each WSI into small patches. A detailed tutorial is given in the third tutorial - Segmenting-and-Patching-Slides. In DSCA, we prepare high- and low-resolution patches. Their details are given as follows.
Specifically, to obtain region-aligned patches at a high resolution and a low one, we first process each WSI into patches at the low resolution (level = 2, downsampled 16x) using the following command:
# Sample patches of SIZE x SIZE at LEVEL
# NOTE: the following paths are for illustration only and should be replaced with your own
# Path where CLAM is installed (refer to the third tutorial)
# Root path to pathology images
# Root path to result files
cd ${DIR_REPO}
echo "run seg & patching for all slides"
CUDA_VISIBLE_DEVICES=0 python3 create_patches_fp.py \
--source ${DIR_READ} \
--save_dir ${DIR_SAVE}/tiles-l${LEVEL}-s${SIZE} \
--patch_size ${SIZE} \
--step_size ${SIZE} \
--preset tcga.csv \
--patch_level ${LEVEL} \
--seg --patch --stitch \
--no_auto_skip --in_child_dir
And then we run the script ./tools/big_to_small_patching.py to obtain aligned high-resolution patches (level = 1, downsampled 4x) by directly transforming the coordinates of low-resolution patches, using the following command:
python3 big_to_small_patching.py \
/NAS02/ExpData/tcga_brca/tiles-l2-s256 \
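To make the coordinate transformation concrete, here is a minimal sketch of the idea, not the actual code of big_to_small_patching.py. It assumes that patch coordinates are stored as level-0 pixel positions (as CLAM does) and that both resolutions use 256 x 256 patches; `low_to_high_coords` is a hypothetical name. A low-resolution patch covers 256 * 16 = 4096 level-0 pixels per side and a high-resolution patch covers 256 * 4 = 1024, so each low-resolution patch maps to a 4 x 4 grid of 16 high-resolution patches.

```python
def low_to_high_coords(x, y, patch_size=256, low_downsample=16, high_downsample=4):
    """Map one low-resolution patch to its aligned high-resolution patches.

    (x, y) is the top-left corner of the low-resolution patch in level-0
    pixels (an assumption of this sketch). Returns the level-0 top-left
    corners of the high-resolution patches tiling the same region, in
    row-major order.
    """
    ratio = low_downsample // high_downsample   # high-res patches per side (4)
    step = patch_size * high_downsample         # level-0 footprint of one high-res patch (1024)
    return [
        (x + col * step, y + row * step)
        for row in range(ratio)
        for col in range(ratio)
    ]


coords = low_to_high_coords(4096, 8192)
print(len(coords))   # 16
print(coords[:5])    # [(4096, 8192), (5120, 8192), (6144, 8192), (7168, 8192), (4096, 9216)]
```

This 1-to-16 correspondence is why the high-resolution stream of DSCA receives 16N patches when the low-resolution stream receives N.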
Finally, we extract features from the high- and low-resolution image patches. A detailed tutorial is given in the fourth tutorial - Extracting-Patch-Features. Specifically, you could use the following command for the high- and low-resolution patches.
# Sample patches of SIZE x SIZE at LEVEL
LEVEL=2 # for extracting features from the low-resolution patches
#LEVEL=1 # uncomment this line for extracting features from the high-resolution patches
# Path where CLAM is installed
# Root path to pathology images
# Sub-directory to the patch coordinates
# Sub-directory to the patch features
cd ${DIR_REPO}
echo "running for extracting features from all tiles"
CUDA_VISIBLE_DEVICES=0 python3 extract_features_fp.py \
--data_h5_dir ${DIR_EXP_DATA}/${SUBDIR_READ} \
--data_slide_dir ${DIR_RAW_DATA} \
--csv_path ${DIR_EXP_DATA}/${SUBDIR_READ}/process_list_autogen.csv \
--feat_dir ${DIR_EXP_DATA}/${SUBDIR_SAVE} \
--batch_size 128 \
--slide_ext .svs \
--color_norm \
--slide_in_child_dir --no_auto_skip
Your file structure is expected to look as follows:
/NAS02/ExpData/tcga_brca                  # The results directory of tcga_brca.
├─ feats-l1-s256-RN50-color_norm          # The directory of all patch features (level = 1).
│  └─ pt_files
│     ├─ TCGA-E2-A1B4-01Z-00-DX1.E585C4FB-0D3E-4160-8192-53A329648F5C.pt   # The patch features of a single slide.
│     ├─ TCGA-B6-A0RQ-01Z-00-DX1.68B4F49D-9F81-4501-BC66-0349012077C8.pt
│     └─ ...
├─ tiles-l1-s256                          # The directory of all segmented patch coordinates (level = 1).
│  ├─ patches
│  │  ├─ TCGA-E2-A1B4-01Z-00-DX1.E585C4FB-0D3E-4160-8192-53A329648F5C.h5   # The patch coordinates of a single slide.
│  │  ├─ TCGA-B6-A0RQ-01Z-00-DX1.68B4F49D-9F81-4501-BC66-0349012077C8.h5
│  │  └─ ...
│  └─ process_list_autogen.csv            # csv file recording all processing details (autogenerated by CLAM).
├─ feats-l2-s256-RN50-color_norm          # The directory of all patch features (level = 2).
│  └─ pt_files
│     ├─ TCGA-E2-A1B4-01Z-00-DX1.E585C4FB-0D3E-4160-8192-53A329648F5C.pt   # The patch features of a single slide.
│     ├─ TCGA-B6-A0RQ-01Z-00-DX1.68B4F49D-9F81-4501-BC66-0349012077C8.pt
│     └─ ...
└─ tiles-l2-s256                          # The directory of all segmented patch coordinates (level = 2).
   ├─ patches
   │  ├─ TCGA-E2-A1B4-01Z-00-DX1.E585C4FB-0D3E-4160-8192-53A329648F5C.h5   # The patch coordinates of a single slide.
   │  ├─ TCGA-B6-A0RQ-01Z-00-DX1.68B4F49D-9F81-4501-BC66-0349012077C8.h5
   │  └─ ...
   └─ process_list_autogen.csv            # csv file recording all processing details (autogenerated by CLAM).
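Before training, it can help to confirm that every slide has all four files in the layout above. The sketch below is one possible way to do that with the standard library only; `LAYOUT` and `find_incomplete_slides` are hypothetical names introduced for illustration and are not part of this repo.

```python
from pathlib import Path

# Sub-directories from the layout above, with the file extension expected in each.
LAYOUT = {
    "feats-l1-s256-RN50-color_norm/pt_files": ".pt",
    "tiles-l1-s256/patches": ".h5",
    "feats-l2-s256-RN50-color_norm/pt_files": ".pt",
    "tiles-l2-s256/patches": ".h5",
}


def find_incomplete_slides(root):
    """Return {slide_id: [sub-directories where its file is missing]}."""
    root = Path(root)
    # Path.stem strips only the last extension, so slide IDs that contain
    # dots (as TCGA file names do) are kept whole.
    found = {
        subdir: {p.stem for p in (root / subdir).glob(f"*{ext}")}
        for subdir, ext in LAYOUT.items()
    }
    all_ids = set().union(*found.values())
    report = {}
    for slide_id in sorted(all_ids):
        missing = [subdir for subdir, ids in found.items() if slide_id not in ids]
        if missing:
            report[slide_id] = missing
    return report


if __name__ == "__main__":
    # Replace with your own results directory.
    print(find_incomplete_slides("/NAS02/ExpData/tcga_brca"))
```

An empty dictionary means that every slide found in any of the four directories is present in all of them.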
We use a YAML file to configure the related paths, the networks, and the training hyper-parameters. An example configuration for training on TCGA_BRCA is available in ./config/config_hier.yaml. We show some important configurations as follows:
- The network option. It means loading a DSCA network.
- magnification. It means loading dual-stream patches.
- dims. It means the dimensionality of different layers in DSCA. These layers are the input layer, patch feature embedding layer, hidden layer, and output layer, from left to right.
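As a reading aid for dims, the following sketch traces the tensor shapes through the simplified DSCA pseudocode given in this README. `trace_shapes` is a hypothetical helper written only for illustration; it assumes the example value [1024, 384, 384, 1] from the pseudocode and the 1-to-16 alignment between low- and high-resolution patches.

```python
def trace_shapes(dims, n_low, batch=1, fusion="cat"):
    """Return the tensor shape after each stage of the simplified forward pass."""
    assert len(dims) == 4
    d_in, d_emb, d_hid, d_out = dims
    # The pseudocode sets the Transformer width to dims[1] but sizes the
    # pooling layer from dims[2], so this sketch assumes they are equal.
    assert d_emb == d_hid
    enc_dim = 2 * d_hid if fusion == "cat" else d_hid
    return {
        "x20 input": (batch, 16 * n_low, d_in),
        "x5 input": (batch, n_low, d_in),
        "stream embeddings": (batch, n_low, d_emb),
        "Transformer output": (batch, n_low, d_hid),
        "fused features": (batch, n_low, enc_dim),
        "pooled representation": (batch, enc_dim),
        "output": (batch, d_out),
    }


shapes = trace_shapes([1024, 384, 384, 1], n_low=100)
for stage, shape in shapes.items():
    print(f"{stage:>22}: {shape}")
print(shapes["fused features"])   # (1, 100, 768)
```

With fusion set to anything other than "cat", the two streams are summed instead of concatenated, so the fused width stays at 384 in this example.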
The pseudocode of the DSCA implementation (simplified for better understanding) is given as follows:
from typing import List

import torch
import torch.nn as nn

# NOTE: make_embedding_layer, make_transformer_layer, GAPool, and compute_pe
# are helpers that are not shown in this simplified listing.


class DSCA(nn.Module):
    """A hierarchical network for WSI with multiple magnifications.

    A typical case of WSI:
        level      = 0, 1, 2,  3
        downsample = 1, 4, 16, 32
    Current version utilizes the levels of 1 and 2.
    """
    def __init__(self, dims: List, args_x20_emb, args_x5_emb, args_tra_layer,
                 dropout: float = 0.25, pool: str = 'gap', join='post', fusion='cat'):
        super(DSCA, self).__init__()
        assert len(dims) == 4  # [1024, 384, 384, 1]
        assert args_x20_emb.backbone in ['avgpool', 'gapool', 'capool']  # arch of embedding layer in high-resolution stream
        assert args_x5_emb.backbone in ['conv1d']  # arch of embedding layer in low-resolution stream
        assert args_tra_layer.backbone in ['Nystromformer', 'Transformer']  # Transformer layer
        assert pool in ['max', 'mean', 'max_mean', 'gap']  # pooling layer
        assert join in ['pre', 'post']  # concat dual-stream embeddings before or after Transformer
        assert fusion in ['cat', 'fusion']  # the way of fusing dual-stream features
        self.x20_emb_backbone = args_x20_emb.backbone

        # dims[0] -> dims[1]: embedding layers
        self.patchx20_embedding_layer = make_embedding_layer(args_x20_emb.backbone, args_x20_emb)
        self.patchx5_embedding_layer = make_embedding_layer(args_x5_emb.backbone, args_x5_emb)
        self.dim_hidden = dims[1]

        # dims[1] -> dims[2]: Transformer layers (default using 'post'; implementation of 'pre' is deleted)
        self.join, self.fusion = join, fusion
        args_tra_layer.d_model = dims[1]
        self.patch_encoder_layer = make_transformer_layer(args_tra_layer.backbone, args_tra_layer)
        self.patch_encoder_layer_parallel = make_transformer_layer(args_tra_layer.backbone, args_tra_layer)
        enc_dim = 2 * dims[2] if fusion == 'cat' else dims[2]

        # global attention pooling layer (default using 'gap'; other implementations are deleted)
        self.pool = GAPool(enc_dim, enc_dim)

        # dims[2] -> dims[3]: output layer
        self.out_layer = nn.Sequential(nn.Linear(enc_dim, dims[3]), nn.Sigmoid())

    def forward(self, x20, x5, x5_coord=None, mode=None):
        """x5 and x20 must be aligned.

        x20: [B, 16N, d], level = 1, downsample = 4
        x5: [B, N, d], level = 2, downsample = 16
        x5_coord: [B, N, 2], the coordinates after discretization for position encoding, used for the stream x20.
        """
        # Patch Embedding of high-resolution patches
        if self.x20_emb_backbone == 'capool':
            # cross-attention pooling also takes the low-resolution patches as input
            patchx20_emb, x20_x5_cross_attn, _ = self.patchx20_embedding_layer(x20, x5)  # [B, 16N, d]->[B, N, d']
        else:
            patchx20_emb = self.patchx20_embedding_layer(x20)  # [B, 16N, d]->[B, N, d']

        # Patch Embedding of low-resolution patches
        patchx5_emb = self.patchx5_embedding_layer(x5)  # [B, N, d]->[B, N, d']

        # Adding Position Embedding
        if x5_coord is not None:
            PEx5 = compute_pe(x5_coord, ndim=self.dim_hidden, device=x20.device, dtype=x20.dtype)
            patchx20_emb = patchx20_emb + PEx5
            patchx5_emb = patchx5_emb + PEx5.clone()

        # Patch Transformer for low- and high-resolution patches
        patchx20_feat = self.patch_encoder_layer_parallel(patchx20_emb)
        patchx5_feat = self.patch_encoder_layer(patchx5_emb)
        if self.fusion == 'cat':
            patch_feat = torch.cat([patchx20_feat, patchx5_feat], dim=2)  # [B, N, 2d']
        else:
            patch_feat = patchx20_feat + patchx5_feat  # [B, N, d']

        # Global Attention Pooling
        rep, patch_attn = self.pool(patch_feat)  # [B, N, enc_dim] -> [B, enc_dim]

        # Output Layer
        out = self.out_layer(rep)
        return out
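The step `[B, 16N, d]->[B, N, d']` in the high-resolution stream is where the sequence length shrinks by a factor of 16. The sketch below illustrates only that length reduction with plain averaging, which is our reading of the simplest ('avgpool') case; it is not the repo's embedding layer, which may also change the feature width. It uses plain Python lists instead of torch tensors to stay dependency-free, and it assumes that the 16 high-resolution features belonging to one low-resolution patch are stored contiguously. `square_avg_pool` is a hypothetical name.

```python
def square_avg_pool(x20, group=16):
    """Average every `group` consecutive feature vectors into one.

    x20: list of 16N feature vectors (each a list of floats), assumed to be
    ordered so that the 16 high-resolution patches of one low-resolution
    patch are contiguous. Returns N pooled vectors.
    """
    assert len(x20) % group == 0, "sequence length must be a multiple of the group size"
    pooled = []
    for start in range(0, len(x20), group):
        block = x20[start:start + group]
        # zip(*block) iterates over feature dimensions across the block.
        pooled.append([sum(values) / group for values in zip(*block)])
    return pooled


# 32 high-resolution features of width 2 -> 2 pooled features of width 2
x20 = [[float(i), 1.0] for i in range(32)]
print(square_avg_pool(x20))   # [[7.5, 1.0], [23.5, 1.0]]
```

After this step both streams have N tokens, which is what allows the per-position concatenation or summation in the fusion step.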
Then you can train DSCA using the following command:
python3 main.py --config config/config_hier.yaml --multi_run
Any issues can be sent via E-mail ([email protected]) or be posted on the issue page of this repo.
