Skip to content

Commit

Permalink
Use tempfile for unit test (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
tushuhei authored Feb 10, 2023
1 parent 059afc5 commit af59f9f
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions scripts/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import math
import os
import sys
import tempfile
import typing
import unittest

Expand Down Expand Up @@ -73,11 +74,10 @@ def test_cmdargs_full(self) -> None:


class TestPreprocess(unittest.TestCase):
ENTRIES_FILE_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'entries_test.txt'))

def test_standard_setup(self) -> None:
with open(self.ENTRIES_FILE_PATH, 'w') as f:
entries_file_path = tempfile.NamedTemporaryFile().name
with open(entries_file_path, 'w') as f:
f.write(('1\tfoo\tbar\n'
'-1\tfoo\n'
'1\tfoo\tbar\tbaz\n'
Expand All @@ -90,30 +90,29 @@ def test_standard_setup(self) -> None:
# 1 1 1 1
# 1 1 1 0
# -1 0 0 1
rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 1)
rows, cols, Y, features = train.preprocess(entries_file_path, 1)
self.assertEqual(features, ['foo', 'bar', 'baz'])
self.assertEqual(Y.tolist(), [True, False, True, True, False])
self.assertEqual(rows.tolist(), [0, 0, 1, 2, 2, 2, 3, 3, 4])
self.assertEqual(cols.tolist(), [0, 1, 0, 0, 1, 2, 1, 0, 2])
os.remove(entries_file_path)

def test_skip_invalid_rows(self) -> None:
with open(self.ENTRIES_FILE_PATH, 'w') as f:
entries_file_path = tempfile.NamedTemporaryFile().name
with open(entries_file_path, 'w') as f:
f.write(('\n1\tfoo\tbar\n'
'-1\n\n'
'-1\tfoo\n\n'))
# The input matrix X and the target vector Y should now look like this:
# Y X(foo bar)
# 1 1 1
# -1 1 0
rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 0)
rows, cols, Y, features = train.preprocess(entries_file_path, 0)
self.assertEqual(features, ['foo', 'bar'])
self.assertEqual(Y.tolist(), [True, False])
self.assertEqual(rows.tolist(), [0, 0, 1])
self.assertEqual(cols.tolist(), [0, 1, 0])

def tearDown(self) -> None:
if (os.path.exists(self.ENTRIES_FILE_PATH)):
os.remove(self.ENTRIES_FILE_PATH)
os.remove(entries_file_path)


class TestSplitData(unittest.TestCase):
Expand Down Expand Up @@ -205,12 +204,10 @@ def test_standard_setup1(self) -> None:


class TestFit(unittest.TestCase):
WEIGHTS_FILE_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'weights_test.txt'))
LOG_FILE_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'train_test.log'))

def test_fit(self) -> None:
weights_file_path = tempfile.NamedTemporaryFile().name
log_file_path = tempfile.NamedTemporaryFile().name
# Prepare a dataset in which the 2nd feature (= the 2nd col in X) is
# perfectly negatively correlated with Y.
X = np.array([
Expand All @@ -225,8 +222,8 @@ def test_fit(self) -> None:
iters = 5
out_span = 2
scores = train.fit(rows, cols, rows, cols, Y, Y, features, iters,
self.WEIGHTS_FILE_PATH, self.LOG_FILE_PATH, out_span)
with open(self.WEIGHTS_FILE_PATH) as f:
weights_file_path, log_file_path, out_span)
with open(weights_file_path) as f:
weights = [
line.split('\t') for line in f.read().splitlines() if line.strip()
]
Expand All @@ -238,7 +235,7 @@ def test_fit(self) -> None:
iters,
msg='The number of lines should equal to the iteration count.')

with open(self.LOG_FILE_PATH) as f:
with open(log_file_path) as f:
log = [line.split('\t') for line in f.read().splitlines() if line.strip()]
self.assertEqual(
len(log),
Expand All @@ -257,10 +254,8 @@ def test_fit(self) -> None:
self.assertEqual(scores.shape[0], len(features))
loaded_scores = [model.get(feature, 0) for feature in features]
self.assertTrue(np.all(np.isclose(scores, loaded_scores)))

def tearDown(self) -> None:
os.remove(self.WEIGHTS_FILE_PATH)
os.remove(self.LOG_FILE_PATH)
os.remove(weights_file_path)
os.remove(log_file_path)


if __name__ == '__main__':
Expand Down

0 comments on commit af59f9f

Please sign in to comment.