# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

import numpy as np
import scipy.ndimage

#----------------------------------------------------------------------------

def get_descriptors_for_minibatch(minibatch, nhood_size, nhoods_per_image):
    S = minibatch.shape # (minibatch, channel, height, width)
    assert len(S) == 4 and S[1] == 3
    N = nhoods_per_image * S[0]
    H = nhood_size // 2
    nhood, chan, x, y = np.ogrid[0:N, 0:3, -H:H+1, -H:H+1]
    img = nhood // nhoods_per_image
    x = x + np.random.randint(H, S[3] - H, size=(N, 1, 1, 1)) # random neighborhood centers, kept H pixels away from the image border
    y = y + np.random.randint(H, S[2] - H, size=(N, 1, 1, 1))
    idx = ((img * S[1] + chan) * S[2] + y) * S[3] + x # flat indices into the minibatch for fancy indexing
    return minibatch.flat[idx]
#----------------------------------------------------------------------------

def finalize_descriptors(desc):
    if isinstance(desc, list):
        desc = np.concatenate(desc, axis=0)
    assert desc.ndim == 4 # (neighborhood, channel, height, width)
    desc -= np.mean(desc, axis=(0, 2, 3), keepdims=True) # normalize each color channel to zero mean ...
    desc /= np.std(desc, axis=(0, 2, 3), keepdims=True)  # ... and unit variance across all neighborhoods
    desc = desc.reshape(desc.shape[0], -1)
    return desc
#----------------------------------------------------------------------------

def sliced_wasserstein(A, B, dir_repeats, dirs_per_repeat):
    assert A.ndim == 2 and A.shape == B.shape                           # (neighborhood, descriptor_component)
    results = []
    for repeat in range(dir_repeats):
        dirs = np.random.randn(A.shape[1], dirs_per_repeat)             # (descriptor_component, direction)
        dirs /= np.sqrt(np.sum(np.square(dirs), axis=0, keepdims=True)) # normalize descriptor components for each direction
        dirs = dirs.astype(np.float32)
        projA = np.matmul(A, dirs)                                      # (neighborhood, direction)
        projB = np.matmul(B, dirs)
        projA = np.sort(projA, axis=0)                                  # sort neighborhood projections for each direction
        projB = np.sort(projB, axis=0)
        dists = np.abs(projA - projB)                                   # pointwise wasserstein distances
        results.append(np.mean(dists))                                  # average over neighborhoods and directions
    return np.mean(results)                                             # average over repeats
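
# Hypothetical sanity check (illustrative only, not used by the metric itself): for random
# Gaussian point clouds, the estimate between a cloud and a mean-shifted distribution should
# come out clearly larger than the near-zero estimate between two independent draws of the
# same distribution. The helper name and its default sizes are made up for illustration.
def _example_sliced_wasserstein_sanity_check(num_points=4096, num_dims=128):
    A = np.random.randn(num_points, num_dims).astype(np.float32)
    B = np.random.randn(num_points, num_dims).astype(np.float32)         # same distribution as A
    C = (np.random.randn(num_points, num_dims) + 1.0).astype(np.float32) # mean shifted by +1 along every axis
    swd_same    = sliced_wasserstein(A, B, dir_repeats=4, dirs_per_repeat=128)
    swd_shifted = sliced_wasserstein(A, C, dir_repeats=4, dirs_per_repeat=128)
    return swd_same, swd_shifted # expect swd_same << swd_shifted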
#----------------------------------------------------------------------------

def downscale_minibatch(minibatch, lod):
    if lod == 0:
        return minibatch
    t = minibatch.astype(np.float32)
    for i in range(lod):
        t = (t[:, :, 0::2, 0::2] + t[:, :, 0::2, 1::2] + t[:, :, 1::2, 0::2] + t[:, :, 1::2, 1::2]) * 0.25 # average each 2x2 pixel block
    return np.round(t).clip(0, 255).astype(np.uint8)
#----------------------------------------------------------------------------

gaussian_filter = np.float32([
    [1,  4,  6,  4, 1],
    [4, 16, 24, 16, 4],
    [6, 24, 36, 24, 6],
    [4, 16, 24, 16, 4],
    [1,  4,  6,  4, 1]]) / 256.0 # 5x5 binomial kernel, normalized to sum to 1

def pyr_down(minibatch): # matches cv2.pyrDown()
    assert minibatch.ndim == 4
    return scipy.ndimage.convolve(minibatch, gaussian_filter[np.newaxis, np.newaxis, :, :], mode='mirror')[:, :, ::2, ::2]

def pyr_up(minibatch): # matches cv2.pyrUp()
    assert minibatch.ndim == 4
    S = minibatch.shape
    res = np.zeros((S[0], S[1], S[2] * 2, S[3] * 2), minibatch.dtype)
    res[:, :, ::2, ::2] = minibatch
    return scipy.ndimage.convolve(res, gaussian_filter[np.newaxis, np.newaxis, :, :] * 4.0, mode='mirror')

def generate_laplacian_pyramid(minibatch, num_levels):
    pyramid = [np.float32(minibatch)]
    for i in range(1, num_levels):
        pyramid.append(pyr_down(pyramid[-1]))
        pyramid[-2] -= pyr_up(pyramid[-1]) # each level above the coarsest stores the residual against the upsampled next level
    return pyramid

def reconstruct_laplacian_pyramid(pyramid):
    minibatch = pyramid[-1]
    for level in pyramid[-2::-1]:
        minibatch = pyr_up(minibatch) + level
    return minibatch
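
# Hypothetical round-trip check (illustrative only): reconstructing a freshly generated
# pyramid recovers the input minibatch up to float32 rounding, because every level above
# the coarsest stores exactly the residual that reconstruction adds back. The helper name
# and the random test shape are made up for illustration.
def _example_laplacian_roundtrip(num_levels=3):
    imgs = np.random.randint(0, 256, size=(2, 3, 32, 32)).astype(np.uint8)
    pyramid = generate_laplacian_pyramid(imgs, num_levels)
    recon = reconstruct_laplacian_pyramid(pyramid)
    return np.max(np.abs(recon - np.float32(imgs))) # expected to be ~0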
#----------------------------------------------------------------------------

class API:
    def __init__(self, image_shape):
        self.nhood_size       = 7    # spatial extent of each descriptor neighborhood
        self.nhoods_per_image = 128  # descriptors extracted per image per pyramid level
        self.dir_repeats      = 4    # independent sets of random projection directions
        self.dirs_per_repeat  = 128  # directions per set
        self.resolutions = []
        res = image_shape[1]
        while res >= 16:             # one metric per Laplacian pyramid level, down to 16x16
            self.resolutions.append(res)
            res //= 2

    def get_metric_names(self):
        return ['SWDx1e3_%d' % res for res in self.resolutions] + ['SWDx1e3_avg']

    def get_metric_formatting(self):
        return ['%-13.4f'] * len(self.get_metric_names())

    def begin(self, mode):
        assert mode in ['warmup', 'reals', 'fakes']
        self.descriptors = [[] for res in self.resolutions]

    def feed(self, mode, minibatch):
        # accumulate patch descriptors from every level of the Laplacian pyramid
        for lod, level in enumerate(generate_laplacian_pyramid(minibatch, len(self.resolutions))):
            desc = get_descriptors_for_minibatch(level, self.nhood_size, self.nhoods_per_image)
            self.descriptors[lod].append(desc)

    def end(self, mode):
        desc = [finalize_descriptors(d) for d in self.descriptors]
        del self.descriptors
        if mode in ['warmup', 'reals']:
            self.desc_real = desc # remember the reference descriptors for subsequent 'fakes' passes
        dist = [sliced_wasserstein(dreal, dfake, self.dir_repeats, self.dirs_per_repeat) for dreal, dfake in zip(self.desc_real, desc)]
        del desc
        dist = [d * 1e3 for d in dist] # report in units of 1e-3, matching the SWDx1e3_* metric names
        return dist + [np.mean(dist)]
#----------------------------------------------------------------------------
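
# Hypothetical usage sketch (illustrative only): drives the API class the way a metric
# harness might, feeding random uint8 minibatches as both 'reals' and 'fakes'. The image
# shape, minibatch size, and minibatch count below are assumptions, not values taken from
# the original training setup.

if __name__ == '__main__':
    np.random.seed(0)
    image_shape = (3, 128, 128) # (channel, height, width)
    api = API(image_shape)
    print(api.get_metric_names())
    for mode in ['reals', 'fakes']:
        api.begin(mode)
        for _ in range(4): # 4 minibatches of 8 random images each
            minibatch = np.random.randint(0, 256, size=(8,) + image_shape).astype(np.uint8)
            api.feed(mode, minibatch)
        scores = api.end(mode)
        print(mode, ' '.join(fmt % s for fmt, s in zip(api.get_metric_formatting(), scores)))

#----------------------------------------------------------------------------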