-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_unigram_truecaser.py
88 lines (64 loc) · 3.08 KB
/
test_unigram_truecaser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from typing import Any, Iterable
from machine.corpora.memory_text import MemoryText
from machine.corpora.text import Text
from machine.corpora.text_corpus import TextCorpus
from machine.corpora.text_row import TextRow, TextRowFlags
from machine.translation.unigram_truecaser import UnigramTruecaser, UnigramTruecaserTrainer
# Tokenized training sentences shared by the tests in this module. Each entry is a
# sentence written with single spaces between tokens (punctuation is its own token)
# and split into a list of token strings.
training_segments = [
    sentence.split()
    for sentence in (
        "The house is made of wood .",
        "I go on adventures .",
        "He read the book about Sherlock Holmes .",
        "John and I agree that you and I are smart .",
    )
]
def test_truecase_empty() -> None:
    """Truecasing an empty token list yields an empty list."""
    assert create_truecaser().truecase([]) == []
def test_truecase_capitialized_name() -> None:
    """All-caps input is restored to the casing learned from ``training_segments``.

    Common words come back lowercase, while the proper name keeps its capitals.
    """
    tokens = ["THE", "ADVENTURES", "OF", "SHERLOCK", "HOLMES"]
    expected = ["the", "adventures", "of", "Sherlock", "Holmes"]
    assert create_truecaser().truecase(tokens) == expected
def test_truecase_unknown_word() -> None:
    """A token never seen in training ("EXPLOITS") is passed through unchanged."""
    tokens = ["THE", "EXPLOITS", "OF", "SHERLOCK", "HOLMES"]
    expected = ["the", "EXPLOITS", "of", "Sherlock", "Holmes"]
    assert create_truecaser().truecase(tokens) == expected
def test_truecase_multiple_sentences() -> None:
    """Input spanning two sentences is truecased throughout.

    The token after the first "." ("YOU") becomes lowercase "you"; it is not
    capitalized as a sentence start.
    """
    tokens = ["SHERLOCK", "HOLMES", "IS", "SMART", ".", "YOU", "AGREE", "."]
    expected = ["Sherlock", "Holmes", "is", "smart", ".", "you", "agree", "."]
    assert create_truecaser().truecase(tokens) == expected
def test_truecase_ignore_first_word_during_training() -> None:
    """A word seen in training only as a segment's first token is not learned.

    "He" appears in ``training_segments`` only at the start of a segment, so
    "HE" is left unchanged while the other tokens are truecased.
    """
    tokens = ["HE", "IS", "SMART", "."]
    expected = ["HE", "is", "smart", "."]
    assert create_truecaser().truecase(tokens) == expected
def create_truecaser() -> UnigramTruecaser:
    """Return a fresh ``UnigramTruecaser`` trained on every segment of ``training_segments``.

    Segments are fed one at a time, in order, via ``train_segment``.
    """
    model = UnigramTruecaser()
    for tokens in training_segments:
        model.train_segment(tokens)
    return model
class MemoryTextCorpus(TextCorpus):
    """In-memory ``TextCorpus`` exposing a single ``MemoryText`` built from the given rows.

    Used in this module to feed ``UnigramTruecaserTrainer`` a corpus without any files.
    """

    def __init__(self, id: str, rows: Iterable[TextRow] = ()) -> None:
        """Store the text id and a list copy of ``rows``.

        Args:
            id: Identifier used for the corpus's single text.
            rows: Rows of that text; any iterable is accepted and copied into a list.
        """
        # Immutable empty default instead of a shared ``[]``. The iterable is
        # copied with ``list()`` either way, so behavior is unchanged.
        self._id = id
        self._rows = list(rows)

    @property
    def texts(self) -> Iterable[Text]:
        """The corpus's only text: a ``MemoryText`` with this corpus's id and rows."""
        return [MemoryText(self._id, self._rows)]

    @property
    def is_tokenized(self) -> bool:
        """Always ``True``: rows are expected to carry already-split token segments."""
        return True
def text_row(text_id: str, ref: Any, text: str = "", flags: TextRowFlags = TextRowFlags.SENTENCE_START) -> TextRow:
    """Build a ``TextRow`` whose segment is ``text`` split on whitespace.

    An empty (or whitespace-only) ``text`` yields an empty segment.
    """
    # With no separator argument, str.split() already returns [] for "".
    segment = text.split()
    return TextRow(text_id, ref, segment, flags)
def test_compare_with_truecase_trainer() -> None:
    """Training via ``UnigramTruecaserTrainer`` matches training segments directly.

    A corpus built from ``training_segments`` is fed to the trainer. The resulting
    model is compared with ``create_truecaser()``: the best-token table, the set of
    casing conditions, and each condition's sample counts must all be identical.
    """
    corpus = MemoryTextCorpus(
        "text1", [text_row("text1", i, " ".join(segment)) for i, segment in enumerate(training_segments)]
    )
    trainer = UnigramTruecaserTrainer(corpus)
    trainer.new_truecaser = UnigramTruecaser()
    trainer.train()

    actual = trainer.new_truecaser
    expected = create_truecaser()

    assert actual._bestTokens == expected._bestTokens
    # Compare in both directions. Equal condition sets catch conditions that only one
    # model has (and avoid a KeyError below). Comparing whole count mappings catches
    # samples that only one model has.
    assert set(actual._casing.get_conditions()) == set(expected._casing.get_conditions())
    for condition in actual._casing.get_conditions():
        assert (
            actual._casing._freq_dist[condition]._sample_counts
            == expected._casing._freq_dist[condition]._sample_counts
        )