Skip to content

Commit

Permalink
Cleanup the training script (#105)
Browse files Browse the repository at this point in the history
* Documentation cleanup

* Use 32bit precision in type hints for JAX compatibility

* Remove the BIAS factor

* Simpler interface for training
  • Loading branch information
tushuhei authored Jan 5, 2023
1 parent 449b150 commit 0533463
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 50 deletions.
37 changes: 18 additions & 19 deletions scripts/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,32 +84,32 @@ def test_standard_setup(self) -> None:
'1\tbar\tfoo\n'
'-1\tbaz\tqux\n'))
# The input matrix X and the target vector Y should look like below now:
# Y X(foo bar baz BIAS)
# 1 1 1 0 1
# -1 1 0 0 1
# 1 1 1 1 1
# 1 1 1 0 1
# -1 0 0 1 1
# Y X(foo bar baz)
# 1 1 1 0
# -1 1 0 0
# 1 1 1 1
# 1 1 1 0
# -1 0 0 1
rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 1)
self.assertEqual(features, ['foo', 'bar', 'baz'])
self.assertEqual(Y.tolist(), [True, False, True, True, False])
self.assertEqual(rows.tolist(), [0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4])
self.assertEqual(cols.tolist(), [0, 1, 3, 0, 3, 0, 1, 2, 3, 1, 0, 3, 2, 3])
self.assertEqual(rows.tolist(), [0, 0, 1, 2, 2, 2, 3, 3, 4])
self.assertEqual(cols.tolist(), [0, 1, 0, 0, 1, 2, 1, 0, 2])

def test_skip_invalid_rows(self) -> None:
with open(self.ENTRIES_FILE_PATH, 'w') as f:
f.write(('\n1\tfoo\tbar\n'
'-1\n\n'
'-1\tfoo\n\n'))
# The input matrix X and the target vector Y should look like below now:
# Y X(foo bar BIAS)
# 1 1 1 1
# -1 1 0 1
# Y X(foo bar)
# 1 1 1
# -1 1 0
rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 0)
self.assertEqual(features, ['foo', 'bar'])
self.assertEqual(Y.tolist(), [True, False])
self.assertEqual(rows.tolist(), [0, 0, 0, 1, 1])
self.assertEqual(cols.tolist(), [0, 1, 2, 0, 2])
self.assertEqual(rows.tolist(), [0, 0, 1])
self.assertEqual(cols.tolist(), [0, 1, 0])

def tearDown(self) -> None:
if (os.path.exists(self.ENTRIES_FILE_PATH)):
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_standard_setup(self) -> None:
self.assertEqual(result.fscore, 2 * p * r / (p + r))


class TestUpdateWeights(unittest.TestCase):
class TestUpdate(unittest.TestCase):
X = np.array([
[1, 0, 1, 0],
[0, 1, 0, 0],
Expand All @@ -194,8 +194,8 @@ def test_standard_setup1(self) -> None:
Y = np.array([1, 1, 0, 0, 1], dtype=bool)
w = np.array([0.1, 0.3, 0.1, 0.1, 0.4])
scores = jnp.zeros(M)
new_w, new_scores, best_feature_index, added_score = train.update_weights(
w, rows, cols, Y, scores, M)
new_w, new_scores, best_feature_index, added_score = train.update(
w, scores, rows, cols, Y)
self.assertFalse(w.argmax() == 0)
self.assertTrue(new_w.argmax() == 0)
self.assertFalse(scores.argmax() == 1)
Expand Down Expand Up @@ -254,9 +254,8 @@ def test_fit(self) -> None:
for weight in weights:
model.setdefault(weight[0], 0)
model[weight[0]] += float(weight[1])
self.assertEqual(scores.shape[0], len(features) + 1)
loaded_scores = [model.get(feature, 0) for feature in features
] + [model.get('BIAS', 0)]
self.assertEqual(scores.shape[0], len(features))
loaded_scores = [model.get(feature, 0) for feature in features]
self.assertTrue(np.all(np.isclose(scores, loaded_scores)))

def tearDown(self) -> None:
Expand Down
58 changes: 27 additions & 31 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def preprocess(
- features (List[str]): The list of features.
"""
features_counter: typing.Counter[str] = Counter()
N = 0
X = []
Y = array.array('B')
with open(entries_filename) as f:
Expand All @@ -77,7 +76,6 @@ def preprocess(
Y.append(cols[0] == '1')
X.append(cols[1:])
features_counter.update(cols[1:])
N += 1
features = [
item[0]
for item in features_counter.most_common()
Expand All @@ -90,19 +88,17 @@ def preprocess(
hit_indices = [feature_index[feat] for feat in x if feat in feature_index]
rows.extend(i for _ in range(len(hit_indices)))
cols.extend(hit_indices) # type: ignore
rows.append(i)
cols.append(len(features)) # type: ignore
return jnp.asarray(rows), jnp.asarray(cols), jnp.asarray(
Y, dtype=bool), features


def split_data(
rows: npt.NDArray[np.int64],
cols: npt.NDArray[np.int64],
rows: npt.NDArray[np.int32],
cols: npt.NDArray[np.int32],
Y: npt.NDArray[np.bool_],
split_ratio: float = .9
) -> typing.Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64],
npt.NDArray[np.int64], npt.NDArray[np.int64],
) -> typing.Tuple[npt.NDArray[np.int32], npt.NDArray[np.int32],
npt.NDArray[np.int32], npt.NDArray[np.int32],
npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
"""Splits a dataset into a training dataset and a test dataset.
Expand All @@ -129,23 +125,23 @@ def split_data(


@partial(jax.jit, static_argnums=[3])
def pred(phis: npt.NDArray[np.float64], rows: npt.NDArray[np.int64],
cols: npt.NDArray[np.int64], N: int) -> npt.NDArray[np.bool_]:
def pred(scores: npt.NDArray[np.float32], rows: npt.NDArray[np.int32],
cols: npt.NDArray[np.int32], N: int) -> npt.NDArray[np.bool_]:
"""Predicts the target output from the learned scores and input entries.
Args:
phis (numpy.ndarray): Contribution scores of features.
scores (numpy.ndarray): Contribution scores of features.
rows (numpy.ndarray): Row indices of True values in the input.
cols (numpy.ndarray): Column indices of True values in the input.
N (int): The number of input entries.
Returns:
res (numpy.ndarray): A prediction of the target.
"""
# This is equivalent to phis.dot(2X - 1) = 2phis.dot(X) - phis.sum() but in a
# sparse matrix-friendly way.
r: npt.NDArray[np.float64] = 2 * jax.ops.segment_sum(
phis.take(cols), rows, N) - phis.sum()
# This is equivalent to scores.dot(2X - 1) = 2 * scores.dot(X) - scores.sum()
# but in a sparse matrix-friendly way.
r: npt.NDArray[np.float32] = 2 * jax.ops.segment_sum(
scores.take(cols), rows, N) - scores.sum()
return r > 0


Expand Down Expand Up @@ -180,20 +176,20 @@ def get_metrics(pred: npt.NDArray[np.bool_],
)


@partial(jax.jit, static_argnums=[5])
def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64],
cols: npt.NDArray[np.int64], Y: npt.NDArray[np.bool_],
scores: typing.Any,
M: int) -> typing.Tuple[typing.Any, typing.Any, int, float]:
"""Calculates the new weight vector from the best feature and its score.
@jax.jit
def update(
w: npt.NDArray[np.float32], scores: typing.Any, rows: npt.NDArray[np.int32],
cols: npt.NDArray[np.int32], Y: npt.NDArray[np.bool_]
) -> typing.Tuple[typing.Any, typing.Any, int, float]:
"""Calculates the new weight vector and the contribution scores.
Args:
w (numpy.ndarray): A weight vector.
scores (JAX array): Contribution scores of features.
rows (numpy.ndarray): Row indices of True values in the input data.
cols (numpy.ndarray): Column indices of True values in the input data.
Y (numpy.ndarray): The target output.
scores (JAX array): Contribution scores of features.
M (int): The number of columns in the input data.
Returns:
A tuple of following items:
Expand All @@ -202,6 +198,8 @@ def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64],
- best_feature_index (int): The index of the best feature.
- score (float): The newly added score for the best feature.
"""
N = w.shape[0]
M = scores.shape[0]
# This is equivalent to w.dot(Y[:, None] ^ X). Note that y ^ x = y + x - 2yx,
# hence w.dot(y ^ x) = w.dot(y) - w(2y - 1).dot(x).
# `segment_sum` is used to implement sparse matrix-friendly dot products.
Expand All @@ -211,7 +209,6 @@ def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64],
positivity: bool = res.at[best_feature_index].get() < 0.5
err_min = err.at[best_feature_index].get()
amount: float = jnp.log((1 - err_min) / (err_min + EPS))
N = Y.shape[0]

# This is equivalent to X_best = X[:, best_feature_index]
X_best = jnp.zeros(
Expand All @@ -224,8 +221,8 @@ def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64],
return w, scores, best_feature_index, score


def fit(rows_train: npt.NDArray[np.int64], cols_train: npt.NDArray[np.int64],
rows_test: npt.NDArray[np.int64], cols_test: npt.NDArray[np.int64],
def fit(rows_train: npt.NDArray[np.int32], cols_train: npt.NDArray[np.int32],
rows_test: npt.NDArray[np.int32], cols_test: npt.NDArray[np.int32],
Y_train: npt.NDArray[np.bool_], Y_test: npt.NDArray[np.bool_],
features: typing.List[str], iters: int, weights_filename: str,
log_filename: str, out_span: int) -> typing.Any:
Expand Down Expand Up @@ -255,7 +252,7 @@ def fit(rows_train: npt.NDArray[np.int64], cols_train: npt.NDArray[np.int64],
'test_accuracy\ttest_precision\ttest_recall\ttest_fscore\n')
print('Outputting learned weights to %s ...' % (weights_filename))

M = len(features) + 1
M = len(features)
scores = jnp.zeros(M)
feature_score_buffer: typing.List[typing.Tuple[str, float]] = []
N_train = Y_train.shape[0]
Expand Down Expand Up @@ -296,11 +293,10 @@ def output_progress(t: int) -> None:
))

for t in range(iters):
w, scores, best_feature_index, score = update_weights(
w, rows_train, cols_train, Y_train, scores, M)
w, scores, best_feature_index, score = update(w, scores, rows_train,
cols_train, Y_train)
w.block_until_ready()
feature = features[best_feature_index] if (
best_feature_index < len(features)) else 'BIAS'
feature = features[best_feature_index]
feature_score_buffer.append((feature, score))
if (t + 1) % out_span == 0:
output_progress(t + 1)
Expand Down

0 comments on commit 0533463

Please sign in to comment.