Skip to content

Commit

Permalink
Normalize weights so that they do not overflow (#63)
Browse files: browse the repository at this point in the history
  • Loading branch information
tushuhei authored May 16, 2022
1 parent fcc9156 commit 32ba63b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,14 @@ def fit(X_train: npt.NDArray[np.bool_],
for t in range(iters):
print('=== %s ===' % (t))
if chunk_size is None:
res: npt.NDArray[np.float64] = w.dot(Y_train[:, None] ^ X_train) / w.sum()
res: npt.NDArray[np.float64] = w.dot(Y_train[:, None] ^ X_train)
else:
res = np.zeros(M_train)
for i in range(0, N_train, chunk_size):
Y_train_chunk = Y_train[i:i + chunk_size]
X_train_chunk = X_train[i:i + chunk_size]
w_chunk = w[i:i + chunk_size]
res += w_chunk.dot(Y_train_chunk[:, None] ^ X_train_chunk)
res = res / w.sum()
err = 0.5 - np.abs(res - 0.5)
m_best = int(err.argmin())
pol_best = res[m_best] < 0.5
Expand All @@ -188,6 +187,7 @@ def fit(X_train: npt.NDArray[np.bool_],
if not pol_best:
miss = ~(miss)
w = w * np.exp(alpha * miss)
w = w / w.sum()
with open(weights_filename, 'a') as f:
feature = features[m_best] if m_best < len(features) else 'BIAS'
f.write('%s\t%.3f\n' % (feature, alpha if pol_best else -alpha))
Expand Down

0 comments on commit 32ba63b

Please sign in to comment.