-
Notifications
You must be signed in to change notification settings - Fork 4
/
deeppa_trainer.py
79 lines (64 loc) · 2.44 KB
/
deeppa_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import logging
import os
import time
from typing import Optional, List, Union
import numpy as np
import torch
from torch import nn, Tensor
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from src.utils.logging import get_logger
from src.base.trainer import BaseTrainer
from src.utils import graph_algo
class DeepPA_Trainer(BaseTrainer):
    """
    Trainer class for DeepPA model.

    Args:
        **args: Keyword arguments forwarded unchanged to ``BaseTrainer``.
            This class additionally reads ``args["adj_mat"]`` (the adjacency
            matrix) and ``args["filter_type"]`` (one of the filter names
            accepted by ``_calculate_supports``).

    Attributes:
        _optimizer: Adam optimizer over ``self.model.parameters()`` with
            learning rate ``self._base_lr``. Both attributes are presumably
            set by ``BaseTrainer.__init__`` — confirm there.
        _supports: List of support tensors calculated from the adjacency
            matrix and filter type.
    """

    def __init__(self, **args):
        super().__init__(**args)
        self._optimizer = Adam(self.model.parameters(), self._base_lr)
        self._supports = self._calculate_supports(args["adj_mat"], args["filter_type"])

    def _calculate_supports(self, adj_mat, filter_type, device=None) -> List[Tensor]:
        """
        Calculate the support matrices based on the adjacency matrix and filter type.

        Self-loops are added (``adj_mat + I``) before the selected transform
        is applied.

        Args:
            adj_mat: The adjacency matrix. Only ``shape[0]`` is used to size
                the identity, so it is assumed to be square
                (num_nodes x num_nodes).
            filter_type: One of ``"scalap"``, ``"normlap"``, ``"symnadj"``,
                ``"transition"``, ``"doubletransition"`` or ``"identity"``.
            device: Optional ``torch.device`` (or device string) for the
                returned tensors. When ``None``, CUDA is used if available
                (matching the previous hard-coded ``.cuda()``), otherwise CPU.

        Returns:
            List of support matrices as ``torch.Tensor``. The list has two
            entries for ``"doubletransition"`` and one entry otherwise.

        Raises:
            AssertionError: If the filter type is not defined.
        """
        num_nodes = adj_mat.shape[0]
        # Add self-loops before normalization.
        new_adj = adj_mat + np.eye(num_nodes)

        if filter_type == "scalap":
            supports = [graph_algo.calculate_scaled_laplacian(new_adj).todense()]
        elif filter_type == "normlap":
            supports = [
                graph_algo.calculate_normalized_laplacian(new_adj)
                .astype(np.float32)
                .todense()
            ]
        elif filter_type == "symnadj":
            supports = [graph_algo.sym_adj(new_adj)]
        elif filter_type == "transition":
            supports = [graph_algo.asym_adj(new_adj)]
        elif filter_type == "doubletransition":
            # Forward and reverse (transposed) transition matrices.
            supports = [
                graph_algo.asym_adj(new_adj),
                graph_algo.asym_adj(np.transpose(new_adj)),
            ]
        elif filter_type == "identity":
            supports = [np.eye(num_nodes, dtype=np.float32)]
        else:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``. Previously an unknown type under -O fell through
            # to an UnboundLocalError on ``supports``.
            raise AssertionError("adj type not defined")

        if device is None:
            # Same target as the previous ``.cuda()`` when a GPU is present;
            # fall back to CPU instead of crashing on CPU-only hosts.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return [torch.tensor(support).to(device) for support in supports]