-
Notifications
You must be signed in to change notification settings - Fork 26
/
odegcn.py
74 lines (59 loc) · 2.3 KB
/
odegcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
# Whether to use the adjoint method (memory-efficient backprop) or not.
adjoint = False
if adjoint:
from torchdiffeq import odeint_adjoint as odeint
else:
from torchdiffeq import odeint
class ODEFunc(nn.Module):
    """Right-hand side ``dx/dt = f(t, x)`` of the graph ODE.

    The state ``x`` is a 4-D tensor laid out as ``(B, N, T, F)``:

    * axis 1 (``N``) is mixed by ``adj``;
    * axis 2 (``T``, size ``temporal_dim``) is mixed by a learned ``T x T``
      matrix;
    * axis 3 (``F``, size ``feature_dim``) is mixed by a learned ``F x F``
      matrix;
    * axis 0 is never mixed (presumably the batch axis).

    The returned derivative has the same shape as ``x``.

    The ``t`` argument of ``forward`` is accepted for solver compatibility
    but is not used, so the system is autonomous.

    Args:
        feature_dim: size of the last axis of ``x``.
        temporal_dim: size of axis 2 of ``x``.
        adj: ``(N, N)`` tensor applied along the node axis. A plain tensor
            is registered as a non-persistent buffer. It therefore follows
            ``module.to(...)`` / ``.cuda()`` / ``.double()`` together with the
            parameters, but is not written to ``state_dict``. An
            ``nn.Parameter`` is kept as a learnable parameter, as before.

    ``x0`` must be assigned (for example through ``ODEblock.set_x0``) before
    ``forward`` is called. It is added as a constant source term.
    """

    def __init__(self, feature_dim, temporal_dim, adj):
        super().__init__()
        if isinstance(adj, nn.Parameter):
            # keep the original behaviour: a Parameter becomes learnable here
            self.adj = adj
        else:
            self.register_buffer('adj', adj, persistent=False)
        self.x0 = None
        # per-node weight on the graph-propagation term (squashed by sigmoid)
        self.alpha = nn.Parameter(0.8 * torch.ones(adj.shape[1]))
        self.beta = 0.6  # NOTE(review): not read anywhere in this module
        # feature-axis mixing: W diag(d) W^T
        self.w = nn.Parameter(torch.eye(feature_dim))
        self.d = nn.Parameter(torch.ones(feature_dim))
        # temporal-axis mixing: W2 diag(d2) W2^T
        self.w2 = nn.Parameter(torch.eye(temporal_dim))
        self.d2 = nn.Parameter(torch.ones(temporal_dim))

    def forward(self, t, x):
        """Return ``dx/dt`` evaluated at state ``x`` (``t`` is ignored)."""
        if self.x0 is None:
            raise RuntimeError(
                "ODEFunc.x0 is not set; call ODEblock.set_x0(x) before integrating"
            )
        # sigmoid keeps each node's weight in (0, 1); reshape to (1, N, 1, 1)
        # so it broadcasts over the other three axes
        alpha = torch.sigmoid(self.alpha).view(1, -1, 1, 1)
        xa = torch.einsum('ij,kjlm->kilm', self.adj, x)
        # Clamping d to [0, 1] makes W diag(d) W^T symmetric PSD. Its
        # eigenvalues stay <= 1 only while W stays (near-)orthogonal; W is
        # initialised to the identity.
        d = torch.clamp(self.d, min=0, max=1)
        w = torch.mm(self.w * d, self.w.t())
        xw = torch.einsum('ijkl,lm->ijkm', x, w)
        d2 = torch.clamp(self.d2, min=0, max=1)
        w2 = torch.mm(self.w2 * d2, self.w2.t())
        xw2 = torch.einsum('ijkl,km->ijml', x, w2)
        return alpha / 2 * xa - x + xw - x + xw2 - x + self.x0
class ODEblock(nn.Module):
    """Integrate an ODE function from ``t[0]`` and return the state at ``t[1]``.

    Args:
        odefunc: module implementing ``f(t, x)``. It must accept an ``x0``
            attribute, which is written by ``set_x0``.
        t: integration times. Accepts a tensor or anything
            ``torch.as_tensor`` understands; a tensor is used as-is.
            Defaults to ``[0, 1]``. A fresh default tensor is created for
            every instance instead of sharing one default object.
    """

    def __init__(self, odefunc, t=None):
        super().__init__()
        self.t = torch.tensor([0, 1]) if t is None else torch.as_tensor(t)
        self.odefunc = odefunc

    def set_x0(self, x0):
        """Store a detached copy of ``x0`` on ``odefunc`` as its source term."""
        self.odefunc.x0 = x0.clone().detach()

    def forward(self, x):
        """Solve with initial value ``x`` and return the solution at ``t[1]``."""
        # cast the time grid to x's dtype and device before integrating
        t = self.t.type_as(x)
        # odeint returns one solution per entry of t; index 1 is t[1]
        return odeint(self.odefunc, x, t, method='euler')[1]
class ODEG(nn.Module):
    """ODE-based graph layer.

    Wraps an ``ODEFunc`` in an ``ODEblock`` that integrates over
    ``[0, time]``. On every call the input itself is also installed as the
    ODE's constant ``x0`` term. The integrated state is then passed through
    ReLU.

    Args:
        feature_dim: feature size forwarded to ``ODEFunc``.
        temporal_dim: temporal size forwarded to ``ODEFunc``.
        adj: adjacency tensor forwarded to ``ODEFunc``.
        time: end time of the integration interval.
    """

    def __init__(self, feature_dim, temporal_dim, adj, time):
        super().__init__()
        func = ODEFunc(feature_dim, temporal_dim, adj)
        span = torch.tensor([0, time])
        self.odeblock = ODEblock(func, t=span)

    def forward(self, x):
        """Integrate from ``x`` (also used as ``x0``) and apply ReLU."""
        block = self.odeblock
        block.set_x0(x)
        return F.relu(block(x))