diff --git a/scripts/finetune.py b/scripts/finetune.py new file mode 100644 index 00000000..5f01c5de --- /dev/null +++ b/scripts/finetune.py @@ -0,0 +1,288 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Finetunes a BudouX model with the given training dataset. + +Example usage: + +$ python finetune.py train_data.txt base_model.json -o weights.txt --val_data=val_data.txt +""" + +import argparse +import array +import json +import typing +from collections import OrderedDict + +from jax import Array, grad, jit +from jax import numpy as jnp + +EPSILON: float = jnp.finfo(float).eps +DEFAULT_OUTPUT_NAME = 'finetuned-weights.txt' +DEFAULT_NUM_ITERS = 1000 +DEFAULT_LOG_SPAN = 100 +DEFAULT_LEARNING_RATE = 0.01 + + +class NormalizedModel(typing.NamedTuple): + features: typing.List[str] + weights: Array + + +class Dataset(typing.NamedTuple): + X: Array + Y: Array + + +class Metrics(typing.NamedTuple): + tp: int + tn: int + fp: int + fn: int + accuracy: float + precision: float + recall: float + fscore: float + loss: float + + +def load_model(file_path: str) -> NormalizedModel: + """Loads a model as a pair of a features list and a normalized weight vector. + + Args: + file_path: A file path for the model JSON file. + + Returns: + A normalized model, which is a pair of a list of feature identifiers and a + normalized weight vector. + """ + with open(file_path) as f: + model = json.load(f) + model_flat = OrderedDict() + for category in model: + for item in model[category]: + model_flat['%s:%s' % (category, item)] = model[category][item] + weights = jnp.array(list(model_flat.values())) + weights = weights / weights.std() + weights = weights - weights.mean() + keys = list(model_flat.keys()) + return NormalizedModel(keys, weights) + + +def load_dataset(file_path: str, model: NormalizedModel) -> Dataset: + """Loads a dataset from the given file path. + + Args: + file_path: A file path for the encoded data file. + model: A normalized model. + + Returns: + A dataset of inputs (X) and outputs (Y). + """ + xs = [] + ys = array.array('B') + with open(file_path) as f: + for row in f: + cols = row.strip().split('\t') + if len(cols) < 2: + continue + ys.append(cols[0] == '1') + xs.append(tuple(k in set(cols[1:]) for k in model.features)) + X = jnp.array(xs) * 2 - 1 + Y = jnp.array(ys) + return Dataset(X, Y) + + +def cross_entropy_loss(weights: Array, x: Array, y: Array) -> Array: + """Calcurates a cross entropy loss with a prediction by a sigmoid function. + + Args: + weights: A weight vector. + x: An input array. + y: A target output array. + + Returns: + A cross entropy loss. + """ + pred = 1 / (1 + jnp.exp(-x.dot(weights))) + return -jnp.mean(y * jnp.log(pred) + (1 - y) * jnp.log(1 - pred)) + + +def get_metrics(weights: Array, dataset: Dataset) -> Metrics: + """Gets evaluation metrics from the learned weight vector and the dataset. + + Args: + weights: A weight vector. + dataset: A dataset. + + Returns: + result (Metrics): The metrics over the given weights and the dataset. + """ + pred = dataset.X.dot(weights) > 0 + actual = dataset.Y + tp: int = jnp.sum(jnp.logical_and(pred == 1, actual == 1)) # type: ignore + tn: int = jnp.sum(jnp.logical_and(pred == 0, actual == 0)) # type: ignore + fp: int = jnp.sum(jnp.logical_and(pred == 1, actual == 0)) # type: ignore + fn: int = jnp.sum(jnp.logical_and(pred == 0, actual == 1)) # type: ignore + loss: float = cross_entropy_loss(weights, dataset.X, + dataset.Y) # type: ignore + accuracy = (tp + tn) / (tp + tn + fp + fn) + precision = tp / (tp + fp + EPSILON) + recall = tp / (tp + fn + EPSILON) + fscore = 2 * precision * recall / (precision + recall + EPSILON) + return Metrics( + tp=tp, + tn=tn, + fp=fp, + fn=fn, + accuracy=accuracy, + precision=precision, + recall=recall, + fscore=fscore, + loss=loss, + ) + + +def fit(weights: Array, + train_dataset: Dataset, + iters: int, + learning_rate: float, + log_span: int, + val_dataset: typing.Optional[Dataset] = None) -> Array: + """Updates the weights with the given dataset. + + Args: + weights: A weight vector. + train_dataset: A train dataset. + iters: A number of iterations. + learning_rate: A learning rate. + log_span: A span to log metrics. + val_dataset: A validation dataset (optional). + + Returns: + An updated weight vector. + """ + grad_loss = jit(grad(cross_entropy_loss, argnums=0)) + for t in range(iters): + weights = weights - learning_rate * grad_loss(weights, train_dataset.X, + train_dataset.Y) + if (t + 1) % log_span != 0: + continue + metrics_train = jit(get_metrics)(weights, train_dataset) + print() + print('iter:\t%d' % (t + 1)) + print() + print('train accuracy:\t%.5f' % metrics_train.accuracy) + print('train prec.:\t%.5f' % metrics_train.precision) + print('train recall:\t%.5f' % metrics_train.recall) + print('train fscore:\t%.5f' % metrics_train.fscore) + print('train loss:\t%.5f' % metrics_train.loss) + print() + + if val_dataset is None: + continue + metrics_val = jit(get_metrics)(weights, val_dataset) + print('val accuracy:\t%.5f' % metrics_val.accuracy) + print('val prec.:\t%.5f' % metrics_val.precision) + print('val recall:\t%.5f' % metrics_val.recall) + print('val fscore:\t%.5f' % metrics_val.fscore) + print('val loss:\t%.5f' % metrics_val.loss) + print() + return weights + + +def write_weights(file_path: str, weights: Array, + features: typing.List[str]) -> None: + """Writes learned weights and corresponsing features to a file. + + Args: + file_path: A file path for the weights file. + weights: A weight vector. + features: A list of feature identifiers. + """ + with open(file_path, 'w') as f: + f.write('\n'.join([ + '%s\t%.6f' % (feature, weights[i]) for i, feature in enumerate(features) + ])) + + +def parse_args( + test: typing.Optional[typing.List[str]] = None) -> argparse.Namespace: + """Parses commandline arguments. + + Args: + test (typing.Optional[typing.List[str]], optional): Commandline args for + testing. Defaults to None. + + Returns: + Parsed arguments (argparse.Namespace). + """ + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument( + 'train_data', help='File path for the encoded training data.') + parser.add_argument('base_model', help='File path for the base model file.') + parser.add_argument( + '-o', + '--output', + help=f'File path for the output weights. (default: {DEFAULT_OUTPUT_NAME})', + type=str, + default=DEFAULT_OUTPUT_NAME) + parser.add_argument( + '--val-data', help='File path for the encoded validation data.', type=str) + parser.add_argument( + '--iters', + help=f'Number of iterations for training. (default: {DEFAULT_NUM_ITERS})', + type=int, + default=DEFAULT_NUM_ITERS) + parser.add_argument( + '--log-span', + help=f'Iteration span to print metrics. (default: {DEFAULT_LOG_SPAN})', + type=int, + default=DEFAULT_LOG_SPAN) + parser.add_argument( + '--learning-rate', + help=f'Learning rate. (default: {DEFAULT_LEARNING_RATE})', + type=float, + default=DEFAULT_LEARNING_RATE) + if test is None: + return parser.parse_args() + else: + return parser.parse_args(test) + + +def main() -> None: + args = parse_args() + train_data_path: str = args.train_data + base_model_path: str = args.base_model + weights_path: str = args.output + iters: int = args.iters + log_span: int = args.log_span + learning_rate: float = args.learning_rate + val_data_path: typing.Optional[str] = args.val_data + + model = load_model(base_model_path) + train_dataset = load_dataset(train_data_path, model) + val_dataset = load_dataset(val_data_path, model) if val_data_path else None + weights = fit( + model.weights, + train_dataset, + iters=iters, + log_span=log_span, + learning_rate=learning_rate, + val_dataset=val_dataset) + write_weights(weights_path, weights, model.features) + + +if __name__ == '__main__': + main() diff --git a/scripts/tests/test_finetune.py b/scripts/tests/test_finetune.py new file mode 100644 index 00000000..8186871e --- /dev/null +++ b/scripts/tests/test_finetune.py @@ -0,0 +1,148 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests the finetune script.""" + +import os +import sys +import tempfile +import unittest + +from jax import numpy as jnp + +# module hack +LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..') +sys.path.insert(0, os.path.abspath(LIB_PATH)) + +from scripts import finetune # noqa (module hack) + + +class TestArgParse(unittest.TestCase): + + def test_cmdargs_invalid_option(self) -> None: + cmdargs = ['-v'] + with self.assertRaises(SystemExit) as cm: + finetune.parse_args(cmdargs) + self.assertEqual(cm.exception.code, 2) + + def test_cmdargs_help(self) -> None: + cmdargs = ['-h'] + with self.assertRaises(SystemExit) as cm: + finetune.parse_args(cmdargs) + self.assertEqual(cm.exception.code, 0) + + def test_cmdargs_no_data(self) -> None: + with self.assertRaises(SystemExit) as cm: + finetune.parse_args([]) + self.assertEqual(cm.exception.code, 2) + + def test_cmdargs_no_base_model(self) -> None: + with self.assertRaises(SystemExit) as cm: + finetune.parse_args(['encoded.txt']) + self.assertEqual(cm.exception.code, 2) + + def test_cmdargs_default(self) -> None: + cmdargs = ['encoded.txt', 'model.json'] + output = finetune.parse_args(cmdargs) + self.assertEqual(output.train_data, 'encoded.txt') + self.assertEqual(output.base_model, 'model.json') + self.assertEqual(output.iters, finetune.DEFAULT_NUM_ITERS) + self.assertEqual(output.log_span, finetune.DEFAULT_LOG_SPAN) + self.assertEqual(output.learning_rate, finetune.DEFAULT_LEARNING_RATE) + self.assertEqual(output.val_data, None) + + def test_cmdargs_with_values(self) -> None: + cmdargs = [ + 'encoded.txt', 'model.json', '--iters', '50', '--log-span', '10', + '--learning-rate', '0.1', '--val-data', 'val.txt' + ] + output = finetune.parse_args(cmdargs) + self.assertEqual(output.train_data, 'encoded.txt') + self.assertEqual(output.base_model, 'model.json') + self.assertEqual(output.iters, 50) + self.assertEqual(output.log_span, 10) + self.assertEqual(output.learning_rate, 0.1) + self.assertEqual(output.val_data, 'val.txt') + + +class TestLoadModel(unittest.TestCase): + + def setUp(self) -> None: + self.model_file_path = tempfile.NamedTemporaryFile().name + with open(self.model_file_path, 'w') as f: + f.write('{"UW1": {"a": 12, "b": 23}, "TW3": {"xyz": 47}}') + + def test_extracted_keys(self) -> None: + result = finetune.load_model(self.model_file_path).features + self.assertListEqual(result, ['UW1:a', 'UW1:b', 'TW3:xyz']) + + def test_value_variance(self) -> None: + result = finetune.load_model(self.model_file_path).weights.var() + self.assertAlmostEqual(float(result), 1, places=5) + + def test_value_mean(self) -> None: + result = finetune.load_model(self.model_file_path).weights.sum() + self.assertAlmostEqual(float(result), 0, places=5) + + def test_value_order(self) -> None: + result = finetune.load_model(self.model_file_path).weights.tolist() + self.assertGreater(result[1], result[0]) + self.assertGreater(result[2], result[1]) + + +class TestLoadDataset(unittest.TestCase): + + def setUp(self) -> None: + self.entries_file_path = tempfile.NamedTemporaryFile().name + with open(self.entries_file_path, 'w') as f: + f.write(('1\tfoo\tbar\n' + '-1\tfoo\n' + '1\tfoo\tbar\tbaz\n' + '1\tbar\tfoo\n' + '-1\tbaz\tqux\n')) + self.model = finetune.NormalizedModel(['foo', 'bar'], jnp.array([23, -37])) + + def test_y(self) -> None: + result = finetune.load_dataset(self.entries_file_path, self.model) + expected = [True, False, True, True, False] + self.assertListEqual(result.Y.tolist(), expected) + + def test_x(self) -> None: + result = finetune.load_dataset(self.entries_file_path, self.model) + expected = [[1, 1], [1, -1], [1, 1], [1, 1], [-1, -1]] + self.assertListEqual(result.X.tolist(), expected) + + +class TestFit(unittest.TestCase): + + def test_health(self) -> None: + w = jnp.array([.9, .5, -.3]) + X = jnp.array([[-1, 1, 1], [1, -1, 1], [1, 1, -1]]) + # The current result is x.dot(w) = [-0.7, 0.1, 1.1] => [False, True, True] + # It tests if the method can learn a new weight that inverses the result. + Y = jnp.array([True, False, False]) + dataset = finetune.Dataset(X, Y) + w = finetune.fit(w, dataset, iters=1000, learning_rate=.01, log_span=100) + self.assertGreater(X.dot(w).tolist()[0], 0) # x.dot(w) > 0 => True. + + +class TestWriteWeights(unittest.TestCase): + + def test_write_weights(self) -> None: + weights = jnp.array([0.012, 0.238, -0.1237]) + features = ['foo', 'bar', 'baz'] + weights_path = tempfile.NamedTemporaryFile().name + finetune.write_weights(weights_path, weights, features) + with open(weights_path) as f: + result = f.read() + self.assertEqual(result, 'foo\t0.012000\nbar\t0.238000\nbaz\t-0.123700')