From 0787167054949969317bcb3d1a3a1ccab3b7613b Mon Sep 17 00:00:00 2001
From: DÁNIEL UNYI
Date: Sun, 15 Dec 2024 18:09:20 +0100
Subject: [PATCH] Model training fix and code refactoring

---
 src/segger/cli/train_model.py          |  41 +++++---
 src/segger/data/parquet/pyg_dataset.py |   4 -
 src/segger/models/segger_model.py      | 139 +++++++++++++++----------
 src/segger/training/train.py           |  55 +++++-----
 4 files changed, 135 insertions(+), 104 deletions(-)

diff --git a/src/segger/cli/train_model.py b/src/segger/cli/train_model.py
index 731fcb3..5a02090 100644
--- a/src/segger/cli/train_model.py
+++ b/src/segger/cli/train_model.py
@@ -1,6 +1,4 @@
 import click
-import typing
-import os
 from segger.cli.utils import add_options, CustomFormatter
 from pathlib import Path
 import logging
@@ -53,8 +51,10 @@ def train_model(args: Namespace):
 
     # Import packages
     logging.info("Importing packages...")
+    import torch
     from segger.training.train import LitSegger
     from segger.training.segger_data_module import SeggerDataModule
+    from segger.prediction.predict_parquet import load_model
     from lightning.pytorch.loggers import CSVLogger
     from pytorch_lightning import Trainer
 
@@ -72,31 +72,38 @@ def train_model(args: Namespace):
     logging.info("Done.")
 
     # Initialize model
-    logging.info("Initializing Segger model and trainer...")
-    metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")])
-
     if args.pretrained_model_dir is not None:
         logging.info("Loading pretrained model...")
-        from segger.prediction.predict_parquet import load_model
-
        ls = load_model(args.pretrained_model_dir / "lightning_logs" / f"version_{args.model_version}" / "checkpoints")
     else:
+        logging.info("Creating new model...")
+        is_token_based = dm.train[0].x_dict["tx"].ndim == 1
+        if is_token_based:
+            # Token-based model: the input is a 1-D tensor of token indices
+            assert dm.train[0].x_dict["tx"].ndim == 1
+            assert dm.train[0].x_dict["tx"].dtype == torch.long
+            num_tx_features = args.num_tx_tokens
+            logging.info(f"Using token-based embeddings as node features; number of tokens: {num_tx_features}")
+        else:
+            # Embedding-based model: the input is a 2-D tensor of scRNAseq embeddings
+            assert dm.train[0].x_dict["tx"].ndim == 2
+            assert dm.train[0].x_dict["tx"].dtype == torch.float32
+            num_tx_features = dm.train[0].x_dict["tx"].shape[1]
+            logging.info(f"Using scRNAseq embeddings as node features; number of features: {num_tx_features}")
+        num_bd_features = dm.train[0].x_dict["bd"].shape[1]
+        logging.info(f"Number of boundary node features: {num_bd_features}")
         ls = LitSegger(
-            num_tx_tokens=args.num_tx_tokens,
+            is_token_based=is_token_based,
+            num_node_features={"tx": num_tx_features, "bd": num_bd_features},
             init_emb=args.init_emb,
             hidden_channels=args.hidden_channels,
-            out_channels=args.out_channels,  # Hard-coded value
-            heads=args.heads,  # Hard-coded value
-            num_mid_layers=args.num_mid_layers,  # Hard-coded value
+            out_channels=args.out_channels,
+            heads=args.heads,
+            num_mid_layers=args.num_mid_layers,
             aggr="sum",  # Hard-coded value
             learning_rate=args.learning_rate,
-            metadata=metadata,
         )
-
-    # Forward pass to initialize the model
-    if args.devices > 1:
-        batch = dm.train[0].to(ls.device)
-        ls.forward(batch)
+    logging.info("Done.")
 
     # Initialize the Lightning trainer
     trainer = Trainer(
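For reference, the two `tx` feature layouts that the CLI now auto-detects look like this. A minimal sketch with made-up shapes and vocabulary size, not part of the patch:

```python
import torch

# Token-based layout: a 1-D LongTensor of vocabulary indices, one per 'tx' node.
x_tx_tokens = torch.randint(0, 500, (1000,))     # dtype: torch.long, ndim == 1

# Embedding-based layout: a 2-D FloatTensor of scRNAseq embeddings per 'tx' node.
x_tx_embeddings = torch.randn(1000, 64)          # dtype: torch.float32, ndim == 2

# The CLI infers the mode from the first training sample:
is_token_based = x_tx_tokens.ndim == 1           # True for the token layout
```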
Data: """ filepath = Path(self.processed_dir) / self.processed_file_names[idx] data = torch.load(filepath) - data["tx"].x = data["tx"].x.to_dense() - if data["tx"].x.dim() == 1: - data["tx"].x = data["tx"].x.unsqueeze(1) - assert data["tx"].x.dim() == 2 # this is an issue in PyG's RandomLinkSplit, dimensions are not consistent if there is only one edge in the graph if hasattr(data["tx", "belongs", "bd"], "edge_label_index"): if data["tx", "belongs", "bd"].edge_label_index.dim() == 1: diff --git a/src/segger/models/segger_model.py b/src/segger/models/segger_model.py index f470574..f535221 100644 --- a/src/segger/models/segger_model.py +++ b/src/segger/models/segger_model.py @@ -1,16 +1,33 @@ import torch -from torch_geometric.nn import GATv2Conv, Linear -from torch.nn import Embedding +from torch_geometric.nn import HeteroConv, GATv2Conv, HeteroDictLinear +from torch import nn from torch import Tensor from typing import Union -# from torch_sparse import SparseTensor - - -class Segger(torch.nn.Module): +class SkipGAT(nn.Module): + def __init__(self, in_channels, out_channels, heads, apply_activation=True): + super().__init__() + self.apply_activation = apply_activation + self.conv = HeteroConv({ + ('tx', 'neighbors', 'tx'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), + ('tx', 'belongs', 'bd'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), + }, aggr='sum') + self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=('tx', 'bd')) + + def forward(self, x_dict, edge_index_dict): + x_conv = self.conv(x_dict, edge_index_dict) + x_lin = self.lin(x_dict) + x_dict = {key: x_conv[key] + x_lin[key] for key in x_dict} + if self.apply_activation: + x_dict = {key: x_dict[key].relu() for key in x_dict} + return x_dict + + +class Segger(nn.Module): def __init__( self, - num_tx_tokens: int, + is_token_based: int, + num_node_features: dict[str, int], init_emb: int = 16, hidden_channels: int = 32, num_mid_layers: int = 3, @@ -21,80 +38,96 @@ def __init__( Initializes the Segger model. Args: - num_tx_tokens (int) : Number of unique 'tx' tokens for embedding. - init_emb (int) : Initial embedding size for both 'tx' and boundary (non-token) nodes. - hidden_channels (int): Number of hidden channels. - num_mid_layers (int) : Number of hidden layers (excluding first and last layers). - out_channels (int) : Number of output channels. - heads (int) : Number of attention heads. + is_token_based (int) : Whether the model is using token-based embeddings or scRNAseq embeddings. + num_node_features (dict[str, int]): Number of node features for each node type. + init_emb (int) : Initial embedding size for both 'tx' and boundary (non-token) nodes. + hidden_channels (int) : Number of hidden channels. + num_mid_layers (int) : Number of hidden layers (excluding first and last layers). + out_channels (int) : Number of output channels. + heads (int) : Number of attention heads. """ super().__init__() - # Embedding for 'tx' (transcript) nodes - self.tx_embedding = Embedding(num_tx_tokens, init_emb) - - # Linear layer for boundary (non-token) nodes - self.lin0 = Linear(-1, init_emb, bias=False) + # Initialize node embeddings + if is_token_based: + # Using token-based embeddings for transcript ('tx') nodes + self.node_init = nn.ModuleDict({ + 'tx': nn.Embedding(num_node_features['tx'], init_emb), + 'bd': nn.Linear(num_node_features['bd'], init_emb), + }) + else: + # Using scRNAseq embeddings (i.e. 
diff --git a/src/segger/models/segger_model.py b/src/segger/models/segger_model.py
index f470574..f535221 100644
--- a/src/segger/models/segger_model.py
+++ b/src/segger/models/segger_model.py
@@ -1,16 +1,33 @@
 import torch
-from torch_geometric.nn import GATv2Conv, Linear
-from torch.nn import Embedding
+from torch_geometric.nn import HeteroConv, GATv2Conv, HeteroDictLinear
+from torch import nn
 from torch import Tensor
 from typing import Union
 
-# from torch_sparse import SparseTensor
-
-
-class Segger(torch.nn.Module):
+class SkipGAT(nn.Module):
+    def __init__(self, in_channels, out_channels, heads, apply_activation=True):
+        super().__init__()
+        self.apply_activation = apply_activation
+        self.conv = HeteroConv({
+            ('tx', 'neighbors', 'tx'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False),
+            ('tx', 'belongs', 'bd'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False),
+        }, aggr='sum')
+        self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=('tx', 'bd'))  # parallel skip connection
+
+    def forward(self, x_dict, edge_index_dict):
+        x_conv = self.conv(x_dict, edge_index_dict)
+        x_lin = self.lin(x_dict)
+        x_dict = {key: x_conv[key] + x_lin[key] for key in x_dict}  # residual-style sum per node type
+        if self.apply_activation:
+            x_dict = {key: x_dict[key].relu() for key in x_dict}
+        return x_dict
+
+
+class Segger(nn.Module):
     def __init__(
         self,
-        num_tx_tokens: int,
+        is_token_based: bool,
+        num_node_features: dict[str, int],
         init_emb: int = 16,
         hidden_channels: int = 32,
         num_mid_layers: int = 3,
@@ -21,80 +38,96 @@ def __init__(
         Initializes the Segger model.
 
         Args:
-            num_tx_tokens (int)  : Number of unique 'tx' tokens for embedding.
-            init_emb (int)       : Initial embedding size for both 'tx' and boundary (non-token) nodes.
-            hidden_channels (int): Number of hidden channels.
-            num_mid_layers (int) : Number of hidden layers (excluding first and last layers).
-            out_channels (int)   : Number of output channels.
-            heads (int)          : Number of attention heads.
+            is_token_based (bool)             : Whether the model uses token-based embeddings (True) or scRNAseq embeddings (False).
+            num_node_features (dict[str, int]): Number of node features for each node type.
+            init_emb (int)                    : Initial embedding size for both 'tx' and boundary (non-token) nodes.
+            hidden_channels (int)             : Number of hidden channels.
+            num_mid_layers (int)              : Number of hidden layers (excluding first and last layers).
+            out_channels (int)                : Number of output channels.
+            heads (int)                       : Number of attention heads.
         """
         super().__init__()
 
-        # Embedding for 'tx' (transcript) nodes
-        self.tx_embedding = Embedding(num_tx_tokens, init_emb)
-
-        # Linear layer for boundary (non-token) nodes
-        self.lin0 = Linear(-1, init_emb, bias=False)
+        # Initialize node embeddings
+        if is_token_based:
+            # Using token-based embeddings for transcript ('tx') nodes
+            self.node_init = nn.ModuleDict({
+                'tx': nn.Embedding(num_node_features['tx'], init_emb),
+                'bd': nn.Linear(num_node_features['bd'], init_emb),
+            })
+        else:
+            # Using scRNAseq embeddings (i.e. prior biological knowledge) for transcript ('tx') nodes
+            self.node_init = nn.ModuleDict({
+                'tx': nn.Linear(num_node_features['tx'], init_emb),
+                'bd': nn.Linear(num_node_features['bd'], init_emb),
+            })
 
         # First GATv2Conv layer
-        self.conv_first = GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
-        self.lin_first = Linear(-1, hidden_channels * heads)
+        self.conv1 = SkipGAT(init_emb, hidden_channels, heads)
 
         # Middle GATv2Conv layers
         self.num_mid_layers = num_mid_layers
         if num_mid_layers > 0:
-            self.conv_mid_layers = torch.nn.ModuleList()
-            self.lin_mid_layers = torch.nn.ModuleList()
+            self.conv_mid_layers = nn.ModuleList()
             for _ in range(num_mid_layers):
-                self.conv_mid_layers.append(GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False))
-                self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))
-
+                self.conv_mid_layers.append(SkipGAT(heads * hidden_channels, hidden_channels, heads))
+
         # Last GATv2Conv layer
-        self.conv_last = GATv2Conv((-1, -1), out_channels, heads=heads, add_self_loops=False)
-        self.lin_last = Linear(-1, out_channels * heads)
+        self.conv_last = SkipGAT(heads * hidden_channels, out_channels, heads)
+
+        # Finalize node embeddings
+        self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=('tx', 'bd'))
+
+        # # Edge probability predictor
+        # self.edge_predictor = nn.Sequential(
+        #     nn.Linear(2 * out_channels, out_channels),
+        #     nn.ReLU(),
+        #     nn.Linear(out_channels, 1),
+        # )
 
-    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
+    def forward(
+        self,
+        x_dict: dict[str, Tensor],
+        edge_index_dict: dict[str, Tensor],
+    ) -> dict[str, Tensor]:
         """
         Forward pass for the Segger model.
 
         Args:
-            x (Tensor): Node features.
-            edge_index (Tensor): Edge indices.
-
-        Returns:
-            Tensor: Output node embeddings.
+            x_dict (dict[str, Tensor]): Node features for each node type.
+            edge_index_dict (dict[str, Tensor]): Edge indices for each edge type.
         """
-        x = torch.nan_to_num(x, nan=0)
-        is_one_dim = (x.ndim == 1) * 1
-        # x = x[:, None]
-        x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim)
-        # First layer
-        x = x.relu()
-        x = self.conv_first(x, edge_index) + self.lin_first(x)
-        x = x.relu()
-
-        # Middle layers
+
+        x_dict = {key: self.node_init[key](x) for key, x in x_dict.items()}
+
+        x_dict = self.conv1(x_dict, edge_index_dict)
+
         if self.num_mid_layers > 0:
             for i in range(self.num_mid_layers):
-                conv_mid = self.conv_mid_layers[i]
-                lin_mid = self.lin_mid_layers[i]
-                x = conv_mid(x, edge_index) + lin_mid(x)
-                x = x.relu()
+                x_dict = self.conv_mid_layers[i](x_dict, edge_index_dict)
+
+        x_dict = self.conv_last(x_dict, edge_index_dict)
 
-        # Last layer
-        x = self.conv_last(x, edge_index) + self.lin_last(x)
+        x_dict = self.node_final(x_dict)
 
-        return x
+        return x_dict
 
-    def decode(self, z: Tensor, edge_index: Union[Tensor]) -> Tensor:
+    def decode(
+        self,
+        z_dict: dict[str, Tensor],
+        edge_index: Tensor,
+    ) -> Tensor:
         """
         Decode the node embeddings to predict edge values.
 
         Args:
-            z (Tensor): Node embeddings.
+            z_dict (dict[str, Tensor]): Node embeddings for each node type.
             edge_index (EdgeIndex): Edge label indices.
 
         Returns:
             Tensor: Predicted edge values.
         """
-        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
+        z_left = z_dict['tx'][edge_index[0]]
+        z_right = z_dict['bd'][edge_index[1]]
+        return (z_left * z_right).sum(dim=-1)
+        # return self.edge_predictor(torch.cat([z_left, z_right], dim=-1)).squeeze()
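To sanity-check the refactored heterogeneous API end to end, the new `Segger` can be smoke-tested on random inputs. All shapes and hyperparameter values below are made up for illustration and are not part of the patch:

```python
import torch
from segger.models.segger_model import Segger

num_tx, num_bd = 20, 5
model = Segger(
    is_token_based=False,                     # 'tx' features are 2-D scRNAseq embeddings
    num_node_features={"tx": 64, "bd": 7},
    init_emb=16,
    hidden_channels=32,
    num_mid_layers=1,
    out_channels=8,
    heads=2,
)
x_dict = {
    "tx": torch.randn(num_tx, 64),
    "bd": torch.randn(num_bd, 7),
}
edge_index_dict = {
    ("tx", "neighbors", "tx"): torch.randint(0, num_tx, (2, 40)),
    ("tx", "belongs", "bd"): torch.stack([
        torch.randint(0, num_tx, (40,)),      # source: transcript nodes
        torch.randint(0, num_bd, (40,)),      # target: boundary nodes
    ]),
}
z_dict = model(x_dict, edge_index_dict)       # {'tx': [20, 8], 'bd': [5, 8]}

edge_label_index = torch.stack([
    torch.randint(0, num_tx, (10,)),
    torch.randint(0, num_bd, (10,)),
])
logits = model.decode(z_dict, edge_label_index)  # shape [10], one logit per tx-bd pair
```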
""" - return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) + z_left = z_dict['tx'][edge_index[0]] + z_right = z_dict['bd'][edge_index[1]] + return (z_left * z_right).sum(dim=-1) + # return self.edge_predictor(torch.cat([z_left, z_right], dim=-1)).squeeze() diff --git a/src/segger/training/train.py b/src/segger/training/train.py index 872f249..5bd6f66 100644 --- a/src/segger/training/train.py +++ b/src/segger/training/train.py @@ -4,10 +4,6 @@ import torchmetrics from torchmetrics import F1Score import lightning as L -from torch_geometric.loader import DataLoader -from torch_geometric.typing import Metadata -from torch_geometric.nn import to_hetero -from torch_geometric.data import HeteroData from segger.models.segger_model import * from segger.data.utils import SpatialTranscriptomicsDataset from typing import Any, List, Tuple, Union @@ -65,22 +61,24 @@ def __init__(self, learning_rate: float = 1e-3, **kwargs): def from_new( self, - num_tx_tokens: int, + is_token_based: int, + num_node_features: dict[str, int], init_emb: int, hidden_channels: int, out_channels: int, heads: int, num_mid_layers: int, aggr: str, - metadata: Union[Tuple, Metadata], ): """ Initializes the LitSegger module with new parameters. Parameters ---------- - num_tx_tokens : int - Number of unique 'tx' tokens for embedding (this must be passed here). + is_token_based : int + Whether the model is using token-based embeddings or scRNAseq embeddings. + num_node_features : dict[str, int] + Number of node features for each node type. init_emb : int Initial embedding size. hidden_channels : int @@ -97,17 +95,15 @@ def from_new( Metadata for heterogeneous graph structure. """ # Create the Segger model (ensure num_tx_tokens is passed here) - model = Segger( - num_tx_tokens=num_tx_tokens, # This is required and must be passed here + self.model = Segger( + is_token_based=is_token_based, + num_node_features=num_node_features, init_emb=init_emb, hidden_channels=hidden_channels, out_channels=out_channels, heads=heads, num_mid_layers=num_mid_layers, ) - # Convert model to handle heterogeneous graphs - model = to_hetero(model, metadata=metadata, aggr=aggr) - self.model = model # Save hyperparameters self.save_hyperparameters() @@ -137,7 +133,8 @@ def forward(self, batch: SpatialTranscriptomicsDataset) -> torch.Tensor: The output of the model. """ z = self.model(batch.x_dict, batch.edge_index_dict) - output = torch.matmul(z["tx"], z["bd"].t()) # Example for bipartite graph + edge_label_index = batch["tx", "belongs", "bd"].edge_label_index + output = self.model.decode(z, edge_label_index) return output def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: @@ -156,17 +153,16 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: torch.Tensor The loss value for the current training step. 
""" - # Forward pass to get the logits - z = self.model(batch.x_dict, batch.edge_index_dict) - output = torch.matmul(z["tx"], z["bd"].t()) - - # Get edge labels and logits + # Get edge labels edge_label_index = batch["tx", "belongs", "bd"].edge_label_index - out_values = output[edge_label_index[0], edge_label_index[1]] edge_label = batch["tx", "belongs", "bd"].edge_label + # Forward pass to get the logits + z = self.model(batch.x_dict, batch.edge_index_dict) + output = self.model.decode(z, edge_label_index) + # Compute binary cross-entropy loss with logits (no sigmoid here) - loss = self.criterion(out_values, edge_label) + loss = self.criterion(output, edge_label) # Log the training loss self.log("train_loss", loss, prog_bar=True, batch_size=batch.num_graphs) @@ -188,20 +184,19 @@ def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor: torch.Tensor The loss value for the current validation step. """ - # Forward pass to get the logits - z = self.model(batch.x_dict, batch.edge_index_dict) - output = torch.matmul(z["tx"], z["bd"].t()) - - # Get edge labels and logits + # Get edge labels edge_label_index = batch["tx", "belongs", "bd"].edge_label_index - out_values = output[edge_label_index[0], edge_label_index[1]] edge_label = batch["tx", "belongs", "bd"].edge_label + # Forward pass to get the logits + z = self.model(batch.x_dict, batch.edge_index_dict) + output = self.model.decode(z, edge_label_index) + # Compute binary cross-entropy loss with logits (no sigmoid here) - loss = self.criterion(out_values, edge_label) + loss = self.criterion(output, edge_label) # Apply sigmoid to logits for AUROC and F1 metrics - out_values_prob = torch.sigmoid(out_values) + out_values_prob = torch.sigmoid(output) # Compute metrics auroc = torchmetrics.AUROC(task="binary")