Merge pull request #7 from dpeerlab/umain

Checkup for segger 🩹 and tutorial notebook - andrew's savior magic

EliHei2 authored Sep 12, 2024
2 parents 8f3f54b + 399d421 commit ec5fb2a
Showing 12 changed files with 1,065 additions and 111 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -166,4 +166,6 @@ data_*
model_*

*.egg_info
figure*
figure*

dev*
782 changes: 782 additions & 0 deletions docs/notebooks/segger_tutorial.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions environment.yml
@@ -4,7 +4,7 @@ channels:
- defaults
dependencies:
- python=3.10
- torch>=2.0.0
- pytorch>=2.0.0
- numpy>=1.21.0
- pandas>=1.3.0
- scipy>=1.7.0
@@ -15,7 +15,7 @@ dependencies:
- lightning>=1.9.0
- torchmetrics>=0.5.0
- scanpy>=1.9.3
- squidpy=1.2.0
- squidpy>=1.2.0
- adjustText>=0.8
- scikit-learn>=0.24.0
- geopandas>=0.9.0
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
"shapely>=1.7.0",
"path>=17.0.0",
"pyarrow>=17.0.0"
"torch-geometric>=2.2.0"
]

[project.optional-dependencies]
@@ -48,7 +49,6 @@ torch-geometric = [
"torch-scatter>=2.1.2",
"torch-sparse>=0.6.18",
"torch-cluster>=1.6.3",
"torch-geometric>=2.2.0"
]
multiprocessing = ["multiprocessing"]
dev = [
2 changes: 1 addition & 1 deletion src/segger/data/__init__.py
@@ -13,7 +13,7 @@
"compute_transcript_metrics",
"SpatialTranscriptomicsSample",
"calculate_gene_celltype_abundance_embedding",
"get_edge_index"
"get_edge_index",
]

from .utils import (
154 changes: 94 additions & 60 deletions src/segger/data/io.py

Large diffs are not rendered by default.

71 changes: 70 additions & 1 deletion src/segger/data/utils.py
@@ -24,6 +24,8 @@ def try_import(module_name):
from shapely.geometry import Polygon
from shapely.affinity import scale
import dask.dataframe as dd
from pyarrow import parquet as pq
import sys

# Attempt to import specific modules with try_import function
try_import('multiprocessing')
@@ -521,7 +523,6 @@ def __init__(self, root: str, transform: Callable = None, pre_transform: Callabl
pre_filter (callable, optional): A function that takes in a Data object and returns a boolean indicating whether to keep it. Defaults to None.
"""
super().__init__(root, transform, pre_transform, pre_filter)
os.makedirs(os.path.join(self.processed_dir, 'raw'), exist_ok=True)

@property
def raw_file_names(self) -> List[str]:
@@ -573,3 +574,71 @@ def get(self, idx: int) -> Data:
return data


def get_xy_extents(
filepath: str,
x: str,
y: str,
) -> Tuple[float, float, float, float]:
"""
Get the bounding box of the x and y coordinates from a Parquet file.
Parameters
----------
filepath : str
The path to the Parquet file.
x : str
The name of the column representing the x-coordinate.
y : str
The name of the column representing the y-coordinate.
Returns
-------
Tuple[float, float, float, float]
The bounding box as (x_min, y_min, x_max, y_max).
"""
# Get index of columns of parquet file
metadata = pq.read_metadata(filepath)
schema_idx = dict(map(reversed, enumerate(metadata.schema.names)))

# Find min and max values across all row groups
x_max = -1
x_min = sys.maxsize
y_max = -1
y_min = sys.maxsize
for i in range(metadata.num_row_groups):
group = metadata.row_group(i)
x_min = min(x_min, group.column(schema_idx[x]).statistics.min)
x_max = max(x_max, group.column(schema_idx[x]).statistics.max)
y_min = min(y_min, group.column(schema_idx[y]).statistics.min)
y_max = max(y_max, group.column(schema_idx[y]).statistics.max)
return x_min, y_min, x_max, y_max
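
A minimal usage sketch (the file path and Xenium-style column names are illustrative):

# Compute the spatial extent of a transcripts file without loading it,
# using only Parquet row-group statistics.
x_min, y_min, x_max, y_max = get_xy_extents(
    'transcripts.parquet', x='x_location', y='y_location'
)
print(f'Bounding box: ({x_min}, {y_min}) to ({x_max}, {y_max})')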


def coo_to_dense_adj(
edge_index: torch.Tensor,
num_nodes: Optional[int] = None,
num_nbrs: Optional[int] = None,
) -> torch.Tensor:
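"""
Convert a COO edge index to a dense neighbor matrix: row i holds the
target indices of the edges leaving node i, padded with -1 up to
num_nbrs columns.
"""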

# Check COO format
if edge_index.shape[0] != 2:
msg = (
"Edge index is not in COO format. First dimension should have "
f"size 2, but found {edge_index.shape[0]}."
)
raise ValueError(msg)

# Get split points
uniques, counts = torch.unique(edge_index[0], return_counts=True)
if num_nodes is None:
num_nodes = uniques.max() + 1
if num_nbrs is None:
num_nbrs = counts.max()
counts = tuple(counts.cpu().tolist())

# Fill matrix with neighbors
nbr_idx = torch.full((num_nodes, num_nbrs), -1)
for i, nbrs in zip(uniques, torch.split(edge_index[1], counts)):
nbr_idx[i, :len(nbrs)] = nbrs

return nbr_idx
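
A small self-contained check of the padding behavior (toy edges chosen for illustration):

import torch
edge_index = torch.tensor([[0, 0, 2],
                           [1, 3, 0]])  # edges 0->1, 0->3, 2->0
nbr_idx = coo_to_dense_adj(edge_index, num_nodes=3, num_nbrs=2)
# nbr_idx == tensor([[ 1,  3],
#                    [-1, -1],
#                    [ 0, -1]])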
4 changes: 2 additions & 2 deletions src/segger/models/segger_model.py
@@ -3,7 +3,7 @@
from torch.nn import Embedding
from torch import Tensor
from typing import Union
from torch_sparse import SparseTensor
#from torch_sparse import SparseTensor

class Segger(torch.nn.Module):
def __init__(self, num_tx_tokens: int, init_emb: int = 16, hidden_channels: int = 32, num_mid_layers: int = 3, out_channels: int = 32, heads: int = 3):
@@ -78,7 +78,7 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
return x


def decode(self, z: Tensor, edge_index: Union[Tensor, SparseTensor]) -> Tensor:
def decode(self, z: Tensor, edge_index: Tensor) -> Tensor:
"""
Decode the node embeddings to predict edge values.
84 changes: 55 additions & 29 deletions src/segger/prediction/predict.py
@@ -9,7 +9,11 @@
from torchmetrics import F1Score
from scipy.sparse.csgraph import connected_components as cc

from segger.data.utils import SpatialTranscriptomicsDataset
from segger.data.utils import (
SpatialTranscriptomicsDataset,
get_edge_index,
coo_to_dense_adj,
)
from segger.data.io import XeniumSample
from segger.models.segger_model import Segger
from segger.training.train import LitSegger
@@ -31,7 +35,7 @@
os.environ["PYTORCH_USE_CUDA_DSA"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

def load_model(checkpoint_path: str, init_emb: int, hidden_channels: int, out_channels: int, heads: int, aggr: str) -> LitSegger:
def load_model(checkpoint_path: str) -> LitSegger:
"""
Load a LitSegger model from a checkpoint.
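
With the slimmed-down signature, the architecture hyperparameters are no longer passed in, presumably because they are restored from the checkpoint itself. A minimal sketch (the path is illustrative):

lit_segger = load_model('models/lit_segger/checkpoints/last.ckpt')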
@@ -122,9 +126,9 @@ def get_similarity_scores(

# Neighbor-filtered similarity scores
shape = batch[from_type].x.shape[0], batch[to_type].x.shape[0]
indices = torch.argwhere(nbr_idx != shape[1]).T
indices[1] = nbr_idx[nbr_idx != shape[1]]
values = similarity.to_sparse().values()
indices = torch.argwhere(nbr_idx != -1).T
indices[1] = nbr_idx[nbr_idx != -1]
values = similarity[nbr_idx != -1].flatten()
sparse_sim = torch.sparse_coo_tensor(indices, values, shape)

# Return in dense format for backwards compatibility
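
For intuition, a self-contained sketch (toy values) of the corrected masking: nbr_idx pads missing neighbors with -1, so filtering on -1 rather than on shape[1] keeps exactly the valid (row, neighbor) pairs:

import torch
nbr_idx = torch.tensor([[1, 2], [0, -1]])            # node 0 -> {1, 2}, node 1 -> {0}
similarity = torch.tensor([[0.9, 0.4], [0.7, 0.0]])  # scores aligned with nbr_idx slots
shape = (2, 3)
indices = torch.argwhere(nbr_idx != -1).T            # rows of valid entries
indices[1] = nbr_idx[nbr_idx != -1]                  # swap slot index for neighbor id
values = similarity[nbr_idx != -1].flatten()
sparse_sim = torch.sparse_coo_tensor(indices, values, shape)
# sparse_sim.to_dense() == tensor([[0.0, 0.9, 0.4],
#                                  [0.7, 0.0, 0.0]])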
@@ -137,6 +141,7 @@ def predict_batch(
lit_segger: LitSegger,
batch: Batch,
score_cut: float,
receptive_field: dict,
use_cc: bool = True,
) -> pd.DataFrame:
"""
@@ -170,29 +175,47 @@ def _get_id():

# Assignments of cells to nuclei
assignments = pd.DataFrame()
assignments['transcript_id'] = batch['tx'].id[0].flatten()

# Transcript-cell similarity scores, filtered by neighbors
scores = get_similarity_scores(lit_segger.model, batch, "tx", "nc")

# 1. Get direct assignments from similarity matrix
belongs = scores.max(1)
assignments['score'] = belongs.values.cpu()
mask = assignments['score'] > score_cut
all_ids = batch['nc'].id[0].flatten()[belongs.indices]
assignments.loc[mask, 'segger_cell_id'] = all_ids[mask]

if use_cc:
# Transcript-transcript similarity scores, filtered by neighbors
scores = get_similarity_scores(lit_segger.model, batch, "tx", "tx")
scores = scores.fill_diagonal_(0) # ignore self-similarity

# 2. Assign remainder using connected components
no_id = assignments['segger_cell_id'].isna().values
no_id_scores = scores[no_id][:, no_id]
n, comps = cc(no_id_scores, connection="weak", directed=False)
new_ids = np.array([_get_id() for _ in range(n)])
assignments.loc[no_id, 'segger_cell_id'] = new_ids[comps]
assignments['transcript_id'] = batch['tx'].id.cpu().numpy()

if len(batch['bd'].id[0]) > 0:
# Transcript-cell similarity scores, filtered by neighbors
edge_index = get_edge_index(
batch['bd'].pos[:, :2].cpu(),
batch['tx'].pos[:, :2].cpu(),
k=receptive_field['k_bd'],
dist=receptive_field['dist_bd'],
method='kd_tree',
).T
batch['tx']['bd_field'] = coo_to_dense_adj(
edge_index,
num_nodes=batch['tx'].id.shape[0],
num_nbrs=receptive_field['k_bd'],
)
scores = get_similarity_scores(lit_segger.model, batch, "tx", "bd")
# 1. Get direct assignments from similarity matrix
belongs = scores.max(1)
assignments['score'] = belongs.values.cpu()
mask = assignments['score'] > score_cut
all_ids = np.concatenate(batch['bd'].id)[belongs.indices.cpu()]
assignments.loc[mask, 'segger_cell_id'] = all_ids[mask]

if use_cc:
# Transcript-transcript similarity scores, filtered by neighbors
edge_index = batch['tx', 'neighbors', 'tx'].edge_index
batch['tx']['tx_field'] = coo_to_dense_adj(
edge_index,
num_nodes=batch['tx'].id.shape[0],
)
scores = get_similarity_scores(lit_segger.model, batch, "tx", "tx")
scores = scores.fill_diagonal_(0) # ignore self-similarity

# 2. Assign remainder using connected components
no_id = assignments['segger_cell_id'].isna().values
no_id_scores = scores[no_id][:, no_id]
n, comps = cc(no_id_scores, connection="weak", directed=False)
new_ids = np.array([_get_id() for _ in range(n)])
assignments.loc[no_id, 'segger_cell_id'] = new_ids[comps]

return assignments

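The new receptive_field argument is a plain dict; only the 'k_bd' and 'dist_bd' keys (neighbor count and search radius for the boundary field) appear in this hunk. A minimal sketch with illustrative values:

receptive_field = {'k_bd': 4, 'dist_bd': 10}  # values are illustrative, not repo defaults
batch_assignments = predict_batch(
    lit_segger, batch, score_cut=0.5, receptive_field=receptive_field
)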
@@ -201,6 +224,7 @@ def predict(
lit_segger: LitSegger,
data_loader: DataLoader,
score_cut: float,
receptive_field: dict,
use_cc: bool = True,
) -> pd.DataFrame:
"""
@@ -232,7 +256,9 @@
# Assign transcripts from each batch to nuclei
# TODO: parallelize this step
for batch in tqdm(data_loader):
batch_assignments = predict_batch(lit_segger, batch, score_cut, use_cc)
batch_assignments = predict_batch(
lit_segger, batch, score_cut, receptive_field, use_cc
)
assignments.append(batch_assignments)

# Join across batches and handle duplicates between batches
49 changes: 49 additions & 0 deletions src/segger/training/segger_data_module.py
@@ -0,0 +1,49 @@
from pytorch_lightning import LightningDataModule
from torch_geometric.loader import DataLoader
import os
from pathlib import Path
from segger.data.io import SpatialTranscriptomicsDataset


# TODO: Add documentation
class SeggerDataModule(LightningDataModule):

def __init__(
self,
data_dir: os.PathLike,
batch_size: int = 4,
num_workers: int = 1,
):
super().__init__()
self.data_dir = Path(data_dir)
self.batch_size = batch_size
self.num_workers = num_workers

# TODO: Add documentation
def setup(self, stage=None):
self.train = SpatialTranscriptomicsDataset(
root=self.data_dir / 'train_tiles'
)
self.test = SpatialTranscriptomicsDataset(
root=self.data_dir / 'test_tiles'
)
self.val = SpatialTranscriptomicsDataset(
root=self.data_dir / 'val_tiles'
)
self.loader_kwargs = dict(
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
)

# TODO: Add documentation
def train_dataloader(self):
return DataLoader(self.train, shuffle=True, **self.loader_kwargs)

# TODO: Add documentation
def test_dataloader(self):
return DataLoader(self.test, shuffle=False, **self.loader_kwargs)

# TODO: Add documentation
def val_dataloader(self):
return DataLoader(self.val, shuffle=False, **self.loader_kwargs)
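
A minimal usage sketch, assuming data_dir already contains the 'train_tiles', 'test_tiles', and 'val_tiles' directories that setup() expects (the path is illustrative):

dm = SeggerDataModule(data_dir='data_segger/', batch_size=4, num_workers=2)
dm.setup()
train_loader = dm.train_dataloader()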
9 changes: 6 additions & 3 deletions src/segger/training/train.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
import inspect


class LitSegger(L.LightningModule):
class LitSegger(LightningModule):
"""
LitSegger is a PyTorch Lightning module for training and validating the Segger model.
@@ -60,7 +60,7 @@ def __init__(self, **kwargs):
self.validation_step_outputs = []
self.criterion = torch.nn.BCEWithLogitsLoss()

def from_new(self, num_tx_tokens: int, init_emb: int, hidden_channels: int, out_channels: int, heads: int, aggr: str, metadata: Union[Tuple, Metadata]):
def from_new(self, num_tx_tokens: int, init_emb: int, hidden_channels: int, out_channels: int, heads: int, num_mid_layers: int, aggr: str, metadata: Union[Tuple, Metadata]):
"""
Initializes the LitSegger module with new parameters.
@@ -78,6 +78,8 @@ def from_new(self, num_tx_tokens: int, init_emb: int, hidden_channels: int, out_
Number of attention heads.
aggr : str
Aggregation method for heterogeneous graph conversion.
num_mid_layers: int
Number of hidden layers (excluding first and last layers).
metadata : Union[Tuple, Metadata]
Metadata for heterogeneous graph structure.
"""
@@ -87,7 +89,8 @@ def from_new(self, num_tx_tokens: int, init_emb: int, hidden_channels: int, out_
init_emb=init_emb,
hidden_channels=hidden_channels,
out_channels=out_channels,
heads=heads
heads=heads,
num_mid_layers=num_mid_layers,
)
# Convert model to handle heterogeneous graphs
model = to_hetero(model, metadata=metadata, aggr=aggr)
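
A hedged construction sketch with the new num_mid_layers argument, assuming LitSegger forwards keyword arguments to from_new; the node and edge types echo the prediction code above, while the 'belongs' relation name and all hyperparameter values are illustrative:

# metadata would normally come from a heterogeneous PyG object, e.g. data.metadata()
metadata = (['tx', 'bd'], [('tx', 'belongs', 'bd'), ('tx', 'neighbors', 'tx')])
ls = LitSegger(
    num_tx_tokens=500,
    init_emb=8,
    hidden_channels=32,
    out_channels=8,
    heads=2,
    num_mid_layers=2,  # hidden layers between the first and last GNN layers
    aggr='sum',
    metadata=metadata,
)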