-
-
Notifications
You must be signed in to change notification settings - Fork 108
/
base.py
719 lines (633 loc) · 24.4 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
import torch
import torch.nn as nn
import numpy as np
from .. import flows
class BaseDistribution(nn.Module):
"""
Base distribution of a flow-based model
Parameters do not depend of target variable (as is the case for a VAE encoder)
"""
def __init__(self):
super().__init__()
def forward(self, num_samples=1):
"""Samples from base distribution and calculates log probability
Args:
num_samples: Number of samples to draw from the distriubtion
Returns:
Samples drawn from the distribution, log probability
"""
raise NotImplementedError
def log_prob(self, z):
"""Calculate log probability of batch of samples
Args:
z: Batch of random variables to determine log probability for
Returns:
log probability for each batch element
"""
raise NotImplementedError
def sample(self, num_samples=1, **kwargs):
"""Samples from base distribution
Args:
num_samples: Number of samples to draw from the distriubtion
Returns:
Samples drawn from the distribution
"""
z, _ = self.forward(num_samples, **kwargs)
return z
class DiagGaussian(BaseDistribution):
"""
Multivariate Gaussian distribution with diagonal covariance matrix
"""
def __init__(self, shape, trainable=True):
"""Constructor
Args:
shape: Tuple with shape of data, if int shape has one dimension
trainable: Flag whether to use trainable or fixed parameters
"""
super().__init__()
if isinstance(shape, int):
shape = (shape,)
if isinstance(shape, list):
shape = tuple(shape)
self.shape = shape
self.n_dim = len(shape)
self.d = np.prod(shape)
if trainable:
self.loc = nn.Parameter(torch.zeros(1, *self.shape))
self.log_scale = nn.Parameter(torch.zeros(1, *self.shape))
else:
self.register_buffer("loc", torch.zeros(1, *self.shape))
self.register_buffer("log_scale", torch.zeros(1, *self.shape))
self.temperature = None # Temperature parameter for annealed sampling
def forward(self, num_samples=1, context=None):
eps = torch.randn(
(num_samples,) + self.shape, dtype=self.loc.dtype, device=self.loc.device
)
if self.temperature is None:
log_scale = self.log_scale
else:
log_scale = self.log_scale + np.log(self.temperature)
z = self.loc + torch.exp(log_scale) * eps
log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
log_scale + 0.5 * torch.pow(eps, 2), list(range(1, self.n_dim + 1))
)
return z, log_p
def log_prob(self, z, context=None):
if self.temperature is None:
log_scale = self.log_scale
else:
log_scale = self.log_scale + np.log(self.temperature)
log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
log_scale + 0.5 * torch.pow((z - self.loc) / torch.exp(log_scale), 2),
list(range(1, self.n_dim + 1)),
)
return log_p
class ConditionalDiagGaussian(BaseDistribution):
"""
Conditional multivariate Gaussian distribution with diagonal
covariance matrix, parameters are obtained by a context encoder,
context meaning the variable to condition on
"""
def __init__(self, shape, context_encoder):
"""Constructor
Args:
shape: Tuple with shape of data, if int shape has one dimension
context_encoder: Computes mean and log of the standard deviation
of the Gaussian, mean is the first half of the last dimension
of the encoder output, log of the standard deviation the second
half
"""
super().__init__()
if isinstance(shape, int):
shape = (shape,)
if isinstance(shape, list):
shape = tuple(shape)
self.shape = shape
self.n_dim = len(shape)
self.d = np.prod(shape)
self.context_encoder = context_encoder
def forward(self, num_samples=1, context=None):
encoder_output = self.context_encoder(context)
split_ind = encoder_output.shape[-1] // 2
mean = encoder_output[..., :split_ind]
log_scale = encoder_output[..., split_ind:]
eps = torch.randn(
(num_samples,) + self.shape, dtype=mean.dtype, device=mean.device
)
z = mean + torch.exp(log_scale) * eps
log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
log_scale + 0.5 * torch.pow(eps, 2), list(range(1, self.n_dim + 1))
)
return z, log_p
def log_prob(self, z, context=None):
encoder_output = self.context_encoder(context)
split_ind = encoder_output.shape[-1] // 2
mean = encoder_output[..., :split_ind]
log_scale = encoder_output[..., split_ind:]
log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
log_scale + 0.5 * torch.pow((z - mean) / torch.exp(log_scale), 2),
list(range(1, self.n_dim + 1)),
)
return log_p
class Uniform(BaseDistribution):
"""
Multivariate uniform distribution
"""
def __init__(self, shape, low=-1.0, high=1.0):
"""Constructor
Args:
shape: Tuple with shape of data, if int shape has one dimension
low: Lower bound of uniform distribution
high: Upper bound of uniform distribution
"""
super().__init__()
if isinstance(shape, int):
shape = (shape,)
if isinstance(shape, list):
shape = tuple(shape)
self.shape = shape
self.d = np.prod(shape)
self.low = torch.tensor(low)
self.high = torch.tensor(high)
self.log_prob_val = -self.d * np.log(self.high - self.low)
def forward(self, num_samples=1, context=None):
eps = torch.rand(
(num_samples,) + self.shape, dtype=self.low.dtype, device=self.low.device
)
z = self.low + (self.high - self.low) * eps
log_p = self.log_prob_val * torch.ones(num_samples, device=self.low.device)
return z, log_p
def log_prob(self, z, context=None):
log_p = self.log_prob_val * torch.ones(z.shape[0], device=z.device)
out_range = torch.logical_or(z < self.low, z > self.high)
ind_inf = torch.any(torch.reshape(out_range, (z.shape[0], -1)), dim=-1)
log_p[ind_inf] = -np.inf
return log_p
class UniformGaussian(BaseDistribution):
"""
Distribution of a 1D random variable with some entries having a uniform and
others a Gaussian distribution
"""
def __init__(self, ndim, ind, scale=None):
"""Constructor
Args:
ndim: Int, number of dimensions
ind: Iterable, indices of uniformly distributed entries
scale: Iterable, standard deviation of Gaussian or width of uniform distribution
"""
super().__init__()
self.ndim = ndim
if isinstance(ind, int):
ind = [ind]
# Set up indices and permutations
self.ndim = ndim
if torch.is_tensor(ind):
self.register_buffer("ind", torch._cast_Long(ind))
else:
self.register_buffer("ind", torch.tensor(ind, dtype=torch.long))
ind_ = []
for i in range(self.ndim):
if not i in self.ind:
ind_ += [i]
self.register_buffer("ind_", torch.tensor(ind_, dtype=torch.long))
perm_ = torch.cat((self.ind, self.ind_))
inv_perm_ = torch.zeros_like(perm_)
for i in range(self.ndim):
inv_perm_[perm_[i]] = i
self.register_buffer("inv_perm", inv_perm_)
if scale is None:
self.register_buffer("scale", torch.ones(self.ndim))
else:
self.register_buffer("scale", scale)
def forward(self, num_samples=1, context=None):
z = self.sample(num_samples)
return z, self.log_prob(z)
def sample(self, num_samples=1, context=None):
eps_u = (
torch.rand(
(num_samples, len(self.ind)),
dtype=self.scale.dtype,
device=self.scale.device,
)
- 0.5
)
eps_g = torch.randn(
(num_samples, len(self.ind_)),
dtype=self.scale.dtype,
device=self.scale.device,
)
z = torch.cat((eps_u, eps_g), -1)
z = z[..., self.inv_perm]
return self.scale * z
def log_prob(self, z, context=None):
log_p_u = torch.broadcast_to(-torch.log(self.scale[self.ind]), (len(z), -1))
log_p_g = (
-0.5 * np.log(2 * np.pi)
- torch.log(self.scale[self.ind_])
- 0.5 * torch.pow(z[..., self.ind_] / self.scale[self.ind_], 2)
)
return torch.sum(log_p_u, -1) + torch.sum(log_p_g, -1)
class ClassCondDiagGaussian(BaseDistribution):
"""
Class conditional multivariate Gaussian distribution with diagonal covariance matrix
"""
def __init__(self, shape, num_classes):
"""Constructor
Args:
shape: Tuple with shape of data, if int shape has one dimension
num_classes: Number of classes to condition on
"""
super().__init__()
if isinstance(shape, int):
shape = (shape,)
if isinstance(shape, list):
shape = tuple(shape)
self.shape = shape
self.n_dim = len(shape)
self.perm = [self.n_dim] + list(range(self.n_dim))
self.d = np.prod(shape)
self.num_classes = num_classes
self.loc = nn.Parameter(torch.zeros(*self.shape, num_classes))
self.log_scale = nn.Parameter(torch.zeros(*self.shape, num_classes))
self.temperature = None # Temperature parameter for annealed sampling
def forward(self, num_samples=1, y=None):
if y is not None:
num_samples = len(y)
else:
y = torch.randint(self.num_classes, (num_samples,), device=self.loc.device)
if y.dim() == 1:
y_onehot = torch.zeros(
(self.num_classes, num_samples),
dtype=self.loc.dtype,
device=self.loc.device,
)
y_onehot.scatter_(0, y[None], 1)
y = y_onehot
else:
y = y.t()
eps = torch.randn(
(num_samples,) + self.shape, dtype=self.loc.dtype, device=self.loc.device
)
loc = (self.loc @ y).permute(*self.perm)
log_scale = (self.log_scale @ y).permute(*self.perm)
if self.temperature is not None:
log_scale = np.log(self.temperature) + log_scale
z = loc + torch.exp(log_scale) * eps
log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
log_scale + 0.5 * torch.pow(eps, 2), list(range(1, self.n_dim + 1))
)
return z, log_p
def log_prob(self, z, y):
if y.dim() == 1:
y_onehot = torch.zeros(
(self.num_classes, len(y)), dtype=self.loc.dtype, device=self.loc.device
)
y_onehot.scatter_(0, y[None], 1)
y = y_onehot
else:
y = y.t()
loc = (self.loc @ y).permute(*self.perm)
log_scale = (self.log_scale @ y).permute(*self.perm)
if self.temperature is not None:
log_scale = np.log(self.temperature) + log_scale
log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
log_scale + 0.5 * torch.pow((z - loc) / torch.exp(log_scale), 2),
list(range(1, self.n_dim + 1)),
)
return log_p
class GlowBase(BaseDistribution):
"""
Base distribution of the Glow model, i.e. Diagonal Gaussian with one mean and
log scale for each channel
"""
def __init__(self, shape, num_classes=None, logscale_factor=3.0):
"""Constructor
Args:
shape: Shape of the variables
num_classes: Number of classes if the base is class conditional, None otherwise
logscale_factor: Scaling factor for mean and log variance
"""
super().__init__()
# Save shape and related statistics
if isinstance(shape, int):
shape = (shape,)
if isinstance(shape, list):
shape = tuple(shape)
self.shape = shape
self.n_dim = len(shape)
self.num_pix = np.prod(shape[1:])
self.d = np.prod(shape)
self.sum_dim = list(range(1, self.n_dim + 1))
self.num_classes = num_classes
self.class_cond = num_classes is not None
self.logscale_factor = logscale_factor
# Set up parameters
self.loc = nn.Parameter(
torch.zeros(1, self.shape[0], *((self.n_dim - 1) * [1]))
)
self.loc_logs = nn.Parameter(
torch.zeros(1, self.shape[0], *((self.n_dim - 1) * [1]))
)
self.log_scale = nn.Parameter(
torch.zeros(1, self.shape[0], *((self.n_dim - 1) * [1]))
)
self.log_scale_logs = nn.Parameter(
torch.zeros(1, self.shape[0], *((self.n_dim - 1) * [1]))
)
# Class conditional parameter if needed
if self.class_cond:
self.loc_cc = nn.Parameter(torch.zeros(self.num_classes, self.shape[0]))
self.log_scale_cc = nn.Parameter(
torch.zeros(self.num_classes, self.shape[0])
)
# Temperature parameter for annealed sampling
self.temperature = None
def forward(self, num_samples=1, y=None):
# Prepare parameter
loc = self.loc * torch.exp(self.loc_logs * self.logscale_factor)
log_scale = self.log_scale * torch.exp(
self.log_scale_logs * self.logscale_factor
)
if self.class_cond:
if y is not None:
num_samples = len(y)
else:
y = torch.randint(
self.num_classes, (num_samples,), device=self.loc.device
)
if y.dim() == 1:
y_onehot = torch.zeros(
(len(y), self.num_classes),
dtype=self.loc.dtype,
device=self.loc.device,
)
y_onehot.scatter_(1, y[:, None], 1)
y = y_onehot
loc = loc + (y @ self.loc_cc).view(
y.size(0), self.shape[0], *((self.n_dim - 1) * [1])
)
log_scale = log_scale + (y @ self.log_scale_cc).view(
y.size(0), self.shape[0], *((self.n_dim - 1) * [1])
)
if self.temperature is not None:
log_scale = log_scale + np.log(self.temperature)
# Sample
eps = torch.randn(
(num_samples,) + self.shape, dtype=self.loc.dtype, device=self.loc.device
)
z = loc + torch.exp(log_scale) * eps
# Get log prob
log_p = (
-0.5 * self.d * np.log(2 * np.pi)
- self.num_pix * torch.sum(log_scale, dim=self.sum_dim)
- 0.5 * torch.sum(torch.pow(eps, 2), dim=self.sum_dim)
)
return z, log_p
def log_prob(self, z, y=None):
# Perpare parameter
loc = self.loc * torch.exp(self.loc_logs * self.logscale_factor)
log_scale = self.log_scale * torch.exp(
self.log_scale_logs * self.logscale_factor
)
if self.class_cond:
if y.dim() == 1:
y_onehot = torch.zeros(
(len(y), self.num_classes),
dtype=self.loc.dtype,
device=self.loc.device,
)
y_onehot.scatter_(1, y[:, None], 1)
y = y_onehot
loc = loc + (y @ self.loc_cc).view(
y.size(0), self.shape[0], *((self.n_dim - 1) * [1])
)
log_scale = log_scale + (y @ self.log_scale_cc).view(
y.size(0), self.shape[0], *((self.n_dim - 1) * [1])
)
if self.temperature is not None:
log_scale = log_scale + np.log(self.temperature)
# Get log prob
log_p = (
-0.5 * self.d * np.log(2 * np.pi)
- self.num_pix * torch.sum(log_scale, dim=self.sum_dim)
- 0.5
* torch.sum(
torch.pow((z - loc) / torch.exp(log_scale), 2), dim=self.sum_dim
)
)
return log_p
class AffineGaussian(BaseDistribution):
"""
Diagonal Gaussian an affine constant transformation applied to it,
can be class conditional or not
"""
def __init__(self, shape, affine_shape, num_classes=None):
"""Constructor
Args:
shape: Shape of the variables
affine_shape: Shape of the parameters in the affine transformation
num_classes: Number of classes if the base is class conditional, None otherwise
"""
super().__init__()
if isinstance(shape, int):
shape = (shape,)
if isinstance(shape, list):
shape = tuple(shape)
self.shape = shape
self.n_dim = len(shape)
self.d = np.prod(shape)
self.sum_dim = list(range(1, self.n_dim + 1))
self.affine_shape = affine_shape
self.num_classes = num_classes
self.class_cond = num_classes is not None
# Affine transformation
if self.class_cond:
self.transform = flows.CCAffineConst(self.affine_shape, self.num_classes)
else:
self.transform = flows.AffineConstFlow(self.affine_shape)
# Temperature parameter for annealed sampling
self.temperature = None
def forward(self, num_samples=1, y=None):
dtype = self.transform.s.dtype
device = self.transform.s.device
if self.class_cond:
if y is not None:
num_samples = len(y)
else:
y = torch.randint(self.num_classes, (num_samples,), device=device)
if y.dim() == 1:
y_onehot = torch.zeros(
(len(y), self.num_classes), dtype=dtype, device=device
)
y_onehot.scatter_(1, y[:, None], 1)
y = y_onehot
if self.temperature is not None:
log_scale = np.log(self.temperature)
else:
log_scale = 0.0
# Sample
eps = torch.randn((num_samples,) + self.shape, dtype=dtype, device=device)
z = np.exp(log_scale) * eps
# Get log prob
log_p = (
-0.5 * self.d * np.log(2 * np.pi)
- self.d * log_scale
- 0.5 * torch.sum(torch.pow(eps, 2), dim=self.sum_dim)
)
# Apply transform
if self.class_cond:
z, log_det = self.transform(z, y)
else:
z, log_det = self.transform(z)
log_p -= log_det
return z, log_p
def log_prob(self, z, y=None):
# Perpare onehot encoding of class if needed
if self.class_cond:
if y.dim() == 1:
y_onehot = torch.zeros(
(len(y), self.num_classes),
dtype=self.transform.s.dtype,
device=self.transform.s.device,
)
y_onehot.scatter_(1, y[:, None], 1)
y = y_onehot
if self.temperature is not None:
log_scale = np.log(self.temperature)
else:
log_scale = 0.0
# Get log prob
if self.class_cond:
z, log_p = self.transform.inverse(z, y)
else:
z, log_p = self.transform.inverse(z)
z = z / np.exp(log_scale)
log_p = (
log_p
- self.d * log_scale
- 0.5 * self.d * np.log(2 * np.pi)
- 0.5 * torch.sum(torch.pow(z, 2), dim=self.sum_dim)
)
return log_p
class GaussianMixture(BaseDistribution):
"""
Mixture of Gaussians with diagonal covariance matrix
"""
def __init__(
self, n_modes, dim, loc=None, scale=None, weights=None, trainable=True
):
"""Constructor
Args:
n_modes: Number of modes of the mixture model
dim: Number of dimensions of each Gaussian
loc: List of mean values
scale: List of diagonals of the covariance matrices
weights: List of mode probabilities
trainable: Flag, if true parameters will be optimized during training
"""
super().__init__()
self.n_modes = n_modes
self.dim = dim
if loc is None:
loc = np.random.randn(self.n_modes, self.dim)
loc = np.array(loc)[None, ...]
if scale is None:
scale = np.ones((self.n_modes, self.dim))
scale = np.array(scale)[None, ...]
if weights is None:
weights = np.ones(self.n_modes)
weights = np.array(weights)[None, ...]
weights /= weights.sum(1)
if trainable:
self.loc = nn.Parameter(torch.tensor(1.0 * loc))
self.log_scale = nn.Parameter(torch.tensor(np.log(1.0 * scale)))
self.weight_scores = nn.Parameter(torch.tensor(np.log(1.0 * weights)))
else:
self.register_buffer("loc", torch.tensor(1.0 * loc))
self.register_buffer("log_scale", torch.tensor(np.log(1.0 * scale)))
self.register_buffer("weight_scores", torch.tensor(np.log(1.0 * weights)))
def forward(self, num_samples=1):
# Get weights
weights = torch.softmax(self.weight_scores, 1)
# Sample mode indices
mode = torch.multinomial(weights[0, :], num_samples, replacement=True)
mode_1h = nn.functional.one_hot(mode, self.n_modes)
mode_1h = mode_1h[..., None]
# Get samples
eps_ = torch.randn(
num_samples, self.dim, dtype=self.loc.dtype, device=self.loc.device
)
scale_sample = torch.sum(torch.exp(self.log_scale) * mode_1h, 1)
loc_sample = torch.sum(self.loc * mode_1h, 1)
z = eps_ * scale_sample + loc_sample
# Compute log probability
eps = (z[:, None, :] - self.loc) / torch.exp(self.log_scale)
log_p = (
-0.5 * self.dim * np.log(2 * np.pi)
+ torch.log(weights)
- 0.5 * torch.sum(torch.pow(eps, 2), 2)
- torch.sum(self.log_scale, 2)
)
log_p = torch.logsumexp(log_p, 1)
return z, log_p
def log_prob(self, z):
# Get weights
weights = torch.softmax(self.weight_scores, 1)
# Compute log probability
eps = (z[:, None, :] - self.loc) / torch.exp(self.log_scale)
log_p = (
-0.5 * self.dim * np.log(2 * np.pi)
+ torch.log(weights)
- 0.5 * torch.sum(torch.pow(eps, 2), 2)
- torch.sum(self.log_scale, 2)
)
log_p = torch.logsumexp(log_p, 1)
return log_p
class GaussianPCA(BaseDistribution):
"""
Gaussian distribution resulting from linearly mapping a normal distributed latent
variable describing the "content of the target"
"""
def __init__(self, dim, latent_dim=None, sigma=0.1):
"""Constructor
Args:
dim: Number of dimensions of the flow variables
latent_dim: Number of dimensions of the latent "content" variable;
if None it is set equal to dim
sigma: Noise level
"""
super().__init__()
self.dim = dim
if latent_dim is None:
self.latent_dim = dim
else:
self.latent_dim = latent_dim
self.loc = nn.Parameter(torch.zeros(1, dim))
self.W = nn.Parameter(torch.randn(latent_dim, dim))
self.log_sigma = nn.Parameter(torch.tensor(np.log(sigma)))
def forward(self, num_samples=1):
eps = torch.randn(
num_samples, self.latent_dim, dtype=self.loc.dtype, device=self.loc.device
)
z_ = torch.matmul(eps, self.W)
z = z_ + self.loc
Sig = torch.matmul(self.W.T, self.W) + torch.exp(
self.log_sigma * 2
) * torch.eye(self.dim, dtype=self.loc.dtype, device=self.loc.device)
log_p = (
self.dim / 2 * np.log(2 * np.pi)
- 0.5 * torch.det(Sig)
- 0.5 * torch.sum(z_ * torch.matmul(z_, torch.inverse(Sig)), 1)
)
return z, log_p
def log_prob(self, z):
z_ = z - self.loc
Sig = torch.matmul(self.W.T, self.W) + torch.exp(
self.log_sigma * 2
) * torch.eye(self.dim, dtype=self.loc.dtype, device=self.loc.device)
log_p = (
self.dim / 2 * np.log(2 * np.pi)
- 0.5 * torch.det(Sig)
- 0.5 * torch.sum(z_ * torch.matmul(z_, torch.inverse(Sig)), 1)
)
return log_p