From e89013bd6a8f74b72c98cef679d803d33a873001 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=81NIEL=20UNYI?= Date: Mon, 9 Dec 2024 18:07:14 +0100 Subject: [PATCH 1/5] Add new parameters to training (learning rate, early stopping, and start training on pretrained model) --- scripts/config.yaml | 6 +++- scripts/submit_job.py | 9 ++++++ src/segger/cli/configs/train/default.yaml | 16 +++++++++++ src/segger/cli/train_model.py | 34 +++++++++++++++-------- src/segger/training/train.py | 7 +++-- 5 files changed, 58 insertions(+), 14 deletions(-) diff --git a/scripts/config.yaml b/scripts/config.yaml index 38b74e0..21b5259 100644 --- a/scripts/config.yaml +++ b/scripts/config.yaml @@ -43,6 +43,10 @@ training: num_workers: 12 accelerator: "cuda" max_epochs: 200 + save_best_model: true + learning_rate: 1e-3 + pretrained_model_dir: null + pretrained_model_version: 0 devices: 8 strategy: "auto" precision: "16-mixed" @@ -50,7 +54,7 @@ training: gpu_memory: "8G" # this is ignored if use_lsf is false prediction: - output_log: "predict_parquet_final_output.log" + output_log: "predict_parquet_output.log" segger_data_dir: "data_segger" models_dir: "model_dir" benchmarks_dir: "benchmark_dir" diff --git a/scripts/submit_job.py b/scripts/submit_job.py index 139ed65..b0aeb36 100644 --- a/scripts/submit_job.py +++ b/scripts/submit_job.py @@ -151,6 +151,10 @@ def run_training(): config["training"]["accelerator"], "--max_epochs", str(config["training"]["max_epochs"]), + "--save_best_model", + str(config["training"]["save_best_model"]), + "--learning_rate", + str(config["training"]["learning_rate"]), "--devices", str(config["training"]["devices"]), "--strategy", @@ -160,6 +164,11 @@ def run_training(): ] ) + if config["training"].get("pretrained_model_dir") is not None: + command.extend(["--pretrained_model_dir", config["training"]["pretrained_model_dir"]]) + if config["training"].get("pretrained_model_version") is not None: + command.extend(["--pretrained_model_version", str(config["training"]["pretrained_model_version"])]) + if config.get("use_lsf", False): command = [ "bsub", diff --git a/src/segger/cli/configs/train/default.yaml b/src/segger/cli/configs/train/default.yaml index cf27fc1..329f68a 100644 --- a/src/segger/cli/configs/train/default.yaml +++ b/src/segger/cli/configs/train/default.yaml @@ -50,6 +50,22 @@ max_epochs: type: int default: 200 help: Number of epochs for training. +save_best_model: + type: bool + default: true + help: Whether to save the best model. +learning_rate: + type: float + default: 1e-3 + help: Learning rate for training. +pretrained_model_dir: + type: Path + default: null + help: Directory containing the pretrained model to use (if any). +pretrained_model_version: + type: int + default: null + help: Version of pretrained model. devices: type: int default: 4 diff --git a/src/segger/cli/train_model.py b/src/segger/cli/train_model.py index 5747cfb..04c1cc0 100644 --- a/src/segger/cli/train_model.py +++ b/src/segger/cli/train_model.py @@ -31,6 +31,10 @@ "--accelerator", type=str, default="cuda", help='Device type to use for training (e.g., "cuda", "cpu").' 
) # Ask for accelerator @click.option("--max_epochs", type=int, default=200, help="Number of epochs for training.") +@click.option("--save_best_model", type=bool, default=True, help="Whether to save the best model.") # unused for now +@click.option("--learning_rate", type=float, default=1e-3, help="Learning rate for training.") +@click.option("--pretrained_model_dir", type=Path, default=None, help="Directory containing the pretrained model to use (if any).") +@click.option("--pretrained_model_version", type=int, default=None, help="Version of pretrained model.") @click.option("--devices", type=int, default=4, help="Number of devices (GPUs) to use.") @click.option("--strategy", type=str, default="auto", help="Training strategy for the trainer.") @click.option("--precision", type=str, default="16-mixed", help="Precision for training.") @@ -65,20 +69,28 @@ def train_model(args: Namespace): # Initialize model logging.info("Initializing Segger model and trainer...") metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]) - ls = LitSegger( - num_tx_tokens=args.num_tx_tokens, - init_emb=args.init_emb, - hidden_channels=args.hidden_channels, - out_channels=args.out_channels, # Hard-coded value - heads=args.heads, # Hard-coded value - num_mid_layers=args.num_mid_layers, # Hard-coded value - aggr="sum", # Hard-coded value - metadata=metadata, - ) + + if args.pretrained_model_dir is not None: + logging.info("Loading pretrained model...") + from segger.prediction.predict_parquet import load_model + + ls = load_model(args.pretrained_model_dir / "lightning_logs" / f"version_{args.pretrained_model_version}" / "checkpoints") + else: + ls = LitSegger( + num_tx_tokens=args.num_tx_tokens, + init_emb=args.init_emb, + hidden_channels=args.hidden_channels, + out_channels=args.out_channels, # Hard-coded value + heads=args.heads, # Hard-coded value + num_mid_layers=args.num_mid_layers, # Hard-coded value + aggr="sum", # Hard-coded value + learning_rate=args.learning_rate, + metadata=metadata, + ) # Forward pass to initialize the model if args.devices > 1: - batch = dm.train[0] + batch = dm.train[0].to(ls.device) ls.forward(batch) # Initialize the Lightning trainer diff --git a/src/segger/training/train.py b/src/segger/training/train.py index 68adbb3..872f249 100644 --- a/src/segger/training/train.py +++ b/src/segger/training/train.py @@ -29,12 +29,14 @@ class LitSegger(LightningModule): The loss function used for training, specifically BCEWithLogitsLoss. """ - def __init__(self, **kwargs): + def __init__(self, learning_rate: float = 1e-3, **kwargs): """ Initializes the LitSegger module with the given parameters. Parameters ---------- + learning_rate : float + The learning rate for the optimizer. **kwargs : dict Keyword arguments for initializing the module. Specific parameters depend on whether the module is initialized with new parameters or components. @@ -59,6 +61,7 @@ def __init__(self, **kwargs): self.validation_step_outputs = [] self.criterion = torch.nn.BCEWithLogitsLoss() + self.learning_rate = learning_rate def from_new( self, @@ -223,5 +226,5 @@ def configure_optimizers(self) -> torch.optim.Optimizer: torch.optim.Optimizer The optimizer for training.
""" - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer From 954d1cb3e7c66c89ad25fda5e7bd9614366b42df Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:08:55 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/segger/cli/train_model.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/segger/cli/train_model.py b/src/segger/cli/train_model.py index 04c1cc0..731fcb3 100644 --- a/src/segger/cli/train_model.py +++ b/src/segger/cli/train_model.py @@ -31,9 +31,14 @@ "--accelerator", type=str, default="cuda", help='Device type to use for training (e.g., "cuda", "cpu").' ) # Ask for accelerator @click.option("--max_epochs", type=int, default=200, help="Number of epochs for training.") -@click.option("--save_best_model", type=bool, default=True, help="Whether to save the best model.") # unused for now +@click.option("--save_best_model", type=bool, default=True, help="Whether to save the best model.")  # unused for now @click.option("--learning_rate", type=float, default=1e-3, help="Learning rate for training.") -@click.option("--pretrained_model_dir", type=Path, default=None, help="Directory containing the pretrained model to use (if any).") +@click.option( + "--pretrained_model_dir", + type=Path, + default=None, + help="Directory containing the pretrained model to use (if any).", +) @click.option("--pretrained_model_version", type=int, default=None, help="Version of pretrained model.") @click.option("--devices", type=int, default=4, help="Number of devices (GPUs) to use.") @click.option("--strategy", type=str, default="auto", help="Training strategy for the trainer.") @click.option("--precision", type=str, default="16-mixed", help="Precision for training.") @@ -73,7 +78,7 @@ def train_model(args: Namespace): if args.pretrained_model_dir is not None: logging.info("Loading pretrained model...") from segger.prediction.predict_parquet import load_model - + ls = load_model(args.pretrained_model_dir / "lightning_logs" / f"version_{args.pretrained_model_version}" / "checkpoints") else: From 6ab4e5fc5258a5c98c9761bdcf4f5143a96666cb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Dec 2024 10:31:39 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/segger/prediction/predict.py | 2 ++ src/segger/prediction/predict_parquet.py | 36 +++++++++++++----------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/segger/prediction/predict.py b/src/segger/prediction/predict.py index 2c2f481..2717af8 100644 --- a/src/segger/prediction/predict.py +++ b/src/segger/prediction/predict.py @@ -6,6 +6,7 @@ import torch.nn.functional as F import torch._dynamo import gc + # import rmm import re import glob @@ -28,6 +29,7 @@ from dask.diagnostics import ProgressBar import time import dask + # from rmm.allocators.cupy import rmm_cupy_allocator from cupyx.scipy.sparse import coo_matrix from torch.utils.dlpack import to_dlpack, from_dlpack diff --git a/src/segger/prediction/predict_parquet.py b/src/segger/prediction/predict_parquet.py index 9af4090..c6d39ef 100644 --- a/src/segger/prediction/predict_parquet.py +++ b/src/segger/prediction/predict_parquet.py @@ -6,6 +6,7 @@
import torch.nn.functional as F import torch._dynamo import gc + # import rmm import re import glob @@ -30,6 +31,7 @@ from dask.diagnostics import ProgressBar import time import dask + # from rmm.allocators.cupy import rmm_cupy_allocator from cupyx.scipy.sparse import coo_matrix from torch.utils.dlpack import to_dlpack, from_dlpack @@ -564,19 +566,19 @@ def segment( print(f"Merged segmentation results with transcripts in {elapsed_time:.2f} seconds.") if use_cc: - + step_start_time = time() if verbose: print(f"Computing connected components for unassigned transcripts...") # Load edge indices from saved Parquet edge_index_dd = pd.read_parquet(edge_index_save_path) - + # Step 2: Get unique transcript_ids from edge_index_dd and their positional indices transcript_ids_in_edges = pd.concat([edge_index_dd["source"], edge_index_dd["target"]]).unique() - + # Create a lookup table with unique indices lookup_table = pd.Series(data=range(len(transcript_ids_in_edges)), index=transcript_ids_in_edges).to_dict() - + # Map source and target to positional indices edge_index_dd["index_source"] = edge_index_dd["source"].map(lookup_table) edge_index_dd["index_target"] = edge_index_dd["target"].map(lookup_table) @@ -584,44 +586,44 @@ def segment( source_indices = np.asarray(edge_index_dd["index_source"]) target_indices = np.asarray(edge_index_dd["index_target"]) data_cp = np.ones(len(source_indices), dtype=np.float32) - + # Create the sparse COO matrix coo_cp_matrix = scipy_coo_matrix( (data_cp, (source_indices, target_indices)), shape=(len(transcript_ids_in_edges), len(transcript_ids_in_edges)), ) - + # Use CuPy's connected components algorithm to compute components n, comps = cc(coo_cp_matrix, directed=True, connection="strong") if verbose: elapsed_time = time() - step_start_time print(f"Computed connected components for unassigned transcripts in {elapsed_time:.2f} seconds.") - + step_start_time = time() if verbose: print(f"The rest...") # # Step 4: Map back the component labels to the original transcript_ids - + def _get_id(): """Generate a random Xenium-style ID.""" return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx" - + new_ids = np.array([_get_id() for _ in range(n)]) comp_labels = new_ids[comps] comp_labels = pd.Series(comp_labels, index=transcript_ids_in_edges) # Step 5: Handle only unassigned transcripts in transcripts_df_filtered unassigned_mask = transcripts_df_filtered["segger_cell_id"].isna() - + unassigned_transcripts_df = transcripts_df_filtered.loc[unassigned_mask, ["transcript_id"]] - + # Step 6: Map component labels only to unassigned transcript_ids new_segger_cell_ids = unassigned_transcripts_df["transcript_id"].map(comp_labels) - + # Step 7: Create a DataFrame with updated 'segger_cell_id' for unassigned transcripts unassigned_transcripts_df = unassigned_transcripts_df.assign(segger_cell_id=new_segger_cell_ids) - + # Step 8: Merge this DataFrame back into the original to update only the unassigned segger_cell_id - + # Merging the updates back to the original DataFrame transcripts_df_filtered = transcripts_df_filtered.merge( unassigned_transcripts_df[["transcript_id", "segger_cell_id"]], @@ -629,14 +631,14 @@ def _get_id(): how="left", # Perform a left join to only update the unassigned rows suffixes=("", "_new"), # Suffix for new column to avoid overwriting ) - + # Step 9: Fill missing segger_cell_id values with the updated values from the merge transcripts_df_filtered["segger_cell_id"] = transcripts_df_filtered["segger_cell_id"].fillna( 
transcripts_df_filtered["segger_cell_id_new"] ) - + transcripts_df_filtered = transcripts_df_filtered.drop(columns=["segger_cell_id_new"]) - + if verbose: elapsed_time = time() - step_start_time print(f"The rest computed in {elapsed_time:.2f} seconds.") From 0787167054949969317bcb3d1a3a1ccab3b7613b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=81NIEL=20UNYI?= Date: Sun, 15 Dec 2024 18:09:20 +0100 Subject: [PATCH 4/5] Model training fix and code refactoring --- src/segger/cli/train_model.py | 41 +++++--- src/segger/data/parquet/pyg_dataset.py | 4 - src/segger/models/segger_model.py | 139 +++++++++++++++---------- src/segger/training/train.py | 53 +++++----- 4 files changed, 134 insertions(+), 103 deletions(-) diff --git a/src/segger/cli/train_model.py b/src/segger/cli/train_model.py index 731fcb3..5a02090 100644 --- a/src/segger/cli/train_model.py +++ b/src/segger/cli/train_model.py @@ -1,6 +1,4 @@ import click -import typing -import os from segger.cli.utils import add_options, CustomFormatter from pathlib import Path import logging @@ -53,8 +51,10 @@ def train_model(args: Namespace): # Import packages logging.info("Importing packages...") + import torch from segger.training.train import LitSegger from segger.training.segger_data_module import SeggerDataModule + from segger.prediction.predict_parquet import load_model from lightning.pytorch.loggers import CSVLogger from pytorch_lightning import Trainer @@ -72,31 +72,38 @@ def train_model(args: Namespace): logging.info("Done.") # Initialize model - logging.info("Initializing Segger model and trainer...") - metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]) - if args.pretrained_model_dir is not None: logging.info("Loading pretrained model...") - from segger.prediction.predict_parquet import load_model - ls = load_model(args.pretrained_model_dir / "lightning_logs" / f"version_{args.pretrained_model_version}" / "checkpoints") else: + logging.info("Creating new model...") + is_token_based = dm.train[0].x_dict["tx"].ndim == 1 + if is_token_based: + # if the model is token-based, the input is a 1D tensor of token indices + assert dm.train[0].x_dict["tx"].ndim == 1 + assert dm.train[0].x_dict["tx"].dtype == torch.long + num_tx_features = args.num_tx_tokens + print("Using token-based embeddings as node features, number of tokens: ", num_tx_features) + else: + # if the model is not token-based, the input is a 2D tensor of scRNAseq embeddings + assert dm.train[0].x_dict["tx"].ndim == 2 + assert dm.train[0].x_dict["tx"].dtype == torch.float32 + num_tx_features = dm.train[0].x_dict["tx"].shape[1] + print("Using scRNAseq embeddings as node features, number of features: ", num_tx_features) + num_bd_features = dm.train[0].x_dict["bd"].shape[1] + print("Number of boundary node features: ", num_bd_features) ls = LitSegger( - num_tx_tokens=args.num_tx_tokens, + is_token_based=is_token_based, + num_node_features={"tx": num_tx_features, "bd": num_bd_features}, init_emb=args.init_emb, hidden_channels=args.hidden_channels, - out_channels=args.out_channels, # Hard-coded value - heads=args.heads, # Hard-coded value - num_mid_layers=args.num_mid_layers, # Hard-coded value + out_channels=args.out_channels, + heads=args.heads, + num_mid_layers=args.num_mid_layers, aggr="sum", # Hard-coded value learning_rate=args.learning_rate, - metadata=metadata, ) - - # Forward pass to initialize the model - if args.devices > 1: - batch = dm.train[0].to(ls.device) - ls.forward(batch) + logging.info("Done.") # Initialize the Lightning trainer trainer = Trainer(
diff --git a/src/segger/data/parquet/pyg_dataset.py b/src/segger/data/parquet/pyg_dataset.py index a54803a..a694018 100644 --- a/src/segger/data/parquet/pyg_dataset.py +++ b/src/segger/data/parquet/pyg_dataset.py @@ -65,10 +65,6 @@ def get(self, idx: int) -> Data: """ filepath = Path(self.processed_dir) / self.processed_file_names[idx] data = torch.load(filepath) - data["tx"].x = data["tx"].x.to_dense() - if data["tx"].x.dim() == 1: - data["tx"].x = data["tx"].x.unsqueeze(1) - assert data["tx"].x.dim() == 2 # this is an issue in PyG's RandomLinkSplit, dimensions are not consistent if there is only one edge in the graph if hasattr(data["tx", "belongs", "bd"], "edge_label_index"): if data["tx", "belongs", "bd"].edge_label_index.dim() == 1: diff --git a/src/segger/models/segger_model.py b/src/segger/models/segger_model.py index f470574..f535221 100644 --- a/src/segger/models/segger_model.py +++ b/src/segger/models/segger_model.py @@ -1,16 +1,33 @@ import torch -from torch_geometric.nn import GATv2Conv, Linear -from torch.nn import Embedding +from torch_geometric.nn import HeteroConv, GATv2Conv, HeteroDictLinear +from torch import nn from torch import Tensor from typing import Union -# from torch_sparse import SparseTensor - - -class Segger(torch.nn.Module): +class SkipGAT(nn.Module): + def __init__(self, in_channels, out_channels, heads, apply_activation=True): + super().__init__() + self.apply_activation = apply_activation + self.conv = HeteroConv({ + ('tx', 'neighbors', 'tx'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), + ('tx', 'belongs', 'bd'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), + }, aggr='sum') + self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=('tx', 'bd')) + + def forward(self, x_dict, edge_index_dict): + x_conv = self.conv(x_dict, edge_index_dict) + x_lin = self.lin(x_dict) + x_dict = {key: x_conv[key] + x_lin[key] for key in x_dict} + if self.apply_activation: + x_dict = {key: x_dict[key].relu() for key in x_dict} + return x_dict + + +class Segger(nn.Module): def __init__( self, - num_tx_tokens: int, + is_token_based: int, + num_node_features: dict[str, int], init_emb: int = 16, hidden_channels: int = 32, num_mid_layers: int = 3, @@ -21,80 +38,96 @@ def __init__( Initializes the Segger model. Args: - num_tx_tokens (int) : Number of unique 'tx' tokens for embedding. - init_emb (int) : Initial embedding size for both 'tx' and boundary (non-token) nodes. - hidden_channels (int): Number of hidden channels. - num_mid_layers (int) : Number of hidden layers (excluding first and last layers). - out_channels (int) : Number of output channels. - heads (int) : Number of attention heads. + is_token_based (int) : Whether the model is using token-based embeddings or scRNAseq embeddings. + num_node_features (dict[str, int]): Number of node features for each node type. + init_emb (int) : Initial embedding size for both 'tx' and boundary (non-token) nodes. + hidden_channels (int) : Number of hidden channels. + num_mid_layers (int) : Number of hidden layers (excluding first and last layers). + out_channels (int) : Number of output channels. + heads (int) : Number of attention heads. 
""" super().__init__() - # Embedding for 'tx' (transcript) nodes - self.tx_embedding = Embedding(num_tx_tokens, init_emb) - - # Linear layer for boundary (non-token) nodes - self.lin0 = Linear(-1, init_emb, bias=False) + # Initialize node embeddings + if is_token_based: + # Using token-based embeddings for transcript ('tx') nodes + self.node_init = nn.ModuleDict({ + 'tx': nn.Embedding(num_node_features['tx'], init_emb), + 'bd': nn.Linear(num_node_features['bd'], init_emb), + }) + else: + # Using scRNAseq embeddings (i.e. prior biological knowledge) for transcript ('tx') nodes + self.node_init = nn.ModuleDict({ + 'tx': nn.Linear(num_node_features['tx'], init_emb), + 'bd': nn.Linear(num_node_features['bd'], init_emb), + }) # First GATv2Conv layer - self.conv_first = GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False) - self.lin_first = Linear(-1, hidden_channels * heads) + self.conv1 = SkipGAT(init_emb, hidden_channels, heads) # Middle GATv2Conv layers self.num_mid_layers = num_mid_layers if num_mid_layers > 0: - self.conv_mid_layers = torch.nn.ModuleList() - self.lin_mid_layers = torch.nn.ModuleList() + self.conv_mid_layers = nn.ModuleList() for _ in range(num_mid_layers): - self.conv_mid_layers.append(GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)) - self.lin_mid_layers.append(Linear(-1, hidden_channels * heads)) - + self.conv_mid_layers.append(SkipGAT(heads * hidden_channels, hidden_channels, heads)) + # Last GATv2Conv layer - self.conv_last = GATv2Conv((-1, -1), out_channels, heads=heads, add_self_loops=False) - self.lin_last = Linear(-1, out_channels * heads) + self.conv_last = SkipGAT(heads * hidden_channels, out_channels, heads) + + # Finalize node embeddings + self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=('tx', 'bd')) + + # # Edge probability predictor + # self.edge_predictor = nn.Sequential( + # nn.Linear(2 * out_channels, out_channels), + # nn.ReLU(), + # nn.Linear(out_channels, 1), + # ) - def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: + def forward( + self, + x_dict: dict[str, Tensor], + edge_index_dict: dict[str, Tensor], + ) -> dict[str, Tensor]: """ Forward pass for the Segger model. Args: - x (Tensor): Node features. - edge_index (Tensor): Edge indices. - - Returns: - Tensor: Output node embeddings. + x_dict (dict[str, Tensor]): Node features for each node type. + edge_index_dict (dict[str, Tensor]): Edge indices for each edge type. 
""" - x = torch.nan_to_num(x, nan=0) - is_one_dim = (x.ndim == 1) * 1 - # x = x[:, None] - x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim) - # First layer - x = x.relu() - x = self.conv_first(x, edge_index) + self.lin_first(x) - x = x.relu() - - # Middle layers + + x_dict = {key: self.node_init[key](x) for key, x in x_dict.items()} + + x_dict = self.conv1(x_dict, edge_index_dict) + if self.num_mid_layers > 0: for i in range(self.num_mid_layers): - conv_mid = self.conv_mid_layers[i] - lin_mid = self.lin_mid_layers[i] - x = conv_mid(x, edge_index) + lin_mid(x) - x = x.relu() + x_dict = self.conv_mid_layers[i](x_dict, edge_index_dict) + + x_dict = self.conv_last(x_dict, edge_index_dict) - # Last layer - x = self.conv_last(x, edge_index) + self.lin_last(x) + x_dict = self.node_final(x_dict) - return x + return x_dict - def decode(self, z: Tensor, edge_index: Union[Tensor]) -> Tensor: + def decode( + self, + z_dict: dict[str, Tensor], + edge_index: Union[Tensor], + ) -> Tensor: """ Decode the node embeddings to predict edge values. Args: - z (Tensor): Node embeddings. + z (dict[str, Tensor]): Node embeddings for each node type. edge_index (EdgeIndex): Edge label indices. Returns: Tensor: Predicted edge values. """ - return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) + z_left = z_dict['tx'][edge_index[0]] + z_right = z_dict['bd'][edge_index[1]] + return (z_left * z_right).sum(dim=-1) + # return self.edge_predictor(torch.cat([z_left, z_right], dim=-1)).squeeze() diff --git a/src/segger/training/train.py b/src/segger/training/train.py index 872f249..5bd6f66 100644 --- a/src/segger/training/train.py +++ b/src/segger/training/train.py @@ -4,10 +4,6 @@ import torchmetrics from torchmetrics import F1Score import lightning as L -from torch_geometric.loader import DataLoader -from torch_geometric.typing import Metadata -from torch_geometric.nn import to_hetero -from torch_geometric.data import HeteroData from segger.models.segger_model import * from segger.data.utils import SpatialTranscriptomicsDataset from typing import Any, List, Tuple, Union @@ -65,22 +61,24 @@ def __init__(self, learning_rate: float = 1e-3, **kwargs): def from_new( self, - num_tx_tokens: int, + is_token_based: int, + num_node_features: dict[str, int], init_emb: int, hidden_channels: int, out_channels: int, heads: int, num_mid_layers: int, aggr: str, - metadata: Union[Tuple, Metadata], ): """ Initializes the LitSegger module with new parameters. Parameters ---------- - num_tx_tokens : int - Number of unique 'tx' tokens for embedding (this must be passed here). + is_token_based : int + Whether the model is using token-based embeddings or scRNAseq embeddings. + num_node_features : dict[str, int] + Number of node features for each node type. init_emb : int Initial embedding size. hidden_channels : int @@ -97,17 +95,15 @@ def from_new( Metadata for heterogeneous graph structure. 
""" # Create the Segger model (ensure num_tx_tokens is passed here) - model = Segger( - num_tx_tokens=num_tx_tokens, # This is required and must be passed here + self.model = Segger( + is_token_based=is_token_based, + num_node_features=num_node_features, init_emb=init_emb, hidden_channels=hidden_channels, out_channels=out_channels, heads=heads, num_mid_layers=num_mid_layers, ) - # Convert model to handle heterogeneous graphs - model = to_hetero(model, metadata=metadata, aggr=aggr) - self.model = model # Save hyperparameters self.save_hyperparameters() @@ -137,7 +133,8 @@ def forward(self, batch: SpatialTranscriptomicsDataset) -> torch.Tensor: The output of the model. """ z = self.model(batch.x_dict, batch.edge_index_dict) - output = torch.matmul(z["tx"], z["bd"].t()) # Example for bipartite graph + edge_label_index = batch["tx", "belongs", "bd"].edge_label_index + output = self.model.decode(z, edge_label_index) return output def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: @@ -156,17 +153,16 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: torch.Tensor The loss value for the current training step. """ - # Forward pass to get the logits - z = self.model(batch.x_dict, batch.edge_index_dict) - output = torch.matmul(z["tx"], z["bd"].t()) - - # Get edge labels and logits + # Get edge labels edge_label_index = batch["tx", "belongs", "bd"].edge_label_index - out_values = output[edge_label_index[0], edge_label_index[1]] edge_label = batch["tx", "belongs", "bd"].edge_label + # Forward pass to get the logits + z = self.model(batch.x_dict, batch.edge_index_dict) + output = self.model.decode(z, edge_label_index) + # Compute binary cross-entropy loss with logits (no sigmoid here) - loss = self.criterion(out_values, edge_label) + loss = self.criterion(output, edge_label) # Log the training loss self.log("train_loss", loss, prog_bar=True, batch_size=batch.num_graphs) @@ -188,20 +184,19 @@ def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor: torch.Tensor The loss value for the current validation step. 
""" - # Forward pass to get the logits - z = self.model(batch.x_dict, batch.edge_index_dict) - output = torch.matmul(z["tx"], z["bd"].t()) - - # Get edge labels and logits + # Get edge labels edge_label_index = batch["tx", "belongs", "bd"].edge_label_index - out_values = output[edge_label_index[0], edge_label_index[1]] edge_label = batch["tx", "belongs", "bd"].edge_label + # Forward pass to get the logits + z = self.model(batch.x_dict, batch.edge_index_dict) + output = self.model.decode(z, edge_label_index) + # Compute binary cross-entropy loss with logits (no sigmoid here) - loss = self.criterion(out_values, edge_label) + loss = self.criterion(output, edge_label) # Apply sigmoid to logits for AUROC and F1 metrics - out_values_prob = torch.sigmoid(out_values) + out_values_prob = torch.sigmoid(output) # Compute metrics auroc = torchmetrics.AUROC(task="binary") From 1f5110475c6d5213a9b1dd968b9891987a17540e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Dec 2024 17:09:43 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/segger/models/segger_model.py | 50 ++++++++++++++++++------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/src/segger/models/segger_model.py b/src/segger/models/segger_model.py index f535221..b53b3dd 100644 --- a/src/segger/models/segger_model.py +++ b/src/segger/models/segger_model.py @@ -4,15 +4,19 @@ from torch import Tensor from typing import Union + class SkipGAT(nn.Module): def __init__(self, in_channels, out_channels, heads, apply_activation=True): super().__init__() self.apply_activation = apply_activation - self.conv = HeteroConv({ - ('tx', 'neighbors', 'tx'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), - ('tx', 'belongs', 'bd'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), - }, aggr='sum') - self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=('tx', 'bd')) + self.conv = HeteroConv( + { + ("tx", "neighbors", "tx"): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), + ("tx", "belongs", "bd"): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), + }, + aggr="sum", + ) + self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=("tx", "bd")) def forward(self, x_dict, edge_index_dict): x_conv = self.conv(x_dict, edge_index_dict) @@ -51,16 +55,20 @@ def __init__( # Initialize node embeddings if is_token_based: # Using token-based embeddings for transcript ('tx') nodes - self.node_init = nn.ModuleDict({ - 'tx': nn.Embedding(num_node_features['tx'], init_emb), - 'bd': nn.Linear(num_node_features['bd'], init_emb), - }) + self.node_init = nn.ModuleDict( + { + "tx": nn.Embedding(num_node_features["tx"], init_emb), + "bd": nn.Linear(num_node_features["bd"], init_emb), + } + ) else: # Using scRNAseq embeddings (i.e. 
prior biological knowledge) for transcript ('tx') nodes - self.node_init = nn.ModuleDict({ - 'tx': nn.Linear(num_node_features['tx'], init_emb), - 'bd': nn.Linear(num_node_features['bd'], init_emb), - }) + self.node_init = nn.ModuleDict( + { + "tx": nn.Linear(num_node_features["tx"], init_emb), + "bd": nn.Linear(num_node_features["bd"], init_emb), + } + ) # First GATv2Conv layer self.conv1 = SkipGAT(init_emb, hidden_channels, heads) @@ -71,12 +79,12 @@ def __init__( self.conv_mid_layers = nn.ModuleList() for _ in range(num_mid_layers): self.conv_mid_layers.append(SkipGAT(heads * hidden_channels, hidden_channels, heads)) - + # Last GATv2Conv layer self.conv_last = SkipGAT(heads * hidden_channels, out_channels, heads) # Finalize node embeddings - self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=('tx', 'bd')) + self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=("tx", "bd")) # # Edge probability predictor # self.edge_predictor = nn.Sequential( @@ -113,10 +121,10 @@ def forward( return x_dict def decode( - self, - z_dict: dict[str, Tensor], - edge_index: Union[Tensor], - ) -> Tensor: + self, + z_dict: dict[str, Tensor], + edge_index: Union[Tensor], + ) -> Tensor: """ Decode the node embeddings to predict edge values. @@ -127,7 +135,7 @@ def decode( Returns: Tensor: Predicted edge values. """ - z_left = z_dict['tx'][edge_index[0]] - z_right = z_dict['bd'][edge_index[1]] + z_left = z_dict["tx"][edge_index[0]] + z_right = z_dict["bd"][edge_index[1]] return (z_left * z_right).sum(dim=-1) # return self.edge_predictor(torch.cat([z_left, z_right], dim=-1)).squeeze()
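
A minimal end-to-end sketch of the refactored heterogeneous model introduced in PATCH 4/5: it builds a toy HeteroData graph with 'tx' and 'bd' nodes, runs the Segger forward pass, scores candidate tx-bd edges with the dot-product decoder, and computes the same BCE-with-logits loss used in training_step. The node counts, feature sizes, vocabulary size, and random labels are made up for illustration; only the constructor and decode signatures follow the patched segger_model.py (PyTorch Geometric >= 2.3 is assumed for HeteroDictLinear).

import torch
import torch.nn.functional as F
from torch_geometric.data import HeteroData

from segger.models.segger_model import Segger

# Toy heterogeneous graph: 40 transcript ('tx') nodes with token ids, 6 boundary ('bd') nodes.
data = HeteroData()
data["tx"].x = torch.randint(0, 500, (40,))  # token-based 'tx' features: 1D long tensor of token ids
data["bd"].x = torch.rand(6, 5)  # 6 boundary nodes with 5 features each
data["tx", "neighbors", "tx"].edge_index = torch.randint(0, 40, (2, 120))
data["tx", "belongs", "bd"].edge_index = torch.stack([torch.randint(0, 40, (80,)), torch.randint(0, 6, (80,))])

model = Segger(
    is_token_based=True,
    num_node_features={"tx": 500, "bd": 5},  # token vocabulary size for 'tx', feature count for 'bd'
    init_emb=16,
    hidden_channels=32,
    num_mid_layers=1,
    out_channels=8,
    heads=2,
)

z = model(data.x_dict, data.edge_index_dict)  # dict of 'tx' and 'bd' embeddings, each out_channels wide
logits = model.decode(z, data["tx", "belongs", "bd"].edge_index)  # one logit per candidate tx-bd edge

# Same loss path as the patched training_step, with made-up labels standing in for edge_label.
labels = torch.randint(0, 2, (80,)).float()
loss = F.binary_cross_entropy_with_logits(logits, labels)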
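
The new --pretrained_model_dir and --pretrained_model_version options from PATCH 1/5 resolve to a Lightning checkpoint directory that load_model turns into a ready LitSegger. A short sketch of the same lookup outside the CLI; the directory and version number below are placeholders, and the lightning_logs/version_<n>/checkpoints layout is the one written by the CSVLogger used in train_model.py.

from pathlib import Path

from segger.prediction.predict_parquet import load_model

pretrained_model_dir = Path("models/previous_run")  # hypothetical output directory of an earlier run
pretrained_model_version = 0

# Same path convention as the patched train_model.py.
checkpoint_dir = pretrained_model_dir / "lightning_logs" / f"version_{pretrained_model_version}" / "checkpoints"
ls = load_model(checkpoint_dir)  # used in place of a freshly initialized LitSegger when fine-tuning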
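
The use_cc branch reformatted in PATCH 3/5 groups still-unassigned transcripts by connected components of the transcript-transcript edge graph and gives every component a fresh Xenium-style id. A self-contained sketch of that grouping with SciPy; the six-edge toy list stands in for the edge-index Parquet file the real code reads, and connection="strong" mirrors the call in predict_parquet.py.

import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

# Toy symmetric edge list between transcript ids (stand-in for the saved edge_index Parquet).
edges = pd.DataFrame(
    {"source": ["t1", "t2", "t2", "t3", "t4", "t5"], "target": ["t2", "t1", "t3", "t2", "t5", "t4"]}
)

# Map transcript ids to positional indices, as in the lookup_table step.
ids = pd.concat([edges["source"], edges["target"]]).unique()
lookup = {t: i for i, t in enumerate(ids)}
src = edges["source"].map(lookup).to_numpy()
dst = edges["target"].map(lookup).to_numpy()

# Sparse adjacency matrix and strongly connected components.
adj = coo_matrix((np.ones(len(src), dtype=np.float32), (src, dst)), shape=(len(ids), len(ids)))
n_comps, labels = connected_components(adj, directed=True, connection="strong")

# One random Xenium-style id per component, mapped back onto the transcript ids;
# only transcripts without a segger_cell_id would receive these values.
new_ids = np.array(["".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx" for _ in range(n_comps)])
comp_labels = pd.Series(new_ids[labels], index=ids)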