Skip to content

Commit

Permalink
Merge pull request #31 from daniel-unyi-42/main
Browse files Browse the repository at this point in the history
Adjustments in model training
  • Loading branch information
EliHei2 authored Oct 10, 2024
2 parents db561c4 + f2f245d commit 9ae16c7
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 25 deletions.
64 changes: 64 additions & 0 deletions src/segger/cli/configs/train/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
dataset_dir:
type: Path
required: true
help: Directory containing the processed Segger dataset.
models_dir:
type: Path
required: true
help: Directory to save the trained model and the training logs.
sample_tag:
type: str
required: true
help: Sample tag for the dataset.
init_emb:
type: int
default: 8
help: Size of the embedding layer.
hidden_channels:
type: int
default: 32
help: Size of hidden channels in the model.
num_tx_tokens:
type: int
default: 500
help: Number of transcript tokens.
out_channels:
type: int
default: 8
help: Number of output channels.
heads:
type: int
default: 2
help: Number of attention heads.
num_mid_layers:
type: int
default: 2
help: Number of mid layers in the model.
batch_size:
type: int
default: 4
help: Batch size for training.
num_workers:
type: int
default: 2
help: Number of workers for data loading.
accelerator:
type: str
default: 'cuda'
help: Device type to use for training (e.g., "cuda", "cpu").
max_epochs:
type: int
default: 200
help: Number of epochs for training.
devices:
type: int
default: 4
help: Number of devices (GPUs) to use.
strategy:
type: str
default: 'auto'
help: Training strategy for the trainer.
precision:
type: str
default: '16-mixed'
help: Precision for training.
48 changes: 24 additions & 24 deletions src/segger/cli/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
from segger.cli.utils import add_options, CustomFormatter
from pathlib import Path
import logging
from argparse import Namespace

help_msg = "Train the Segger segmentation model."
# Path to default YAML configuration file
train_yml = Path(__file__).parent / 'configs' / 'train' / 'default.yaml'

help_msg = "Train the Segger segmentation model."
@click.command(name="train_model", help=help_msg)
@add_options(config_path=train_yml)
@click.option('--dataset_dir', type=Path, required=True, help='Directory containing the processed Segger dataset.')
Expand All @@ -25,12 +28,7 @@
@click.option('--devices', type=int, default=4, help='Number of devices (GPUs) to use.')
@click.option('--strategy', type=str, default='auto', help='Training strategy for the trainer.')
@click.option('--precision', type=str, default='16-mixed', help='Precision for training.')
def train_model(dataset_dir: Path, models_dir: Path, sample_tag: str,
init_emb: int = 8, hidden_channels: int = 32, num_tx_tokens: int = 500,
out_channels: int = 8, heads: int = 2, num_mid_layers: int = 2,
batch_size: int = 4, num_workers: int = 2,
accelerator: str = 'cuda', max_epochs: int = 200,
devices: int = 4, strategy: str = 'auto', precision: str = '16-mixed'):
def train_model(args: Namespace):

# Setup logging
ch = logging.StreamHandler()
Expand All @@ -50,9 +48,9 @@ def train_model(dataset_dir: Path, models_dir: Path, sample_tag: str,
# Load datasets
logging.info("Loading Xenium datasets...")
dm = SeggerDataModule(
data_dir=dataset_dir,
batch_size=batch_size, # Hard-coded batch size
num_workers=num_workers, # Hard-coded number of workers
data_dir=args.dataset_dir,
batch_size=args.batch_size, # Hard-coded batch size
num_workers=args.num_workers, # Hard-coded number of workers
)

dm.setup()
Expand All @@ -62,25 +60,25 @@ def train_model(dataset_dir: Path, models_dir: Path, sample_tag: str,
logging.info("Initializing Segger model and trainer...")
metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")])
ls = LitSegger(
num_tx_tokens=num_tx_tokens,
init_emb=init_emb,
hidden_channels=hidden_channels,
out_channels=out_channels, # Hard-coded value
heads=heads, # Hard-coded value
num_mid_layers=num_mid_layers, # Hard-coded value
num_tx_tokens=args.num_tx_tokens,
init_emb=args.init_emb,
hidden_channels=args.hidden_channels,
out_channels=args.out_channels, # Hard-coded value
heads=args.heads, # Hard-coded value
num_mid_layers=args.num_mid_layers, # Hard-coded value
aggr='sum', # Hard-coded value
metadata=metadata,
)

# Initialize the Lightning trainer
trainer = Trainer(
accelerator=accelerator, # Directly use the specified accelerator
strategy=strategy, # Hard-coded value
precision=precision, # Hard-coded value
devices=devices, # Hard-coded value
max_epochs=max_epochs, # Hard-coded value
default_root_dir=models_dir,
logger=CSVLogger(models_dir),
accelerator=args.accelerator, # Directly use the specified accelerator
strategy=args.strategy, # Hard-coded value
precision=args.precision, # Hard-coded value
devices=args.devices, # Hard-coded value
max_epochs=args.max_epochs, # Hard-coded value
default_root_dir=args.models_dir,
logger=CSVLogger(args.models_dir),
)

logging.info("Done.")
Expand All @@ -93,7 +91,6 @@ def train_model(dataset_dir: Path, models_dir: Path, sample_tag: str,
)
logging.info("Done.")

train_yml = Path(__file__).parent / 'configs' / 'train' / 'default.yaml'

@click.command(name="slurm", help="Train on Slurm cluster")
@add_options(config_path=train_yml)
Expand All @@ -105,3 +102,6 @@ def train():
pass

train.add_command(train_slurm)

if __name__ == '__main__':
train_model()
10 changes: 9 additions & 1 deletion src/segger/data/parquet/pyg_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,12 @@ def get(self, idx: int) -> Data:
filepath = Path(self.processed_dir) / self.processed_file_names[idx]
data = torch.load(filepath)
data['tx'].x = data['tx'].x.to_dense()
return data
if data['tx'].x.dim() == 1:
data['tx'].x = data['tx'].x.unsqueeze(1)
assert data['tx'].x.dim() == 2
# Workaround for an issue in PyG's RandomLinkSplit: tensor dimensions are not consistent when the graph contains only one edge.
if data['tx', 'belongs', 'bd'].edge_label_index.dim() == 1:
data['tx', 'belongs', 'bd'].edge_label_index = data['tx', 'belongs', 'bd'].edge_label_index.unsqueeze(1)
data['tx', 'belongs', 'bd'].edge_label = data['tx', 'belongs', 'bd'].edge_label.unsqueeze(0)
assert data['tx', 'belongs', 'bd'].edge_label_index.dim() == 2
return data

0 comments on commit 9ae16c7

Please sign in to comment.