Commit

updated CLI and fixed #19
EliHei2 committed Oct 7, 2024
1 parent 060efd2 · commit e1a7b05
Showing 4 changed files with 125 additions and 97 deletions.
scripts/predict_model_sample.py (4 changes: 2 additions & 2 deletions)
@@ -40,13 +40,13 @@
model,
dm,
save_dir=benchmarks_dir,
seg_tag='segger_embedding_1001_cc_true',
seg_tag='segger_embedding_1001_0.5',
transcript_file=transcripts_file,
file_format='anndata',
receptive_field = receptive_field,
min_transcripts=5,
# max_transcripts=1500,
cell_id_col='segger_cell_id',
use_cc=True,
use_cc=False,
knn_method='cuda'
)
src/segger/cli/predict.py (113 changes: 63 additions & 50 deletions)
@@ -1,60 +1,73 @@
import click
import typing
import os
from segger.cli.utils import add_options, CustomFormatter
from segger.training.segger_data_module import SeggerDataModule
from segger.prediction.predict import segment, load_model
from pathlib import Path
import logging
import os

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

predict_yml = Path(__file__).parent / 'configs' / 'predict' / 'default.yaml'


@click.command(name="predict", help="Predict using the Segger model")
#@click.option('--foo', default="bar") # add more options above, not below
@add_options(config_path=predict_yml)
def predict(args):

# Setup
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(CustomFormatter())
logging.basicConfig(level=logging.INFO, handlers=[ch])

# Import packages
logging.info("Importing packages...")
from segger.data.utils import XeniumDataset
from torch_geometric.loader import DataLoader
from segger.prediction.predict import load_model, predict
logging.info("Done.")
@click.command(name="run_segmentation", help="Run the Segger segmentation model.")
@click.option('--segger_data_dir', type=Path, required=True, help='Directory containing the processed Segger dataset.')
@click.option('--models_dir', type=Path, required=True, help='Directory containing the trained models.')
@click.option('--benchmarks_dir', type=Path, required=True, help='Directory to save the segmentation results.')
@click.option('--transcripts_file', type=str, required=True, help='Path to the transcripts file.')
@click.option('--batch_size', type=int, default=1, help='Batch size for processing.')
@click.option('--num_workers', type=int, default=1, help='Number of workers for data loading.')
@click.option('--model_version', type=int, default=0, help='Model version to load.')
@click.option('--save_tag', type=str, default='segger_embedding_1001_0.5', help='Tag for saving segmentation results.')
@click.option('--min_transcripts', type=int, default=5, help='Minimum number of transcripts for segmentation.')
@click.option('--cell_id_col', type=str, default='segger_cell_id', help='Column name for cell IDs.')
@click.option('--use_cc', is_flag=True, default=False, help='Use connected components if specified.')
@click.option('--knn_method', type=str, default='cuda', help='Method for KNN computation.')
@click.option('--file_format', type=str, default='anndata', help='File format for output data.')
@click.option('--k_bd', type=int, default=4, help='K value for boundary computation.')
@click.option('--dist_bd', type=int, default=12, help='Distance for boundary computation.')
@click.option('--k_tx', type=int, default=5, help='K value for transcript computation.')
@click.option('--dist_tx', type=int, default=5, help='Distance for transcript computation.')
def run_segmentation(segger_data_dir: Path, models_dir: Path, benchmarks_dir: Path,
transcripts_file: str, batch_size: int = 1, num_workers: int = 1,
model_version: int = 0, save_tag: str = 'segger_embedding_1001_0.5',
min_transcripts: int = 5, cell_id_col: str = 'segger_cell_id',
use_cc: bool = False, knn_method: str = 'cuda',
file_format: str = 'anndata', k_bd: int = 4, dist_bd: int = 12,
k_tx: int = 5, dist_tx: int = 5):

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load datasets and model
logging.info("Loading Xenium datasets and Segger model...")
dataset = XeniumDataset(args.dataset_path)
data_loader = DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.workers,
pin_memory=True,
shuffle=False,
logger.info("Initializing Segger data module...")
# Initialize the Lightning data module
dm = SeggerDataModule(
data_dir=segger_data_dir,
batch_size=batch_size,
num_workers=num_workers,
)
if len(data_loader) == 0:
msg = f"Nothing to predict: No data found at '{args.dataset_path}'."
logging.warning(msg)
return
lit_segger = load_model(args.checkpoint_path)
logging.info("Done.")

dm.setup()

logger.info("Loading the model...")
# Load in the latest checkpoint
model_path = models_dir / 'lightning_logs' / f'version_{model_version}'
model = load_model(model_path / 'checkpoints')

# Make prediction on dataset
logging.info("Making predictions on data")
predictions = predict(
lit_segger=lit_segger,
data_loader=data_loader,
score_cut=args.score_cut,
use_cc=args.use_cc,
logger.info("Running segmentation...")
segment(
model,
dm,
save_dir=benchmarks_dir,
seg_tag=save_tag,
transcript_file=transcripts_file,
file_format=file_format,
receptive_field={'k_bd': k_bd, 'dist_bd': dist_bd, 'k_tx': k_tx, 'dist_tx': dist_tx},
min_transcripts=min_transcripts,
cell_id_col=cell_id_col,
use_cc=use_cc,
knn_method=knn_method,
)
logging.info("Done.")

logger.info("Segmentation completed.")

# Write predictions to file
logging.info("Saving predictions to file")
predictions.to_csv(args.output_path, index=False)
logging.info("Done.")
if __name__ == '__main__':
run_segmentation()
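
For reference, one way to exercise the new run_segmentation entry point is through click's test runner. The sketch below is a hypothetical invocation: every path in it is a placeholder rather than a path from this repository, and the option names simply mirror the decorators above.

from click.testing import CliRunner
from segger.cli.predict import run_segmentation

# Hypothetical invocation of the new CLI; all paths are placeholders.
runner = CliRunner()
result = runner.invoke(run_segmentation, [
    '--segger_data_dir', '/path/to/segger_dataset',
    '--models_dir', '/path/to/models',
    '--benchmarks_dir', '/path/to/benchmarks',
    '--transcripts_file', '/path/to/transcripts.parquet',
    '--model_version', '0',
    '--save_tag', 'segger_embedding_1001_0.5',
    '--use_cc',               # optional flag; connected components stay off without it
    '--knn_method', 'cuda',
])
print(result.exit_code, result.output)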
src/segger/cli/train_model.py (103 changes: 59 additions & 44 deletions)
@@ -5,88 +5,103 @@
from pathlib import Path
import logging

help_msg = "Train the Segger segmentation model."

def train_model(args):

# Setup
@click.command(name="train_model", help=help_msg)
@add_options(config_path=train_yml)
@click.option('--dataset_dir', type=Path, required=True, help='Directory containing the processed Segger dataset.')
@click.option('--models_dir', type=Path, required=True, help='Directory to save the trained model and the training logs.')
@click.option('--sample_tag', type=str, required=True, help='Sample tag for the dataset.')
@click.option('--init_emb', type=int, default=8, help='Size of the embedding layer.')
@click.option('--hidden_channels', type=int, default=32, help='Size of hidden channels in the model.')
@click.option('--num_tx_tokens', type=int, default=500, help='Number of transcript tokens.')
@click.option('--out_channels', type=int, default=8, help='Number of output channels.')
@click.option('--heads', type=int, default=2, help='Number of attention heads.')
@click.option('--num_mid_layers', type=int, default=2, help='Number of mid layers in the model.')
@click.option('--batch_size', type=int, default=4, help='Batch size for training.')
@click.option('--num_workers', type=int, default=2, help='Number of workers for data loading.')
@click.option('--accelerator', type=str, default='cuda', help='Device type to use for training (e.g., "cuda", "cpu").') # Ask for accelerator
@click.option('--max_epochs', type=int, default=200, help='Number of epochs for training.')
@click.option('--devices', type=int, default=4, help='Number of devices (GPUs) to use.')
@click.option('--strategy', type=str, default='auto', help='Training strategy for the trainer.')
@click.option('--precision', type=str, default='16-mixed', help='Precision for training.')
def train_model(dataset_dir: Path, models_dir: Path, sample_tag: str,
init_emb: int = 8, hidden_channels: int = 32, num_tx_tokens: int = 500,
out_channels: int = 8, heads: int = 2, num_mid_layers: int = 2,
batch_size: int = 4, num_workers: int = 2,
accelerator: str = 'cuda', max_epochs: int = 200,
devices: int = 4, strategy: str = 'auto', precision: str = '16-mixed'):

# Setup logging
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(CustomFormatter())
logging.basicConfig(level=logging.INFO, handlers=[ch])

# Import packages
logging.info("Importing packages...")
from segger.data.utils import XeniumDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import to_hetero
from segger.data.io import XeniumSample
from segger.training.train import LitSegger
from lightning import Trainer
from segger.training.segger_data_module import SeggerDataModule
from lightning.pytorch.loggers import CSVLogger
from pytorch_lightning import Trainer
logging.info("Done.")

# Load datasets
logging.info("Loading Xenium datasets...")
trn_ds = XeniumDataset(root=Path(args.data_dir) / 'train_tiles')
val_ds = XeniumDataset(root=Path(args.data_dir) / 'val_tiles')
kwargs = dict(
num_workers=0,
pin_memory=True,
)
trn_loader = DataLoader(
trn_ds, batch_size=args.batch_size_train, shuffle=True, **kwargs
)
val_loader = DataLoader(
val_ds, batch_size=args.batch_size_val, shuffle=False, **kwargs
dm = SeggerDataModule(
data_dir=dataset_dir,
batch_size=batch_size, # Hard-coded batch size
num_workers=num_workers, # Hard-coded number of workers
)

dm.setup()
logging.info("Done.")

# Initialize model
logging.info("Initializing Segger model and trainer...")
metadata = (
["tx", "nc"], [("tx", "belongs", "nc"), ("tx", "neighbors", "tx")]
)
lit_segger = LitSegger(
init_emb=args.init_emb,
hidden_channels=args.hidden_channels,
out_channels=args.out_channels,
heads=args.heads,
aggr=args.aggr,
metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")])
ls = LitSegger(
num_tx_tokens=num_tx_tokens,
init_emb=init_emb,
hidden_channels=hidden_channels,
out_channels=out_channels, # Hard-coded value
heads=heads, # Hard-coded value
num_mid_layers=num_mid_layers, # Hard-coded value
aggr='sum', # Hard-coded value
metadata=metadata,
)

# Initialize lightning trainer
# Initialize the Lightning trainer
trainer = Trainer(
accelerator=args.accelerator,
strategy=args.strategy,
precision=args.precision,
devices=args.devices,
max_epochs=args.epochs,
default_root_dir=args.model_dir,
accelerator=accelerator, # Directly use the specified accelerator
strategy=strategy, # Hard-coded value
precision=precision, # Hard-coded value
devices=devices, # Hard-coded value
max_epochs=max_epochs, # Hard-coded value
default_root_dir=models_dir,
logger=CSVLogger(models_dir),
)

logging.info("Done.")

# Train model
logging.info("Training model...")
trainer.fit(
model=lit_segger,
train_dataloaders=trn_loader,
val_dataloaders=val_loader,
model=ls,
datamodule=dm
)
logging.info("Done...")

logging.info("Done.")

train_yml = Path(__file__).parent / 'configs' / 'train' / 'default.yaml'


@click.command(name="slurm", help="Train on Slurm cluster")
#@click.option('--foo', default="bar") # add more options above, not below
@add_options(config_path=train_yml)
def train_slurm(args):
train_model(args)


@click.group(help="Train the Segger model")
def train():
pass


train.add_command(train_slurm)
train.add_command(train_slurm)
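
A rough sketch for the reworked train_model command follows the same pattern. The directories and sample tag are placeholders, and it assumes any options injected from the train YAML config via add_options have usable defaults.

from click.testing import CliRunner
from segger.cli.train_model import train_model

# Rough sketch only; paths and the sample tag are placeholders.
runner = CliRunner()
result = runner.invoke(train_model, [
    '--dataset_dir', '/path/to/segger_dataset',
    '--models_dir', '/path/to/models',
    '--sample_tag', 'sample_0',
    '--accelerator', 'cuda',
    '--devices', '1',
    '--max_epochs', '200',
    '--precision', '16-mixed',
])
print(result.exit_code, result.output)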
src/segger/prediction/predict.py (2 changes: 1 addition & 1 deletion)
@@ -341,7 +341,7 @@ def segment(
save_dir: Union[str, Path],
seg_tag: str,
transcript_file: Union[str, Path],
score_cut: float = .25,
score_cut: float = .5,
use_cc: bool = True,
file_format: str = 'anndata',
receptive_field: dict = {'k_bd': 4, 'dist_bd': 10, 'k_tx': 5, 'dist_tx': 3},
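
Because the default score_cut moves from .25 to .5, callers that want the previous threshold now need to pass it explicitly. A minimal sketch, mirroring the signatures shown elsewhere in this commit and using placeholder paths:

from pathlib import Path
from segger.training.segger_data_module import SeggerDataModule
from segger.prediction.predict import load_model, segment

# Minimal sketch: keep the old 0.25 score threshold after the default change.
# All paths are placeholders.
dm = SeggerDataModule(
    data_dir=Path('/path/to/segger_dataset'),
    batch_size=1,
    num_workers=1,
)
dm.setup()

model = load_model(Path('/path/to/models/lightning_logs/version_0/checkpoints'))

segment(
    model,
    dm,
    save_dir=Path('/path/to/benchmarks'),
    seg_tag='segger_embedding_1001_0.5',
    transcript_file=Path('/path/to/transcripts.parquet'),
    score_cut=0.25,   # override the new default of 0.5 to match the old behavior
    use_cc=True,
    file_format='anndata',
    receptive_field={'k_bd': 4, 'dist_bd': 12, 'k_tx': 5, 'dist_tx': 5},
    min_transcripts=5,
    cell_id_col='segger_cell_id',
    knn_method='cuda',
)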
