From 059afc56ef57e3d01b9ee9d74027b549863144c0 Mon Sep 17 00:00:00 2001 From: Shuhei Iitsuka Date: Fri, 10 Feb 2023 09:29:41 +0900 Subject: [PATCH] Refactor feature_extractor (#113) --- budoux/feature_extractor.py | 55 ---------------- budoux/utils.py | 3 - scripts/encode_data.py | 53 ++++++++++++++-- {tests => scripts/tests}/test_encode_data.py | 54 +++++++++++++++- tests/test_feature_extractor.py | 67 -------------------- 5 files changed, 99 insertions(+), 133 deletions(-) delete mode 100644 budoux/feature_extractor.py rename {tests => scripts/tests}/test_encode_data.py (72%) delete mode 100644 tests/test_feature_extractor.py diff --git a/budoux/feature_extractor.py b/budoux/feature_extractor.py deleted file mode 100644 index 2cfd89ce..00000000 --- a/budoux/feature_extractor.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Methods to encode source sentences to features.""" - -import typing - -from .utils import INVALID - - -def get_feature(w1: str, w2: str, w3: str, w4: str, w5: str, - w6: str) -> typing.List[str]: - """Generates a feature from characters around (w1-6). - - Args: - w1 (str): The character 3 characters before the break point. - w2 (str): The character 2 characters before the break point. - w3 (str): The character right before the break point. - w4 (str): The character right after the break point. - w5 (str): The character 2 characters after the break point. - w6 (str): The character 3 characters after the break point. - - Returns: - The feature (list[str]). - - """ - raw_feature = { - 'UW1': w1, - 'UW2': w2, - 'UW3': w3, - 'UW4': w4, - 'UW5': w5, - 'UW6': w6, - 'BW1': w2 + w3, - 'BW2': w3 + w4, - 'BW3': w4 + w5, - 'TW1': w1 + w2 + w3, - 'TW2': w2 + w3 + w4, - 'TW3': w3 + w4 + w5, - 'TW4': w4 + w5 + w6, - } - for key, value in list(raw_feature.items()): - if INVALID in value: - del raw_feature[key] - return [f'{item[0]}:{item[1]}' for item in raw_feature.items()] diff --git a/budoux/utils.py b/budoux/utils.py index fd73d6ce..278fddaa 100644 --- a/budoux/utils.py +++ b/budoux/utils.py @@ -15,6 +15,3 @@ SEP = '▁' """The separator string to specify breakpoints.""" - -INVALID = '▔' -"""The invalid feature string.""" diff --git a/scripts/encode_data.py b/scripts/encode_data.py index f984116e..025c84e9 100644 --- a/scripts/encode_data.py +++ b/scripts/encode_data.py @@ -20,11 +20,51 @@ import sys import typing -from budoux import feature_extractor, utils +from budoux import utils ArgList = typing.Optional[typing.List[str]] DEFAULT_OUTPUT_FILENAME = 'encoded_data.txt' +INVALID = '▔' +"""The invalid feature string.""" + + +def get_feature(w1: str, w2: str, w3: str, w4: str, w5: str, + w6: str) -> typing.List[str]: + """Generates a feature from characters around (w1-6). + + Args: + w1 (str): The character 3 characters before the break point. + w2 (str): The character 2 characters before the break point. + w3 (str): The character right before the break point. + w4 (str): The character right after the break point. + w5 (str): The character 2 characters after the break point. + w6 (str): The character 3 characters after the break point. + + Returns: + The feature (list[str]). + + """ + raw_feature = { + 'UW1': w1, + 'UW2': w2, + 'UW3': w3, + 'UW4': w4, + 'UW5': w5, + 'UW6': w6, + 'BW1': w2 + w3, + 'BW2': w3 + w4, + 'BW3': w4 + w5, + 'TW1': w1 + w2 + w3, + 'TW2': w2 + w3 + w4, + 'TW3': w3 + w4 + w5, + 'TW4': w4 + w5 + w6, + } + for key, value in list(raw_feature.items()): + if INVALID in value: + del raw_feature[key] + return [f'{item[0]}:{item[1]}' for item in raw_feature.items()] + def parse_args(test: ArgList = None) -> argparse.Namespace: """Parses commandline arguments. @@ -65,12 +105,11 @@ def process(i: int, sentence: str, sep_indices: typing.Set[int]) -> str: sentence (str): A sentence sep_indices (typing.Set[int]): A set of separator indices. """ - feature = feature_extractor.get_feature( - sentence[i - 3] if i > 2 else utils.INVALID, - sentence[i - 2] if i > 1 else utils.INVALID, sentence[i - 1], - sentence[i] if i < len(sentence) else utils.INVALID, - sentence[i + 1] if i + 1 < len(sentence) else utils.INVALID, - sentence[i + 2] if i + 2 < len(sentence) else utils.INVALID) + feature = get_feature(sentence[i - 3] if i > 2 else INVALID, + sentence[i - 2] if i > 1 else INVALID, sentence[i - 1], + sentence[i] if i < len(sentence) else INVALID, + sentence[i + 1] if i + 1 < len(sentence) else INVALID, + sentence[i + 2] if i + 2 < len(sentence) else INVALID) positive = i in sep_indices line = '\t'.join(['1' if positive else '-1'] + feature) return line diff --git a/tests/test_encode_data.py b/scripts/tests/test_encode_data.py similarity index 72% rename from tests/test_encode_data.py rename to scripts/tests/test_encode_data.py index 03186d86..25161ded 100644 --- a/tests/test_encode_data.py +++ b/scripts/tests/test_encode_data.py @@ -15,16 +15,64 @@ import os import sys +import typing import unittest from budoux import utils # module hack -LIB_PATH = os.path.join(os.path.dirname(__file__), '..') +LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..') sys.path.insert(0, os.path.abspath(LIB_PATH)) + from scripts import encode_data # type: ignore # noqa (module hack) +class TestGetFeature(unittest.TestCase): + + def test_standard(self) -> None: + feature = encode_data.get_feature('a', 'b', 'c', 'd', 'e', 'f') + self.assertSetEqual( + set(feature), + { + # Unigram of Words (UW) + 'UW1:a', + 'UW2:b', + 'UW3:c', + 'UW4:d', + 'UW5:e', + 'UW6:f', + + # Bigram of Words (BW) + 'BW1:bc', + 'BW2:cd', + 'BW3:de', + + # Trigram of Words (TW) + 'TW1:abc', + 'TW2:bcd', + 'TW3:cde', + 'TW4:def', + }, + 'Features should be extracted.') + + def test_with_invalid(self) -> None: + + def find_by_prefix(prefix: str, feature: typing.List[str]) -> bool: + for item in feature: + if item.startswith(prefix): + return True + return False + + feature = encode_data.get_feature('a', 'a', encode_data.INVALID, 'a', 'a', + 'a') + self.assertFalse( + find_by_prefix('UW3:', feature), + 'Should omit the Unigram feature when the character is invalid.') + self.assertFalse( + find_by_prefix('BW2:', feature), + 'Should omit the Bigram feature that covers an invalid character.') + + class TestArgParse(unittest.TestCase): def test_cmdargs_invalid_option(self) -> None: @@ -107,3 +155,7 @@ def test_doubled_seps(self) -> None: sentence, sep_indices = encode_data.normalize_input(source) self.assertEqual(sentence, 'ABCDEFG') self.assertEqual(sep_indices, {3, 5, 7}) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_feature_extractor.py b/tests/test_feature_extractor.py deleted file mode 100644 index f537c006..00000000 --- a/tests/test_feature_extractor.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests methods for the feature extractor.""" - -import typing -import unittest - -from budoux import feature_extractor, utils - - -class TestFeatureExtractor(unittest.TestCase): - - def test_get_feature(self) -> None: - feature = feature_extractor.get_feature('a', 'b', 'c', 'd', 'e', 'f') - self.assertSetEqual( - set(feature), - { - # Unigram of Words (UW) - 'UW1:a', - 'UW2:b', - 'UW3:c', - 'UW4:d', - 'UW5:e', - 'UW6:f', - - # Bigram of Words (BW) - 'BW1:bc', - 'BW2:cd', - 'BW3:de', - - # Trigram of Words (TW) - 'TW1:abc', - 'TW2:bcd', - 'TW3:cde', - 'TW4:def', - }, - 'Features should be extracted.') - - def find_by_prefix(prefix: str, feature: typing.List[str]) -> bool: - for item in feature: - if item.startswith(prefix): - return True - return False - - feature = feature_extractor.get_feature('a', 'a', utils.INVALID, 'a', 'a', - 'a') - self.assertFalse( - find_by_prefix('UW3:', feature), - 'Should omit the Unigram feature when the character is invalid.') - self.assertFalse( - find_by_prefix('BW2:', feature), - 'Should omit the Bigram feature that covers an invalid character.') - - -if __name__ == '__main__': - unittest.main()