import torch


class cca_loss():
    """Negative total-correlation objective for Deep CCA."""

    def __init__(self, outdim_size, use_all_singular_values, device):
        self.outdim_size = outdim_size                        # number of canonical components
        self.use_all_singular_values = use_all_singular_values
        self.device = device
    def loss(self, H1, H2):
        """
        CCA loss as introduced in the original Deep CCA paper
        (Andrew et al., 2013). H1 and H2 are (batch, features) views of the
        same samples; returns the negative sum of the top `outdim_size`
        canonical correlations (or of all of them when
        `use_all_singular_values` is set). Other formulations are possible.
        """
        r1 = 1e-3   # ridge added to SigmaHat11 for numerical stability
        r2 = 1e-3   # ridge added to SigmaHat22
        eps = 1e-9  # floor for eigenvalues

        # Work with (features, batch) matrices: o1/o2 feature dims, m samples.
        H1, H2 = H1.t(), H2.t()
        o1 = o2 = H1.size(0)
        m = H1.size(1)

        # Center each view across the batch.
        H1bar = H1 - H1.mean(dim=1, keepdim=True)
        H2bar = H2 - H2.mean(dim=1, keepdim=True)
        # Regularized sample cross- and auto-covariance matrices.
        SigmaHat12 = (1.0 / (m - 1)) * torch.matmul(H1bar, H2bar.t())
        SigmaHat11 = (1.0 / (m - 1)) * torch.matmul(H1bar, H1bar.t()) \
            + r1 * torch.eye(o1, device=self.device)
        SigmaHat22 = (1.0 / (m - 1)) * torch.matmul(H2bar, H2bar.t()) \
            + r2 * torch.eye(o2, device=self.device)
        # Calculating the root inverse of the covariance matrices by using
        # eigendecomposition. (torch.symeig was removed from recent PyTorch;
        # torch.linalg.eigh is the replacement and likewise returns
        # eigenvalues in ascending order.)
        D1, V1 = torch.linalg.eigh(SigmaHat11)
        D2, V2 = torch.linalg.eigh(SigmaHat22)
        # Added to increase stability: keep only eigenpairs with eigenvalues
        # above eps so that D ** -0.5 stays finite.
        posInd1 = torch.gt(D1, eps).nonzero()[:, 0]
        D1 = D1[posInd1]
        V1 = V1[:, posInd1]
        posInd2 = torch.gt(D2, eps).nonzero()[:, 0]
        D2 = D2[posInd2]
        V2 = V2[:, posInd2]
        # SigmaHat^{-1/2} = V diag(D^{-1/2}) V^T
        SigmaHat11RootInv = torch.matmul(
            torch.matmul(V1, torch.diag(D1 ** -0.5)), V1.t())
        SigmaHat22RootInv = torch.matmul(
            torch.matmul(V2, torch.diag(D2 ** -0.5)), V2.t())

        # T = SigmaHat11^{-1/2} SigmaHat12 SigmaHat22^{-1/2}; its singular
        # values are the canonical correlations between the two views.
        Tval = torch.matmul(torch.matmul(SigmaHat11RootInv, SigmaHat12),
                            SigmaHat22RootInv)
        if self.use_all_singular_values:
            # All singular values are used: corr = trace((T^T T)^{1/2}).
            # torch.sqrt is elementwise, not a matrix square root, so go
            # through the eigenvalues of T^T T (the squared singular values).
            tmp = torch.matmul(Tval.t(), Tval)
            eigvals = torch.clamp(torch.linalg.eigvalsh(tmp), min=eps)
            corr = torch.sum(torch.sqrt(eigvals))
        else:
            # Just the top self.outdim_size singular values are used.
            trace_TT = torch.matmul(Tval.t(), Tval)
            # Small ridge for more stability before the eigendecomposition.
            trace_TT = trace_TT + r1 * torch.eye(
                trace_TT.shape[0], dtype=trace_TT.dtype, device=self.device)
            U, V = torch.linalg.eigh(trace_TT)
            U = torch.clamp(U, min=eps)           # floor eigenvalues before sqrt
            U = U.topk(self.outdim_size)[0]
            corr = torch.sum(torch.sqrt(U))
        return -corr
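

# Minimal usage sketch, assuming two random (batch, features) views and an
# output dimension of 10; the shapes and seed here are illustrative, and the
# loss object itself holds no learnable parameters.
if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device("cpu")
    H1 = torch.randn(500, 20, dtype=torch.double)   # view 1: 500 samples, 20 features
    H2 = torch.randn(500, 20, dtype=torch.double)   # view 2, same shape
    objective = cca_loss(outdim_size=10, use_all_singular_values=False,
                         device=device)
    print(objective.loss(H1, H2))   # negative sum of the top-10 canonical correlations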