Model training fix and code refactoring
daniel-unyi-42 committed Dec 15, 2024
1 parent 6ab4e5f commit 0787167
Showing 4 changed files with 134 additions and 103 deletions.
41 changes: 24 additions & 17 deletions src/segger/cli/train_model.py
@@ -1,6 +1,4 @@
import click
import typing
import os
from segger.cli.utils import add_options, CustomFormatter
from pathlib import Path
import logging
@@ -53,8 +51,10 @@ def train_model(args: Namespace):

# Import packages
logging.info("Importing packages...")
import torch
from segger.training.train import LitSegger
from segger.training.segger_data_module import SeggerDataModule
from segger.prediction.predict_parquet import load_model
from lightning.pytorch.loggers import CSVLogger
from pytorch_lightning import Trainer

@@ -72,31 +72,38 @@ def train_model(args: Namespace):
logging.info("Done.")

# Initialize model
logging.info("Initializing Segger model and trainer...")
metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")])

if args.pretrained_model_dir is not None:
logging.info("Loading pretrained model...")
from segger.prediction.predict_parquet import load_model

ls = load_model(args.pretrained_model_dir / "lightning_logs" / f"version_{args.model_version}" / "checkpoints")
else:
logging.info("Creating new model...")
is_token_based = dm.train[0].x_dict["tx"].ndim == 1
if is_token_based:
# if the model is token-based, the input is a 1D tensor of token indices
assert dm.train[0].x_dict["tx"].ndim == 1
assert dm.train[0].x_dict["tx"].dtype == torch.long
num_tx_features = args.num_tx_tokens
print("Using token-based embeddings as node features, number of tokens: ", num_tx_features)
else:
# if the model is not token-based, the input is a 2D tensor of scRNAseq embeddings
assert dm.train[0].x_dict["tx"].ndim == 2
assert dm.train[0].x_dict["tx"].dtype == torch.float32
num_tx_features = dm.train[0].x_dict["tx"].shape[1]
print("Using scRNAseq embeddings as node features, number of features: ", num_tx_features)
num_bd_features = dm.train[0].x_dict["bd"].shape[1]
print("Number of boundary node features: ", num_bd_features)
ls = LitSegger(
num_tx_tokens=args.num_tx_tokens,
is_token_based=is_token_based,
num_node_features={"tx": num_tx_features, "bd": num_bd_features},
init_emb=args.init_emb,
hidden_channels=args.hidden_channels,
out_channels=args.out_channels, # Hard-coded value
heads=args.heads, # Hard-coded value
num_mid_layers=args.num_mid_layers, # Hard-coded value
out_channels=args.out_channels,
heads=args.heads,
num_mid_layers=args.num_mid_layers,
aggr="sum", # Hard-coded value
learning_rate=args.learning_rate,
metadata=metadata,
)

# Forward pass to initialize the model
if args.devices > 1:
batch = dm.train[0].to(ls.device)
ls.forward(batch)
logging.info("Done.")

# Initialize the Lightning trainer
trainer = Trainer(
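
A minimal, self-contained sketch of the token-based versus embedding-based feature detection added above; the helper name infer_tx_features, the vocabulary size of 500, and the toy tensors are assumptions made purely for illustration:

import torch

# Token-based input: a 1-D tensor of integer token indices per transcript node.
tx_tokens = torch.tensor([3, 17, 42, 7], dtype=torch.long)
# Embedding-based input: a 2-D tensor of scRNAseq embeddings per transcript node.
tx_embeddings = torch.randn(4, 50)

def infer_tx_features(tx_x: torch.Tensor, num_tx_tokens: int = 500) -> int:
    # Mirrors the branch above: 1-D long tensors are token indices,
    # 2-D float tensors are precomputed scRNAseq embeddings.
    if tx_x.ndim == 1:
        assert tx_x.dtype == torch.long
        return num_tx_tokens
    assert tx_x.ndim == 2 and tx_x.dtype == torch.float32
    return tx_x.shape[1]

print(infer_tx_features(tx_tokens))      # 500, the token vocabulary size
print(infer_tx_features(tx_embeddings))  # 50, the embedding width
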
4 changes: 0 additions & 4 deletions src/segger/data/parquet/pyg_dataset.py
@@ -65,10 +65,6 @@ def get(self, idx: int) -> Data:
"""
filepath = Path(self.processed_dir) / self.processed_file_names[idx]
data = torch.load(filepath)
data["tx"].x = data["tx"].x.to_dense()
if data["tx"].x.dim() == 1:
data["tx"].x = data["tx"].x.unsqueeze(1)
assert data["tx"].x.dim() == 2
# This is an issue in PyG's RandomLinkSplit: dimensions are not consistent if there is only one edge in the graph
if hasattr(data["tx", "belongs", "bd"], "edge_label_index"):
if data["tx", "belongs", "bd"].edge_label_index.dim() == 1:
139 changes: 86 additions & 53 deletions src/segger/models/segger_model.py
@@ -1,16 +1,33 @@
import torch
from torch_geometric.nn import GATv2Conv, Linear
from torch.nn import Embedding
from torch_geometric.nn import HeteroConv, GATv2Conv, HeteroDictLinear
from torch import nn
from torch import Tensor
from typing import Union

# from torch_sparse import SparseTensor


class Segger(torch.nn.Module):
class SkipGAT(nn.Module):
def __init__(self, in_channels, out_channels, heads, apply_activation=True):
super().__init__()
self.apply_activation = apply_activation
self.conv = HeteroConv({
('tx', 'neighbors', 'tx'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False),
('tx', 'belongs', 'bd'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False),
}, aggr='sum')
self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=('tx', 'bd'))

def forward(self, x_dict, edge_index_dict):
x_conv = self.conv(x_dict, edge_index_dict)
x_lin = self.lin(x_dict)
x_dict = {key: x_conv[key] + x_lin[key] for key in x_dict}
if self.apply_activation:
x_dict = {key: x_dict[key].relu() for key in x_dict}
return x_dict


class Segger(nn.Module):
def __init__(
self,
num_tx_tokens: int,
is_token_based: int,
num_node_features: dict[str, int],
init_emb: int = 16,
hidden_channels: int = 32,
num_mid_layers: int = 3,
@@ -21,80 +38,96 @@ def __init__(
Initializes the Segger model.
Args:
num_tx_tokens (int) : Number of unique 'tx' tokens for embedding.
init_emb (int) : Initial embedding size for both 'tx' and boundary (non-token) nodes.
hidden_channels (int): Number of hidden channels.
num_mid_layers (int) : Number of hidden layers (excluding first and last layers).
out_channels (int) : Number of output channels.
heads (int) : Number of attention heads.
is_token_based (int) : Whether the model is using token-based embeddings or scRNAseq embeddings.
num_node_features (dict[str, int]): Number of node features for each node type.
init_emb (int) : Initial embedding size for both 'tx' and boundary (non-token) nodes.
hidden_channels (int) : Number of hidden channels.
num_mid_layers (int) : Number of hidden layers (excluding first and last layers).
out_channels (int) : Number of output channels.
heads (int) : Number of attention heads.
"""
super().__init__()

# Embedding for 'tx' (transcript) nodes
self.tx_embedding = Embedding(num_tx_tokens, init_emb)

# Linear layer for boundary (non-token) nodes
self.lin0 = Linear(-1, init_emb, bias=False)
# Initialize node embeddings
if is_token_based:
# Using token-based embeddings for transcript ('tx') nodes
self.node_init = nn.ModuleDict({
'tx': nn.Embedding(num_node_features['tx'], init_emb),
'bd': nn.Linear(num_node_features['bd'], init_emb),
})
else:
# Using scRNAseq embeddings (i.e. prior biological knowledge) for transcript ('tx') nodes
self.node_init = nn.ModuleDict({
'tx': nn.Linear(num_node_features['tx'], init_emb),
'bd': nn.Linear(num_node_features['bd'], init_emb),
})

# First GATv2Conv layer
self.conv_first = GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
self.lin_first = Linear(-1, hidden_channels * heads)
self.conv1 = SkipGAT(init_emb, hidden_channels, heads)

# Middle GATv2Conv layers
self.num_mid_layers = num_mid_layers
if num_mid_layers > 0:
self.conv_mid_layers = torch.nn.ModuleList()
self.lin_mid_layers = torch.nn.ModuleList()
self.conv_mid_layers = nn.ModuleList()
for _ in range(num_mid_layers):
self.conv_mid_layers.append(GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False))
self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))

self.conv_mid_layers.append(SkipGAT(heads * hidden_channels, hidden_channels, heads))

# Last GATv2Conv layer
self.conv_last = GATv2Conv((-1, -1), out_channels, heads=heads, add_self_loops=False)
self.lin_last = Linear(-1, out_channels * heads)
self.conv_last = SkipGAT(heads * hidden_channels, out_channels, heads)

# Finalize node embeddings
self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=('tx', 'bd'))

# # Edge probability predictor
# self.edge_predictor = nn.Sequential(
# nn.Linear(2 * out_channels, out_channels),
# nn.ReLU(),
# nn.Linear(out_channels, 1),
# )

def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
def forward(
self,
x_dict: dict[str, Tensor],
edge_index_dict: dict[str, Tensor],
) -> dict[str, Tensor]:
"""
Forward pass for the Segger model.
Args:
x (Tensor): Node features.
edge_index (Tensor): Edge indices.
Returns:
Tensor: Output node embeddings.
x_dict (dict[str, Tensor]): Node features for each node type.
edge_index_dict (dict[str, Tensor]): Edge indices for each edge type.
"""
x = torch.nan_to_num(x, nan=0)
is_one_dim = (x.ndim == 1) * 1
# x = x[:, None]
x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim)
# First layer
x = x.relu()
x = self.conv_first(x, edge_index) + self.lin_first(x)
x = x.relu()

# Middle layers

x_dict = {key: self.node_init[key](x) for key, x in x_dict.items()}

x_dict = self.conv1(x_dict, edge_index_dict)

if self.num_mid_layers > 0:
for i in range(self.num_mid_layers):
conv_mid = self.conv_mid_layers[i]
lin_mid = self.lin_mid_layers[i]
x = conv_mid(x, edge_index) + lin_mid(x)
x = x.relu()
x_dict = self.conv_mid_layers[i](x_dict, edge_index_dict)

x_dict = self.conv_last(x_dict, edge_index_dict)

# Last layer
x = self.conv_last(x, edge_index) + self.lin_last(x)
x_dict = self.node_final(x_dict)

return x
return x_dict

def decode(self, z: Tensor, edge_index: Union[Tensor]) -> Tensor:
def decode(
self,
z_dict: dict[str, Tensor],
edge_index: Union[Tensor],
) -> Tensor:
"""
Decode the node embeddings to predict edge values.
Args:
z (Tensor): Node embeddings.
z (dict[str, Tensor]): Node embeddings for each node type.
edge_index (EdgeIndex): Edge label indices.
Returns:
Tensor: Predicted edge values.
"""
return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
z_left = z_dict['tx'][edge_index[0]]
z_right = z_dict['bd'][edge_index[1]]
return (z_left * z_right).sum(dim=-1)
# return self.edge_predictor(torch.cat([z_left, z_right], dim=-1)).squeeze()
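
A minimal usage sketch for the new SkipGAT layer and the dot-product edge scoring used by decode, assuming both are importable from segger.models.segger_model as defined above; the node counts, feature width, and edge indices are toy values chosen for illustration:

import torch
from segger.models.segger_model import SkipGAT

# Toy heterogeneous graph: 5 transcript ('tx') nodes and 2 boundary ('bd') nodes,
# both already projected to a 16-dimensional initial embedding.
x_dict = {"tx": torch.randn(5, 16), "bd": torch.randn(2, 16)}
edge_index_dict = {
    ("tx", "neighbors", "tx"): torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
    ("tx", "belongs", "bd"): torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 1, 1]]),
}

layer = SkipGAT(in_channels=16, out_channels=8, heads=4)
out = layer(x_dict, edge_index_dict)
print(out["tx"].shape, out["bd"].shape)  # torch.Size([5, 32]) torch.Size([2, 32])

# Dot-product scoring of candidate ('tx', 'belongs', 'bd') edges, as in Segger.decode.
edge_label_index = torch.tensor([[0, 4], [0, 1]])
logits = (out["tx"][edge_label_index[0]] * out["bd"][edge_label_index[1]]).sum(dim=-1)
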
53 changes: 24 additions & 29 deletions src/segger/training/train.py
@@ -4,10 +4,6 @@
import torchmetrics
from torchmetrics import F1Score
import lightning as L
from torch_geometric.loader import DataLoader
from torch_geometric.typing import Metadata
from torch_geometric.nn import to_hetero
from torch_geometric.data import HeteroData
from segger.models.segger_model import *
from segger.data.utils import SpatialTranscriptomicsDataset
from typing import Any, List, Tuple, Union
@@ -65,22 +61,24 @@ def __init__(self, learning_rate: float = 1e-3, **kwargs):

def from_new(
self,
num_tx_tokens: int,
is_token_based: int,
num_node_features: dict[str, int],
init_emb: int,
hidden_channels: int,
out_channels: int,
heads: int,
num_mid_layers: int,
aggr: str,
metadata: Union[Tuple, Metadata],
):
"""
Initializes the LitSegger module with new parameters.
Parameters
----------
num_tx_tokens : int
Number of unique 'tx' tokens for embedding (this must be passed here).
is_token_based : int
Whether the model is using token-based embeddings or scRNAseq embeddings.
num_node_features : dict[str, int]
Number of node features for each node type.
init_emb : int
Initial embedding size.
hidden_channels : int
@@ -97,17 +95,15 @@ def from_new(
Metadata for heterogeneous graph structure.
"""
# Create the Segger model (ensure num_tx_tokens is passed here)
model = Segger(
num_tx_tokens=num_tx_tokens, # This is required and must be passed here
self.model = Segger(
is_token_based=is_token_based,
num_node_features=num_node_features,
init_emb=init_emb,
hidden_channels=hidden_channels,
out_channels=out_channels,
heads=heads,
num_mid_layers=num_mid_layers,
)
# Convert model to handle heterogeneous graphs
model = to_hetero(model, metadata=metadata, aggr=aggr)
self.model = model
# Save hyperparameters
self.save_hyperparameters()

@@ -137,7 +133,8 @@ def forward(self, batch: SpatialTranscriptomicsDataset) -> torch.Tensor:
The output of the model.
"""
z = self.model(batch.x_dict, batch.edge_index_dict)
output = torch.matmul(z["tx"], z["bd"].t()) # Example for bipartite graph
edge_label_index = batch["tx", "belongs", "bd"].edge_label_index
output = self.model.decode(z, edge_label_index)
return output

def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
@@ -156,17 +153,16 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
torch.Tensor
The loss value for the current training step.
"""
# Forward pass to get the logits
z = self.model(batch.x_dict, batch.edge_index_dict)
output = torch.matmul(z["tx"], z["bd"].t())

# Get edge labels and logits
# Get edge labels
edge_label_index = batch["tx", "belongs", "bd"].edge_label_index
out_values = output[edge_label_index[0], edge_label_index[1]]
edge_label = batch["tx", "belongs", "bd"].edge_label

# Forward pass to get the logits
z = self.model(batch.x_dict, batch.edge_index_dict)
output = self.model.decode(z, edge_label_index)

# Compute binary cross-entropy loss with logits (no sigmoid here)
loss = self.criterion(out_values, edge_label)
loss = self.criterion(output, edge_label)

# Log the training loss
self.log("train_loss", loss, prog_bar=True, batch_size=batch.num_graphs)
@@ -188,20 +184,19 @@ def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
torch.Tensor
The loss value for the current validation step.
"""
# Forward pass to get the logits
z = self.model(batch.x_dict, batch.edge_index_dict)
output = torch.matmul(z["tx"], z["bd"].t())

# Get edge labels and logits
# Get edge labels
edge_label_index = batch["tx", "belongs", "bd"].edge_label_index
out_values = output[edge_label_index[0], edge_label_index[1]]
edge_label = batch["tx", "belongs", "bd"].edge_label

# Forward pass to get the logits
z = self.model(batch.x_dict, batch.edge_index_dict)
output = self.model.decode(z, edge_label_index)

# Compute binary cross-entropy loss with logits (no sigmoid here)
loss = self.criterion(out_values, edge_label)
loss = self.criterion(output, edge_label)

# Apply sigmoid to logits for AUROC and F1 metrics
out_values_prob = torch.sigmoid(out_values)
out_values_prob = torch.sigmoid(output)

# Compute metrics
auroc = torchmetrics.AUROC(task="binary")
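
A minimal, self-contained sketch of the loss and metric computation that training_step and validation_step now share, with toy logits and labels; the criterion is assumed to be torch.nn.BCEWithLogitsLoss, consistent with the "binary cross-entropy loss with logits" comments above:

import torch
import torchmetrics

criterion = torch.nn.BCEWithLogitsLoss()

# Toy stand-ins for model.decode(z, edge_label_index) and the batch's edge labels.
logits = torch.tensor([2.1, -0.7, 0.3])
edge_label = torch.tensor([1.0, 0.0, 1.0])

# Both training and validation compute BCE directly on the logits (no sigmoid here).
loss = criterion(logits, edge_label)

# Validation additionally applies a sigmoid and reports AUROC (and F1, imported above).
probs = torch.sigmoid(logits)
auroc = torchmetrics.AUROC(task="binary")(probs, edge_label.int())
f1 = torchmetrics.F1Score(task="binary")(probs, edge_label.int())
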
