diff --git a/scripts/tests/test_train.py b/scripts/tests/test_train.py index 235c8727..cfd2b09e 100644 --- a/scripts/tests/test_train.py +++ b/scripts/tests/test_train.py @@ -84,17 +84,17 @@ def test_standard_setup(self) -> None: '1\tbar\tfoo\n' '-1\tbaz\tqux\n')) # The input matrix X and the target vector Y should look like below now: - # Y X(foo bar baz BIAS) - # 1 1 1 0 1 - # -1 1 0 0 1 - # 1 1 1 1 1 - # 1 1 1 0 1 - # -1 0 0 1 1 + # Y X(foo bar baz) + # 1 1 1 0 + # -1 1 0 0 + # 1 1 1 1 + # 1 1 1 0 + # -1 0 0 1 rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 1) self.assertEqual(features, ['foo', 'bar', 'baz']) self.assertEqual(Y.tolist(), [True, False, True, True, False]) - self.assertEqual(rows.tolist(), [0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4]) - self.assertEqual(cols.tolist(), [0, 1, 3, 0, 3, 0, 1, 2, 3, 1, 0, 3, 2, 3]) + self.assertEqual(rows.tolist(), [0, 0, 1, 2, 2, 2, 3, 3, 4]) + self.assertEqual(cols.tolist(), [0, 1, 0, 0, 1, 2, 1, 0, 2]) def test_skip_invalid_rows(self) -> None: with open(self.ENTRIES_FILE_PATH, 'w') as f: @@ -102,14 +102,14 @@ def test_skip_invalid_rows(self) -> None: '-1\n\n' '-1\tfoo\n\n')) # The input matrix X and the target vector Y should look like below now: - # Y X(foo bar BIAS) - # 1 1 1 1 - # -1 1 0 1 + # Y X(foo bar) + # 1 1 1 + # -1 1 0 rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 0) self.assertEqual(features, ['foo', 'bar']) self.assertEqual(Y.tolist(), [True, False]) - self.assertEqual(rows.tolist(), [0, 0, 0, 1, 1]) - self.assertEqual(cols.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(rows.tolist(), [0, 0, 1]) + self.assertEqual(cols.tolist(), [0, 1, 0]) def tearDown(self) -> None: if (os.path.exists(self.ENTRIES_FILE_PATH)): @@ -179,7 +179,7 @@ def test_standard_setup(self) -> None: self.assertEqual(result.fscore, 2 * p * r / (p + r)) -class TestUpdateWeights(unittest.TestCase): +class TestUpdate(unittest.TestCase): X = np.array([ [1, 0, 1, 0], [0, 1, 0, 0], @@ -194,8 +194,8 @@ def test_standard_setup1(self) -> None: Y = np.array([1, 1, 0, 0, 1], dtype=bool) w = np.array([0.1, 0.3, 0.1, 0.1, 0.4]) scores = jnp.zeros(M) - new_w, new_scores, best_feature_index, added_score = train.update_weights( - w, rows, cols, Y, scores, M) + new_w, new_scores, best_feature_index, added_score = train.update( + w, scores, rows, cols, Y) self.assertFalse(w.argmax() == 0) self.assertTrue(new_w.argmax() == 0) self.assertFalse(scores.argmax() == 1) @@ -254,9 +254,8 @@ def test_fit(self) -> None: for weight in weights: model.setdefault(weight[0], 0) model[weight[0]] += float(weight[1]) - self.assertEqual(scores.shape[0], len(features) + 1) - loaded_scores = [model.get(feature, 0) for feature in features - ] + [model.get('BIAS', 0)] + self.assertEqual(scores.shape[0], len(features)) + loaded_scores = [model.get(feature, 0) for feature in features] self.assertTrue(np.all(np.isclose(scores, loaded_scores))) def tearDown(self) -> None: diff --git a/scripts/train.py b/scripts/train.py index d45b732b..3c548400 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -66,7 +66,6 @@ def preprocess( - features (List[str]): The list of features. """ features_counter: typing.Counter[str] = Counter() - N = 0 X = [] Y = array.array('B') with open(entries_filename) as f: @@ -77,7 +76,6 @@ def preprocess( Y.append(cols[0] == '1') X.append(cols[1:]) features_counter.update(cols[1:]) - N += 1 features = [ item[0] for item in features_counter.most_common() @@ -90,19 +88,17 @@ def preprocess( hit_indices = [feature_index[feat] for feat in x if feat in feature_index] rows.extend(i for _ in range(len(hit_indices))) cols.extend(hit_indices) # type: ignore - rows.append(i) - cols.append(len(features)) # type: ignore return jnp.asarray(rows), jnp.asarray(cols), jnp.asarray( Y, dtype=bool), features def split_data( - rows: npt.NDArray[np.int64], - cols: npt.NDArray[np.int64], + rows: npt.NDArray[np.int32], + cols: npt.NDArray[np.int32], Y: npt.NDArray[np.bool_], split_ratio: float = .9 -) -> typing.Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], - npt.NDArray[np.int64], npt.NDArray[np.int64], +) -> typing.Tuple[npt.NDArray[np.int32], npt.NDArray[np.int32], + npt.NDArray[np.int32], npt.NDArray[np.int32], npt.NDArray[np.bool_], npt.NDArray[np.bool_]]: """Splits a dataset into a training dataset and a test dataset. @@ -129,12 +125,12 @@ def split_data( @partial(jax.jit, static_argnums=[3]) -def pred(phis: npt.NDArray[np.float64], rows: npt.NDArray[np.int64], - cols: npt.NDArray[np.int64], N: int) -> npt.NDArray[np.bool_]: +def pred(scores: npt.NDArray[np.float32], rows: npt.NDArray[np.int32], + cols: npt.NDArray[np.int32], N: int) -> npt.NDArray[np.bool_]: """Predicts the target output from the learned scores and input entries. Args: - phis (numpy.ndarray): Contribution scores of features. + scores (numpy.ndarray): Contribution scores of features. rows (numpy.ndarray): Row indices of True values in the input. cols (numpy.ndarray): Column indices of True values in the input. N (int): The number of input entries. @@ -142,10 +138,10 @@ def pred(phis: npt.NDArray[np.float64], rows: npt.NDArray[np.int64], Returns: res (numpy.ndarray): A prediction of the target. """ - # This is equivalent to phis.dot(2X - 1) = 2phis.dot(X) - phis.sum() but in a - # sparse matrix-friendly way. - r: npt.NDArray[np.float64] = 2 * jax.ops.segment_sum( - phis.take(cols), rows, N) - phis.sum() + # This is equivalent to scores.dot(2X - 1) = 2 * scores.dot(X) - scores.sum() + # but in a sparse matrix-friendly way. + r: npt.NDArray[np.float32] = 2 * jax.ops.segment_sum( + scores.take(cols), rows, N) - scores.sum() return r > 0 @@ -180,20 +176,20 @@ def get_metrics(pred: npt.NDArray[np.bool_], ) -@partial(jax.jit, static_argnums=[5]) -def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64], - cols: npt.NDArray[np.int64], Y: npt.NDArray[np.bool_], - scores: typing.Any, - M: int) -> typing.Tuple[typing.Any, typing.Any, int, float]: - """Calculates the new weight vector from the best feature and its score. +@jax.jit +def update( + w: npt.NDArray[np.float32], scores: typing.Any, rows: npt.NDArray[np.int32], + cols: npt.NDArray[np.int32], Y: npt.NDArray[np.bool_] +) -> typing.Tuple[typing.Any, typing.Any, int, float]: + """Calculates the new weight vector and the contribution scores. Args: w (numpy.ndarray): A weight vector. + scores (JAX array): Contribution scores of features. rows (numpy.ndarray): Row indices of True values in the input data. cols (numpy.ndarray): Column indices of True values in the input data. Y (numpy.ndarray): The target output. - scores (JAX array): Contribution scores of features. - M (int): The number of columns in the input data. + Returns: A tuple of following items: @@ -202,6 +198,8 @@ def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64], - best_feature_index (int): The index of the best feature. - score (float): The newly added score for the best feature. """ + N = w.shape[0] + M = scores.shape[0] # This is quivalent to w.dot(Y[:, None] ^ X). Note that y ^ x = y + x - 2yx, # hence w.dot(y ^ x) = w.dot(y) - w(2y - 1).dot(x). # `segment_sum` is used to implement sparse matrix-friendly dot products. @@ -211,7 +209,6 @@ def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64], positivity: bool = res.at[best_feature_index].get() < 0.5 err_min = err.at[best_feature_index].get() amount: float = jnp.log((1 - err_min) / (err_min + EPS)) - N = Y.shape[0] # This is equivalent to X_best = X[:, best_feature_index] X_best = jnp.zeros( @@ -224,8 +221,8 @@ def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64], return w, scores, best_feature_index, score -def fit(rows_train: npt.NDArray[np.int64], cols_train: npt.NDArray[np.int64], - rows_test: npt.NDArray[np.int64], cols_test: npt.NDArray[np.int64], +def fit(rows_train: npt.NDArray[np.int32], cols_train: npt.NDArray[np.int32], + rows_test: npt.NDArray[np.int32], cols_test: npt.NDArray[np.int32], Y_train: npt.NDArray[np.bool_], Y_test: npt.NDArray[np.bool_], features: typing.List[str], iters: int, weights_filename: str, log_filename: str, out_span: int) -> typing.Any: @@ -255,7 +252,7 @@ def fit(rows_train: npt.NDArray[np.int64], cols_train: npt.NDArray[np.int64], 'test_accuracy\ttest_precision\ttest_recall\ttest_fscore\n') print('Outputting learned weights to %s ...' % (weights_filename)) - M = len(features) + 1 + M = len(features) scores = jnp.zeros(M) feature_score_buffer: typing.List[typing.Tuple[str, float]] = [] N_train = Y_train.shape[0] @@ -296,11 +293,10 @@ def output_progress(t: int) -> None: )) for t in range(iters): - w, scores, best_feature_index, score = update_weights( - w, rows_train, cols_train, Y_train, scores, M) + w, scores, best_feature_index, score = update(w, scores, rows_train, + cols_train, Y_train) w.block_until_ready() - feature = features[best_feature_index] if ( - best_feature_index < len(features)) else 'BIAS' + feature = features[best_feature_index] feature_score_buffer.append((feature, score)) if (t + 1) % out_span == 0: output_progress(t + 1)