-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathggnn_utils.py
46 lines (37 loc) · 1.39 KB
/
ggnn_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from torch_geometric.data import Batch, Data
class SimpleBatch(list):
def to(self, device):
for ele in self:
ele.to(device)
return self
def collate_wrapper(x, y, edge_index, edge_attr, device, return_y=True):
x = torch.tensor(x, dtype=torch.float, device=device)
y = torch.tensor(y, dtype=torch.float, device=device)
x = x.transpose(dim0=1, dim1=0) # (T, N, num_nodes, num_features)
y_T_first = y.transpose(dim0=1, dim1=0) # (T, N, num_nodes, num_features)
# do not tranpose y_truth
T = x.size()[0]
N = x.size()[1]
# generate batched sequence.
sequences = []
for t in range(T):
cur_batch_x = x[t]
cur_batch_y = y_T_first[t]
batch = Batch.from_data_list([
Data(x=cur_batch_x[i],
edge_index=edge_index,
edge_attr=edge_attr,
y=cur_batch_y[i]) for i in range(N)
])
sequences.append(batch)
if return_y:
return SimpleBatch(sequences), y
else:
return SimpleBatch(sequences)
def collate_wrapper_multi_branches(x_numpy, y_numpy, edge_index_list, device):
sequences_multi_branches = []
for edge_index in edge_index_list:
sequences, y = collate_wrapper(x_numpy, y_numpy, edge_index, device, return_y=True)
sequences_multi_branches.append(sequences)
return sequences_multi_branches, y