-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathutils.py
47 lines (39 loc) · 1.42 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python3
import sys
import os
import argparse
import json
import random
import shutil
import torch
from torch import cuda
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def log_bernoulli_loss(pred, y, min_eps=1e-5, max_eps=1. - 1e-5, average=True):
    """Return the negative Bernoulli log-likelihood of ``y`` under ``pred``.

    Both tensors are flattened to ``(size(0), -1)`` so every element after
    the first dimension contributes to that row's loss.

    Args:
        pred: Tensor of probabilities; values are clamped into
            ``[min_eps, max_eps]`` before the logarithms are taken.
        y: Target tensor with the same number of elements per row as
            ``pred`` (presumably values in [0, 1] -- confirm against callers).
        min_eps: Lower clamp bound for ``pred``; keeps ``log(pred)`` finite.
        max_eps: Upper clamp bound for ``pred``; keeps ``log(1 - pred)``
            finite.
        average: If true, return the mean of the per-row losses as a scalar
            tensor; otherwise return the per-row losses.

    Returns:
        A scalar tensor when ``average`` is true, else a 1-D tensor with one
        loss per row.
    """
    # reshape (rather than view) also accepts non-contiguous inputs, e.g. the
    # result of transpose()/permute(); for contiguous inputs it is still a
    # no-copy view.
    prob = torch.clamp(pred.reshape(pred.size(0), -1), min=min_eps, max=max_eps)
    y_vec = y.reshape(y.size(0), -1)
    log_bernoulli = y_vec * torch.log(prob) + (1. - y_vec) * torch.log(1. - prob)
    per_row = -torch.sum(log_bernoulli, 1)
    if average:
        return torch.mean(per_row)
    return per_row
def logsumexp(x):
    """Return ``log(sum(exp(x)))`` reduced over dimension 1.

    Delegates to ``torch.logsumexp``, which is numerically stable for large
    magnitudes and returns ``-inf`` (not NaN) for a slice whose entries are
    all ``-inf``; a manual ``x - max(x)`` shift yields ``inf - inf = nan``
    in that case.

    Args:
        x: Tensor with at least two dimensions.

    Returns:
        Tensor with dimension 1 removed.
    """
    return torch.logsumexp(x, dim=1)
def kl_loss(mean1, logvar1, mean2, logvar2):
    """Return the mean over rows of a closed-form KL term summed over dim 1.

    Per element the term is
    ``-0.5 * (lv1 - lv2 - (m1 - m2)**2 / exp(lv2) - exp(lv1 - lv2) + 1)``,
    which is the KL divergence KL(N(m1, exp(lv1)) || N(m2, exp(lv2))) for
    diagonal Gaussians when the ``logvar`` arguments are log-variances.

    Args:
        mean1, logvar1: Parameters of the first distribution.
        mean2, logvar2: Parameters of the second distribution.

    Returns:
        Scalar tensor: the dim-1 sums averaged over all remaining entries.
    """
    log_ratio = logvar1 - logvar2
    scaled_sq_diff = torch.pow(mean1 - mean2, 2) / logvar2.exp()
    per_dim = log_ratio - scaled_sq_diff - torch.exp(log_ratio) + 1
    per_row = -0.5 * torch.sum(per_dim, 1)
    return per_row.mean()
def kl_loss_diag(mean, logvar, logvar_prior=None, average=True):
    """Return a closed-form KL term against a zero-mean prior, summed over dim 1.

    Without ``logvar_prior`` the per-element term is
    ``-0.5 * (logvar - mean**2 - exp(logvar) + 1)`` (unit-variance prior when
    ``logvar`` is a log-variance). With ``logvar_prior`` the prior's
    log-variance is broadcast across the leading dimension and the term
    becomes ``-0.5 * (logvar - lp - mean**2 / exp(lp) - exp(logvar - lp) + 1)``.

    Args:
        mean: Mean tensor.
        logvar: Tensor shaped like ``mean``.
        logvar_prior: Optional tensor that is unsqueezed at dim 0 and
            expanded to ``logvar``'s shape; ``None`` selects the first form.
        average: If true, return the mean of the dim-1 sums; otherwise
            return the sums themselves.

    Returns:
        A scalar tensor when ``average`` is true, else the tensor of
        dim-1 sums.
    """
    sq_mean = torch.pow(mean, 2)
    if logvar_prior is not None:
        prior = logvar_prior.unsqueeze(0).expand_as(logvar)
        terms = logvar - prior - sq_mean / prior.exp() - torch.exp(logvar - prior) + 1
    else:
        terms = logvar - sq_mean - torch.exp(logvar) + 1
    kl = -0.5 * torch.sum(terms, 1)
    return kl.mean() if average else kl