Skip to content

Commit

Permalink
Take split_dataset out from fit (#42)
Browse files Browse the repository at this point in the history
* Take split_dataset out from fit

Change-Id: Ifc92c4c82208117b760f70f8b179e58cee9e4d3f

* Edit flake8 rule to ignore closing bracket indentation

Change-Id: Icc45432c18b8671751817c58c2e12f615560eb41
  • Loading branch information
tushuhei authored Mar 28, 2022
1 parent 73eb614 commit 5172e0a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
27 changes: 21 additions & 6 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,10 @@ def split_dataset(
return X_train, X_test, Y_train, Y_test


def fit(X: npt.NDArray[np.bool_],
Y: npt.NDArray[np.bool_],
def fit(X_train: npt.NDArray[np.bool_],
Y_train: npt.NDArray[np.bool_],
X_test: npt.NDArray[np.bool_],
Y_test: npt.NDArray[np.bool_],
features: typing.List[str],
iters: int,
weights_filename: str,
Expand All @@ -128,8 +130,10 @@ def fit(X: npt.NDArray[np.bool_],
"""Trains an AdaBoost classifier.
Args:
X (numpy.ndarray): Training entries.
Y (numpy.ndarray): Training labels.
X_train (numpy.ndarray): Training entries.
Y_train (numpy.ndarray): Training labels.
X_test (numpy.ndarray): Testing entries.
Y_test (numpy.ndarray): Testing labels.
features (List[str]): Features, which correspond to the columns of entries.
iters (int): A number of training iterations.
weights_filename (str): A file path to write the learned weights.
Expand All @@ -147,7 +151,16 @@ def fit(X: npt.NDArray[np.bool_],
print('Outputting learned weights to %s ...' % (weights_filename))

phis: typing.Dict[int, float] = dict()
X_train, X_test, Y_train, Y_test = split_dataset(X, Y)

assert (X_train.shape[1] == X_test.shape[1]
), 'Training and test entries should have the same number of features.'
assert (X_train.shape[1] - 1 == len(features)
), 'The training data should have the same number of features + BIAS.'
assert (X_train.shape[0] == Y_train.shape[0]
), 'Training entries and labels should have the same number of items.'
assert (X_test.shape[0] == Y_test.shape[0]
), 'Testing entries and labels should have the same number of items.'

N_train, M_train = X_train.shape
w = np.ones(N_train) / N_train

Expand Down Expand Up @@ -227,7 +240,9 @@ def main() -> None:
chunk_size = int(args.chunk_size) if args.chunk_size is not None else None

X, Y, features = preprocess(train_data_filename, feature_thres)
fit(X, Y, features, iterations, weights_filename, log_filename, chunk_size)
X_train, X_test, Y_train, Y_test = split_dataset(X, Y)
fit(X_train, Y_train, X_test, Y_test, features, iterations, weights_filename,
log_filename, chunk_size)

print('Training done. Export the model by passing %s to build_model.py' %
(weights_filename))
Expand Down
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ console_scripts =
based_on_style = yapf

[flake8]
# E501: line too long
# E124: closing bracket does not match visual indentation
# E126: over-indentation
# E501: line too long
# BLK100: black formattable
ignore = E126,E501,BLK100
ignore = E124,E126,E501,BLK100
indent-size = 2

[mypy]
Expand Down
19 changes: 10 additions & 9 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def setUp(self) -> None:
with open(ENTRIES_FILE_PATH, 'w') as f:
f.write((
' 1\tA\tC\n' # the first column represents the label (-1 / 1).
'-1\tA\tB\n' # the rest cols represents the associated features.
'-1\tA\tB\n' # the rest columns represents the associated features.
' 1\tA\tC\n'
'-1\tA\n'
' 1\tA\tC\n'))
Expand Down Expand Up @@ -110,13 +110,14 @@ def test_preprocess(self) -> None:
],
'X should represent the filtered entry features with a bias column.')

self.assertListEqual(Y.tolist(), [
True,
False,
True,
False,
True,
], 'Y should represent the entry labels even filtered.')
self.assertListEqual(
Y.tolist(), [
True,
False,
True,
False,
True,
], 'Y should represent the entry labels even some labels are filtered.')

def test_split_dataset(self) -> None:
N = 100
Expand Down Expand Up @@ -148,7 +149,7 @@ def test_fit(self) -> None:
])
features = ['a', 'b', 'c']
iters = 1
train.fit(X, Y, features, iters, WEIGHTS_FILE_PATH, LOG_FILE_PATH)
train.fit(X, Y, X, Y, features, iters, WEIGHTS_FILE_PATH, LOG_FILE_PATH)
with open(WEIGHTS_FILE_PATH) as f:
weights = f.read().splitlines()
top_feature = weights[0].split('\t')[0]
Expand Down

0 comments on commit 5172e0a

Please sign in to comment.