Skip to content

Commit

Permalink
fixes #16: added automated merging and saving of segmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
EliHei2 committed Sep 25, 2024
1 parent e3ffffc commit 441434a
Showing 1 changed file with 101 additions and 1 deletion.
102 changes: 101 additions & 1 deletion src/segger/prediction/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from segger.data.io import XeniumSample
from segger.models.segger_model import Segger
from segger.training.train import LitSegger
from segger.training.segger_data_module import SeggerDataModule
from lightning import LightningModule
from torch_geometric.nn import to_hetero
import random
Expand All @@ -28,6 +29,13 @@
import typing
import re
from tqdm import tqdm
from segger.data.utils import create_anndata
import dask.dataframe as dd
import dask
import pandas as pd
from dask import delayed
from typing import Union, Optional
import anndata as ad


# CONFIG
Expand Down Expand Up @@ -268,4 +276,96 @@ def predict(
idx = assignments.groupby('transcript_id')['score'].idxmax()
assignments = assignments.loc[idx].reset_index(drop=True)

return assignments
return assignments


def segment(
model: LitSegger,
dm: SeggerDataModule,
save_dir: Union[str, Path],
seg_tag: str,
transcript_file: Union[str, Path],
file_format: str = 'anndata',
receptive_field: dict = {'k_bd': 4, 'dist_bd': 10, 'k_tx': 5, 'dist_tx': 3},
**anndata_kwargs
) -> None:
"""
Perform segmentation using the model, merge segmentation results with transcripts_df, and save in the specified format.
Parameters:
----------
model : LitSegger
The trained segmentation model.
dm : SeggerDataModule
The SeggerDataModule instance for data loading.
save_dir : Union[str, Path]
Directory to save the final segmentation results.
seg_tag : str
Tag to include in the saved filename.
transcript_file : Union[str, Path]
Path to the transcripts parquet file.
file_format : str, optional
File format to save the results ('csv', 'parquet', or 'anndata'). Defaults to 'anndata'.
**anndata_kwargs : dict, optional
Additional keyword arguments passed to the `create_anndata` function, such as:
- panel_df: pd.DataFrame
- min_transcripts: int
- cell_id_col: str
- qv_threshold: float
- min_cell_area: float
- max_cell_area: float
Returns:
-------
None
"""
# Ensure the save directory exists
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)

# Define delayed prediction steps for parallel execution
delayed_train = delayed(predict)(model, dm.train_dataloader(), score_cut=0.5, receptive_field=receptive_field, use_cc=True)
delayed_val = delayed(predict)(model, dm.val_dataloader(), score_cut=0.5, receptive_field=receptive_field, use_cc=True)
delayed_test = delayed(predict)(model, dm.test_dataloader(), score_cut=0.5, receptive_field=receptive_field, use_cc=True)

# Trigger parallel execution and get results
segmentation_train, segmentation_val, segmentation_test = dask.compute(delayed_train, delayed_val, delayed_test)

# Combine the segmentation results
seg_combined = pd.concat([segmentation_train, segmentation_val, segmentation_test]).reset_index()

# Group by transcript_id and keep the row with the highest score
seg_final = seg_combined.loc[seg_combined.groupby('transcript_id')['score'].idxmax()]

# Drop rows where segger_cell_id is NaN
seg_final = seg_final.dropna(subset=['segger_cell_id'])

# Reset the index
seg_final.reset_index(drop=True, inplace=True)

# Load the transcript data
transcripts_df = dd.read_parquet(transcript_file)

# Convert seg_final to a Dask DataFrame and merge with transcripts
seg_final_dd = dd.from_pandas(seg_final, npartitions=transcripts_df.npartitions)
transcripts_df_filtered = transcripts_df.merge(seg_final_dd, on='transcript_id', how='inner')

# Compute the final result
transcripts_df_filtered = transcripts_df_filtered.compute()

# Save the merged result based on the file format
if file_format == 'csv':
save_path = save_dir / f'{seg_tag}_segmentation.csv'
transcripts_df_filtered.to_csv(save_path, index=False)
elif file_format == 'parquet':
save_path = save_dir / f'{seg_tag}_segmentation.parquet'
transcripts_df_filtered.to_parquet(save_path, index=False)
elif file_format == 'anndata':
# Create an AnnData object and save as h5ad, passing additional arguments from kwargs
save_path = save_dir / f'{seg_tag}_segmentation.h5ad'
segger_adata = create_anndata(transcripts_df_filtered, **anndata_kwargs)
segger_adata.write(save_path)
else:
raise ValueError(f"Unsupported file format: {file_format}")

print(f"Segmentation results saved at {save_path}")

0 comments on commit 441434a

Please sign in to comment.