Skip to content

Commit

Permalink
[Issue #39] Add Mean Absolute Error (MAE) of speaker count in metrics.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wq2012 committed Oct 25, 2024
1 parent 6e061a9 commit 22a8b49
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
24 changes: 24 additions & 0 deletions DiarizationLM/diarizationlm/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
recognition for unsegmented recordings."
arXiv preprint arXiv:2004.09249 (2020).
https://arxiv.org/pdf/2004.09249
- Speaker Count Mean Absolute Error (SpkCntMAE): This metric is used to evaluate
the accuracy of the predicted number of speakers. See
Quan Wang, Yiling Huang, Han Lu, Guanlong Zhao, Ignacio Lopez Moreno. "Highly
Efficient Real-Time Streaming and Fully On-Device Speaker Diarization with
Multi-Stage Clustering."
arXiv preprint arXiv:2210.13690 (2022).
https://arxiv.org/abs/2210.13690
Note: This implementation is different from Google's internal implementation
that we used in the paper, but is a best-effort attempt to replicate the
Expand Down Expand Up @@ -54,6 +61,8 @@ class UtteranceMetrics:
cpwer_correct: int = 0
cpwer_total: int = 0

speaker_count_error: int = 0


def merge_cpwer(
wer_metrics: list[UtteranceMetrics], cpwer_metrics: UtteranceMetrics
Expand Down Expand Up @@ -185,6 +194,14 @@ def compute_utterance_metrics(
continue
metrics_to_concat.append(spk_pair_metrics[(r + 1, c + 1)])
merge_cpwer(metrics_to_concat, result)

########################################
# Compute speaker count error.
########################################
hyp_spk_count = len(set(hyp_spk_list))
ref_spk_count = len(set(ref_spk_list))
result.speaker_count_error = hyp_spk_count - ref_spk_count

return result


Expand Down Expand Up @@ -235,6 +252,7 @@ def compute_metrics_on_json_dict(
final_cpwer_sub = 0
final_cpwer_delete = 0
final_cpwer_insert = 0
final_speaker_count_absolute_error_total = 0
for utt in result_dict["utterances"]:
final_wer_total += utt["wer_total"]
final_wer_correct += utt["wer_correct"]
Expand All @@ -250,6 +268,9 @@ def compute_metrics_on_json_dict(
final_cpwer_sub += utt["cpwer_sub"]
final_cpwer_delete += utt["cpwer_delete"]
final_cpwer_insert += utt["cpwer_insert"]
final_speaker_count_absolute_error_total += abs(
utt["speaker_count_error"]
)

result_dict["WER"] = (
final_wer_sub + final_wer_delete + final_wer_insert
Expand All @@ -260,4 +281,7 @@ def compute_metrics_on_json_dict(
result_dict["cpWER"] = (
final_cpwer_sub + final_cpwer_delete + final_cpwer_insert
) / final_cpwer_total
result_dict["SpkCntMAE"] = final_speaker_count_absolute_error_total / len(
result_dict["utterances"]
)
return result_dict
13 changes: 12 additions & 1 deletion DiarizationLM/diarizationlm/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_wder_same_words(self):
self.assertEqual(result.cpwer_sub, 0)
self.assertEqual(result.cpwer_correct, 5)
self.assertEqual(result.cpwer_total, 6)
self.assertEqual(result.speaker_count_error, 0)

def test_wder_diff_words(self):
hyp = "a b c d e f g h"
Expand All @@ -69,14 +70,14 @@ def test_wder_diff_words(self):
self.assertEqual(result.cpwer_sub, 3)
self.assertEqual(result.cpwer_correct, 4)
self.assertEqual(result.cpwer_total, 9)
self.assertEqual(result.speaker_count_error, 0)

def test_cpwer_edge_word_permutation(self):
hyp = "y x"
ref = "x y"
hyp_spk = "2 1"
ref_spk = "1 2"
result = metrics.compute_utterance_metrics(hyp, ref, hyp_spk, ref_spk)
print("result:", result)
self.assertEqual(result.wer_insert, 1)
self.assertEqual(result.wer_delete, 1)
self.assertEqual(result.wer_sub, 0)
Expand All @@ -88,6 +89,14 @@ def test_cpwer_edge_word_permutation(self):
self.assertEqual(result.cpwer_correct, 2)
self.assertEqual(result.cpwer_total, 2)

def test_speaker_count_error(self):
    """Signed speaker count error when hypothesis under-counts speakers.

    The hypothesis uses 2 distinct speakers while the reference uses 3,
    so the expected signed error is 2 - 3 = -1.
    """
    hyp_text = "a b c d"
    ref_text = "a b c d"
    hyp_speakers = "1 2 2 2"
    ref_speakers = "1 2 2 3"
    result = metrics.compute_utterance_metrics(
        hyp_text, ref_text, hyp_speakers, ref_speakers
    )
    self.assertEqual(result.speaker_count_error, -1)

def test_compute_metrics_on_json_dict(self):
json_dict = {
"utterances": [
Expand Down Expand Up @@ -183,6 +192,7 @@ def test_compute_metrics_on_json_file_oracle(self):
self.assertAlmostEqual(result["WER"], 0.2363, delta=0.001)
self.assertAlmostEqual(result["WDER"], 0.0, delta=0.001)
self.assertAlmostEqual(result["cpWER"], 0.2363, delta=0.001)
self.assertAlmostEqual(result["SpkCntMAE"], 0.0, delta=0.001)

def test_compute_metrics_on_json_file_degraded(self):
json_file = os.path.join("testdata", "example_data.json")
Expand All @@ -197,6 +207,7 @@ def test_compute_metrics_on_json_file_degraded(self):
self.assertAlmostEqual(result["WER"], 0.2363, delta=0.001)
self.assertAlmostEqual(result["WDER"], 0.0, delta=0.001)
self.assertAlmostEqual(result["cpWER"], 0.2363, delta=0.001)
self.assertAlmostEqual(result["SpkCntMAE"], 0.0, delta=0.001)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion DiarizationLM/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import setuptools

VERSION = "0.1.3"
VERSION = "0.1.4"

with open("README.md", "r") as file_object:
LONG_DESCRIPTION = file_object.read()
Expand Down

0 comments on commit 22a8b49

Please sign in to comment.