Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 15, 2024
1 parent 0787167 commit 1f51104
Showing 1 changed file with 29 additions and 21 deletions.
50 changes: 29 additions & 21 deletions src/segger/models/segger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
from torch import Tensor
from typing import Union


class SkipGAT(nn.Module):
def __init__(self, in_channels, out_channels, heads, apply_activation=True):
super().__init__()
self.apply_activation = apply_activation
self.conv = HeteroConv({
('tx', 'neighbors', 'tx'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False),
('tx', 'belongs', 'bd'): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False),
}, aggr='sum')
self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=('tx', 'bd'))
self.conv = HeteroConv(
{
("tx", "neighbors", "tx"): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False),
("tx", "belongs", "bd"): GATv2Conv(in_channels, out_channels, heads=heads, add_self_loops=False),
},
aggr="sum",
)
self.lin = HeteroDictLinear(in_channels, out_channels * heads, types=("tx", "bd"))

def forward(self, x_dict, edge_index_dict):
x_conv = self.conv(x_dict, edge_index_dict)
Expand Down Expand Up @@ -51,16 +55,20 @@ def __init__(
# Initialize node embeddings
if is_token_based:
# Using token-based embeddings for transcript ('tx') nodes
self.node_init = nn.ModuleDict({
'tx': nn.Embedding(num_node_features['tx'], init_emb),
'bd': nn.Linear(num_node_features['bd'], init_emb),
})
self.node_init = nn.ModuleDict(
{
"tx": nn.Embedding(num_node_features["tx"], init_emb),
"bd": nn.Linear(num_node_features["bd"], init_emb),
}
)
else:
# Using scRNAseq embeddings (i.e. prior biological knowledge) for transcript ('tx') nodes
self.node_init = nn.ModuleDict({
'tx': nn.Linear(num_node_features['tx'], init_emb),
'bd': nn.Linear(num_node_features['bd'], init_emb),
})
self.node_init = nn.ModuleDict(
{
"tx": nn.Linear(num_node_features["tx"], init_emb),
"bd": nn.Linear(num_node_features["bd"], init_emb),
}
)

# First GATv2Conv layer
self.conv1 = SkipGAT(init_emb, hidden_channels, heads)
Expand All @@ -71,12 +79,12 @@ def __init__(
self.conv_mid_layers = nn.ModuleList()
for _ in range(num_mid_layers):
self.conv_mid_layers.append(SkipGAT(heads * hidden_channels, hidden_channels, heads))

# Last GATv2Conv layer
self.conv_last = SkipGAT(heads * hidden_channels, out_channels, heads)

# Finalize node embeddings
self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=('tx', 'bd'))
self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=("tx", "bd"))

# # Edge probability predictor
# self.edge_predictor = nn.Sequential(
Expand Down Expand Up @@ -113,10 +121,10 @@ def forward(
return x_dict

def decode(
self,
z_dict: dict[str, Tensor],
edge_index: Union[Tensor],
) -> Tensor:
self,
z_dict: dict[str, Tensor],
edge_index: Union[Tensor],
) -> Tensor:
"""
Decode the node embeddings to predict edge values.
Expand All @@ -127,7 +135,7 @@ def decode(
Returns:
Tensor: Predicted edge values.
"""
z_left = z_dict['tx'][edge_index[0]]
z_right = z_dict['bd'][edge_index[1]]
z_left = z_dict["tx"][edge_index[0]]
z_right = z_dict["bd"][edge_index[1]]
return (z_left * z_right).sum(dim=-1)
# return self.edge_predictor(torch.cat([z_left, z_right], dim=-1)).squeeze()

0 comments on commit 1f51104

Please sign in to comment.