Commit
Merge pull request #71 from daniel-unyi-42/main
Add new parameters to training
daniel-unyi-42 authored Dec 15, 2024
2 parents d091fb4 + 1f51104 commit 17d378a
Showing 9 changed files with 215 additions and 123 deletions.
6 changes: 5 additions & 1 deletion scripts/config.yaml
@@ -43,14 +43,18 @@ training:
num_workers: 12
accelerator: "cuda"
max_epochs: 200
save_best_model: true
learning_rate: 1e-3
pretrained_model_dir: null
pretrained_model_version: 0
devices: 8
strategy: "auto"
precision: "16-mixed"
memory: "16G" # this is ignored if use_lsf is false
gpu_memory: "8G" # this is ignored if use_lsf is false

prediction:
output_log: "predict_parquet_final_output.log"
output_log: "predict_parquet_output.log"
segger_data_dir: "data_segger"
models_dir: "model_dir"
benchmarks_dir: "benchmark_dir"
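As a note on how these keys are consumed downstream, here is a minimal sketch of reading the training section shown above (the path is the one in this repo; note that PyYAML's YAML 1.1 resolver parses 1e-3 without a decimal point as a string, not a float, which is one reason submit_job.py below wraps every value in str()):

import yaml

# Minimal sketch: read the training section shown above.
with open("scripts/config.yaml") as f:
    config = yaml.safe_load(f)

training = config["training"]
max_epochs = training["max_epochs"]                    # 200
learning_rate = float(training["learning_rate"])       # "1e-3" -> 0.001 (see note above)
save_best_model = training["save_best_model"]          # True
pretrained_dir = training.get("pretrained_model_dir")  # YAML null -> None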
9 changes: 9 additions & 0 deletions scripts/submit_job.py
@@ -151,6 +151,10 @@ def run_training():
config["training"]["accelerator"],
"--max_epochs",
str(config["training"]["max_epochs"]),
"--save_best_model",
str(config["training"]["save_best_model"]),
"--learning_rate",
str(config["training"]["learning_rate"]),
"--devices",
str(config["training"]["devices"]),
"--strategy",
@@ -160,6 +164,11 @@
]
)

if config["training"].get("pretrained_model_dir") is not None:
command.extend(["--pretrained_model_dir", config["training"]["pretrained_model_dir"]])
if config["training"].get("pretrained_model_version") is not None:
command.extend(["--pretrained_model_version", str(config["training"]["pretrained_model_version"])])

if config.get("use_lsf", False):
command = [
"bsub",
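For the config above, run_training() assembles roughly the following command (a sketch; the entry-point name is an assumption, only the flags come from the snippet):

command = [
    "python", "train_model.py",  # entry point is an assumption
    "--accelerator", "cuda",
    "--max_epochs", "200",
    "--save_best_model", "True",
    "--learning_rate", "1e-3",
    "--devices", "8",
    "--strategy", "auto",
    "--precision", "16-mixed",
    # pretrained_model_dir is null above, so --pretrained_model_dir is omitted;
    # --pretrained_model_version 0 is still appended, because 0 is not None.
    "--pretrained_model_version", "0",
]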
16 changes: 16 additions & 0 deletions src/segger/cli/configs/train/default.yaml
@@ -50,6 +50,22 @@ max_epochs:
type: int
default: 200
help: Number of epochs for training.
save_best_model:
type: bool
default: true
help: Whether to save the best model.
learning_rate:
type: float
default: 1e-3
help: Learning rate for training.
pretrained_model_dir:
type: Path
default: null
help: Directory containing the pretrained model to use (if any).
pretrained_model_version:
type: int
default: null
help: Version of pretrained model.
devices:
type: int
default: 4
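The new entries follow the file's existing pattern: option name, then type, default, and help string. For illustration only, a sketch of how such a spec can be mapped onto click options (this is not the actual add_options implementation in segger.cli.utils):

import click
import yaml
from pathlib import Path

TYPES = {"int": int, "float": float, "bool": bool, "str": str, "Path": Path}

def options_from_spec(spec_path: str):
    """Turn a YAML spec like default.yaml above into click options (sketch)."""
    with open(spec_path) as f:
        spec = yaml.safe_load(f)
    def decorate(f):
        # Apply in reverse so options appear in file order in --help.
        for name, opt in reversed(list(spec.items())):
            f = click.option(f"--{name}", type=TYPES[opt["type"]],
                             default=opt["default"], help=opt["help"])(f)
        return f
    return decorate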
62 changes: 43 additions & 19 deletions src/segger/cli/train_model.py
@@ -1,6 +1,4 @@
import click
import typing
import os
from segger.cli.utils import add_options, CustomFormatter
from pathlib import Path
import logging
@@ -31,6 +29,15 @@
"--accelerator", type=str, default="cuda", help='Device type to use for training (e.g., "cuda", "cpu").'
) # Ask for accelerator
@click.option("--max_epochs", type=int, default=200, help="Number of epochs for training.")
@click.option("--save_best_model", type=bool, default=True, help="Whether to save the best model.") # unused for now
@click.option("--learning_rate", type=float, default=1e-3, help="Learning rate for training.")
@click.option(
"--pretrained_model_dir",
type=Path,
default=None,
help="Directory containing the pretrained modelDirectory containing the pretrained model to use (if any).",
)
@click.option("--pretrained_model_version", type=int, default=None, help="Version of pretrained model.")
@click.option("--devices", type=int, default=4, help="Number of devices (GPUs) to use.")
@click.option("--strategy", type=str, default="auto", help="Training strategy for the trainer.")
@click.option("--precision", type=str, default="16-mixed", help="Precision for training.")
@@ -44,8 +51,10 @@ def train_model(args: Namespace):

# Import packages
logging.info("Importing packages...")
import torch
from segger.training.train import LitSegger
from segger.training.segger_data_module import SeggerDataModule
from segger.prediction.predict_parquet import load_model
from lightning.pytorch.loggers import CSVLogger
from pytorch_lightning import Trainer

@@ -63,23 +72,38 @@ def train_model(args: Namespace):
logging.info("Done.")

# Initialize model
logging.info("Initializing Segger model and trainer...")
metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")])
ls = LitSegger(
num_tx_tokens=args.num_tx_tokens,
init_emb=args.init_emb,
hidden_channels=args.hidden_channels,
out_channels=args.out_channels, # Hard-coded value
heads=args.heads, # Hard-coded value
num_mid_layers=args.num_mid_layers, # Hard-coded value
aggr="sum", # Hard-coded value
metadata=metadata,
)

# Forward pass to initialize the model
if args.devices > 1:
batch = dm.train[0]
ls.forward(batch)
if args.pretrained_model_dir is not None:
logging.info("Loading pretrained model...")
        ls = load_model(args.pretrained_model_dir / "lightning_logs" / f"version_{args.pretrained_model_version}" / "checkpoints")
else:
logging.info("Creating new model...")
is_token_based = dm.train[0].x_dict["tx"].ndim == 1
if is_token_based:
# if the model is token-based, the input is a 1D tensor of token indices
assert dm.train[0].x_dict["tx"].ndim == 1
assert dm.train[0].x_dict["tx"].dtype == torch.long
num_tx_features = args.num_tx_tokens
print("Using token-based embeddings as node features, number of tokens: ", num_tx_features)
else:
# if the model is not token-based, the input is a 2D tensor of scRNAseq embeddings
assert dm.train[0].x_dict["tx"].ndim == 2
assert dm.train[0].x_dict["tx"].dtype == torch.float32
num_tx_features = dm.train[0].x_dict["tx"].shape[1]
print("Using scRNAseq embeddings as node features, number of features: ", num_tx_features)
num_bd_features = dm.train[0].x_dict["bd"].shape[1]
print("Number of boundary node features: ", num_bd_features)
ls = LitSegger(
is_token_based=is_token_based,
num_node_features={"tx": num_tx_features, "bd": num_bd_features},
init_emb=args.init_emb,
hidden_channels=args.hidden_channels,
out_channels=args.out_channels,
heads=args.heads,
num_mid_layers=args.num_mid_layers,
aggr="sum", # Hard-coded value
learning_rate=args.learning_rate,
)
logging.info("Done.")

# Initialize the Lightning trainer
trainer = Trainer(
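The branch above infers the input mode from the transcript features themselves. The same check in isolation, on toy tensors (all sizes illustrative):

import torch

x_tokens = torch.randint(0, 500, (1000,))  # token indices: 1D, dtype long
x_embeds = torch.randn(1000, 64)           # scRNAseq embeddings: 2D, float32

for x in (x_tokens, x_embeds):
    is_token_based = x.ndim == 1
    if is_token_based:
        assert x.dtype == torch.long
        num_tx_features = 500        # vocabulary size (--num_tx_tokens)
    else:
        assert x.dtype == torch.float32
        num_tx_features = x.shape[1]  # embedding width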
4 changes: 0 additions & 4 deletions src/segger/data/parquet/pyg_dataset.py
@@ -65,10 +65,6 @@ def get(self, idx: int) -> Data:
"""
filepath = Path(self.processed_dir) / self.processed_file_names[idx]
data = torch.load(filepath)
data["tx"].x = data["tx"].x.to_dense()
if data["tx"].x.dim() == 1:
data["tx"].x = data["tx"].x.unsqueeze(1)
assert data["tx"].x.dim() == 2
# this is an issue in PyG's RandomLinkSplit, dimensions are not consistent if there is only one edge in the graph
if hasattr(data["tx", "belongs", "bd"], "edge_label_index"):
if data["tx", "belongs", "bd"].edge_label_index.dim() == 1:
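The RandomLinkSplit workaround handles graphs with a single labeled edge, where the index degenerates from shape [2, N] to [2]; the branch truncated here presumably restores the second dimension, roughly:

import torch

edge_label_index = torch.tensor([3, 7])  # one edge: shape [2] instead of [2, 1]
if edge_label_index.dim() == 1:
    edge_label_index = edge_label_index.unsqueeze(1)
assert edge_label_index.shape == (2, 1)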
143 changes: 92 additions & 51 deletions src/segger/models/segger_model.py
@@ -1,16 +1,37 @@
import torch
from torch_geometric.nn import GATv2Conv, Linear
from torch.nn import Embedding
from torch_geometric.nn import HeteroConv, GATv2Conv, HeteroDictLinear
from torch import nn
from torch import Tensor
from typing import Union

# from torch_sparse import SparseTensor


class Segger(torch.nn.Module):
class SkipGAT(nn.Module):
def __init__(self, in_channels, out_channels, heads, apply_activation=True):
super().__init__()
self.apply_activation = apply_activation
self.conv = HeteroConv(
{
("tx", "neighbors", "tx"): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False),
("tx", "belongs", "bd"): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False),
},
aggr="sum",
)
self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=("tx", "bd"))

def forward(self, x_dict, edge_index_dict):
x_conv = self.conv(x_dict, edge_index_dict)
x_lin = self.lin(x_dict)
x_dict = {key: x_conv[key] + x_lin[key] for key in x_dict}
if self.apply_activation:
x_dict = {key: x_dict[key].relu() for key in x_dict}
return x_dict


class Segger(nn.Module):
def __init__(
self,
num_tx_tokens: int,
is_token_based: bool,
num_node_features: dict[str, int],
init_emb: int = 16,
hidden_channels: int = 32,
num_mid_layers: int = 3,
@@ -21,80 +42,100 @@ def __init__(
Initializes the Segger model.
Args:
num_tx_tokens (int) : Number of unique 'tx' tokens for embedding.
init_emb (int) : Initial embedding size for both 'tx' and boundary (non-token) nodes.
hidden_channels (int): Number of hidden channels.
num_mid_layers (int) : Number of hidden layers (excluding first and last layers).
out_channels (int) : Number of output channels.
heads (int) : Number of attention heads.
is_token_based (bool) : Whether the model is using token-based embeddings or scRNAseq embeddings.
num_node_features (dict[str, int]): Number of node features for each node type.
init_emb (int) : Initial embedding size for both 'tx' and boundary (non-token) nodes.
hidden_channels (int) : Number of hidden channels.
num_mid_layers (int) : Number of hidden layers (excluding first and last layers).
out_channels (int) : Number of output channels.
heads (int) : Number of attention heads.
"""
super().__init__()

# Embedding for 'tx' (transcript) nodes
self.tx_embedding = Embedding(num_tx_tokens, init_emb)

# Linear layer for boundary (non-token) nodes
self.lin0 = Linear(-1, init_emb, bias=False)
# Initialize node embeddings
if is_token_based:
# Using token-based embeddings for transcript ('tx') nodes
self.node_init = nn.ModuleDict(
{
"tx": nn.Embedding(num_node_features["tx"], init_emb),
"bd": nn.Linear(num_node_features["bd"], init_emb),
}
)
else:
# Using scRNAseq embeddings (i.e. prior biological knowledge) for transcript ('tx') nodes
self.node_init = nn.ModuleDict(
{
"tx": nn.Linear(num_node_features["tx"], init_emb),
"bd": nn.Linear(num_node_features["bd"], init_emb),
}
)

# First GATv2Conv layer
self.conv_first = GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
self.lin_first = Linear(-1, hidden_channels * heads)
self.conv1 = SkipGAT(init_emb, hidden_channels, heads)

# Middle GATv2Conv layers
self.num_mid_layers = num_mid_layers
if num_mid_layers > 0:
self.conv_mid_layers = torch.nn.ModuleList()
self.lin_mid_layers = torch.nn.ModuleList()
self.conv_mid_layers = nn.ModuleList()
for _ in range(num_mid_layers):
self.conv_mid_layers.append(GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False))
self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))
self.conv_mid_layers.append(SkipGAT(heads * hidden_channels, hidden_channels, heads))

# Last GATv2Conv layer
self.conv_last = GATv2Conv((-1, -1), out_channels, heads=heads, add_self_loops=False)
self.lin_last = Linear(-1, out_channels * heads)
self.conv_last = SkipGAT(heads * hidden_channels, out_channels, heads)

# Finalize node embeddings
self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=("tx", "bd"))

def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
# # Edge probability predictor
# self.edge_predictor = nn.Sequential(
# nn.Linear(2 * out_channels, out_channels),
# nn.ReLU(),
# nn.Linear(out_channels, 1),
# )

def forward(
self,
x_dict: dict[str, Tensor],
edge_index_dict: dict[str, Tensor],
) -> dict[str, Tensor]:
"""
Forward pass for the Segger model.
Args:
x (Tensor): Node features.
edge_index (Tensor): Edge indices.
Returns:
Tensor: Output node embeddings.
x_dict (dict[str, Tensor]): Node features for each node type.
edge_index_dict (dict[str, Tensor]): Edge indices for each edge type.
"""
x = torch.nan_to_num(x, nan=0)
is_one_dim = (x.ndim == 1) * 1
# x = x[:, None]
x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim)
# First layer
x = x.relu()
x = self.conv_first(x, edge_index) + self.lin_first(x)
x = x.relu()

# Middle layers

x_dict = {key: self.node_init[key](x) for key, x in x_dict.items()}

x_dict = self.conv1(x_dict, edge_index_dict)

if self.num_mid_layers > 0:
for i in range(self.num_mid_layers):
conv_mid = self.conv_mid_layers[i]
lin_mid = self.lin_mid_layers[i]
x = conv_mid(x, edge_index) + lin_mid(x)
x = x.relu()
x_dict = self.conv_mid_layers[i](x_dict, edge_index_dict)

x_dict = self.conv_last(x_dict, edge_index_dict)

# Last layer
x = self.conv_last(x, edge_index) + self.lin_last(x)
x_dict = self.node_final(x_dict)

return x
return x_dict

def decode(self, z: Tensor, edge_index: Union[Tensor]) -> Tensor:
def decode(
self,
z_dict: dict[str, Tensor],
edge_index: Tensor,
) -> Tensor:
"""
Decode the node embeddings to predict edge values.
Args:
z (Tensor): Node embeddings.
z_dict (dict[str, Tensor]): Node embeddings for each node type.
edge_index (EdgeIndex): Edge label indices.
Returns:
Tensor: Predicted edge values.
"""
return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
z_left = z_dict["tx"][edge_index[0]]
z_right = z_dict["bd"][edge_index[1]]
return (z_left * z_right).sum(dim=-1)
# return self.edge_predictor(torch.cat([z_left, z_right], dim=-1)).squeeze()
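To see the refactored heterogeneous interface end to end, here is a toy forward pass and edge scoring (a sketch; all sizes are illustrative, and out_channels/heads are passed explicitly since their defaults are not visible in this diff):

import torch
from segger.models.segger_model import Segger

# Toy heterogeneous batch: 10 transcripts with 64-dim scRNAseq embeddings,
# 4 boundary nodes with 6 features each.
x_dict = {"tx": torch.randn(10, 64), "bd": torch.randn(4, 6)}
edge_index_dict = {
    ("tx", "neighbors", "tx"): torch.randint(0, 10, (2, 30)),
    ("tx", "belongs", "bd"): torch.stack(
        [torch.randint(0, 10, (40,)), torch.randint(0, 4, (40,))]
    ),
}

model = Segger(
    is_token_based=False,
    num_node_features={"tx": 64, "bd": 6},
    init_emb=16, hidden_channels=32, num_mid_layers=1,
    out_channels=8, heads=2,
)
z_dict = model(x_dict, edge_index_dict)  # {"tx": [10, 8], "bd": [4, 8]}

# Score candidate tx -> bd assignments with the dot-product decoder.
edge_label_index = torch.tensor([[0, 1, 2], [0, 1, 3]])
scores = model.decode(z_dict, edge_label_index)  # shape [3]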
2 changes: 2 additions & 0 deletions src/segger/prediction/predict.py
@@ -6,6 +6,7 @@
import torch.nn.functional as F
import torch._dynamo
import gc

# import rmm
import re
import glob
@@ -28,6 +29,7 @@
from dask.diagnostics import ProgressBar
import time
import dask

# from rmm.allocators.cupy import rmm_cupy_allocator
from cupyx.scipy.sparse import coo_matrix
from torch.utils.dlpack import to_dlpack, from_dlpack