diff --git a/src/segger/cli/configs/train/default.yaml b/src/segger/cli/configs/train/default.yaml new file mode 100644 index 0000000..b685eac --- /dev/null +++ b/src/segger/cli/configs/train/default.yaml @@ -0,0 +1,64 @@ +dataset_dir: + type: Path + required: true + help: Directory containing the processed Segger dataset. +models_dir: + type: Path + required: true + help: Directory to save the trained model and the training logs. +sample_tag: + type: str + required: true + help: Sample tag for the dataset. +init_emb: + type: int + default: 8 + help: Size of the embedding layer. +hidden_channels: + type: int + default: 32 + help: Size of hidden channels in the model. +num_tx_tokens: + type: int + default: 500 + help: Number of transcript tokens. +out_channels: + type: int + default: 8 + help: Number of output channels. +heads: + type: int + default: 2 + help: Number of attention heads. +num_mid_layers: + type: int + default: 2 + help: Number of mid layers in the model. +batch_size: + type: int + default: 4 + help: Batch size for training. +num_workers: + type: int + default: 2 + help: Number of workers for data loading. +accelerator: + type: str + default: 'cuda' + help: Device type to use for training (e.g., "cuda", "cpu"). +max_epochs: + type: int + default: 200 + help: Number of epochs for training. +devices: + type: int + default: 4 + help: Number of devices (GPUs) to use. +strategy: + type: str + default: 'auto' + help: Training strategy for the trainer. +precision: + type: str + default: '16-mixed' + help: Precision for training. diff --git a/src/segger/cli/train_model.py b/src/segger/cli/train_model.py index d2edde5..78bd5d7 100644 --- a/src/segger/cli/train_model.py +++ b/src/segger/cli/train_model.py @@ -4,9 +4,12 @@ from segger.cli.utils import add_options, CustomFormatter from pathlib import Path import logging +from argparse import Namespace -help_msg = "Train the Segger segmentation model." 
+# Path to default YAML configuration file +train_yml = Path(__file__).parent / 'configs' / 'train' / 'default.yaml' +help_msg = "Train the Segger segmentation model." @click.command(name="train_model", help=help_msg) @add_options(config_path=train_yml) @click.option('--dataset_dir', type=Path, required=True, help='Directory containing the processed Segger dataset.') @@ -25,12 +28,7 @@ @click.option('--devices', type=int, default=4, help='Number of devices (GPUs) to use.') @click.option('--strategy', type=str, default='auto', help='Training strategy for the trainer.') @click.option('--precision', type=str, default='16-mixed', help='Precision for training.') -def train_model(dataset_dir: Path, models_dir: Path, sample_tag: str, - init_emb: int = 8, hidden_channels: int = 32, num_tx_tokens: int = 500, - out_channels: int = 8, heads: int = 2, num_mid_layers: int = 2, - batch_size: int = 4, num_workers: int = 2, - accelerator: str = 'cuda', max_epochs: int = 200, - devices: int = 4, strategy: str = 'auto', precision: str = '16-mixed'): +def train_model(args: Namespace): # Setup logging ch = logging.StreamHandler() @@ -50,9 +48,9 @@ def train_model(dataset_dir: Path, models_dir: Path, sample_tag: str, # Load datasets logging.info("Loading Xenium datasets...") dm = SeggerDataModule( - data_dir=dataset_dir, - batch_size=batch_size, # Hard-coded batch size - num_workers=num_workers, # Hard-coded number of workers + data_dir=args.dataset_dir, + batch_size=args.batch_size, # Batch size read from args + num_workers=args.num_workers, # Number of data-loading workers read from args ) dm.setup() @@ -62,25 +60,25 @@ def train_model(dataset_dir: Path, models_dir: Path, sample_tag: str, logging.info("Initializing Segger model and trainer...") metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]) ls = LitSegger( - num_tx_tokens=num_tx_tokens, - init_emb=init_emb, - hidden_channels=hidden_channels, - out_channels=out_channels, # Hard-coded value - heads=heads, # Hard-coded value 
- num_mid_layers=num_mid_layers, # Hard-coded value + num_tx_tokens=args.num_tx_tokens, + init_emb=args.init_emb, + hidden_channels=args.hidden_channels, + out_channels=args.out_channels, # Read from args + heads=args.heads, # Read from args + num_mid_layers=args.num_mid_layers, # Read from args aggr='sum', # Hard-coded value metadata=metadata, ) # Initialize the Lightning trainer trainer = Trainer( - accelerator=accelerator, # Directly use the specified accelerator - strategy=strategy, # Hard-coded value - precision=precision, # Hard-coded value - devices=devices, # Hard-coded value - max_epochs=max_epochs, # Hard-coded value - default_root_dir=models_dir, - logger=CSVLogger(models_dir), + accelerator=args.accelerator, # Directly use the specified accelerator + strategy=args.strategy, # Read from args + precision=args.precision, # Read from args + devices=args.devices, # Read from args + max_epochs=args.max_epochs, # Read from args + default_root_dir=args.models_dir, + logger=CSVLogger(args.models_dir), ) logging.info("Done.") @@ -93,7 +91,6 @@ def train_model(dataset_dir: Path, models_dir: Path, sample_tag: str, ) logging.info("Done.") -train_yml = Path(__file__).parent / 'configs' / 'train' / 'default.yaml' @click.command(name="slurm", help="Train on Slurm cluster") @add_options(config_path=train_yml) @@ -105,3 +102,6 @@ def train(): pass train.add_command(train_slurm) + +if __name__ == '__main__': + train_model() \ No newline at end of file diff --git a/src/segger/data/parquet/pyg_dataset.py b/src/segger/data/parquet/pyg_dataset.py index 8642ae3..5599cb3 100644 --- a/src/segger/data/parquet/pyg_dataset.py +++ b/src/segger/data/parquet/pyg_dataset.py @@ -64,4 +64,12 @@ def get(self, idx: int) -> Data: filepath = Path(self.processed_dir) / self.processed_file_names[idx] data = torch.load(filepath) data['tx'].x = data['tx'].x.to_dense() - return data \ No newline at end of file + if data['tx'].x.dim() == 1: + data['tx'].x = data['tx'].x.unsqueeze(1) 
+ assert data['tx'].x.dim() == 2 + # Workaround for PyG's RandomLinkSplit: when the graph has only one edge, edge_label_index comes back 1-D, so restore the 2-D shape (and give edge_label a leading dimension) + if data['tx', 'belongs', 'bd'].edge_label_index.dim() == 1: + data['tx', 'belongs', 'bd'].edge_label_index = data['tx', 'belongs', 'bd'].edge_label_index.unsqueeze(1) + data['tx', 'belongs', 'bd'].edge_label = data['tx', 'belongs', 'bd'].edge_label.unsqueeze(0) + assert data['tx', 'belongs', 'bd'].edge_label_index.dim() == 2 + return data