diff --git a/src/segger/models/segger_model.py b/src/segger/models/segger_model.py index f535221..b53b3dd 100644 --- a/src/segger/models/segger_model.py +++ b/src/segger/models/segger_model.py @@ -4,15 +4,19 @@ from torch import Tensor from typing import Union + class SkipGAT(nn.Module): def __init__(self, in_channels, out_channels, heads, apply_activation=True): super().__init__() self.apply_activation = apply_activation - self.conv = HeteroConv({ - ('tx', 'neighbors', 'tx'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), - ('tx', 'belongs', 'bd'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), - }, aggr='sum') - self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=('tx', 'bd')) + self.conv = HeteroConv( + { + ("tx", "neighbors", "tx"): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), + ("tx", "belongs", "bd"): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False), + }, + aggr="sum", + ) + self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=("tx", "bd")) def forward(self, x_dict, edge_index_dict): x_conv = self.conv(x_dict, edge_index_dict) @@ -51,16 +55,20 @@ def __init__( # Initialize node embeddings if is_token_based: # Using token-based embeddings for transcript ('tx') nodes - self.node_init = nn.ModuleDict({ - 'tx': nn.Embedding(num_node_features['tx'], init_emb), - 'bd': nn.Linear(num_node_features['bd'], init_emb), - }) + self.node_init = nn.ModuleDict( + { + "tx": nn.Embedding(num_node_features["tx"], init_emb), + "bd": nn.Linear(num_node_features["bd"], init_emb), + } + ) else: # Using scRNAseq embeddings (i.e. prior biological knowledge) for transcript ('tx') nodes - self.node_init = nn.ModuleDict({ - 'tx': nn.Linear(num_node_features['tx'], init_emb), - 'bd': nn.Linear(num_node_features['bd'], init_emb), - }) + self.node_init = nn.ModuleDict( + { + "tx": nn.Linear(num_node_features["tx"], init_emb), + "bd": nn.Linear(num_node_features["bd"], init_emb), + } + ) # First GATv2Conv layer self.conv1 = SkipGAT(init_emb, hidden_channels, heads) @@ -71,12 +79,12 @@ def __init__( self.conv_mid_layers = nn.ModuleList() for _ in range(num_mid_layers): self.conv_mid_layers.append(SkipGAT(heads * hidden_channels, hidden_channels, heads)) - + # Last GATv2Conv layer self.conv_last = SkipGAT(heads * hidden_channels, out_channels, heads) # Finalize node embeddings - self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=('tx', 'bd')) + self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=("tx", "bd")) # # Edge probability predictor # self.edge_predictor = nn.Sequential( @@ -113,10 +121,10 @@ def forward( return x_dict def decode( - self, - z_dict: dict[str, Tensor], - edge_index: Union[Tensor], - ) -> Tensor: + self, + z_dict: dict[str, Tensor], + edge_index: Union[Tensor], + ) -> Tensor: """ Decode the node embeddings to predict edge values. @@ -127,7 +135,7 @@ def decode( Returns: Tensor: Predicted edge values. """ - z_left = z_dict['tx'][edge_index[0]] - z_right = z_dict['bd'][edge_index[1]] + z_left = z_dict["tx"][edge_index[0]] + z_right = z_dict["bd"][edge_index[1]] return (z_left * z_right).sum(dim=-1) # return self.edge_predictor(torch.cat([z_left, z_right], dim=-1)).squeeze()