Skip to content

Commit

Permalink
Refactor feature_extractor (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
tushuhei authored Feb 10, 2023
1 parent 4be9bcc commit 059afc5
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 133 deletions.
55 changes: 0 additions & 55 deletions budoux/feature_extractor.py

This file was deleted.

3 changes: 0 additions & 3 deletions budoux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,3 @@

SEP = '▁'
"""The separator string to specify breakpoints."""

INVALID = '▔'
"""The invalid feature string."""
53 changes: 46 additions & 7 deletions scripts/encode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,51 @@
import sys
import typing

from budoux import feature_extractor, utils
from budoux import utils

ArgList = typing.Optional[typing.List[str]]
DEFAULT_OUTPUT_FILENAME = 'encoded_data.txt'

INVALID = '▔'
"""The invalid feature string."""


def get_feature(w1: str, w2: str, w3: str, w4: str, w5: str,
w6: str) -> typing.List[str]:
"""Generates a feature from characters around (w1-6).
Args:
w1 (str): The character 3 characters before the break point.
w2 (str): The character 2 characters before the break point.
w3 (str): The character right before the break point.
w4 (str): The character right after the break point.
w5 (str): The character 2 characters after the break point.
w6 (str): The character 3 characters after the break point.
Returns:
The feature (list[str]).
"""
raw_feature = {
'UW1': w1,
'UW2': w2,
'UW3': w3,
'UW4': w4,
'UW5': w5,
'UW6': w6,
'BW1': w2 + w3,
'BW2': w3 + w4,
'BW3': w4 + w5,
'TW1': w1 + w2 + w3,
'TW2': w2 + w3 + w4,
'TW3': w3 + w4 + w5,
'TW4': w4 + w5 + w6,
}
for key, value in list(raw_feature.items()):
if INVALID in value:
del raw_feature[key]
return [f'{item[0]}:{item[1]}' for item in raw_feature.items()]


def parse_args(test: ArgList = None) -> argparse.Namespace:
"""Parses commandline arguments.
Expand Down Expand Up @@ -65,12 +105,11 @@ def process(i: int, sentence: str, sep_indices: typing.Set[int]) -> str:
sentence (str): A sentence
sep_indices (typing.Set[int]): A set of separator indices.
"""
feature = feature_extractor.get_feature(
sentence[i - 3] if i > 2 else utils.INVALID,
sentence[i - 2] if i > 1 else utils.INVALID, sentence[i - 1],
sentence[i] if i < len(sentence) else utils.INVALID,
sentence[i + 1] if i + 1 < len(sentence) else utils.INVALID,
sentence[i + 2] if i + 2 < len(sentence) else utils.INVALID)
feature = get_feature(sentence[i - 3] if i > 2 else INVALID,
sentence[i - 2] if i > 1 else INVALID, sentence[i - 1],
sentence[i] if i < len(sentence) else INVALID,
sentence[i + 1] if i + 1 < len(sentence) else INVALID,
sentence[i + 2] if i + 2 < len(sentence) else INVALID)
positive = i in sep_indices
line = '\t'.join(['1' if positive else '-1'] + feature)
return line
Expand Down
54 changes: 53 additions & 1 deletion tests/test_encode_data.py → scripts/tests/test_encode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,64 @@

import os
import sys
import typing
import unittest

from budoux import utils

# module hack
LIB_PATH = os.path.join(os.path.dirname(__file__), '..')
LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.insert(0, os.path.abspath(LIB_PATH))

from scripts import encode_data # type: ignore # noqa (module hack)


class TestGetFeature(unittest.TestCase):

def test_standard(self) -> None:
feature = encode_data.get_feature('a', 'b', 'c', 'd', 'e', 'f')
self.assertSetEqual(
set(feature),
{
# Unigram of Words (UW)
'UW1:a',
'UW2:b',
'UW3:c',
'UW4:d',
'UW5:e',
'UW6:f',

# Bigram of Words (BW)
'BW1:bc',
'BW2:cd',
'BW3:de',

# Trigram of Words (TW)
'TW1:abc',
'TW2:bcd',
'TW3:cde',
'TW4:def',
},
'Features should be extracted.')

def test_with_invalid(self) -> None:

def find_by_prefix(prefix: str, feature: typing.List[str]) -> bool:
for item in feature:
if item.startswith(prefix):
return True
return False

feature = encode_data.get_feature('a', 'a', encode_data.INVALID, 'a', 'a',
'a')
self.assertFalse(
find_by_prefix('UW3:', feature),
'Should omit the Unigram feature when the character is invalid.')
self.assertFalse(
find_by_prefix('BW2:', feature),
'Should omit the Bigram feature that covers an invalid character.')


class TestArgParse(unittest.TestCase):

def test_cmdargs_invalid_option(self) -> None:
Expand Down Expand Up @@ -107,3 +155,7 @@ def test_doubled_seps(self) -> None:
sentence, sep_indices = encode_data.normalize_input(source)
self.assertEqual(sentence, 'ABCDEFG')
self.assertEqual(sep_indices, {3, 5, 7})


if __name__ == '__main__':
unittest.main()
67 changes: 0 additions & 67 deletions tests/test_feature_extractor.py

This file was deleted.

0 comments on commit 059afc5

Please sign in to comment.