pit_criterion.py

# -*- coding: utf-8 -*-
"""
Created on Wed Jul 1 11:00:46 2020

@author: yoonsanghyu
"""
# Created on 2018/12
# Author: Kaituo XU

from itertools import permutations

import torch

EPS = 1e-8


def cal_loss(source, estimate_source, source_lengths):
    """
    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B]
    """
    max_snr, perms, max_snr_idx = cal_si_snr_with_pit(source,
                                                      estimate_source,
                                                      source_lengths)
    loss = 0 - torch.mean(max_snr)
    reorder_estimate_source = reorder_source(estimate_source, perms, max_snr_idx)
    return loss, max_snr, estimate_source, reorder_estimate_source


def cal_si_snr_with_pit(source, estimate_source, source_lengths):
    """Calculate SI-SNR with PIT training.
    Args:
        source: [B, C, T], B is batch size
        estimate_source: [B, C, T]
        source_lengths: [B], each item is between [0, T]
    """
    assert source.size() == estimate_source.size()
    B, C, T = source.size()
    # mask padding position along T
    mask = get_mask(source, source_lengths)
    estimate_source *= mask

    # Step 1. Zero-mean norm
    num_samples = source_lengths.view(-1, 1, 1).float()  # [B, 1, 1]
    mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    # mask padding position along T
    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = torch.unsqueeze(zero_mean_target, dim=1)  # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2)  # [B, C, 1, T]
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)  # [B, C, C, 1]
    s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS)  # [B, C, C]

    # Get max_snr of each utterance
    # permutations, [C!, C]
    perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
    # one-hot, [C!, C, C]
    index = torch.unsqueeze(perms, 2)
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
    # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
    snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
    max_snr_idx = torch.argmax(snr_set, dim=1)  # [B]
    # max_snr = torch.gather(snr_set, 1, max_snr_idx.view(-1, 1))  # [B, 1]
    max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)
    max_snr /= C
    return max_snr, perms, max_snr_idx
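

# Worked example of the permutation search in cal_si_snr_with_pit() above
# (a sketch in comments, not executed): with C = 2, perms == [[0, 1], [1, 0]]
# and pair_wise_si_snr[b] is a 2 x 2 matrix M where M[i, j] is the SI-SNR of
# estimate i against target j. The einsum sums M along each permutation, so
# snr_set[b] == [M[0, 0] + M[1, 1], M[0, 1] + M[1, 0]], and argmax picks the
# better speaker pairing for that utterance.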


def reorder_source(source, perms, max_snr_idx):
    """
    Args:
        source: [B, C, T]
        perms: [C!, C], permutations
        max_snr_idx: [B], each item is between [0, C!)
    Returns:
        reorder_source: [B, C, T]
    """
    B, C, *_ = source.size()
    # [B, C], permutation whose SI-SNR is max of each utterance
    # for each utterance, reorder estimate source according to this permutation
    max_snr_perm = torch.index_select(perms, dim=0, index=max_snr_idx)
    # print('max_snr_perm', max_snr_perm)
    # maybe use torch.gather()/index_select()/scatter() to impl this?
    reorder_source = torch.zeros_like(source)
    for b in range(B):
        for c in range(C):
            reorder_source[b, c] = source[b, max_snr_perm[b][c]]
    return reorder_source
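

# An illustrative vectorized alternative to the loop in reorder_source() above,
# following the torch.gather() idea from the TODO comment; `reorder_source_gather`
# is a hypothetical helper shown only as a sketch, and the loop version remains
# the reference implementation.
def reorder_source_gather(source, perms, max_snr_idx):
    """
    Args:
        source: [B, C, T]
        perms: [C!, C], permutations
        max_snr_idx: [B], each item is between [0, C!)
    Returns:
        reordered: [B, C, T]
    """
    B, C, T = source.size()
    max_snr_perm = torch.index_select(perms, dim=0, index=max_snr_idx)  # [B, C]
    # out[b, c, t] = source[b, max_snr_perm[b, c], t]
    index = max_snr_perm.unsqueeze(2).expand(B, C, T)  # [B, C, T]
    return torch.gather(source, dim=1, index=index)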


def get_mask(source, source_lengths):
    """
    Args:
        source: [B, C, T]
        source_lengths: [B]
    Returns:
        mask: [B, 1, T]
    """
    B, _, T = source.size()
    mask = source.new_ones((B, 1, T))
    for i in range(B):
        mask[i, :, source_lengths[i]:] = 0
    return mask


if __name__ == "__main__":
    torch.manual_seed(123)
    B, C, T = 2, 3, 12
    # fake data; cast to float so the SI-SNR math (division, log10, einsum)
    # runs on floating-point tensors
    source = torch.randint(4, (B, C, T)).float()
    estimate_source = torch.randint(4, (B, C, T)).float()
    source[1, :, -3:] = 0
    estimate_source[1, :, -3:] = 0
    source_lengths = torch.LongTensor([T, T - 3])
    print('source', source)
    print('estimate_source', estimate_source)
    print('source_lengths', source_lengths)

    loss, max_snr, estimate_source, reorder_estimate_source = cal_loss(source, estimate_source, source_lengths)
    print('loss', loss)
    print('max_snr', max_snr)
    print('reorder_estimate_source', reorder_estimate_source)
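

# Typical training-loop usage (a sketch: `model`, `mixture`, and `optimizer` are
# placeholders for whatever the surrounding training code provides):
#
#     estimate_source = model(mixture)  # [B, C, T]
#     loss, max_snr, _, reordered = cal_loss(source, estimate_source, source_lengths)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()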