Skip to content

Commit

Permalink
Geml (#210)
Browse files Browse the repository at this point in the history
* Rebase GEML to master

* Denormalize

* Repair 3 problems:
  1. The parameters of the MutiLearning layer were not registered with grad.
  2. The SLSTM layer's hidden state, cell state, and output were wrong.
  3. The geo weight matrix was wrong according to the paper.
Also speed up the calculation of the semantic weight matrix.

* checkout unrelated file

* checkout unrelated file

* checkout unrelated file

* Fix the batch_size error (use the runtime batch size instead of the configured one).
  • Loading branch information
Apolsus authored Nov 21, 2021
1 parent fc91e66 commit 9670c6a
Showing 1 changed file with 42 additions and 29 deletions.
71 changes: 42 additions & 29 deletions libcity/model/traffic_od_prediction/GEML.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from libcity.model import loss
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel

# TODO Adjust parameters to make the model converge, or denormalization the loss

class SLSTM(nn.Module):
def __init__(self, feature_dim, hidden_dim, device, p_interval):
super(SLSTM, self).__init__()
Expand All @@ -29,18 +29,22 @@ def __init__(self, feature_dim, hidden_dim, device, p_interval):
nn.Tanh()
)

self.tanh = nn.Tanh()

self.device = device

def forward(self, x):
# (T, B * N, 2E)
h_cur = 0
h = torch.zeros((x.shape[1], self.hidden_dim)).repeat(self.p_interval, 1).unsqueeze(dim=0).to(
self.device) # (B * N, 2E)
c = torch.zeros((x.shape[1], self.cell_dim)).to(self.device) # (B * N, 2E)
h = torch.zeros((x.shape[1], self.hidden_dim)).unsqueeze(dim=0).repeat(self.p_interval, 1, 1).to(self.device)
# (P, B * N, 2E)
c = torch.zeros((x.shape[1], self.hidden_dim)).unsqueeze(dim=0).repeat(self.p_interval, 1, 1).to(self.device)
# (P, B * N, 2E)

for t in range(x.shape[0]):
T = x.shape[0]

for t in range(T):
x_ = x[t, :, :] # (B * N, 2E)
x_ = torch.cat((x_, h[h_cur % self.p_interval]), 1) # (B * N, 2E + 2E)
x_ = torch.cat((x_, h[t % self.p_interval]), 1) # (B * N, 2E + 2E)

f = self.f_gate(x_) # (B * N, 2E)

Expand All @@ -50,36 +54,33 @@ def forward(self, x):

g = self.g_gate(x_) # (B * N, 2E)

c = f * c + i * g
c = f * c[t % self.p_interval] + i * g # (B * N, 2E)

h[(h_cur - 1) % self.p_interval] = o * c
c = self.tanh(c) # (B * N, 2E)

h_cur = h_cur + 1
h[t % self.p_interval] = o * c # (B * N, 2E)

return h[0]
return h[(T - 1) % self.p_interval] # (B * N, 2E)


class MutiLearning(nn.Module):
def __init__(self, fea_dim, device):
super(MutiLearning, self).__init__()
self.fea_dim = fea_dim
transition = torch.randn(self.fea_dim, self.fea_dim)
self.transition = nn.Parameter(data=transition.to(device), requires_grad=True)
project_in = torch.randn(self.fea_dim, 1).to(device)
self.project_in = project_in
project_out = torch.randn(self.fea_dim, 1).to(device)
self.project_out = project_out
self.transition = nn.Parameter(data=torch.randn(self.fea_dim, self.fea_dim).to(device), requires_grad=True)
self.project_in = nn.Parameter(data=torch.randn(self.fea_dim, 1).to(device), requires_grad=True)
self.project_out = nn.Parameter(data=torch.randn(self.fea_dim, 1).to(device), requires_grad=True)

def forward(self, x: torch.Tensor):
# (B, H * W, E)
x_t = x.permute(0, 2, 1) # (B, E, N)
# (B, N, 2E)
x_t = x.permute(0, 2, 1) # (B, 2E, N)

x_in = torch.matmul(x, self.project_in) # (B, N, 1)

x_out = torch.matmul(x, self.project_out) # (B, N, 1)

x = torch.matmul(x, self.transition)
# (B, N, E)
# (B, N, 2E)
x = torch.bmm(x, x_t)
# (B, N, N)

Expand Down Expand Up @@ -141,9 +142,10 @@ def forward(self, input_seq, adj_seq):
return torch.stack(embed, dim=1)


def generate_geo_adj(cost_matrix: np.matrix):
# TODO cost_matrix[cost_matrix > ?] = inf
cost_matrix = torch.Tensor(cost_matrix) # (N, N)
def generate_geo_adj(adj_matrix: np.matrix):
adj_matrix = torch.Tensor(adj_matrix) # (N, N)
cost_matrix = torch.Tensor([[abs(i - j) for j in range(adj_matrix.shape[0])] for i in range(adj_matrix.shape[1])])
cost_matrix = cost_matrix * adj_matrix
sum_cost_vector = torch.sum(cost_matrix, dim=1, keepdim=True) # (N, 1)
weight_matrix = cost_matrix / sum_cost_vector
weight_matrix[range(weight_matrix.shape[0]), range(weight_matrix.shape[1])] = 1
Expand All @@ -159,13 +161,13 @@ def generate_semantic_adj(demand_matrix, device):

adj_matrix[in_matrix > 0] = 1

# (B, T, N, 1)
degree_vector = torch.sum(adj_matrix, dim=3, keepdim=True)
# (B, T, N, 1)

sum_degree_vector = torch.matmul(adj_matrix, degree_vector)
# (B, T, N, 1)

weight_matrix = torch.matmul(1 / (sum_degree_vector + torch.full(sum_degree_vector.shape, 1e-3).to(device)),
degree_vector.permute((0, 1, 3, 2))) # (B, T, N, N)
weight_matrix = torch.matmul(1 / (sum_degree_vector + 1e-3), degree_vector.permute((0, 1, 3, 2))) # (B, T, N, N)

weight_matrix[:, :, range(weight_matrix.shape[2]), range(weight_matrix.shape[3])] = 1

Expand All @@ -176,11 +178,13 @@ class GEML(AbstractTrafficStateModel):
def __init__(self, config, data_feature):
super().__init__(config, data_feature)
self.num_nodes = self.data_feature.get('num_nodes')
self._scaler = self.data_feature.get('scaler')
self.output_dim = config.get('output_dim')
self.device = config.get('device', torch.device('cpu'))
self.input_window = config.get('input_window', 1)
self.output_window = config.get('output_window', 1)
self.p_interval = config.get('p_interval', 1)

self.p_interval = config.get('p_interval', 1)
self.embed_dim = config.get('embed_dim')
self.batch_size = config.get('batch_size')
self.loss_p0 = config.get('loss_p0', 0.5)
Expand All @@ -203,7 +207,7 @@ def __init__(self, config, data_feature):
def forward(self, batch):
x = batch['X'].squeeze(dim=-1)
# (B, T, N, N)
x_ge_embed = self.GCN(x, self.geo_adj)
x_ge_embed = self.GCN(x, self.geo_adj[:x.shape[0], ...])
# (B, T, N, E)

x_se_embed = self.GCN(x, self.semantic_adj)
Expand All @@ -218,7 +222,7 @@ def forward(self, batch):

# _, (h, _) = self.LSTM(x_embed)
# x_embed_pred = h[0].reshape((self.batch_size, -1, 2 * self.embed_dim))
x_embed_pred = self.LSTM(x_embed).reshape((self.batch_size, -1, 2 * self.embed_dim))
x_embed_pred = self.LSTM(x_embed).reshape((x.shape[0], -1, 2 * self.embed_dim))
# (B, N, 2E)

out = self.mutiLearning(x_embed_pred)
Expand All @@ -230,6 +234,15 @@ def calculate_loss(self, batch):
y_in_true = torch.sum(y_true, dim=-2, keepdim=True) # (B, TO, N, 1)
y_out_true = torch.sum(y_true.permute(0, 1, 3, 2, 4), dim=-2, keepdim=True) # (B, TO, N, 1)
y_pred, y_in, y_out = self.predict(batch)

y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
y_in_true = self._scaler.inverse_transform(y_in_true[..., :self.output_dim])
y_out_true = self._scaler.inverse_transform(y_out_true[..., :self.output_dim])

y_pred = self._scaler.inverse_transform(y_pred[..., :self.output_dim])
y_in = self._scaler.inverse_transform(y_in[..., :self.output_dim])
y_out = self._scaler.inverse_transform(y_out[..., :self.output_dim])

loss_pred = loss.masked_mse_torch(y_pred, y_true)
loss_in = loss.masked_mse_torch(y_in, y_in_true)
loss_out = loss.masked_mse_torch(y_out, y_out_true)
Expand Down

0 comments on commit 9670c6a

Please sign in to comment.