Add a script to finetune models. #145

Merged May 31, 2023 (6 commits)
Changes from 1 commit
Brush up
Change-Id: I5932a30d5f4f935a2e764f309fe4277adbc4e7d1
tushuhei committed May 31, 2023
commit 701e097ae06becabc0b15de9f0e33edd25ff35ab
120 changes: 89 additions & 31 deletions scripts/finetune.py
@@ -11,15 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetunes the model with the given training data."""
"""Finetunes a BudouX model with the given training dataset.

Example usage:

$ python finetune.py train_data.txt base_model.json -o weights.txt --val_data=val_data.txt
"""

import argparse
import array
-import typing
import json
+import typing
from collections import OrderedDict

+from jax import Array, grad, jit
from jax import numpy as jnp
-from jax import Array, jit, grad

EPSILON: float = jnp.finfo(float).eps
DEFAULT_OUTPUT_NAME = 'finetuned-weights.txt'
DEFAULT_NUM_ITERS = 1000
DEFAULT_LOG_SPAN = 100
DEFAULT_LEARNING_RATE = 0.01


class NormalizedModel(typing.NamedTuple):
@@ -31,6 +43,7 @@ class Dataset(typing.NamedTuple):
X: Array
Y: Array
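
For orientation, here is a toy instantiation of these two containers (the values are made up, and NormalizedModel's fields are inferred from how main() uses model.features and model.weights further down):

from jax import numpy as jnp

# Hypothetical feature identifiers paired with a normalized weight vector.
model = NormalizedModel(features=['UW3:a', 'BW2:xy'],
                        weights=jnp.array([0.6, -0.4]))
# X holds rows of +1/-1 feature indicators; Y holds boolean break labels,
# mirroring the fixtures used in the unit tests below.
dataset = Dataset(X=jnp.array([[1, -1], [1, 1]]),
                  Y=jnp.array([True, False]))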


class Metrics(typing.NamedTuple):
tp: int
tn: int
@@ -40,11 +53,12 @@ class Metrics(typing.NamedTuple):
precision: float
recall: float
fscore: float
loss: float


def load_model(file_path: str) -> NormalizedModel:
"""Loads a model as a pair of a features list and a normalized weight vector.

Args:
file_path: A file path for the model JSON file.

@@ -67,7 +81,7 @@ def load_model(file_path: str) -> NormalizedModel:

def load_dataset(file_path: str, model: NormalizedModel) -> Dataset:
"""Loads a dataset from the given file path.

Args:
file_path: A file path for the encoded data file.
model: A normalized model.
@@ -91,7 +105,7 @@ def load_dataset(file_path: str, model: NormalizedModel) -> Dataset:

def cross_entropy_loss(weights: Array, x: Array, y: Array) -> Array:
"""Calcurates a cross entropy loss with a prediction by a sigmoid function.

Args:
weights: A weight vector.
x: An input array.
@@ -106,6 +120,7 @@ def cross_entropy_loss(weights: Array, x: Array, y: Array) -> Array:
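
The body of cross_entropy_loss is collapsed in this view. Judging from its docstring and the EPSILON constant defined above, it plausibly computes something like the following sketch (an assumption, not the committed implementation):

from jax import Array
from jax import numpy as jnp

EPSILON: float = jnp.finfo(float).eps

def sigmoid_cross_entropy(weights: Array, x: Array, y: Array) -> Array:
  pred = 1 / (1 + jnp.exp(-x.dot(weights)))  # sigmoid prediction
  # EPSILON keeps the logs finite when pred saturates at 0 or 1.
  return -jnp.mean(y * jnp.log(pred + EPSILON) +
                   (1 - y) * jnp.log(1 - pred + EPSILON))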

def get_metrics(weights: Array, dataset: Dataset) -> Metrics:
"""Gets evaluation metrics from the learned weight vector and the dataset.

Args:
weights: A weight vector.
dataset: A dataset.
@@ -119,9 +134,12 @@ def get_metrics(weights: Array, dataset: Dataset) -> Metrics:
tn: int = jnp.sum(jnp.logical_and(pred == 0, actual == 0)) # type: ignore
fp: int = jnp.sum(jnp.logical_and(pred == 1, actual == 0)) # type: ignore
fn: int = jnp.sum(jnp.logical_and(pred == 0, actual == 1)) # type: ignore
+loss: float = cross_entropy_loss(weights, dataset.X,
+                                 dataset.Y)  # type: ignore
accuracy = (tp + tn) / (tp + tn + fp + fn)
-precision = tp / (tp + fp)
-recall = tp / (tp + fn)
+precision = tp / (tp + fp + EPSILON)
+recall = tp / (tp + fn + EPSILON)
+fscore = 2 * precision * recall / (precision + recall + EPSILON)
return Metrics(
tp=tp,
tn=tn,
@@ -130,53 +148,61 @@ def get_metrics(weights: Array, dataset: Dataset) -> Metrics:
accuracy=accuracy,
precision=precision,
recall=recall,
-fscore=2 * precision * recall / (precision + recall),
+fscore=fscore,
+loss=loss,
)
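
The EPSILON terms added above keep these ratios finite on degenerate inputs: with no positive predictions (tp == fp == 0), precision evaluates to 0 / EPSILON == 0.0 rather than the NaN that 0 / 0 would produce, and likewise for recall and the F-score denominator. A quick illustration:

from jax import numpy as jnp

eps = jnp.finfo(float).eps
tp = jnp.sum(jnp.array([False]))  # no true positives
fp = jnp.sum(jnp.array([False]))  # no false positives
print(tp / (tp + fp))        # nan, from 0 / 0
print(tp / (tp + fp + eps))  # 0.0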


def fit(weights: Array,
train_dataset: Dataset,
-val_dataset: typing.Optional[Dataset] = None,
-iter: int = 1000,
-learning_rate: float = 0.1,
-log_span: int = 100) -> Array:
+iters: int,
+learning_rate: float,
+log_span: int,
+val_dataset: typing.Optional[Dataset] = None) -> Array:
"""Updates the weights with the given dataset.

Args:
weights: A weight vector.
train_dataset: A train dataset.
+iters: A number of iterations.
+learning_rate: A learning rate.
+log_span: A span to log metrics.
val_dataset: A validation dataset (optional).
-iter: A number of iterations (default: 1000).
-learning_rate: A learning rate (default: 0.1).
-log_span: A span to log metrics (default: 100).

Returns:
An updated weight vector.
"""
grad_loss = jit(grad(cross_entropy_loss, argnums=0))
-for t in range(iter):
+for t in range(iters):
weights = weights - learning_rate * grad_loss(weights, train_dataset.X,
train_dataset.Y)
-if (t + 1) % log_span != 0: continue
+if (t + 1) % log_span != 0:
+  continue
metrics_train = jit(get_metrics)(weights, train_dataset)
print()
print('iter:\t%d' % (t + 1))
print()
print('train accuracy:\t%.5f' % metrics_train.accuracy)
print('train prec.:\t%.5f' % metrics_train.precision)
print('train recall:\t%.5f' % metrics_train.recall)
print('train fscore:\t%.5f' % metrics_train.fscore)
print('train loss:\t%.5f' % metrics_train.loss)
print()

-if val_dataset is None: continue
-metrics_val = jit(get_metrics(weights, val_dataset))
-print('val accuracy:\t%.5f' % metrics_test.accuracy)
-print('val prec.:\t%.5f' % metrics_test.precision)
-print('val recall:\t%.5f' % metrics_test.recall)
-print('val fscore:\t%.5f' % metrics_test.fscore)
+if val_dataset is None:
+  continue
+metrics_val = jit(get_metrics)(weights, val_dataset)
+print('val accuracy:\t%.5f' % metrics_val.accuracy)
+print('val prec.:\t%.5f' % metrics_val.precision)
+print('val recall:\t%.5f' % metrics_val.recall)
+print('val fscore:\t%.5f' % metrics_val.fscore)
+print('val loss:\t%.5f' % metrics_val.loss)
+print()
print()
return weights
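
A minimal end-to-end use of fit, in the same spirit as the TestFit health check in the test file below (toy numbers; real inputs come from load_model and load_dataset):

from jax import numpy as jnp

weights = jnp.array([0.9, 0.5, -0.3])
dataset = Dataset(X=jnp.array([[-1, 1, 1], [1, -1, 1], [1, 1, -1]]),
                  Y=jnp.array([True, False, False]))
weights = fit(weights, dataset, iters=1000, learning_rate=0.01, log_span=100)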


-def write_weights(file_path: str, weights: Array, features: typing.List[str]):
+def write_weights(file_path: str, weights: Array,
+                  features: typing.List[str]) -> None:
"""Writes learned weights and corresponsing features to a file.

Args:
@@ -185,7 +211,9 @@ def write_weights(file_path: str, weights: Array, features: typing.List[str]):
features: A list of feature identifiers.
"""
with open(file_path, 'w') as f:
-f.write('\n'.join(['%s\t%.6f' % (feature, weights[i]) for i, feature in enumerate(features)]))
+f.write('\n'.join([
+    '%s\t%.6f' % (feature, weights[i]) for i, feature in enumerate(features)
+]))
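
Each line written here is a feature identifier and its weight separated by a tab, so the output can be loaded back with a sketch like this (the file name is a placeholder):

weights_by_feature = {}
with open('finetuned-weights.txt') as f:
  for line in f:
    feature, weight = line.rstrip('\n').split('\t')
    weights_by_feature[feature] = float(weight)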


def parse_args(
@@ -199,16 +227,34 @@
Returns:
Parsed arguments (argparse.Namespace).
"""
-parser = argparse.ArgumentParser(description=__doc__)
+parser = argparse.ArgumentParser(
+    description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument(
'train_data', help='File path for the encoded training data.')
parser.add_argument('base_model', help='File path for the base model file.')
parser.add_argument(
'-o',
'--output',
-help=f'Output file path for the learned weights. (default: finetuned-weights.txt)',
+help=f'File path for the output weights. (default: {DEFAULT_OUTPUT_NAME})',
type=str,
-default='finetuned-weights.txt')
+default=DEFAULT_OUTPUT_NAME)
parser.add_argument(
'--val-data', help='File path for the encoded validation data.', type=str)
parser.add_argument(
'--iters',
help=f'Number of iterations for training. (default: {DEFAULT_NUM_ITERS})',
type=int,
default=DEFAULT_NUM_ITERS)
parser.add_argument(
'--log-span',
help=f'Iteration span to print metrics. (default: {DEFAULT_LOG_SPAN})',
type=int,
default=DEFAULT_LOG_SPAN)
parser.add_argument(
'--learning-rate',
help=f'Learning rate. (default: {DEFAULT_LEARNING_RATE})',
type=float,
default=DEFAULT_LEARNING_RATE)
if test is None:
return parser.parse_args()
else:
@@ -220,11 +266,23 @@ def main() -> None:
train_data_path: str = args.train_data
base_model_path: str = args.base_model
weights_path: str = args.output
iters: int = args.iters
log_span: int = args.log_span
learning_rate: float = args.learning_rate
val_data_path: typing.Optional[str] = args.val_data

model = load_model(base_model_path)
train_dataset = load_dataset(train_data_path, model)
-weights = fit(model.weights, train_dataset.X, train_dataset.Y)
+val_dataset = load_dataset(val_data_path, model) if val_data_path else None
+weights = fit(
+    model.weights,
+    train_dataset,
+    iters=iters,
+    log_span=log_span,
+    learning_rate=learning_rate,
+    val_dataset=val_dataset)
write_weights(weights_path, weights, model.features)


if __name__ == '__main__':
main()
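
With the new flags in place, a full invocation of the script might look like this (file names and values are placeholders):

$ python scripts/finetune.py train_data.txt base_model.json \
    -o weights.txt --val-data val_data.txt \
    --iters 500 --log-span 50 --learning-rate 0.01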
24 changes: 22 additions & 2 deletions scripts/tests/test_finetune.py
@@ -15,9 +15,10 @@

import os
import sys
-import unittest
import tempfile
+import unittest
from collections import OrderedDict

from jax import numpy as jnp

# module hack
@@ -56,6 +57,23 @@ def test_cmdargs_default(self) -> None:
output = finetune.parse_args(cmdargs)
self.assertEqual(output.train_data, 'encoded.txt')
self.assertEqual(output.base_model, 'model.json')
self.assertEqual(output.iters, finetune.DEFAULT_NUM_ITERS)
self.assertEqual(output.log_span, finetune.DEFAULT_LOG_SPAN)
self.assertEqual(output.learning_rate, finetune.DEFAULT_LEARNING_RATE)
self.assertEqual(output.val_data, None)

def test_cmdargs_with_values(self) -> None:
cmdargs = [
'encoded.txt', 'model.json', '--iters', '50', '--log-span', '10',
'--learning-rate', '0.1', '--val-data', 'val.txt'
]
output = finetune.parse_args(cmdargs)
self.assertEqual(output.train_data, 'encoded.txt')
self.assertEqual(output.base_model, 'model.json')
self.assertEqual(output.iters, 50)
self.assertEqual(output.log_span, 10)
self.assertEqual(output.learning_rate, 0.1)
self.assertEqual(output.val_data, 'val.txt')


class TestLoadModel(unittest.TestCase):
@@ -105,15 +123,17 @@ def test_x(self) -> None:
expected = [[1, 1], [1, -1], [1, 1], [1, 1], [-1, -1]]
self.assertListEqual(result.X.tolist(), expected)


class TestFit(unittest.TestCase):

def test_health(self) -> None:
w = jnp.array([.9, .5, -.3])
X = jnp.array([[-1, 1, 1], [1, -1, 1], [1, 1, -1]])
# The current result is x.dot(w) = [-0.7, 0.1, 1.7] => [False, True, True].
# It tests if the method can learn new weights that invert the result.
Y = jnp.array([True, False, False])
dataset = finetune.Dataset(X, Y)
-w = finetune.fit(w, dataset, iter=1000, learning_rate=.01, log_span=100)
+w = finetune.fit(w, dataset, iters=1000, learning_rate=.01, log_span=100)
self.assertGreater(X.dot(w).tolist()[0], 0) # x.dot(w) > 0 => True.
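
Assuming the repository's standard layout, the suite should run with the stock unittest runner from the repository root, e.g.:

$ python -m unittest scripts/tests/test_finetune.py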

