-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
125 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,60 +1,73 @@ | ||
import click | ||
import typing | ||
import os | ||
from segger.cli.utils import add_options, CustomFormatter | ||
from segger.training.segger_data_module import SeggerDataModule | ||
from segger.prediction.predict import segment, load_model | ||
from pathlib import Path | ||
import logging | ||
import os | ||
|
||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' | ||
|
||
predict_yml = Path(__file__).parent / 'configs' / 'predict' / 'default.yaml' | ||
|
||
|
||
@click.command(name="predict", help="Predict using the Segger model") | ||
#@click.option('--foo', default="bar") # add more options above, not below | ||
@add_options(config_path=predict_yml) | ||
def predict(args): | ||
|
||
# Setup | ||
ch = logging.StreamHandler() | ||
ch.setLevel(logging.INFO) | ||
ch.setFormatter(CustomFormatter()) | ||
logging.basicConfig(level=logging.INFO, handlers=[ch]) | ||
|
||
# Import packages | ||
logging.info("Importing packages...") | ||
from segger.data.utils import XeniumDataset | ||
from torch_geometric.loader import DataLoader | ||
from segger.prediction.predict import load_model, predict | ||
logging.info("Done.") | ||
@click.command(name="run_segmentation", help="Run the Segger segmentation model.") | ||
@click.option('--segger_data_dir', type=Path, required=True, help='Directory containing the processed Segger dataset.') | ||
@click.option('--models_dir', type=Path, required=True, help='Directory containing the trained models.') | ||
@click.option('--benchmarks_dir', type=Path, required=True, help='Directory to save the segmentation results.') | ||
@click.option('--transcripts_file', type=str, required=True, help='Path to the transcripts file.') | ||
@click.option('--batch_size', type=int, default=1, help='Batch size for processing.') | ||
@click.option('--num_workers', type=int, default=1, help='Number of workers for data loading.') | ||
@click.option('--model_version', type=int, default=0, help='Model version to load.') | ||
@click.option('--save_tag', type=str, default='segger_embedding_1001_0.5', help='Tag for saving segmentation results.') | ||
@click.option('--min_transcripts', type=int, default=5, help='Minimum number of transcripts for segmentation.') | ||
@click.option('--cell_id_col', type=str, default='segger_cell_id', help='Column name for cell IDs.') | ||
@click.option('--use_cc', is_flag=True, default=False, help='Use connected components if specified.') | ||
@click.option('--knn_method', type=str, default='cuda', help='Method for KNN computation.') | ||
@click.option('--file_format', type=str, default='anndata', help='File format for output data.') | ||
@click.option('--k_bd', type=int, default=4, help='K value for boundary computation.') | ||
@click.option('--dist_bd', type=int, default=12, help='Distance for boundary computation.') | ||
@click.option('--k_tx', type=int, default=5, help='K value for transcript computation.') | ||
@click.option('--dist_tx', type=int, default=5, help='Distance for transcript computation.') | ||
def run_segmentation(segger_data_dir: Path, models_dir: Path, benchmarks_dir: Path, | ||
transcripts_file: str, batch_size: int = 1, num_workers: int = 1, | ||
model_version: int = 0, save_tag: str = 'segger_embedding_1001_0.5', | ||
min_transcripts: int = 5, cell_id_col: str = 'segger_cell_id', | ||
use_cc: bool = False, knn_method: str = 'cuda', | ||
file_format: str = 'anndata', k_bd: int = 4, dist_bd: int = 12, | ||
k_tx: int = 5, dist_tx: int = 5): | ||
|
||
# Setup logging | ||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
# Load datasets and model | ||
logging.info("Loading Xenium datasets and Segger model...") | ||
dataset = XeniumDataset(args.dataset_path) | ||
data_loader = DataLoader( | ||
dataset, | ||
batch_size=args.batch_size, | ||
num_workers=args.workers, | ||
pin_memory=True, | ||
shuffle=False, | ||
logger.info("Initializing Segger data module...") | ||
# Initialize the Lightning data module | ||
dm = SeggerDataModule( | ||
data_dir=segger_data_dir, | ||
batch_size=batch_size, | ||
num_workers=num_workers, | ||
) | ||
if len(data_loader) == 0: | ||
msg = f"Nothing to predict: No data found at '{args.dataset_path}'." | ||
logging.warning(msg) | ||
return | ||
lit_segger = load_model(args.checkpoint_path) | ||
logging.info("Done.") | ||
|
||
dm.setup() | ||
|
||
logger.info("Loading the model...") | ||
# Load in the latest checkpoint | ||
model_path = models_dir / 'lightning_logs' / f'version_{model_version}' | ||
model = load_model(model_path / 'checkpoints') | ||
|
||
# Make prediction on dataset | ||
logging.info("Making predictions on data") | ||
predictions = predict( | ||
lit_segger=lit_segger, | ||
data_loader=data_loader, | ||
score_cut=args.score_cut, | ||
use_cc=args.use_cc, | ||
logger.info("Running segmentation...") | ||
segment( | ||
model, | ||
dm, | ||
save_dir=benchmarks_dir, | ||
seg_tag=save_tag, | ||
transcript_file=transcripts_file, | ||
file_format=file_format, | ||
receptive_field={'k_bd': k_bd, 'dist_bd': dist_bd, 'k_tx': k_tx, 'dist_tx': dist_tx}, | ||
min_transcripts=min_transcripts, | ||
cell_id_col=cell_id_col, | ||
use_cc=use_cc, | ||
knn_method=knn_method, | ||
) | ||
logging.info("Done.") | ||
|
||
logger.info("Segmentation completed.") | ||
|
||
# Write predictions to file | ||
logging.info("Saving predictions to file") | ||
predictions.to_csv(args.output_path, index=False) | ||
logging.info("Done.") | ||
if __name__ == '__main__': | ||
run_segmentation() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters