diff --git a/scripts/tests/test_train.py b/scripts/tests/test_train.py index cfd2b09e..8046f62c 100644 --- a/scripts/tests/test_train.py +++ b/scripts/tests/test_train.py @@ -16,6 +16,7 @@ import math import os import sys +import tempfile import typing import unittest @@ -73,11 +74,10 @@ def test_cmdargs_full(self) -> None: class TestPreprocess(unittest.TestCase): - ENTRIES_FILE_PATH = os.path.abspath( - os.path.join(os.path.dirname(__file__), 'entries_test.txt')) def test_standard_setup(self) -> None: - with open(self.ENTRIES_FILE_PATH, 'w') as f: + entries_file_path = tempfile.NamedTemporaryFile().name + with open(entries_file_path, 'w') as f: f.write(('1\tfoo\tbar\n' '-1\tfoo\n' '1\tfoo\tbar\tbaz\n' @@ -90,14 +90,16 @@ def test_standard_setup(self) -> None: # 1 1 1 1 # 1 1 1 0 # -1 0 0 1 - rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 1) + rows, cols, Y, features = train.preprocess(entries_file_path, 1) self.assertEqual(features, ['foo', 'bar', 'baz']) self.assertEqual(Y.tolist(), [True, False, True, True, False]) self.assertEqual(rows.tolist(), [0, 0, 1, 2, 2, 2, 3, 3, 4]) self.assertEqual(cols.tolist(), [0, 1, 0, 0, 1, 2, 1, 0, 2]) + os.remove(entries_file_path) def test_skip_invalid_rows(self) -> None: - with open(self.ENTRIES_FILE_PATH, 'w') as f: + entries_file_path = tempfile.NamedTemporaryFile().name + with open(entries_file_path, 'w') as f: f.write(('\n1\tfoo\tbar\n' '-1\n\n' '-1\tfoo\n\n')) @@ -105,15 +107,12 @@ def test_skip_invalid_rows(self) -> None: # Y X(foo bar) # 1 1 1 # -1 1 0 - rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 0) + rows, cols, Y, features = train.preprocess(entries_file_path, 0) self.assertEqual(features, ['foo', 'bar']) self.assertEqual(Y.tolist(), [True, False]) self.assertEqual(rows.tolist(), [0, 0, 1]) self.assertEqual(cols.tolist(), [0, 1, 0]) - - def tearDown(self) -> None: - if (os.path.exists(self.ENTRIES_FILE_PATH)): - os.remove(self.ENTRIES_FILE_PATH) + os.remove(entries_file_path) class TestSplitData(unittest.TestCase): @@ -205,12 +204,10 @@ def test_standard_setup1(self) -> None: class TestFit(unittest.TestCase): - WEIGHTS_FILE_PATH = os.path.abspath( - os.path.join(os.path.dirname(__file__), 'weights_test.txt')) - LOG_FILE_PATH = os.path.abspath( - os.path.join(os.path.dirname(__file__), 'train_test.log')) def test_fit(self) -> None: + weights_file_path = tempfile.NamedTemporaryFile().name + log_file_path = tempfile.NamedTemporaryFile().name # Prepare a dataset that the 2nd feature (= the 2nd col in X) perfectly # correlates with Y in a negative way. X = np.array([ @@ -225,8 +222,8 @@ def test_fit(self) -> None: iters = 5 out_span = 2 scores = train.fit(rows, cols, rows, cols, Y, Y, features, iters, - self.WEIGHTS_FILE_PATH, self.LOG_FILE_PATH, out_span) - with open(self.WEIGHTS_FILE_PATH) as f: + weights_file_path, log_file_path, out_span) + with open(weights_file_path) as f: weights = [ line.split('\t') for line in f.read().splitlines() if line.strip() ] @@ -238,7 +235,7 @@ def test_fit(self) -> None: iters, msg='The number of lines should equal to the iteration count.') - with open(self.LOG_FILE_PATH) as f: + with open(log_file_path) as f: log = [line.split('\t') for line in f.read().splitlines() if line.strip()] self.assertEqual( len(log), @@ -257,10 +254,8 @@ def test_fit(self) -> None: self.assertEqual(scores.shape[0], len(features)) loaded_scores = [model.get(feature, 0) for feature in features] self.assertTrue(np.all(np.isclose(scores, loaded_scores))) - - def tearDown(self) -> None: - os.remove(self.WEIGHTS_FILE_PATH) - os.remove(self.LOG_FILE_PATH) + os.remove(weights_file_path) + os.remove(log_file_path) if __name__ == '__main__':