From e78281e81a16bed323219d1243fcd1bb0ee8514d Mon Sep 17 00:00:00 2001 From: Shuhei Iitsuka Date: Wed, 21 Dec 2022 16:34:30 +0900 Subject: [PATCH 1/2] Faster training with sparse matrix --- scripts/tests/test_train.py | 322 +++++++++++++++++----------------- scripts/train.py | 332 ++++++++++++++++++------------------ setup.cfg | 1 + 3 files changed, 320 insertions(+), 335 deletions(-) diff --git a/scripts/tests/test_train.py b/scripts/tests/test_train.py index 7a250f44..235c8727 100644 --- a/scripts/tests/test_train.py +++ b/scripts/tests/test_train.py @@ -16,11 +16,11 @@ import math import os import sys +import typing import unittest -from pathlib import Path import numpy as np -import numpy.typing as npt +from jax import numpy as jnp # module hack LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..') @@ -28,13 +28,6 @@ from scripts import train # type: ignore # noqa (module hack) -ENTRIES_FILE_PATH = os.path.abspath( - os.path.join(os.path.dirname(__file__), 'entries_test.txt')) -WEIGHTS_FILE_PATH = os.path.abspath( - os.path.join(os.path.dirname(__file__), 'weights_test.txt')) -LOG_FILE_PATH = os.path.abspath( - os.path.join(os.path.dirname(__file__), 'train_test.log')) - class TestArgParse(unittest.TestCase): @@ -64,12 +57,11 @@ def test_cmdargs_default(self) -> None: self.assertEqual(output.feature_thres, train.DEFAULT_FEATURE_THRES) self.assertEqual(output.iter, train.DEFAULT_ITERATION) self.assertEqual(output.out_span, train.DEFAULT_OUT_SPAN) - self.assertEqual(output.chunk_size, None) def test_cmdargs_full(self) -> None: cmdargs = [ 'encoded.txt', '-o', 'out.txt', '--log', 'foo.log', '--feature-thres', - '100', '--iter', '10', '--chunk-size', '1000', '--out-span', '50' + '100', '--iter', '10', '--out-span', '50' ] output = train.parse_args(cmdargs) self.assertEqual(output.encoded_train_data, 'encoded.txt') @@ -77,173 +69,164 @@ def test_cmdargs_full(self) -> None: self.assertEqual(output.log, 'foo.log') self.assertEqual(output.feature_thres, 100) 
self.assertEqual(output.iter, 10) - self.assertEqual(output.chunk_size, 1000) self.assertEqual(output.out_span, 50) -class TestTrain(unittest.TestCase): +class TestPreprocess(unittest.TestCase): + ENTRIES_FILE_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), 'entries_test.txt')) - def setUp(self) -> None: - Path(WEIGHTS_FILE_PATH).touch() - Path(LOG_FILE_PATH).touch() - with open(ENTRIES_FILE_PATH, 'w') as f: - f.write(( - ' 1\tA\tC\n' # the first column represents the label (-1 / 1). - '-1\tA\tB\n' # the rest columns represents the associated features. - ' 1\tA\tC\n' - '-1\tA\n' - ' 1\tA\tC\n')) + def test_standard_setup(self) -> None: + with open(self.ENTRIES_FILE_PATH, 'w') as f: + f.write(('1\tfoo\tbar\n' + '-1\tfoo\n' + '1\tfoo\tbar\tbaz\n' + '1\tbar\tfoo\n' + '-1\tbaz\tqux\n')) + # The input matrix X and the target vector Y should look like below now: + # Y X(foo bar baz BIAS) + # 1 1 1 0 1 + # -1 1 0 0 1 + # 1 1 1 1 1 + # 1 1 1 0 1 + # -1 0 0 1 1 + rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 1) + self.assertEqual(features, ['foo', 'bar', 'baz']) + self.assertEqual(Y.tolist(), [True, False, True, True, False]) + self.assertEqual(rows.tolist(), [0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4]) + self.assertEqual(cols.tolist(), [0, 1, 3, 0, 3, 0, 1, 2, 3, 1, 0, 3, 2, 3]) - def test_pred(self) -> None: - X: npt.NDArray[np.bool_] = np.array([ - [True, False, True, False], - [False, True, False, True], - ]) - phis = { - 1: 8.0, # Weights Feature #1 by 8. - 2: 2.0, # Weights Feature #2 by 2. 
- } - # Since Feature #1 (= the 2nd col in X) wins, the prediction should be: - # [ - # False, - # True, - # ] - pred = train.pred(phis, X) - self.assertListEqual(pred.tolist(), [False, True]) - - def test_preprocess(self) -> None: - freq_thres = 0 - X, Y, features = train.preprocess(ENTRIES_FILE_PATH, freq_thres) - self.assertListEqual(features, ['A', 'C', 'B'], - 'Features should be ordered by frequency.') - - self.assertListEqual( - X.tolist(), - [ - # A C B BIAS - [True, True, False, True], - [True, False, True, True], - [True, True, False, True], - [True, False, False, True], - [True, True, False, True], - ], - 'X should represent the entry features with a bias column.') - - self.assertListEqual(Y.tolist(), [ - True, - False, - True, - False, - True, - ], 'Y should represent the entry labels.') - - freq_thres = 4 - X, Y, features = train.preprocess(ENTRIES_FILE_PATH, freq_thres) - self.assertListEqual( - features, ['A'], - 'Features with smaller frequency than the threshold should be filtered.' 
- ) + def test_skip_invalid_rows(self) -> None: + with open(self.ENTRIES_FILE_PATH, 'w') as f: + f.write(('\n1\tfoo\tbar\n' + '-1\n\n' + '-1\tfoo\n\n')) + # The input matrix X and the target vector Y should look like below now: + # Y X(foo bar BIAS) + # 1 1 1 1 + # -1 1 0 1 + rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 0) + self.assertEqual(features, ['foo', 'bar']) + self.assertEqual(Y.tolist(), [True, False]) + self.assertEqual(rows.tolist(), [0, 0, 0, 1, 1]) + self.assertEqual(cols.tolist(), [0, 1, 2, 0, 2]) - self.assertListEqual( - X.tolist(), - [ - # A BIAS - [True, True], - [True, True], - [True, True], - [True, True], - [True, True], - ], - 'X should represent the filtered entry features with a bias column.') - - self.assertListEqual( - Y.tolist(), [ - True, - False, - True, - False, - True, - ], 'Y should represent the entry labels even some labels are filtered.') - - def test_split_dataset(self) -> None: - N = 100 - X = np.random.rand(N, 2) - Y = np.arange(N) - split_ratio = .8 - X_train, X_test, Y_train, Y_test = train.split_dataset(X, Y, split_ratio) - self.assertAlmostEqual(X_train.shape[0], N * split_ratio) - self.assertAlmostEqual(X_test.shape[0], N * (1 - split_ratio)) - self.assertAlmostEqual(X_train.shape[1], 2) - self.assertAlmostEqual(X_test.shape[1], 2) - self.assertAlmostEqual(Y_train.shape[0], N * split_ratio) - self.assertAlmostEqual(Y_test.shape[0], N * (1 - split_ratio)) + def tearDown(self) -> None: + if (os.path.exists(self.ENTRIES_FILE_PATH)): + os.remove(self.ENTRIES_FILE_PATH) - def test_fit(self) -> None: - # Prepare a dataset that the 2nd feature (= the 2nd col in X) perfectly - # correlates with Y in a negative way. 
- X: npt.NDArray[np.bool_] = np.array([ - [False, True, True, False], - [True, True, False, True], - [False, False, True, False], - [True, False, False, True], + +class TestSplitData(unittest.TestCase): + + def test_standard_setup(self) -> None: + split_ratio = 0.6 + X = np.array([ + [0, 1, 0], + [1, 0, 0], + [1, 0, 1], + [0, 1, 1], + [0, 1, 0], ]) - Y: npt.NDArray[np.bool_] = np.array([ - False, - False, - True, - True, + Y = np.array([0, 1, 0, 1, 0], dtype=bool) + rows, cols = np.where(X == 1) + rows_train, cols_train, rows_test, cols_test, Y_train, Y_test = train.split_data( + rows, cols, Y, split_ratio) + self.assertEqual(rows_train.tolist(), [0, 1, 2, 2]) + self.assertEqual(cols_train.tolist(), [1, 0, 0, 2]) + self.assertEqual(rows_test.tolist(), [0, 0, 1]) + self.assertEqual(cols_test.tolist(), [1, 2, 1]) + self.assertEqual(Y_train.tolist(), [0, 1, 0]) + self.assertEqual(Y_test.tolist(), [1, 0]) + + +class TestPred(unittest.TestCase): + + def test_standard_setup(self) -> None: + X = np.array([ + [1, 1, 0], + [1, 0, 1], + [0, 1, 0], + [0, 0, 1], ]) - features = ['a', 'b', 'c'] - iters = 5 - out_span = 2 - train.fit(X, Y, X, Y, features, iters, WEIGHTS_FILE_PATH, LOG_FILE_PATH, - out_span) - with open(WEIGHTS_FILE_PATH) as f: - weights = [ - line.split('\t') for line in f.read().splitlines() if line.strip() - ] - top_feature = weights[0][0] - self.assertEqual( - top_feature, 'b', msg='The most effective feature should be selected.') - self.assertEqual( - len(weights), - iters, - msg='The number of lines should equal to the iteration count.') + phis = np.array([0.4, 0.2, -0.3]) + N = X.shape[0] + rows, cols = np.where(X == 1) + res = train.pred(phis, rows, cols, N) + expected = [ + 0.4 + 0.2 - (-0.3) > 0, + 0.4 - 0.2 + (-0.3) > 0, + -0.4 + 0.2 - (-0.3) > 0, + -0.4 - 0.2 + (-0.3) > 0, + ] + self.assertEqual(res.tolist(), expected) - with open(LOG_FILE_PATH) as f: - log = [line.split('\t') for line in f.read().splitlines() if line.strip()] - self.assertEqual( - 
len(log), - math.ceil(iters / out_span) + 1, - msg='The number of lines should equal to the ceil of iteration / out_span plus one for the header' - ) - self.assertEqual( - len(set(len(line) for line in log)), - 1, - msg='The header and the body should have the same number of columns.') - def test_fit_chunk(self) -> None: +class TestGetMetrics(unittest.TestCase): + + def test_standard_setup(self) -> None: + pred = np.array([0, 0, 1, 0, 0], dtype=bool) + target = np.array([1, 0, 1, 1, 1], dtype=bool) + result = train.get_metrics(pred, target) + self.assertEqual(result.tp, 1) + self.assertEqual(result.tn, 1) + self.assertEqual(result.fp, 0) + self.assertEqual(result.fn, 3) + self.assertEqual(result.accuracy, 2 / 5) + p = 1 / 1 + r = 1 / 4 + self.assertEqual(result.precision, p) + self.assertEqual(result.recall, r) + self.assertEqual(result.fscore, 2 * p * r / (p + r)) + + +class TestUpdateWeights(unittest.TestCase): + X = np.array([ + [1, 0, 1, 0], + [0, 1, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [0, 1, 1, 0], + ]) + + def test_standard_setup1(self) -> None: + rows, cols = np.where(self.X == 1) + M = self.X.shape[-1] + Y = np.array([1, 1, 0, 0, 1], dtype=bool) + w = np.array([0.1, 0.3, 0.1, 0.1, 0.4]) + scores = jnp.zeros(M) + new_w, new_scores, best_feature_index, added_score = train.update_weights( + w, rows, cols, Y, scores, M) + self.assertFalse(w.argmax() == 0) + self.assertTrue(new_w.argmax() == 0) + self.assertFalse(scores.argmax() == 1) + self.assertTrue(new_scores.argmax() == 1) + self.assertEqual(best_feature_index, 1) + self.assertTrue(added_score > 0) + + +class TestFit(unittest.TestCase): + WEIGHTS_FILE_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), 'weights_test.txt')) + LOG_FILE_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), 'train_test.log')) + + def test_fit(self) -> None: # Prepare a dataset that the 2nd feature (= the 2nd col in X) perfectly # correlates with Y in a negative way. 
- X: npt.NDArray[np.bool_] = np.array([ - [False, True, True, False], - [True, True, False, True], - [False, False, True, False], - [True, False, False, True], - ]) - Y: npt.NDArray[np.bool_] = np.array([ - False, - False, - True, - True, + X = np.array([ + [0, 1, 1, 1], + [1, 1, 0, 1], + [0, 0, 1, 1], + [1, 0, 0, 1], ]) + Y = np.array([0, 0, 1, 1]) + rows, cols = np.where(X == 1) features = ['a', 'b', 'c'] iters = 5 out_span = 2 - chunk_size = 2 - train.fit(X, Y, X, Y, features, iters, WEIGHTS_FILE_PATH, LOG_FILE_PATH, - out_span, chunk_size) - with open(WEIGHTS_FILE_PATH) as f: + scores = train.fit(rows, cols, rows, cols, Y, Y, features, iters, + self.WEIGHTS_FILE_PATH, self.LOG_FILE_PATH, out_span) + with open(self.WEIGHTS_FILE_PATH) as f: weights = [ line.split('\t') for line in f.read().splitlines() if line.strip() ] @@ -255,7 +238,7 @@ def test_fit_chunk(self) -> None: iters, msg='The number of lines should equal to the iteration count.') - with open(LOG_FILE_PATH) as f: + with open(self.LOG_FILE_PATH) as f: log = [line.split('\t') for line in f.read().splitlines() if line.strip()] self.assertEqual( len(log), @@ -267,13 +250,18 @@ def test_fit_chunk(self) -> None: 1, msg='The header and the body should have the same number of columns.') - train.fit(X, Y, X, Y, features, iters, WEIGHTS_FILE_PATH, LOG_FILE_PATH, - out_span, 2) + model: typing.Dict[str, float] = {} + for weight in weights: + model.setdefault(weight[0], 0) + model[weight[0]] += float(weight[1]) + self.assertEqual(scores.shape[0], len(features) + 1) + loaded_scores = [model.get(feature, 0) for feature in features + ] + [model.get('BIAS', 0)] + self.assertTrue(np.all(np.isclose(scores, loaded_scores))) def tearDown(self) -> None: - os.remove(WEIGHTS_FILE_PATH) - os.remove(LOG_FILE_PATH) - os.remove(ENTRIES_FILE_PATH) + os.remove(self.WEIGHTS_FILE_PATH) + os.remove(self.LOG_FILE_PATH) if __name__ == '__main__': diff --git a/scripts/train.py b/scripts/train.py index 25ad9fcd..73f6c156 100644 --- 
a/scripts/train.py +++ b/scripts/train.py @@ -11,18 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Runs training and exports the learned weights to build a model.""" +"""Runs model training and exports the learned scores to build a model.""" import argparse +import array import typing from collections import Counter from functools import partial from typing import NamedTuple +import jax import jax.numpy as jnp import numpy as np import numpy.typing as npt -from jax import device_put, jit EPS = np.finfo(float).eps # type: np.floating[typing.Any] DEFAULT_OUTPUT_NAME = 'weights.txt' @@ -46,9 +47,11 @@ class Result(NamedTuple): def preprocess( entries_filename: str, feature_thres: int -) -> typing.Tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_], - typing.List[str]]: - """Loads entries and translates them into NumPy arrays. +) -> typing.Tuple[typing.Any, typing.Any, typing.Any, typing.List[str]]: + """Loads entries and translates them into JAX arrays. The boolean matrix of + the input data is represented by row indices and column indices of True values + instead of the matrix itself for memory efficiency, assuming the matrix is + highly sparse. Row and column indices are not guaranteed to be sorted. Args: entries_filename (str): A file path to the entries file. @@ -56,99 +59,112 @@ def preprocess( below the given value. Returns: - X (numpy.ndarray): Input entries. - Y (numpy.ndarray): Output labels. - features (List[str]): Effective features. + A tuple of following items: + - rows (JAX array): Row indices of True values in the input data. + - cols (JAX array): Column indices of True values in the input data. + - Y (JAX array): The target output data. + - features (List[str]): The list of features. 
""" - with open(entries_filename) as f: - entries = [ - row.strip().split('\t') for row in f.read().splitlines() if row.strip() - ] - print('#entries:\t%d' % (len(entries))) - features_counter: typing.Counter[str] = Counter() - for entry in entries: - features_counter.update(entry[1:]) + N = 0 + X = [] + Y = array.array('B') + with open(entries_filename) as f: + for row in f: + cols = row.strip().split('\t') + if len(cols) < 2: + continue + Y.append(cols[0] == '1') + X.append(cols[1:]) + features_counter.update(cols[1:]) + N += 1 features = [ item[0] for item in features_counter.most_common() if item[1] > feature_thres ] - print('#features:\t%d' % (len(features))) feature_index = dict([(feature, i) for i, feature in enumerate(features)]) - - M = len(features) + 1 - N = len(entries) - Y: npt.NDArray[np.bool_] = np.zeros(N, dtype=bool) - X: npt.NDArray[np.bool_] = np.zeros((N, M), dtype=bool) - - for i, entry in enumerate(entries): - Y[i] = entry[0] == '1' - indices = [feature_index[col] for col in entry[1:] if col in feature_index] - X[i, indices] = True - X[:, -1] = True # add a bias column. - return X, Y, features - - -@jit -def pred(phis: typing.Dict[int, float], - X: npt.NDArray[np.bool_]) -> npt.NDArray[np.bool_]: - """Predicts the output from the given classifiers and input entries. 
+ rows = array.array('I') + cols = array.array('I') # type: ignore + for i, x in enumerate(X): + hit_indices = [feature_index[feat] for feat in x if feat in feature_index] + rows.extend(i for _ in range(len(hit_indices))) + cols.extend(hit_indices) # type: ignore + rows.append(i) + cols.append(len(features)) # type: ignore + return jnp.asarray(rows), jnp.asarray(cols), jnp.asarray( + Y, dtype=bool), features + + +def split_data( + rows: npt.NDArray[np.int64], + cols: npt.NDArray[np.int64], + Y: npt.NDArray[np.bool_], + split_ratio: float = .9 +) -> typing.Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], + npt.NDArray[np.int64], npt.NDArray[np.int64], + npt.NDArray[np.bool_], npt.NDArray[np.bool_]]: + """Splits a dataset into a training dataset and a test dataset. Args: - phis (Dict[int, float]): Classifiers represented as a mapping from the - feature index to its score. - X (numpy.ndarray): Input entries. + rows (numpy.ndarray): Row indices of True values in the input data. + cols (numpy.ndarray): Column indices of True values in the input data. + Y (numpy.ndarray): The target output. + split_ratio (float, optional): The split ratio for the training dataset. + The value should be between 0 and 1. The default value is 0.9 (=90% for + training). Returns: - A list of inferred labels. + A tuple of: + - rows_train (numpy.ndarray): Row indices of True values in the training input data. + - cols_train (numpy.ndarray): Column indices of True values in the training input data. + - rows_test (numpy.ndarray): Row indices of True values in the test input data. + - cols_test (numpy.ndarray): Column indices of True values in the test input data. + - Y_train (numpy.ndarray): The training target output. + - Y_test (numpy.ndarray): The test target output. 
""" - alphas: npt.NDArray[np.float64] - y: npt.NDArray[np.int64] - - alphas = jnp.array(list(phis.values())) - y = 2 * ( - X[:, list(phis.keys())] == True # noqa (cannot replace `==` with `is`) - ) - 1 - result: npt.NDArray[np.bool_] = y.dot(alphas) > 0 - return result + thres = int(Y.shape[0] * split_ratio) + return (rows[rows < thres], cols[rows < thres], rows[rows >= thres] - thres, + cols[rows >= thres], Y[:thres], Y[thres:]) -def split_dataset( - X: npt.NDArray[typing.Any], - Y: npt.NDArray[typing.Any], - split_ratio: float = 0.9 -) -> typing.Tuple[npt.NDArray[typing.Any], npt.NDArray[typing.Any], - npt.NDArray[typing.Any], npt.NDArray[typing.Any]]: - """Splits given entries and labels to training and testing datasets. +@partial(jax.jit, static_argnums=[3]) +def pred(phis: npt.NDArray[np.float64], rows: npt.NDArray[np.int64], + cols: npt.NDArray[np.int64], N: int) -> npt.NDArray[np.bool_]: + """Predicts the target output from the learned scores and input entries. Args: - X (numpy.ndarray): Entries to split. - Y (numpy.ndarray): Labels to split. - split_ratio (float, optional): The ratio to hold for the training dataset. + phis (numpy.ndarray): Contribution scores of features. + rows (numpy.ndarray): Row indices of True values in the input. + cols (numpy.ndarray): Column indices of True values in the input. + N (int): The number of input entries. Returns: - X_train (numpy.ndarray): Training entries. - X_test (numpy.ndarray): Testing entries. - Y_train (numpy.ndarray): Training labels. - Y_test (numpy.ndarray): Testing labels. + res (numpy.ndarray): A prediction of the target. 
""" - N, _ = X.shape - np.random.seed(0) - indices = np.random.permutation(N) - X_train = X[indices[:int(N * split_ratio)]] - X_test = X[indices[int(N * split_ratio):]] - Y_train = Y[indices[:int(N * split_ratio)]] - Y_test = Y[indices[int(N * split_ratio):]] - return X_train, X_test, Y_train, Y_test + # This is equivalent to phis.dot(2X - 1) = 2phis.dot(X) - phis.sum() but in a + # sparse matrix-friendly way. + r: npt.NDArray[np.float64] = 2 * jax.ops.segment_sum( + phis.take(cols), rows, N) - phis.sum() + return r > 0 +@jax.jit def get_metrics(pred: npt.NDArray[np.bool_], actual: npt.NDArray[np.bool_]) -> Result: - tp = np.sum(np.logical_and(pred == 1, actual == 1)) - tn = np.sum(np.logical_and(pred == 0, actual == 0)) - fp = np.sum(np.logical_and(pred == 1, actual == 0)) - fn = np.sum(np.logical_and(pred == 0, actual == 1)) + """Gets evaluation metrics from the prediction and the actual target. + + Args: + pred (numpy.ndarray): A prediction of the target. + actual (numpy.ndarray): The actual target. + + Returns: + result (Result): A result. 
+ """ + tp = jnp.sum(jnp.logical_and(pred == 1, actual == 1)) + tn = jnp.sum(jnp.logical_and(pred == 0, actual == 0)) + fp = jnp.sum(jnp.logical_and(pred == 1, actual == 0)) + fn = jnp.sum(jnp.logical_and(pred == 0, actual == 1)) accuracy = (tp + tn) / (tp + tn + fp + fn) precision = tp / (tp + fp) recall = tp / (tp + fn) @@ -164,66 +180,72 @@ def get_metrics(pred: npt.NDArray[np.bool_], ) -@jit -def update_weight( - w: npt.NDArray[np.float64], YX: npt.NDArray[np.bool_] -) -> typing.Tuple[npt.NDArray[np.float64], float, int, bool]: - res = w.dot(YX) - err = 0.5 - jnp.abs(res - 0.5) - m_best = err.argmin() - pol_best = res.at[m_best].get() < 0.5 - err_min = err.at[m_best].get() - alpha = jnp.log((1 - err_min) / (err_min + EPS)) - w = w * jnp.exp(alpha * (YX[:, m_best] == pol_best)) - w = w / w.sum() - return w, alpha, m_best, pol_best +@partial(jax.jit, static_argnums=[5]) +def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64], + cols: npt.NDArray[np.int64], Y: npt.NDArray[np.bool_], + scores: typing.Any, + M: int) -> typing.Tuple[typing.Any, typing.Any, int, float]: + """Calculates the new weight vector from the best feature and its score. + Args: + w (numpy.ndarray): A weight vector. + rows (numpy.ndarray): Row indices of True values in the input data. + cols (numpy.ndarray): Column indices of True values in the input data. + Y (numpy.ndarray): The target output. + scores (JAX array): Contribution scores of features. + M (int): The number of columns in the input data. -@partial(jit, static_argnames=['chunk_size']) -def update_weight_chunk( - w: npt.NDArray[np.float64], YX: npt.NDArray[np.bool_], - chunk_size: int) -> typing.Tuple[npt.NDArray[np.float64], float, int, bool]: - N, M = YX.shape - res = jnp.zeros(M) - for chunk in range(0, N, chunk_size): - res += w[chunk:chunk + chunk_size].dot(YX[chunk:chunk + chunk_size]) + Returns: + A tuple of following items: + - w (numpy.ndarray): The new weight vector. 
+ - scores (JAX array): The new contribution scores. + - best_feature_index (int): The index of the best feature. + - score (float): The newly added score for the best feature. + """ + # This is equivalent to w.dot(Y[:, None] ^ X). Note that y ^ x = y + x - 2yx, + # hence w.dot(y ^ x) = w.dot(y) - w(2y - 1).dot(x). + # `segment_sum` is used to implement sparse matrix-friendly dot products. + res = w.dot(Y) - jax.ops.segment_sum((w * (2 * Y - 1)).take(rows), cols, M) err = 0.5 - jnp.abs(res - 0.5) - m_best = err.argmin() - pol_best = res.at[m_best].get() < 0.5 - err_min = err.at[m_best].get() - alpha = jnp.log((1 - err_min) / (err_min + EPS)) - w = w * jnp.exp(alpha * (YX[:, m_best] == pol_best)) + best_feature_index: int = err.argmin() + positivity: bool = res.at[best_feature_index].get() < 0.5 + err_min = err.at[best_feature_index].get() + amount: float = jnp.log((1 - err_min) / (err_min + EPS)) + N = Y.shape[0] + + # This is equivalent to X_best = X[:, best_feature_index] + X_best = jnp.zeros( + N, dtype=bool).at[jnp.where(cols == best_feature_index, rows, N)].set( + True, mode='drop') + w = w * jnp.exp(amount * (Y ^ X_best == positivity)) w = w / w.sum() - return w, alpha, m_best, pol_best + score = amount * (2 * positivity - 1) + scores = scores.at[best_feature_index].add(score) + return w, scores, best_feature_index, score -def fit(X_train: npt.NDArray[np.bool_], - Y_train: npt.NDArray[np.bool_], - X_test: npt.NDArray[np.bool_], - Y_test: npt.NDArray[np.bool_], - features: typing.List[str], - iters: int, - weights_filename: str, - log_filename: str, - out_span: int, - chunk_size: typing.Optional[int] = None) -> typing.Dict[int, float]: - """Trains an AdaBoost classifier.
+def fit(rows_train: npt.NDArray[np.int64], cols_train: npt.NDArray[np.int64], + rows_test: npt.NDArray[np.int64], cols_test: npt.NDArray[np.int64], + Y_train: npt.NDArray[np.bool_], Y_test: npt.NDArray[np.bool_], + features: typing.List[str], iters: int, weights_filename: str, + log_filename: str, out_span: int) -> typing.Any: + """Trains an AdaBoost binary classifier. Args: - X_train (numpy.ndarray): Training entries. - Y_train (numpy.ndarray): Training labels. - X_test (numpy.ndarray): Testing entries. - Y_test (numpy.ndarray): Testing labels. + rows_train (numpy.ndarray): Row indices of True values in the training input data. + cols_train (numpy.ndarray): Column indices of True values in the training input data. + rows_test (numpy.ndarray): Row indices of True values in the test input data. + cols_test (numpy.ndarray): Column indices of True values in the test input data. + Y_train (numpy.ndarray): The training target output. + Y_test (numpy.ndarray): The test target output. features (List[str]): Features, which correspond to the columns of entries. iters (int): A number of training iterations. weights_filename (str): A file path to write the learned weights. log_filename (str): A file path to log the accuracy along with training. out_span (int): Iteration span to output metics and weights. - chunk_size (Optional[int]): A chunk size to split training entries for - memory efficiency. Returns: - phi (Dict[int, float]): Learned child classifiers. + scores (Any): The contribution scores. """ with open(weights_filename, 'w') as f: f.write('') @@ -233,35 +255,22 @@ def fit(X_train: npt.NDArray[np.bool_], 'test_accuracy\ttest_precision\ttest_recall\ttest_fscore\n') print('Outputting learned weights to %s ...' % (weights_filename)) - phis: typing.Dict[int, float] = dict() - phi_buffer: typing.List[typing.Tuple[str, float]] = [] - - assert (X_train.shape[1] == X_test.shape[1] - ), 'Training and test entries should have the same number of features.'
- assert (X_train.shape[1] - 1 == len(features) - ), 'The training data should have the same number of features + BIAS.' - assert (X_train.shape[0] == Y_train.shape[0] - ), 'Training entries and labels should have the same number of items.' - assert (X_test.shape[0] == Y_test.shape[0] - ), 'Testing entries and labels should have the same number of items.' - - X_train = device_put(X_train) - Y_train = device_put(Y_train) - X_test = device_put(X_test) - Y_test = device_put(Y_test) - N_train, _ = X_train.shape + M = len(features) + 1 + scores = jnp.zeros(M) + feature_score_buffer: typing.List[typing.Tuple[str, float]] = [] + N_train = Y_train.shape[0] + N_test = Y_test.shape[0] w = jnp.ones(N_train) / N_train - YX_train = Y_train[:, None] ^ X_train def output_progress(t: int) -> None: + print('=== %s ===' % t) with open(weights_filename, 'a') as f: - f.write('\n'.join('%s\t%.6f' % p for p in phi_buffer) + '\n') - phi_buffer.clear() - pred_train = pred(phis, X_train) - pred_test = pred(phis, X_test) + f.write('\n'.join('%s\t%.6f' % p for p in feature_score_buffer) + '\n') + feature_score_buffer.clear() + pred_train = pred(scores, rows_train, cols_train, N_train) + pred_test = pred(scores, rows_test, cols_test, N_test) metrics_train = get_metrics(pred_train, Y_train) metrics_test = get_metrics(pred_test, Y_test) - print('=== %s ===' % t) print() print('train accuracy:\t%.5f' % metrics_train.accuracy) print('train prec.:\t%.5f' % metrics_train.precision) @@ -287,23 +296,17 @@ def output_progress(t: int) -> None: )) for t in range(iters): - if chunk_size: - w, alpha, m_best, pol_best = update_weight_chunk(w, YX_train, chunk_size) - else: - w, alpha, m_best, pol_best = update_weight(w, YX_train) + w, scores, best_feature_index, score = update_weights( + w, rows_train, cols_train, Y_train, scores, M) w.block_until_ready() - m_best = int(m_best) - alpha_signed = alpha if pol_best else -alpha - phis.setdefault(m_best, 0) - phis[m_best] += alpha_signed - feature = 
features[m_best] if m_best < len(features) else 'BIAS' - phi_buffer.append((feature, alpha_signed)) + feature = features[best_feature_index] if ( + best_feature_index < len(features)) else 'BIAS' + feature_score_buffer.append((feature, score)) if (t + 1) % out_span == 0: output_progress(t + 1) - if len(phi_buffer) > 0: + if len(feature_score_buffer) > 0: output_progress(t + 1) - - return phis + return scores def parse_args(test: ArgList = None) -> argparse.Namespace: @@ -345,11 +348,6 @@ def parse_args(test: ArgList = None) -> argparse.Namespace: help=f'Iteration span to output metrics and weights. (default: {DEFAULT_OUT_SPAN})', type=int, default=DEFAULT_OUT_SPAN) - parser.add_argument( - '--chunk-size', - type=int, - help='A chunk size to split training entries for memory efficiency. (default: None)', - default=None) if test is None: return parser.parse_args() else: @@ -358,20 +356,18 @@ def parse_args(test: ArgList = None) -> argparse.Namespace: def main() -> None: args = parse_args() - train_data_filename: str = args.encoded_train_data + data_filename: str = args.encoded_train_data weights_filename: str = args.output log_filename: str = args.log feature_thres = int(args.feature_thres) iterations = int(args.iter) out_span = int(args.out_span) - chunk_size = int(args.chunk_size) if args.chunk_size is not None else None - - X, Y, features = preprocess(train_data_filename, feature_thres) - X_train, X_test, Y_train, Y_test = split_dataset(X, Y) - del X, Y - fit(X_train, Y_train, X_test, Y_test, features, iterations, weights_filename, - log_filename, out_span, chunk_size) + X_rows, X_cols, Y, features = preprocess(data_filename, feature_thres) + X_rows_train, X_cols_train, X_rows_test, X_cols_test, Y_train, Y_test = split_data( + X_rows, X_cols, Y) + fit(X_rows_train, X_cols_train, X_rows_test, X_cols_test, Y_train, Y_test, + features, iterations, weights_filename, log_filename, out_span) print('Training done. 
Export the model by passing %s to build_model.py' % (weights_filename)) diff --git a/setup.cfg b/setup.cfg index 4b0d2b7f..f3bf4a52 100644 --- a/setup.cfg +++ b/setup.cfg @@ -56,3 +56,4 @@ indent-size = 2 python_version = 3.9 pretty = True strict = True +allow_untyped_calls = True From 94b76be296bd4e7018dcce8e1d6a9b6f51a30904 Mon Sep 17 00:00:00 2001 From: Shuhei Iitsuka Date: Thu, 5 Jan 2023 12:12:23 +0900 Subject: [PATCH 2/2] Style fix --- scripts/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train.py b/scripts/train.py index 73f6c156..d45b732b 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -297,7 +297,7 @@ def output_progress(t: int) -> None: for t in range(iters): w, scores, best_feature_index, score = update_weights( - w, rows_train, cols_train, Y_train, scores, M) + w, rows_train, cols_train, Y_train, scores, M) w.block_until_ready() feature = features[best_feature_index] if ( best_feature_index < len(features)) else 'BIAS'