-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpit_wrapper.py
319 lines (279 loc) · 14.3 KB
/
pit_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
from itertools import permutations
import torch
from torch import nn
from scipy.optimize import linear_sum_assignment
class PITLossWrapper(nn.Module):
    r"""Permutation invariant training (PIT) loss wrapper.

    Wraps a loss function, searches for the assignment between estimates and
    targets that minimises it for every batch element, and returns the mean
    of those per-sample minima.

    Args:
        loss_func: function called as
            ``loss_func(est_targets, targets, **kwargs)``.
        pit_from (str): Determines how PIT is applied.

            * ``'pw_mtx'`` (pairwise matrix): `loss_func` computes pairwise
              losses and returns a torch.Tensor of shape
              :math:`(batch, n\_src, n\_src)`. Each element
              :math:`[batch, i, j]` is the loss between
              :math:`est\_targets[:, i]` and :math:`targets[:, j]` (the
              layout produced by :meth:`~PITLossWrapper.get_pw_losses`).
            * ``'pw_pt'`` (pairwise point): `loss_func` computes the loss for
              a batch of single source and single estimates (tensors won't
              have the source axis). Output shape : :math:`(batch)`.
              See :meth:`~PITLossWrapper.get_pw_losses`.
            * ``'perm_avg'`` (permutation average): `loss_func` computes the
              average loss for a given permutation of the sources and
              estimates. Output shape : :math:`(batch)`.
              See :meth:`~PITLossWrapper.best_perm_from_perm_avg_loss`.

            In terms of efficiency, ``'perm_avg'`` is the least efficient.

        perm_reduce (Callable): torch function to reduce permutation losses.
            Defaults to None (equivalent to mean). Signature of the func
            (pwl_set, **kwargs) : (B, n_src!, n_src) --> (B, n_src!).
            `perm_reduce` can receive **kwargs during forward using the
            `reduce_kwargs` argument (dict). If those arguments are static,
            consider defining a small function or using `functools.partial`.
            Only used in `'pw_mtx'` and `'pw_pt'` `pit_from` modes.

    For each of these modes, the best permutation and reordering will be
    automatically computed.

    Raises:
        ValueError: if `pit_from` is not one of the supported modes.

    Examples:
        >>> import torch
        >>> from asteroid.losses import pairwise_neg_sisdr
        >>> sources = torch.randn(10, 3, 16000)
        >>> est_sources = torch.randn(10, 3, 16000)
        >>> # Compute PIT loss based on pairwise losses
        >>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
        >>> loss_val = loss_func(est_sources, sources)
        >>>
        >>> # Using reduce
        >>> def reduce(perm_loss, src):
        ...     weighted = perm_loss * src.norm(dim=-1, keepdim=True)
        ...     return torch.mean(weighted, dim=-1)
        >>>
        >>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
        ...                            perm_reduce=reduce)
        >>> reduce_kwargs = {'src': sources}
        >>> loss_val = loss_func(est_sources, sources,
        ...                      reduce_kwargs=reduce_kwargs)
    """

    # Accepted values of `pit_from` and the error raised for anything else.
    _SUPPORTED_PIT_FROM = ("pw_mtx", "pw_pt", "perm_avg")
    _UNSUPPORTED_PIT_FROM_MSG = (
        "Unsupported loss function type for now. Expected "
        "one of [`pw_mtx`, `pw_pt`, `perm_avg`]"
    )

    def __init__(self, loss_func, pit_from="pw_mtx", perm_reduce=None):
        super().__init__()
        self.loss_func = loss_func
        self.pit_from = pit_from
        self.perm_reduce = perm_reduce
        if self.pit_from not in self._SUPPORTED_PIT_FROM:
            raise ValueError(self._UNSUPPORTED_PIT_FROM_MSG)

    def forward(self, est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs):
        """Find the best permutation and return the loss.

        Args:
            est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of target estimates.
            targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of training targets
            return_est: Boolean. Whether to return the reordered targets
                estimates (To compute metrics or to save example).
            reduce_kwargs (dict or None): kwargs that will be passed to the
                pairwise losses reduce function (`perm_reduce`). Ignored in
                ``'perm_avg'`` mode.
            **kwargs: additional keyword argument that will be passed to the
                loss function.

        Returns:
            - Best permutation loss for each batch sample, average over
                the batch. torch.Tensor(loss_value)
            - The reordered targets estimates if return_est is True.
                torch.Tensor of shape [batch, nsrc, *].

        Raises:
            ValueError: if `self.pit_from` was changed to an unsupported
                value after construction.
        """
        n_src = targets.shape[1]
        assert n_src < 10, f"Expected source axis along dim 1, found {n_src}"

        if self.pit_from == "perm_avg":
            # Cannot get pairwise losses from this type of loss.
            # Find best permutation directly.
            min_loss, batch_indices = self.best_perm_from_perm_avg_loss(
                self.loss_func, est_targets, targets, **kwargs
            )
        else:
            if self.pit_from == "pw_mtx":
                # Loss function already returns pairwise losses
                pw_losses = self.loss_func(est_targets, targets, **kwargs)
            elif self.pit_from == "pw_pt":
                # Compute pairwise losses with a for loop.
                pw_losses = self.get_pw_losses(self.loss_func, est_targets, targets, **kwargs)
            else:
                # Only reachable if `pit_from` was reassigned after __init__;
                # fail loudly instead of silently returning None.
                raise ValueError(self._UNSUPPORTED_PIT_FROM_MSG)

            assert pw_losses.ndim == 3, (
                "Something went wrong with the loss function, please read the docs."
            )
            assert pw_losses.shape[0] == targets.shape[0], "PIT loss needs same batch dim as input"

            reduce_kwargs = {} if reduce_kwargs is None else reduce_kwargs
            min_loss, batch_indices = self.find_best_perm(
                pw_losses, perm_reduce=self.perm_reduce, **reduce_kwargs
            )

        # Take the mean over the batch
        mean_loss = torch.mean(min_loss)
        if not return_est:
            return mean_loss
        return mean_loss, self.reorder_source(est_targets, batch_indices)

    @staticmethod
    def get_pw_losses(loss_func, est_targets, targets, **kwargs):
        """Get pair-wise losses between the training targets and its estimate
        for a given loss function.

        Args:
            loss_func: function called as
                ``loss_func(est_src, target_src, **kwargs)`` on tensors
                without the source axis.
            est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of target estimates.
            targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of training targets.
            **kwargs: additional keyword argument that will be passed to the
                loss function.

        Returns:
            torch.Tensor of size [batch, nsrc, nsrc], losses computed for
            all (estimate, target) pairs. Element [batch, i, j] is the loss
            between estimate `i` and target `j`.

        This function can be called on a loss function which returns a tensor
        of size [batch]. There are more efficient ways to compute pair-wise
        losses using broadcasting.
        """
        batch_size, n_src, *_ = targets.shape
        pair_wise_losses = targets.new_empty(batch_size, n_src, n_src)
        # Moving the source axis first lets us iterate over sources directly.
        for est_idx, est_src in enumerate(est_targets.transpose(0, 1)):
            for target_idx, target_src in enumerate(targets.transpose(0, 1)):
                pair_wise_losses[:, est_idx, target_idx] = loss_func(est_src, target_src, **kwargs)
        return pair_wise_losses

    @staticmethod
    def best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs):
        """Find best permutation from loss function with source axis.

        Args:
            loss_func: function called as
                ``loss_func(est_targets, targets, **kwargs)`` and returning
                one loss value per batch element.
            est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of target estimates.
            targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of training targets.
            **kwargs: additional keyword argument that will be passed to the
                loss function.

        Returns:
            tuple:
                :class:`torch.Tensor`: The loss corresponding to the best
                permutation of size (batch,).

                :class:`torch.Tensor`: The indices of the best permutations,
                of size (batch, nsrc), on the device of `est_targets`.
        """
        n_src = targets.shape[1]
        # Keep the permutation table next to the estimates it will index.
        perms = torch.tensor(
            list(permutations(range(n_src))), dtype=torch.long, device=est_targets.device
        )
        # One loss evaluation per permutation of the estimates: [batch, n_src!]
        loss_set = torch.stack(
            [loss_func(est_targets[:, perm], targets, **kwargs) for perm in perms], dim=1
        )
        # Indexes and values of min losses for each batch element
        min_loss, min_loss_idx = torch.min(loss_set, dim=1)
        # Permutation indices for each batch (row lookup, no Python loop).
        batch_indices = perms[min_loss_idx.to(perms.device)]
        return min_loss, batch_indices

    @staticmethod
    def find_best_perm(pair_wise_losses, perm_reduce=None, **kwargs):
        """Find the best permutation, given the pair-wise losses.

        Dispatch between factorial method if number of sources is small
        (<= 3) and hungarian method for more sources. If `perm_reduce` is not
        None, the factorial method is always used.

        Args:
            pair_wise_losses (:class:`torch.Tensor`):
                Tensor of shape [batch, n_src, n_src]. Pairwise losses.
            perm_reduce (Callable): torch function to reduce permutation losses.
                Defaults to None (equivalent to mean). Signature of the func
                (pwl_set, **kwargs) : (B, n_src!, n_src) --> (B, n_src!)
            **kwargs: additional keyword argument that will be passed to the
                permutation reduce function.

        Returns:
            tuple:
                :class:`torch.Tensor`: The loss corresponding to the best
                permutation of size (batch,).

                :class:`torch.Tensor`: The indices of the best permutations.
        """
        n_src = pair_wise_losses.shape[-1]
        if perm_reduce is not None or n_src <= 3:
            return PITLossWrapper.find_best_perm_factorial(
                pair_wise_losses, perm_reduce=perm_reduce, **kwargs
            )
        return PITLossWrapper.find_best_perm_hungarian(pair_wise_losses)

    @staticmethod
    def reorder_source(source, batch_indices):
        """Reorder sources according to the best permutation.

        Args:
            source (torch.Tensor): Tensor of shape [batch, n_src, *].
            batch_indices (torch.Tensor): Tensor of shape [batch, n_src].
                Contains optimal permutation indices for each batch.

        Returns:
            :class:`torch.Tensor`:
                Reordered sources of shape [batch, n_src, *], where
                ``out[b, i] = source[b, batch_indices[b, i]]``.
        """
        # Indices must be integer tensors living with the data they index.
        batch_indices = batch_indices.to(device=source.device, dtype=torch.long)
        # Column vector of batch positions broadcasts against [batch, n_src].
        batch_range = torch.arange(source.shape[0], device=source.device).unsqueeze(-1)
        return source[batch_range, batch_indices]

    @staticmethod
    def find_best_perm_factorial(pair_wise_losses, perm_reduce=None, **kwargs):
        """Find the best permutation given the pair-wise losses by looping
        through all the permutations.

        Args:
            pair_wise_losses (:class:`torch.Tensor`):
                Tensor of shape [batch, n_src, n_src]. Pairwise losses.
            perm_reduce (Callable): torch function to reduce permutation losses.
                Defaults to None (equivalent to mean). Signature of the func
                (pwl_set, **kwargs) : (B, n_src!, n_src) --> (B, n_src!)
            **kwargs: additional keyword argument that will be passed to the
                permutation reduce function.

        Returns:
            tuple:
                :class:`torch.Tensor`: The loss corresponding to the best
                permutation of size (batch,).

                :class:`torch.Tensor`: The indices of the best permutations.

        MIT Copyright (c) 2018 Kaituo XU.
        See `Original code
        <https://github.com/kaituoxu/Conv-TasNet/blob/master>`__ and `License
        <https://github.com/kaituoxu/Conv-TasNet/blob/master/LICENSE>`__.
        """
        n_src = pair_wise_losses.shape[-1]
        # After transposition, dim 1 corresp. to sources and dim 2 to estimates
        pwl = pair_wise_losses.transpose(-1, -2)
        # [n_src!, n_src] table of every permutation, on the losses' device.
        perms = pwl.new_tensor(list(permutations(range(n_src))), dtype=torch.long)
        if perm_reduce is None:
            # one-hot, [n_src!, n_src, n_src]
            perms_one_hot = pwl.new_zeros((*perms.size(), n_src)).scatter_(
                2, perms.unsqueeze(2), 1
            )
            # Loss mean of each permutation: sum the selected entries, then
            # divide by the number of sources.
            loss_set = torch.einsum("bij,pij->bp", pwl, perms_one_hot) / n_src
        else:
            # [batch, n_src!, n_src] : Pairwise losses for each permutation.
            src_idx = torch.arange(n_src, device=pwl.device)
            pwl_set = pwl[:, src_idx, perms]
            # Apply reduce [batch, n_src!, n_src] --> [batch, n_src!]
            loss_set = perm_reduce(pwl_set, **kwargs)
        # Indexes and values of min losses for each batch element
        min_loss, min_loss_idx = torch.min(loss_set, dim=1)
        # Permutation indices for each batch (row lookup, no Python loop).
        batch_indices = perms[min_loss_idx.to(perms.device)]
        return min_loss, batch_indices

    @staticmethod
    def find_best_perm_hungarian(pair_wise_losses: torch.Tensor):
        """Find the best permutation given the pair-wise losses, using the
        Hungarian algorithm.

        Args:
            pair_wise_losses (:class:`torch.Tensor`):
                Tensor of shape [batch, n_src, n_src]. Pairwise losses.

        Returns:
            tuple:
                :class:`torch.Tensor`: The loss corresponding to the best
                permutation of size (batch,).

                :class:`torch.Tensor`: The indices of the best permutations,
                on the same device as `pair_wise_losses`.
        """
        # After transposition, dim 1 corresp. to sources and dim 2 to estimates
        pwl = pair_wise_losses.transpose(-1, -2)
        # Just bring the numbers to cpu(), not the graph
        pwl_copy = pwl.detach().cpu()
        # Loop over batch + row indices are always ordered for square matrices,
        # so only the column assignment is kept.
        col_assignments = [linear_sum_assignment(mat.numpy())[1].tolist() for mat in pwl_copy]
        # Indices must be int64 and on the losses' device for torch.gather.
        batch_indices = torch.tensor(col_assignments, dtype=torch.long, device=pwl.device)
        # Gathering from `pwl` (not the detached copy) keeps the loss differentiable.
        min_loss = torch.gather(pwl, 2, batch_indices[..., None]).mean([-1, -2])
        return min_loss, batch_indices