diff --git a/.github/workflows/style-check.yml b/.github/workflows/style-check.yml
index 263359d8..f7e18406 100644
--- a/.github/workflows/style-check.yml
+++ b/.github/workflows/style-check.yml
@@ -7,9 +7,28 @@ jobs:
     - uses: actions/checkout@v2
     - uses: actions/setup-python@v2
       with:
-        python-version: '3.7'
-    - run: python -m pip install --upgrade yapf toml
-    - run: yapf --diff --recursive budoux tests scripts
+        python-version: '3.9'
+    - name: Install dependencies
+      run: |
+        pip install --upgrade pip
+        pip install -r requirements_dev.txt
+    - name: Run yapf
+      run: |
+        yapf --diff --recursive budoux tests scripts
+    - name: Run mypy
+      if: ${{ always() }}
+      uses: sasanquaneuf/mypy-github-action@a0c442aa252655d7736ce6696e06227ccdd62870
+      with:
+        checkName: python-style-check
+      env:
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+    - name: Run flake8
+      if: ${{ always() }}
+      uses: suo/flake8-github-action@3e87882219642e01aa8a6bbd03b4b0adb8542c2a
+      with:
+        checkName: python-style-check
+      env:
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
   typescript-style-check:
     runs-on: ubuntu-latest
     steps:
diff --git a/budoux/feature_extractor.py b/budoux/feature_extractor.py
index 83f2aede..b3640b19 100644
--- a/budoux/feature_extractor.py
+++ b/budoux/feature_extractor.py
@@ -19,13 +19,14 @@
 import os
 import sys
 import typing
+
 from .utils import SEP, Result
 
 with open(os.path.join(os.path.dirname(__file__), 'unicode_blocks.json')) as f:
   block_starts: typing.List[int] = json.load(f)
 
 
-def unicode_block_index(w: str):
+def unicode_block_index(w: str) -> int:
   """Returns the index of the Unicode block that the character belongs to.
 
   Args:
@@ -38,7 +39,7 @@ def unicode_block_index(w: str):
 
 
 def get_feature(w1: str, w2: str, w3: str, w4: str, w5: str, w6: str, p1: str,
-                p2: str, p3: str):
+                p2: str, p3: str) -> typing.List[str]:
   """Generates a feature from characters around (w1-6) and past results (p1-3).
 
   Args:
@@ -129,7 +130,7 @@ def get_feature(w1: str, w2: str, w3: str, w4: str, w5: str, w6: str, p1: str,
   return [f'{item[0]}:{item[1]}' for item in raw_feature.items()]
 
 
-def process(source_filename: str, entries_filename: str):
+def process(source_filename: str, entries_filename: str) -> None:
   """Extratcs features from source sentences and outputs as entries.
 
   Args:
@@ -141,8 +142,8 @@ def process(source_filename: str, entries_filename: str):
   with open(entries_filename, 'w', encoding=sys.getdefaultencoding()) as f:
     f.write('')
 
-  for row in data:
-    chunks = row.strip().split(SEP)
+  for datum in data:
+    chunks = datum.strip().split(SEP)
     chunk_lengths = [len(chunk) for chunk in chunks]
     sep_indices = set(itertools.accumulate(chunk_lengths, lambda x, y: x + y))
     sentence = ''.join(chunks)
diff --git a/budoux/main.py b/budoux/main.py
index 632ec293..0321364d 100644
--- a/budoux/main.py
+++ b/budoux/main.py
@@ -116,38 +116,34 @@ def parse_args(test: ArgList = None) -> argparse.Namespace:
   return parser.parse_args()
 
 
-def _main(test: ArgList = None):
+def _main(test: ArgList = None) -> str:
   args = parse_args(test=test)
   with open(args.model, "r") as f:
     model = json.load(f)
   parser = budoux.Parser(model)
-
   if args.html:
     if args.text is None:
-      inputs = sys.stdin.read()
+      inputs_html = sys.stdin.read()
     else:
-      inputs = args.text
-    res = parser.translate_html_string(inputs, thres=args.thres)
+      inputs_html = args.text
+    res = parser.translate_html_string(inputs_html, thres=args.thres)
   else:
     if args.text is None:
       inputs = [v.rstrip() for v in sys.stdin.readlines()]
     else:
       inputs = [v.rstrip() for v in args.text.splitlines()]
     outputs = [parser.parse(sentence, thres=args.thres) for sentence in inputs]
-    res = ["\n".join(res) for res in outputs]
+    conbined_output = ["\n".join(output) for output in outputs]
 
     ors = "\n" + args.delim + "\n"
-    res = ors.join(res)
+    res = ors.join(conbined_output)
 
-  if test is not None:
-    return res
-  else:
-    print(res)
+  return res
 
 
-def main(test: ArgList = None):
+def main(test: ArgList = None) -> None:
   try:
-    _main(test)
+    print(_main(test))
   except KeyboardInterrupt:
     exit(0)
 
diff --git a/budoux/parser.py b/budoux/parser.py
index 651b58b5..0b0f7e72 100644
--- a/budoux/parser.py
+++ b/budoux/parser.py
@@ -17,8 +17,9 @@
 import os
 import typing
 from html.parser import HTMLParser
+
 from .feature_extractor import get_feature
-from .utils import Result, SEP
+from .utils import SEP, Result
 
 MODEL_DIR = os.path.join(os.path.dirname(__file__), 'models')
 PARENT_CSS_STYLE = 'word-break: keep-all; overflow-wrap: break-word;'
@@ -37,7 +38,7 @@ class TextContentExtractor(HTMLParser):
   """
   output = ''
 
-  def handle_data(self, data):
+  def handle_data(self, data: str) -> None:
     self.output += data
 
 
@@ -60,7 +61,7 @@ def __init__(self, chunks: typing.List[str]):
     self.chunks_joined = SEP.join(chunks)
     self.to_skip = False
 
-  def handle_starttag(self, tag: str, attrs: HTMLAttr):
+  def handle_starttag(self, tag: str, attrs: HTMLAttr) -> None:
     attr_pairs = []
     for attr in attrs:
       if attr[1] is None:
@@ -71,18 +72,18 @@ def handle_starttag(self, tag: str, attrs: HTMLAttr):
       self.output += '<%s %s>' % (tag, encoded_attrs)
     self.to_skip = tag.upper() in SKIP_NODES
 
-  def handle_endtag(self, tag: str):
+  def handle_endtag(self, tag: str) -> None:
     self.output += '</%s>' % (tag)
     self.to_skip = False
 
-  def handle_data(self, data: str):
+  def handle_data(self, data: str) -> None:
     if self.to_skip:
       self.output += data
       if self.chunks_joined[0] == SEP:
         self.chunks_joined = self.chunks_joined[1 + len(data):]
       else:
         self.chunks_joined = self.chunks_joined[len(data):]
-      return
+      return None
     for char in data:
       if char == self.chunks_joined[0]:
         self.chunks_joined = self.chunks_joined[1:]
@@ -110,7 +111,9 @@ def __init__(self, model: typing.Dict[str, int]):
     """
     self.model = model
 
-  def parse(self, sentence: str, thres: int = DEFAULT_THRES):
+  def parse(self,
+            sentence: str,
+            thres: int = DEFAULT_THRES) -> typing.List[str]:
     """Parses the input sentence and returns a list of semantic chunks.
 
     Args:
@@ -134,7 +137,7 @@ def parse(self, sentence: str, thres: int = DEFAULT_THRES):
                             p1, p2, p3)
       score = 0
       for f in feature:
-        if not f in self.model:
+        if f not in self.model:
           continue
         score += self.model[f]
       if score > thres:
@@ -147,7 +150,7 @@ def parse(self, sentence: str, thres: int = DEFAULT_THRES):
       p3 = p
     return chunks
 
-  def translate_html_string(self, html: str, thres: int = DEFAULT_THRES):
+  def translate_html_string(self, html: str, thres: int = DEFAULT_THRES) -> str:
     """Translates the given HTML string with markups for semantic line breaks.
 
     Args:
@@ -167,7 +170,7 @@ def translate_html_string(self, html: str, thres: int = DEFAULT_THRES):
     return '<span style="%s">%s</span>' % (PARENT_CSS_STYLE, resolver.output)
 
 
-def load_default_japanese_parser():
+def load_default_japanese_parser() -> Parser:
   """Loads a parser equipped with the default Japanese model.
 
   Returns:
@@ -175,4 +178,4 @@
   """
   with open(os.path.join(MODEL_DIR, 'ja-knbc.json')) as f:
     model = json.load(f)
-  return Parser(model)
\ No newline at end of file
+  return Parser(model)
diff --git a/budoux/utils.py b/budoux/utils.py
index 6c7859b7..aaaaf2b7 100644
--- a/budoux/utils.py
+++ b/budoux/utils.py
@@ -23,4 +23,4 @@ class Result(Enum):
   """An enum to represent the type of inference result."""
   UNKNOWN = 'U'
   POSITIVE = 'B'
-  NEGATIVE = 'O'
\ No newline at end of file
+  NEGATIVE = 'O'
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 04240c19..854b3695 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -1,6 +1,10 @@
-numpy
-html5lib
-yapf
 build
-twine
+flake8
+html5lib
+numpy
+mypy
 toml
+twine
+types-html5lib
+types-setuptools
+yapf
diff --git a/scripts/build_model.py b/scripts/build_model.py
index 0545d493..339befda 100644
--- a/scripts/build_model.py
+++ b/scripts/build_model.py
@@ -22,7 +22,9 @@
 import typing
 
 
-def rollup(weights_filename: str, model_filename: str, scale: int = 1000):
+def rollup(weights_filename: str,
+           model_filename: str,
+           scale: int = 1000) -> None:
   """Rolls up the weights and outputs a model in JSON with integer scores.
 
   Args:
@@ -43,7 +45,7 @@ def rollup(weights_filename: str, model_filename: str, scale: int = 1000):
     json.dump(decision_trees_intscore, f)
 
 
-def main():
+def main() -> None:
   parser = argparse.ArgumentParser(description=__doc__)
   parser.add_argument(
       'weight_file', help='A file path for the learned weights.')
diff --git a/scripts/context.py b/scripts/context.py
index 3b4a5781..12956356 100644
--- a/scripts/context.py
+++ b/scripts/context.py
@@ -17,6 +17,4 @@
 
 LIB_PATH = os.path.join(os.path.dirname(__file__), '..')
 sys.path.insert(0, os.path.abspath(LIB_PATH))
-from budoux import feature_extractor
-from budoux import utils
-from budoux import Parser
+from budoux import Parser, feature_extractor, utils  # noqa (unused)
diff --git a/scripts/encode_data.py b/scripts/encode_data.py
index 031c4344..4168db8d 100644
--- a/scripts/encode_data.py
+++ b/scripts/encode_data.py
@@ -14,10 +14,11 @@
 """Encodes the training data with extracted features."""
 
 import argparse
+
 from context import feature_extractor
 
 
-def main():
+def main() -> None:
   parser = argparse.ArgumentParser(description=__doc__)
   parser.add_argument(
       'source_data',
diff --git a/scripts/load_knbc.py b/scripts/load_knbc.py
index a1add91e..cc812be6 100644
--- a/scripts/load_knbc.py
+++ b/scripts/load_knbc.py
@@ -17,9 +17,10 @@
 import os
 import tarfile
 import typing
-import urllib.request
 import urllib.error
+import urllib.request
 from html.parser import HTMLParser
+
 from context import utils
 
 RESOURCE_URL = (
@@ -29,15 +30,15 @@
 class KNBCHTMLParser(HTMLParser):
   """Parses the HTML files in the KNBC corpus and outputs the chunks."""
 
-  def __init__(self, split_tab: bool = True):
+  def __init__(self, split_tab: bool = True) -> None:
     super().__init__()
     self.chunks = ['']
     self.n_rows = 0
     self.n_cols = 0
-    self.current_word = None
+    self.current_word: typing.Optional[str] = None
     self.split_tab = split_tab
 
-  def handle_starttag(self, tag, _):
+  def handle_starttag(self, tag: str, _: typing.Any) -> None:
     if tag == 'tr':
       self.n_rows += 1
       self.n_cols = 0
@@ -45,21 +46,22 @@ def __init__(self, split_tab: bool = True):
     if tag == 'td':
       self.n_cols += 1
 
-  def handle_endtag(self, tag):
+  def handle_endtag(self, tag: str) -> None:
     if tag != 'tr':
-      return
-    if (self.n_rows > 2 and self.n_cols == 1 and
-        (self.split_tab or self.current_word == '文節区切り')):
+      return None
+    flag1 = self.n_rows > 2 and self.n_cols == 1
+    flag2 = self.split_tab or self.current_word == '文節区切り'
+    if flag1 and flag2:
       self.chunks.append('')
-    if self.n_cols == 5:
+    if self.n_cols == 5 and type(self.current_word) is str:
       self.chunks[-1] += self.current_word
 
-  def handle_data(self, data):
+  def handle_data(self, data: str) -> None:
     if self.n_cols == 1:
       self.current_word = data
 
 
-def break_before_open_parentheses(chunks: typing.List[str]):
+def break_before_open_parentheses(chunks: typing.List[str]) -> typing.List[str]:
   """Adds chunk breaks before every open parentheses.
 
   Args:
@@ -80,7 +82,7 @@ def break_before_open_parentheses(chunks: typing.List[str]):
   return out
 
 
-def postprocess(chunks: typing.List[str]):
+def postprocess(chunks: typing.List[str]) -> typing.List[str]:
   """Applies some processes to modify the extracted chunks.
 
   Args:
@@ -93,7 +95,7 @@ def postprocess(chunks: typing.List[str]):
   return chunks
 
 
-def download_knbc(target_dir: str):
+def download_knbc(target_dir: str) -> None:
   """Downloads the KNBC corpus and extracts files.
 
   Args:
@@ -110,7 +112,7 @@ def download_knbc(target_dir: str):
     t.extractall(path=target_dir)
 
 
-def main():
+def parse_args() -> argparse.Namespace:
   parser = argparse.ArgumentParser(description=__doc__)
   parser.add_argument(
       '-o',
@@ -118,7 +120,11 @@ def main():
       help='''File path to output the training data.
       (default: source.txt)''',
       default='source.txt')
-  args = parser.parse_args()
+  return parser.parse_args()
+
+
+def main() -> None:
+  args = parse_args()
   outfile = args.outfile
   html_dir = 'data/KNBC_v1.0_090925_utf8/html/'
   if not os.path.isdir(html_dir):
diff --git a/scripts/train.py b/scripts/train.py
index e62dd111..643073da 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -16,12 +16,17 @@
 import argparse
 import typing
 from collections import Counter
+
 import numpy as np
+import numpy.typing as npt
 
-EPS = np.finfo(float).eps
+EPS = np.finfo(float).eps  # type: np.floating[typing.Any]
 
 
-def preprocess(entries_filename: str, feature_thres: int):
+def preprocess(
+    entries_filename: str, feature_thres: int
+) -> typing.Tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_],
+                  typing.List[str]]:
   """Loads entries and translates them into NumPy arrays.
 
   Args:
@@ -51,8 +56,8 @@ def preprocess(entries_filename: str, feature_thres: int):
   M = len(features) + 1
   N = len(entries)
 
-  Y = np.zeros(N, dtype=bool)
-  X = np.zeros((N, M), dtype=bool)
+  Y: npt.NDArray[np.bool_] = np.zeros(N, dtype=bool)
+  X: npt.NDArray[np.bool_] = np.zeros((N, M), dtype=bool)
 
   for i, entry in enumerate(entries):
     Y[i] = entry[0] == '1'
@@ -63,7 +68,8 @@ def preprocess(entries_filename: str, feature_thres: int):
   return X, Y, features
 
 
-def pred(phis: typing.Dict[int, float], X: np.ndarray) -> np.ndarray:
+def pred(phis: typing.Dict[int, float],
+         X: npt.NDArray[np.bool_]) -> npt.NDArray[np.bool_]:
   """Predicts the output from the given classifiers and input entries.
 
   Args:
@@ -74,16 +80,21 @@ def pred(phis: typing.Dict[int, float], X: np.ndarray) -> np.ndarray:
   Returns:
     A list of inferred labels.
   """
+  alphas: npt.NDArray[np.float64]
+  y: npt.NDArray[np.int64]
+
   alphas = np.array(list(phis.values()))
-  y = 2 * (X[:, list(phis.keys())] == True) - 1
+  y = 2 * (X[:, list(phis.keys())]
+           == True) - 1  # noqa (cannot replace `==` with `is`)
   return y.dot(alphas) > 0
 
 
 def split_dataset(
-    X: np.ndarray,
-    Y: np.ndarray,
-    split_ratio=0.9
-) -> typing.Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    X: npt.NDArray[typing.Any],
+    Y: npt.NDArray[typing.Any],
+    split_ratio: float = 0.9
+) -> typing.Tuple[npt.NDArray[typing.Any], npt.NDArray[typing.Any],
+                  npt.NDArray[typing.Any], npt.NDArray[typing.Any]]:
   """Splits given entries and labels to training and testing datasets.
 
   Args:
@@ -107,8 +118,9 @@ def split_dataset(
   return X_train, X_test, Y_train, Y_test
 
 
-def fit(X: np.ndarray, Y: np.ndarray, features: typing.List[str], iters: int,
-        weights_filename: str, log_filename: str):
+def fit(X: npt.NDArray[np.bool_], Y: npt.NDArray[np.bool_],
+        features: typing.List[str], iters: int, weights_filename: str,
+        log_filename: str) -> typing.Dict[int, float]:
   """Trains an AdaBoost classifier.
 
   Args:
@@ -135,7 +147,7 @@ def fit(X: np.ndarray, Y: np.ndarray, features: typing.List[str], iters: int,
 
   for t in range(iters):
     print('=== %s ===' % (t))
-    res: np.ndarray = w.dot(Y_train[:, None] ^ X_train) / w.sum()
+    res: npt.NDArray[np.float64] = w.dot(Y_train[:, None] ^ X_train) / w.sum()
     err = 0.5 - np.abs(res - 0.5)
     m_best = int(err.argmin())
     pol_best = res[m_best] < 0.5
@@ -161,7 +173,7 @@ def fit(X: np.ndarray, Y: np.ndarray, features: typing.List[str], iters: int,
   return phis
 
 
-def main():
+def parse_args() -> argparse.Namespace:
   parser = argparse.ArgumentParser(description=__doc__)
   parser.add_argument(
       'encoded_train_data', help='File path for the encoded training data.')
@@ -183,7 +195,11 @@ def main():
       help='Number of iterations for training. (default: 10000)',
       default=10000)
 
-  args = parser.parse_args()
+  return parser.parse_args()
+
+
+def main() -> None:
+  args = parse_args()
   train_data_filename = args.encoded_train_data
   weights_filename = args.output
   log_filename = args.log
diff --git a/setup.cfg b/setup.cfg
index c131d26b..67c61249 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -27,3 +27,15 @@ console_scripts =
 
 [yapf]
 based_on_style = yapf
+
+[flake8]
+# E501: line too long
+# E126: over-indentation
+# BLK100: black formattable
+ignore = E126,E501,BLK100
+indent-size = 2
+
+[mypy]
+python_version = 3.9
+pretty = True
+strict = True
diff --git a/tests/test_feature_extractor.py b/tests/test_feature_extractor.py
index 36653556..b1c10662 100644
--- a/tests/test_feature_extractor.py
+++ b/tests/test_feature_extractor.py
@@ -13,19 +13,23 @@
 # limitations under the License.
 """Tests methods for the feature extractor."""
 
-import unittest
+import io
 import os
 import sys
+import unittest
 from pathlib import Path
 
+# module hack
 LIB_PATH = os.path.join(os.path.dirname(__file__), '..')
 sys.path.insert(0, os.path.abspath(LIB_PATH))
 
-from budoux import feature_extractor
-from budoux import utils
+from budoux import feature_extractor, utils  # noqa (module hack)
+
+if isinstance(sys.stdin, io.TextIOWrapper) and sys.version_info >= (3, 7):
+  sys.stdin.reconfigure(encoding='utf-8')
 
-sys.stdin.reconfigure(encoding='utf-8')
-sys.stdout.reconfigure(encoding='utf-8')
+if isinstance(sys.stdout, io.TextIOWrapper) and sys.version_info >= (3, 7):
+  sys.stdout.reconfigure(encoding='utf-8')
 
 SOURCE_FILE_PATH = os.path.abspath(
     os.path.join(os.path.dirname(__file__), 'source_test.txt'))
@@ -35,22 +39,22 @@
 
 class TestFeatureExtractor(unittest.TestCase):
 
-  def setUp(self):
+  def setUp(self) -> None:
     Path(ENTRIES_FILE_PATH).touch()
     self.test_entry = f'これは{utils.SEP}美しい{utils.SEP}ペンです。'
     with open(SOURCE_FILE_PATH, 'w', encoding=sys.getdefaultencoding()) as f:
       f.write(self.test_entry)
 
-  def test_unicode_block_index(self):
+  def test_unicode_block_index(self) -> None:
 
-    def check(character, block):
+    def check(character: str, block: int) -> None:
       self.assertEqual(feature_extractor.unicode_block_index(character), block)
 
     check('a', 1)  # 'a' falls the 1st block 'Basic Latin'.
     check('あ', 108)  # 'あ' falls the 108th block 'Hiragana'.
     check('安', 120)  # '安' falls the 120th block 'Kanji'.
 
-  def test_get_feature(self):
+  def test_get_feature(self) -> None:
     feature = feature_extractor.get_feature('a', 'b', 'c', 'd', 'e', 'f', 'x',
                                             'y', 'z')
     self.assertSetEqual(
@@ -134,7 +138,7 @@ def test_get_feature(self):
         'BB3:999999', feature,
         'BB features that imply the end of line should not be included.')
 
-  def test_process(self):
+  def test_process(self) -> None:
     feature_extractor.process(SOURCE_FILE_PATH, ENTRIES_FILE_PATH)
     with open(
         ENTRIES_FILE_PATH, encoding=sys.getdefaultencoding(),
@@ -169,7 +173,7 @@ def test_process(self):
     self.assertIn('UW3:い', features[3])
     self.assertIn('UW3:。', features[-1])
 
-  def tearDown(self):
+  def tearDown(self) -> None:
     os.remove(SOURCE_FILE_PATH)
     os.remove(ENTRIES_FILE_PATH)
 
diff --git a/tests/test_main.py b/tests/test_main.py
index c7cdd891..cd2881dd 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -13,37 +13,41 @@
 # limitations under the License.
 """Tests the BudouX CLI."""
 
-from os.path import *
+import io
 import sys
 import unittest
-from pathlib import Path
+from os.path import abspath, dirname, join
 
+# module hack
 LIB_PATH = join(dirname(__file__), '..')
 sys.path.insert(0, abspath(LIB_PATH))
-from budoux import main
+from budoux import main  # noqa (module hack)
 
-sys.stdin.reconfigure(encoding='utf-8')
-sys.stdout.reconfigure(encoding='utf-8')
+if isinstance(sys.stdin, io.TextIOWrapper) and sys.version_info >= (3, 7):
+  sys.stdin.reconfigure(encoding='utf-8')
+
+if isinstance(sys.stdout, io.TextIOWrapper) and sys.version_info >= (3, 7):
+  sys.stdout.reconfigure(encoding='utf-8')
 
 
 class TestCommonOption(unittest.TestCase):
 
-  def test_cmdargs_invalid_option(self):
+  def test_cmdargs_invalid_option(self) -> None:
     cmdargs = ['-v']
     with self.assertRaises(SystemExit) as cm:
       main.parse_args(cmdargs)
     self.assertEqual(cm.exception.code, 2)
 
-  def test_cmdargs_help(self):
+  def test_cmdargs_help(self) -> None:
     cmdargs = ['-h']
     with self.assertRaises(SystemExit) as cm:
       main.parse_args(cmdargs)
     self.assertEqual(cm.exception.code, 0)
 
-  def test_cmdargs_version(self):
+  def test_cmdargs_version(self) -> None:
     cmdargs = ['-V']
     with self.assertRaises(SystemExit) as cm:
       main.parse_args(cmdargs)
@@ -53,44 +57,38 @@
 
 class TestTextArguments(unittest.TestCase):
 
-  def test_cmdargs_single_text(self):
+  def test_cmdargs_single_text(self) -> None:
     cmdargs = ['これはテストです。']
     output = main._main(cmdargs)
 
     self.assertEqual(output, "これは\nテストです。")
 
-  def test_cmdargs_single_multiline_text(self):
+  def test_cmdargs_single_multiline_text(self) -> None:
     cmdargs = ["これはテストです。\n今日は晴天です。"]
     output = main._main(cmdargs)
 
     self.assertEqual(output, "これは\nテストです。\n---\n今日は\n晴天です。")
 
-  def test_cmdargs_single_multiline_text_with_delimiter(self):
+  def test_cmdargs_single_multiline_text_with_delimiter(self) -> None:
     cmdargs = ["これはテストです。\n今日は晴天です。", "-d", "@"]
     output = main._main(cmdargs)
 
     self.assertEqual(output, "これは\nテストです。\n@\n今日は\n晴天です。")
 
-  def test_cmdargs_single_multiline_text_with_empty_delimiter(self):
+  def test_cmdargs_single_multiline_text_with_empty_delimiter(self) -> None:
     cmdargs = ["これはテストです。\n今日は晴天です。", "-d", ""]
     output = main._main(cmdargs)
 
     self.assertEqual(output, "これは\nテストです。\n\n今日は\n晴天です。")
 
-  def test_cmdargs_single_text(self):
-    cmdargs = ["これはテストです。\n今日は晴天です。"]
-    output = main._main(cmdargs)
-
-    self.assertEqual(output, "これは\nテストです。\n---\n今日は\n晴天です。")
-
-  def test_cmdargs_multi_text(self):
+  def test_cmdargs_multi_text(self) -> None:
     cmdargs = ['これはテストです。', '今日は晴天です。']
     with self.assertRaises(SystemExit) as cm:
       main.main(cmdargs)
 
     self.assertEqual(cm.exception.code, 2)
 
-  def test_cmdargs_single_html(self):
+  def test_cmdargs_single_html(self) -> None:
     cmdargs = ['-H', '今日はとても天気です。']
     output = main._main(cmdargs)
 
@@ -99,14 +97,14 @@ def test_cmdargs_single_html(self):
         '<span style="word-break: keep-all; overflow-wrap: break-word;">'
         '今日は<wbr>とても<wbr>天気です。</span>')
 
-  def test_cmdargs_multi_html(self):
+  def test_cmdargs_multi_html(self) -> None:
     cmdargs = ['-H', '今日はとても天気です。', 'これはテストです。']
     with self.assertRaises(SystemExit) as cm:
       main._main(cmdargs)
 
     self.assertEqual(cm.exception.code, 2)
 
-  def test_cmdargs_thres(self):
+  def test_cmdargs_thres(self) -> None:
     cmdargs = ['--thres', '0', '今日はとても天気です。']
     output_granular = main._main(cmdargs)
     cmdargs = ['--thres', '10000000', '今日はとても天気です。']
@@ -123,7 +121,7 @@
 
 class TestStdin(unittest.TestCase):
 
-  def test_cmdargs_blank_stdin(self):
+  def test_cmdargs_blank_stdin(self) -> None:
     with open(
         join(abspath(dirname(__file__)), "in/1.in"),
         "r",
@@ -133,7 +131,7 @@ def test_cmdargs_blank_stdin(self):
 
     self.assertEqual(output, "")
 
-  def test_cmdargs_text_stdin(self):
+  def test_cmdargs_text_stdin(self) -> None:
     with open(
         join(abspath(dirname(__file__)), "in/2.in"),
         "r",
@@ -143,7 +141,7 @@ def test_cmdargs_text_stdin(self):
 
     self.assertEqual(output, "これは\nテストです。")
 
-  def test_cmdargs_html_stdin(self):
+  def test_cmdargs_html_stdin(self) -> None:
     with open(
         join(abspath(dirname(__file__)), "in/3.in"),
         "r",
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 45429171..f5c85bba 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -13,21 +13,23 @@
 # limitations under the License.
 """Tests the BudouX parser."""
 
-import unittest
 import os
 import sys
+import unittest
 import xml.etree.ElementTree as ET
+
 import html5lib
 
+# module hack
 LIB_PATH = os.path.join(os.path.dirname(__file__), '..')
 sys.path.insert(0, os.path.abspath(LIB_PATH))
 
-from budoux import parser
+from budoux import parser  # noqa (module hack)
 
 html_parser = html5lib.HTMLParser()
 
 
-def compare_html_string(a, b):
+def compare_html_string(a: str, b: str) -> bool:
   a_normalized = ET.tostring(html_parser.parse(a))
   b_normalized = ET.tostring(html_parser.parse(b))
   return a_normalized == b_normalized
@@ -35,7 +37,7 @@
 
 class TestTextContentExtractor(unittest.TestCase):
 
-  def test_output(self):
+  def test_output(self) -> None:
     input = '<p><a href="#">Hello</a>, <b>World</b></p>'
     expected = 'Hello, World'
     extractor = parser.TextContentExtractor()
@@ -47,7 +49,7 @@ def test_output(self):
 
 class TestHTMLChunkResolver(unittest.TestCase):
 
-  def test_output(self):
+  def test_output(self) -> None:
     input = '<p>ab<b>cd</b>ef</p>'
     expected = '<p>ab<b>cd<wbr></b>ef</p>'
     resolver = parser.HTMLChunkResolver(['abc', 'def'])
@@ -60,7 +62,7 @@ def test_output(self):
 class TestParser(unittest.TestCase):
   TEST_SENTENCE = 'abcdeabcd'
 
-  def test_parse(self):
+  def test_parse(self) -> None:
     p = parser.Parser({
         'UW4:a': 10000,  # means "should separate right before 'a'".
     })
@@ -89,7 +91,7 @@ def test_parse(self):
     self.assertListEqual(chunks, [],
                          'should return a blank list when the input is blank.')
 
-  def test_translate_html_string(self):
+  def test_translate_html_string(self) -> None:
     p = parser.Parser({
         'UW4:a': 10000,  # means "should separate right before 'a'".
     })
@@ -142,4 +144,4 @@ def test_translate_html_string(self):
 
 
 if __name__ == '__main__':
-  unittest.main()
\ No newline at end of file
+  unittest.main()
diff --git a/tests/test_train.py b/tests/test_train.py
index ddcd47a4..1028d13c 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -17,12 +17,15 @@
 import sys
 import unittest
 from pathlib import Path
+
 import numpy as np
+import numpy.typing as npt
 
+# module hack
 LIB_PATH = os.path.join(os.path.dirname(__file__), '..')
 sys.path.insert(0, os.path.abspath(LIB_PATH))
 
-from scripts import train
+from scripts import train  # type: ignore # noqa (module hack)
 
 ENTRIES_FILE_PATH = os.path.abspath(
     os.path.join(os.path.dirname(__file__), 'entries_test.txt'))
@@ -34,7 +37,7 @@
 
 class TestTrain(unittest.TestCase):
 
-  def setUp(self):
+  def setUp(self) -> None:
     Path(WEIGHTS_FILE_PATH).touch()
     Path(LOG_FILE_PATH).touch()
     with open(ENTRIES_FILE_PATH, 'w') as f:
@@ -45,14 +48,14 @@ def setUp(self):
             '-1\tA\n'
             ' 1\tA\tC\n'))
 
-  def test_pred(self):
-    X = np.array([
+  def test_pred(self) -> None:
+    X: npt.NDArray[np.bool_] = np.array([
         [True, False, True, False],
         [False, True, False, True],
     ])
     phis = {
-        1: 8,  # Weights Feature #1 by 8.
-        2: 2,  # Weights Feature #2 by 2.
+        1: 8.0,  # Weights Feature #1 by 8.
+        2: 2.0,  # Weights Feature #2 by 2.
     }
     # Since Feature #1 (= the 2nd col in X) wins, the prediction should be:
     # [
@@ -62,7 +65,7 @@ def test_pred(self):
     pred = train.pred(phis, X)
     self.assertListEqual(pred.tolist(), [False, True])
 
-  def test_preprocess(self):
+  def test_preprocess(self) -> None:
     freq_thres = 0
     X, Y, features = train.preprocess(ENTRIES_FILE_PATH, freq_thres)
     self.assertListEqual(features, ['A', 'C', 'B'],
@@ -115,7 +118,7 @@ def test_preprocess(self):
         True,
     ], 'Y should represent the entry labels even filtered.')
 
-  def test_split_dataset(self):
+  def test_split_dataset(self) -> None:
     N = 100
     X = np.random.rand(N, 2)
     Y = np.arange(N)
@@ -128,16 +131,16 @@ def test_split_dataset(self):
     self.assertAlmostEqual(Y_train.shape[0], N * split_ratio)
     self.assertAlmostEqual(Y_test.shape[0], N * (1 - split_ratio))
 
-  def test_fit(self):
+  def test_fit(self) -> None:
     # Prepare a dataset that the 2nd feature (= the 2nd col in X) perfectly
     # correlates with Y in a negative way.
-    X = np.array([
+    X: npt.NDArray[np.bool_] = np.array([
         [False, True, True, False],
         [True, True, False, True],
         [False, False, True, False],
         [True, False, False, True],
     ])
-    Y = np.array([
+    Y: npt.NDArray[np.bool_] = np.array([
         False,
         False,
         True,
@@ -152,7 +155,7 @@ def test_fit(self):
     self.assertEqual(
         top_feature, 'b', msg='The most effective feature should be selected.')
 
-  def tearDown(self):
+  def tearDown(self) -> None:
     os.remove(WEIGHTS_FILE_PATH)
     os.remove(LOG_FILE_PATH)
     os.remove(ENTRIES_FILE_PATH)