From 4d45fb74d9f24eee61b0e226d2a779edb3d5c3ef Mon Sep 17 00:00:00 2001
From: Shuhei Iitsuka
Date: Mon, 16 May 2022 13:50:42 +0900
Subject: [PATCH] Normalize weights not to overflow

---
 scripts/train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/train.py b/scripts/train.py
index 5f424c8c..a9bf3f0e 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -166,7 +166,7 @@ def fit(X_train: npt.NDArray[np.bool_],
   for t in range(iters):
     print('=== %s ===' % (t))
     if chunk_size is None:
-      res: npt.NDArray[np.float64] = w.dot(Y_train[:, None] ^ X_train) / w.sum()
+      res: npt.NDArray[np.float64] = w.dot(Y_train[:, None] ^ X_train)
     else:
       res = np.zeros(M_train)
       for i in range(0, N_train, chunk_size):
@@ -174,7 +174,6 @@ def fit(X_train: npt.NDArray[np.bool_],
         X_train_chunk = X_train[i:i + chunk_size]
         w_chunk = w[i:i + chunk_size]
         res += w_chunk.dot(Y_train_chunk[:, None] ^ X_train_chunk)
-      res = res / w.sum()
     err = 0.5 - np.abs(res - 0.5)
     m_best = int(err.argmin())
     pol_best = res[m_best] < 0.5
@@ -188,6 +187,7 @@ def fit(X_train: npt.NDArray[np.bool_],
     if not pol_best:
       miss = ~(miss)
     w = w * np.exp(alpha * miss)
+    w = w / w.sum()
     with open(weights_filename, 'a') as f:
       feature = features[m_best] if m_best < len(features) else 'BIAS'
       f.write('%s\t%.3f\n' % (feature, alpha if pol_best else -alpha))