From fdfc66f7c3a765f63f85dbed6f8fd86a06674b74 Mon Sep 17 00:00:00 2001 From: Paul Tardy Date: Tue, 1 Oct 2019 14:03:39 +0200 Subject: [PATCH] Preparing for pip (#1581) * put requirements in setup and add entry points * moving binaries to onmt.bin.* * add scripts in root for compatibility * remove requirements install from travis * fix test_preprocess * fix scripts * fix scripts path for docs --- .travis.yml | 1 - docs/source/options/preprocess.rst | 2 +- docs/source/options/server.rst | 2 +- docs/source/options/train.rst | 2 +- docs/source/options/translate.rst | 2 +- onmt/bin/__init__.py | 0 onmt/bin/preprocess.py | 287 +++++++++++++++++++++++++++++ onmt/bin/server.py | 133 +++++++++++++ onmt/bin/train.py | 204 ++++++++++++++++++++ onmt/bin/translate.py | 53 ++++++ onmt/tests/test_preprocess.py | 2 +- preprocess.py | 281 +--------------------------- requirements.opt.txt | 3 - requirements.txt | 6 - server.py | 129 +------------ setup.py | 34 +++- train.py | 198 +------------------- translate.py | 47 +---- 18 files changed, 717 insertions(+), 669 deletions(-) create mode 100644 onmt/bin/__init__.py create mode 100755 onmt/bin/preprocess.py create mode 100755 onmt/bin/server.py create mode 100755 onmt/bin/train.py create mode 100755 onmt/bin/translate.py mode change 100755 => 100644 preprocess.py delete mode 100644 requirements.txt mode change 100755 => 100644 server.py mode change 100755 => 100644 train.py mode change 100755 => 100644 translate.py diff --git a/.travis.yml b/.travis.yml index ac41b68cec..f1411e0ef6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,6 @@ addons: before_install: # Install CPU version of PyTorch. - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install torch==1.2.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html; fi - - pip install -r requirements.txt - pip install -r requirements.opt.txt - python setup.py install env: diff --git a/docs/source/options/preprocess.rst b/docs/source/options/preprocess.rst index 09b86f0677..d7ccc66e8b 100644 --- a/docs/source/options/preprocess.rst +++ b/docs/source/options/preprocess.rst @@ -2,6 +2,6 @@ Preprocess ========== .. argparse:: - :filename: ../preprocess.py + :filename: ../onmt/bin/preprocess.py :func: _get_parser :prog: preprocess.py \ No newline at end of file diff --git a/docs/source/options/server.rst b/docs/source/options/server.rst index 31ba6dfe67..63b2676fbe 100644 --- a/docs/source/options/server.rst +++ b/docs/source/options/server.rst @@ -2,6 +2,6 @@ Server ========= .. argparse:: - :filename: ../server.py + :filename: ../onmt/bin/server.py :func: _get_parser :prog: server.py \ No newline at end of file diff --git a/docs/source/options/train.rst b/docs/source/options/train.rst index 9b1f68f6d7..67dc1cb22f 100644 --- a/docs/source/options/train.rst +++ b/docs/source/options/train.rst @@ -2,6 +2,6 @@ Train ===== .. argparse:: - :filename: ../train.py + :filename: ../onmt/bin/train.py :func: _get_parser :prog: train.py \ No newline at end of file diff --git a/docs/source/options/translate.rst b/docs/source/options/translate.rst index 967e0e819c..db0423a43b 100644 --- a/docs/source/options/translate.rst +++ b/docs/source/options/translate.rst @@ -2,6 +2,6 @@ Translate ========= .. 
argparse:: - :filename: ../translate.py + :filename: ../onmt/bin/translate.py :func: _get_parser :prog: translate.py \ No newline at end of file diff --git a/onmt/bin/__init__.py b/onmt/bin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/onmt/bin/preprocess.py b/onmt/bin/preprocess.py new file mode 100755 index 0000000000..fc47a4bb08 --- /dev/null +++ b/onmt/bin/preprocess.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" + Pre-process Data / features files and build vocabulary +""" +import codecs +import glob +import gc +import torch +from collections import Counter, defaultdict + +from onmt.utils.logging import init_logger, logger +from onmt.utils.misc import split_corpus +import onmt.inputters as inputters +import onmt.opts as opts +from onmt.utils.parse import ArgumentParser +from onmt.inputters.inputter import _build_fields_vocab,\ + _load_vocab + +from functools import partial +from multiprocessing import Pool + + +def check_existing_pt_files(opt, corpus_type, ids, existing_fields): + """ Check if there are existing .pt files to avoid overwriting them """ + existing_shards = [] + for maybe_id in ids: + if maybe_id: + shard_base = corpus_type + "_" + maybe_id + else: + shard_base = corpus_type + pattern = opt.save_data + '.{}.*.pt'.format(shard_base) + if glob.glob(pattern): + if opt.overwrite: + maybe_overwrite = ("will be overwritten because " + "`-overwrite` option is set.") + else: + maybe_overwrite = ("won't be overwritten, pass the " + "`-overwrite` option if you want to.") + logger.warning("Shards for corpus {} already exist, {}" + .format(shard_base, maybe_overwrite)) + existing_shards += [maybe_id] + return existing_shards + + +def process_one_shard(corpus_params, params): + corpus_type, fields, src_reader, tgt_reader, opt, existing_fields,\ + src_vocab, tgt_vocab = corpus_params + i, (src_shard, tgt_shard, maybe_id, filter_pred) = params + # create one counter per shard + sub_sub_counter = defaultdict(Counter) + assert len(src_shard) == len(tgt_shard) + logger.info("Building shard %d." % i) + dataset = inputters.Dataset( + fields, + readers=([src_reader, tgt_reader] + if tgt_reader else [src_reader]), + data=([("src", src_shard), ("tgt", tgt_shard)] + if tgt_reader else [("src", src_shard)]), + dirs=([opt.src_dir, None] + if tgt_reader else [opt.src_dir]), + sort_key=inputters.str2sortkey[opt.data_type], + filter_pred=filter_pred + ) + if corpus_type == "train" and existing_fields is None: + for ex in dataset.examples: + for name, field in fields.items(): + if ((opt.data_type == "audio") and (name == "src")): + continue + try: + f_iter = iter(field) + except TypeError: + f_iter = [(name, field)] + all_data = [getattr(ex, name, None)] + else: + all_data = getattr(ex, name) + for (sub_n, sub_f), fd in zip( + f_iter, all_data): + has_vocab = (sub_n == 'src' and + src_vocab is not None) or \ + (sub_n == 'tgt' and + tgt_vocab is not None) + if (hasattr(sub_f, 'sequential') + and sub_f.sequential and not has_vocab): + val = fd + sub_sub_counter[sub_n].update(val) + if maybe_id: + shard_base = corpus_type + "_" + maybe_id + else: + shard_base = corpus_type + data_path = "{:s}.{:s}.{:d}.pt".\ + format(opt.save_data, shard_base, i) + + logger.info(" * saving %sth %s data shard to %s." 
+ % (i, shard_base, data_path)) + + dataset.save(data_path) + + del dataset.examples + gc.collect() + del dataset + gc.collect() + + return sub_sub_counter + + +def maybe_load_vocab(corpus_type, counters, opt): + src_vocab = None + tgt_vocab = None + existing_fields = None + if corpus_type == "train": + if opt.src_vocab != "": + try: + logger.info("Using existing vocabulary...") + existing_fields = torch.load(opt.src_vocab) + except torch.serialization.pickle.UnpicklingError: + logger.info("Building vocab from text file...") + src_vocab, src_vocab_size = _load_vocab( + opt.src_vocab, "src", counters, + opt.src_words_min_frequency) + if opt.tgt_vocab != "": + tgt_vocab, tgt_vocab_size = _load_vocab( + opt.tgt_vocab, "tgt", counters, + opt.tgt_words_min_frequency) + return src_vocab, tgt_vocab, existing_fields + + +def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt): + assert corpus_type in ['train', 'valid'] + + if corpus_type == 'train': + counters = defaultdict(Counter) + srcs = opt.train_src + tgts = opt.train_tgt + ids = opt.train_ids + elif corpus_type == 'valid': + counters = None + srcs = [opt.valid_src] + tgts = [opt.valid_tgt] + ids = [None] + + src_vocab, tgt_vocab, existing_fields = maybe_load_vocab( + corpus_type, counters, opt) + + existing_shards = check_existing_pt_files( + opt, corpus_type, ids, existing_fields) + + # every corpus has shards, no new one + if existing_shards == ids and not opt.overwrite: + return + + def shard_iterator(srcs, tgts, ids, existing_shards, + existing_fields, corpus_type, opt): + """ + Builds a single iterator yielding every shard of every corpus. + """ + for src, tgt, maybe_id in zip(srcs, tgts, ids): + if maybe_id in existing_shards: + if opt.overwrite: + logger.warning("Overwrite shards for corpus {}" + .format(maybe_id)) + else: + if corpus_type == "train": + assert existing_fields is not None,\ + ("A 'vocab.pt' file should be passed to " + "`-src_vocab` when adding a corpus to " + "a set of already existing shards.") + logger.warning("Ignore corpus {} because " + "shards already exist" + .format(maybe_id)) + continue + if ((corpus_type == "train" or opt.filter_valid) + and tgt is not None): + filter_pred = partial( + inputters.filter_example, + use_src_len=opt.data_type == "text", + max_src_len=opt.src_seq_length, + max_tgt_len=opt.tgt_seq_length) + else: + filter_pred = None + src_shards = split_corpus(src, opt.shard_size) + tgt_shards = split_corpus(tgt, opt.shard_size) + for i, (ss, ts) in enumerate(zip(src_shards, tgt_shards)): + yield (i, (ss, ts, maybe_id, filter_pred)) + + shard_iter = shard_iterator(srcs, tgts, ids, existing_shards, + existing_fields, corpus_type, opt) + + with Pool(opt.num_threads) as p: + dataset_params = (corpus_type, fields, src_reader, tgt_reader, + opt, existing_fields, src_vocab, tgt_vocab) + func = partial(process_one_shard, dataset_params) + for sub_counter in p.imap(func, shard_iter): + if sub_counter is not None: + for key, value in sub_counter.items(): + counters[key].update(value) + + if corpus_type == "train": + vocab_path = opt.save_data + '.vocab.pt' + if existing_fields is None: + fields = _build_fields_vocab( + fields, counters, opt.data_type, + opt.share_vocab, opt.vocab_size_multiple, + opt.src_vocab_size, opt.src_words_min_frequency, + opt.tgt_vocab_size, opt.tgt_words_min_frequency) + else: + fields = existing_fields + torch.save(fields, vocab_path) + + +def build_save_vocab(train_dataset, fields, opt): + fields = inputters.build_vocab( + train_dataset, fields, opt.data_type, 
opt.share_vocab, + opt.src_vocab, opt.src_vocab_size, opt.src_words_min_frequency, + opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency, + vocab_size_multiple=opt.vocab_size_multiple + ) + vocab_path = opt.save_data + '.vocab.pt' + torch.save(fields, vocab_path) + + +def count_features(path): + """ + path: location of a corpus file with whitespace-delimited tokens and + │-delimited features within the token + returns: the number of features in the dataset + """ + with codecs.open(path, "r", "utf-8") as f: + first_tok = f.readline().split(None, 1)[0] + return len(first_tok.split(u"│")) - 1 + + +def preprocess(opt): + ArgumentParser.validate_preprocess_args(opt) + torch.manual_seed(opt.seed) + + init_logger(opt.log_file) + + logger.info("Extracting features...") + + src_nfeats = 0 + tgt_nfeats = 0 + for src, tgt in zip(opt.train_src, opt.train_tgt): + src_nfeats += count_features(src) if opt.data_type == 'text' \ + else 0 + tgt_nfeats += count_features(tgt) # tgt always text so far + logger.info(" * number of source features: %d." % src_nfeats) + logger.info(" * number of target features: %d." % tgt_nfeats) + + logger.info("Building `Fields` object...") + fields = inputters.get_fields( + opt.data_type, + src_nfeats, + tgt_nfeats, + dynamic_dict=opt.dynamic_dict, + src_truncate=opt.src_seq_length_trunc, + tgt_truncate=opt.tgt_seq_length_trunc) + + src_reader = inputters.str2reader[opt.data_type].from_opt(opt) + tgt_reader = inputters.str2reader["text"].from_opt(opt) + + logger.info("Building & saving training data...") + build_save_dataset( + 'train', fields, src_reader, tgt_reader, opt) + + if opt.valid_src and opt.valid_tgt: + logger.info("Building & saving validation data...") + build_save_dataset('valid', fields, src_reader, tgt_reader, opt) + + +def _get_parser(): + parser = ArgumentParser(description='preprocess.py') + + opts.config_opts(parser) + opts.preprocess_opts(parser) + return parser + + +def main(): + parser = _get_parser() + + opt = parser.parse_args() + preprocess(opt) + + +if __name__ == "__main__": + main() diff --git a/onmt/bin/server.py b/onmt/bin/server.py new file mode 100755 index 0000000000..44f8723f83 --- /dev/null +++ b/onmt/bin/server.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +import configargparse + +from flask import Flask, jsonify, request +from onmt.translate import TranslationServer, ServerModelError + +STATUS_OK = "ok" +STATUS_ERROR = "error" + + +def start(config_file, + url_root="./translator", + host="0.0.0.0", + port=5000, + debug=True): + def prefix_route(route_function, prefix='', mask='{0}{1}'): + def newroute(route, *args, **kwargs): + return route_function(mask.format(prefix, route), *args, **kwargs) + return newroute + + app = Flask(__name__) + app.route = prefix_route(app.route, url_root) + translation_server = TranslationServer() + translation_server.start(config_file) + + @app.route('/models', methods=['GET']) + def get_models(): + out = translation_server.list_models() + return jsonify(out) + + @app.route('/health', methods=['GET']) + def health(): + out = {} + out['status'] = STATUS_OK + return jsonify(out) + + @app.route('/clone_model/<int:model_id>', methods=['POST']) + def clone_model(model_id): + out = {} + data = request.get_json(force=True) + timeout = -1 + if 'timeout' in data: + timeout = data['timeout'] + del data['timeout'] + + opt = data.get('opt', None) + try: + model_id, load_time = translation_server.clone_model( + model_id, opt, timeout) + except ServerModelError as e: + out['status'] = STATUS_ERROR + out['error'] = str(e) +
else: + out['status'] = STATUS_OK + out['model_id'] = model_id + out['load_time'] = load_time + + return jsonify(out) + + @app.route('/unload_model/<int:model_id>', methods=['GET']) + def unload_model(model_id): + out = {"model_id": model_id} + + try: + translation_server.unload_model(model_id) + out['status'] = STATUS_OK + except Exception as e: + out['status'] = STATUS_ERROR + out['error'] = str(e) + + return jsonify(out) + + @app.route('/translate', methods=['POST']) + def translate(): + inputs = request.get_json(force=True) + out = {} + try: + translation, scores, n_best, times = translation_server.run(inputs) + assert len(translation) == len(inputs) + assert len(scores) == len(inputs) + + out = [[{"src": inputs[i]['src'], "tgt": translation[i], + "n_best": n_best, + "pred_score": scores[i]} + for i in range(len(translation))]] + except ServerModelError as e: + out['error'] = str(e) + out['status'] = STATUS_ERROR + + return jsonify(out) + + @app.route('/to_cpu/<int:model_id>', methods=['GET']) + def to_cpu(model_id): + out = {'model_id': model_id} + translation_server.models[model_id].to_cpu() + + out['status'] = STATUS_OK + return jsonify(out) + + @app.route('/to_gpu/<int:model_id>', methods=['GET']) + def to_gpu(model_id): + out = {'model_id': model_id} + translation_server.models[model_id].to_gpu() + + out['status'] = STATUS_OK + return jsonify(out) + + app.run(debug=debug, host=host, port=port, use_reloader=False, + threaded=True) + + +def _get_parser(): + parser = configargparse.ArgumentParser( + config_file_parser_class=configargparse.YAMLConfigFileParser, + description="OpenNMT-py REST Server") + parser.add_argument("--ip", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default="5000") + parser.add_argument("--url_root", type=str, default="/translator") + parser.add_argument("--debug", "-d", action="store_true") + parser.add_argument("--config", "-c", type=str, + default="./available_models/conf.json") + return parser + + +def main(): + parser = _get_parser() + args = parser.parse_args() + start(args.config, url_root=args.url_root, host=args.ip, port=args.port, + debug=args.debug) + + +if __name__ == "__main__": + main() diff --git a/onmt/bin/train.py b/onmt/bin/train.py new file mode 100755 index 0000000000..4f0718e4a0 --- /dev/null +++ b/onmt/bin/train.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python +"""Train models.""" +import os +import signal +import torch + +import onmt.opts as opts +import onmt.utils.distributed + +from onmt.utils.misc import set_random_seed +from onmt.utils.logging import init_logger, logger +from onmt.train_single import main as single_main +from onmt.utils.parse import ArgumentParser +from onmt.inputters.inputter import build_dataset_iter, \ + load_old_vocab, old_style_vocab, build_dataset_iter_multiple + +from itertools import cycle + + +def train(opt): + ArgumentParser.validate_train_opts(opt) + ArgumentParser.update_model_opts(opt) + ArgumentParser.validate_model_opts(opt) + + # Load checkpoint if we resume from a previous training. + if opt.train_from: + logger.info('Loading checkpoint from %s' % opt.train_from) + checkpoint = torch.load(opt.train_from, + map_location=lambda storage, loc: storage) + logger.info('Loading vocab from checkpoint at %s.'
% opt.train_from) + vocab = checkpoint['vocab'] + else: + vocab = torch.load(opt.data + '.vocab.pt') + + # check for code where vocab is saved instead of fields + # (in the future this will be done in a smarter way) + if old_style_vocab(vocab): + fields = load_old_vocab( + vocab, opt.model_type, dynamic_dict=opt.copy_attn) + else: + fields = vocab + + if len(opt.data_ids) > 1: + train_shards = [] + for train_id in opt.data_ids: + shard_base = "train_" + train_id + train_shards.append(shard_base) + train_iter = build_dataset_iter_multiple(train_shards, fields, opt) + else: + if opt.data_ids[0] is not None: + shard_base = "train_" + opt.data_ids[0] + else: + shard_base = "train" + train_iter = build_dataset_iter(shard_base, fields, opt) + + nb_gpu = len(opt.gpu_ranks) + + if opt.world_size > 1: + queues = [] + mp = torch.multiprocessing.get_context('spawn') + semaphore = mp.Semaphore(opt.world_size * opt.queue_size) + # Create a thread to listen for errors in the child processes. + error_queue = mp.SimpleQueue() + error_handler = ErrorHandler(error_queue) + # Train with multiprocessing. + procs = [] + for device_id in range(nb_gpu): + q = mp.Queue(opt.queue_size) + queues += [q] + procs.append(mp.Process(target=run, args=( + opt, device_id, error_queue, q, semaphore), daemon=True)) + procs[device_id].start() + logger.info(" Starting process pid: %d " % procs[device_id].pid) + error_handler.add_child(procs[device_id].pid) + producer = mp.Process(target=batch_producer, + args=(train_iter, queues, semaphore, opt,), + daemon=True) + producer.start() + error_handler.add_child(producer.pid) + + for p in procs: + p.join() + producer.terminate() + + elif nb_gpu == 1: # case 1 GPU only + single_main(opt, 0) + else: # case only CPU + single_main(opt, -1) + + +def batch_producer(generator_to_serve, queues, semaphore, opt): + init_logger(opt.log_file) + set_random_seed(opt.seed, False) + # generator_to_serve = iter(generator_to_serve) + + def pred(x): + """ + Filters batches that belong only + to gpu_ranks of current node + """ + for rank in opt.gpu_ranks: + if x[0] % opt.world_size == rank: + return True + + generator_to_serve = filter( + pred, enumerate(generator_to_serve)) + + def next_batch(device_id): + new_batch = next(generator_to_serve) + semaphore.acquire() + return new_batch[1] + + b = next_batch(0) + + for device_id, q in cycle(enumerate(queues)): + b.dataset = None + if isinstance(b.src, tuple): + b.src = tuple([_.to(torch.device(device_id)) + for _ in b.src]) + else: + b.src = b.src.to(torch.device(device_id)) + b.tgt = b.tgt.to(torch.device(device_id)) + b.indices = b.indices.to(torch.device(device_id)) + b.alignment = b.alignment.to(torch.device(device_id)) \ + if hasattr(b, 'alignment') else None + b.src_map = b.src_map.to(torch.device(device_id)) \ + if hasattr(b, 'src_map') else None + + # hack to dodge unpicklable `dict_keys` + b.fields = list(b.fields) + q.put(b) + b = next_batch(device_id) + + +def run(opt, device_id, error_queue, batch_queue, semaphore): + """ run process """ + try: + gpu_rank = onmt.utils.distributed.multi_init(opt, device_id) + if gpu_rank != opt.gpu_ranks[device_id]: + raise AssertionError("An error occurred in \ + Distributed initialization") + single_main(opt, device_id, batch_queue, semaphore) + except KeyboardInterrupt: + pass # killed by parent, do nothing + except Exception: + # propagate exception to parent process, keeping original traceback + import traceback + error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc())) + + +class 
ErrorHandler(object): + """A class that listens for exceptions in children processes and propagates + the tracebacks to the parent process.""" + + def __init__(self, error_queue): + """ init error handler """ + import signal + import threading + self.error_queue = error_queue + self.children_pids = [] + self.error_thread = threading.Thread( + target=self.error_listener, daemon=True) + self.error_thread.start() + signal.signal(signal.SIGUSR1, self.signal_handler) + + def add_child(self, pid): + """ error handler """ + self.children_pids.append(pid) + + def error_listener(self): + """ error listener """ + (rank, original_trace) = self.error_queue.get() + self.error_queue.put((rank, original_trace)) + os.kill(os.getpid(), signal.SIGUSR1) + + def signal_handler(self, signalnum, stackframe): + """ signal handler """ + for pid in self.children_pids: + os.kill(pid, signal.SIGINT) # kill children processes + (rank, original_trace) = self.error_queue.get() + msg = """\n\n-- Tracebacks above this line can probably + be ignored --\n\n""" + msg += original_trace + raise Exception(msg) + + +def _get_parser(): + parser = ArgumentParser(description='train.py') + + opts.config_opts(parser) + opts.model_opts(parser) + opts.train_opts(parser) + return parser + + +def main(): + parser = _get_parser() + + opt = parser.parse_args() + train(opt) + + +if __name__ == "__main__": + main() diff --git a/onmt/bin/translate.py b/onmt/bin/translate.py new file mode 100755 index 0000000000..b0a7820ac6 --- /dev/null +++ b/onmt/bin/translate.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals +from itertools import repeat + +from onmt.utils.logging import init_logger +from onmt.utils.misc import split_corpus +from onmt.translate.translator import build_translator + +import onmt.opts as opts +from onmt.utils.parse import ArgumentParser + + +def translate(opt): + ArgumentParser.validate_translate_opts(opt) + logger = init_logger(opt.log_file) + + translator = build_translator(opt, report_score=True) + src_shards = split_corpus(opt.src, opt.shard_size) + tgt_shards = split_corpus(opt.tgt, opt.shard_size) \ + if opt.tgt is not None else repeat(None) + shard_pairs = zip(src_shards, tgt_shards) + + for i, (src_shard, tgt_shard) in enumerate(shard_pairs): + logger.info("Translating shard %d." 
% i) + translator.translate( + src=src_shard, + tgt=tgt_shard, + src_dir=opt.src_dir, + batch_size=opt.batch_size, + batch_type=opt.batch_type, + attn_debug=opt.attn_debug + ) + + +def _get_parser(): + parser = ArgumentParser(description='translate.py') + + opts.config_opts(parser) + opts.translate_opts(parser) + return parser + + +def main(): + parser = _get_parser() + + opt = parser.parse_args() + translate(opt) + + +if __name__ == "__main__": + main() diff --git a/onmt/tests/test_preprocess.py b/onmt/tests/test_preprocess.py index ff63b1d758..18a64162f3 100644 --- a/onmt/tests/test_preprocess.py +++ b/onmt/tests/test_preprocess.py @@ -12,7 +12,7 @@ import onmt import onmt.inputters import onmt.opts -import preprocess +import onmt.bin.preprocess as preprocess parser = configargparse.ArgumentParser(description='preprocess.py') diff --git a/preprocess.py b/preprocess.py old mode 100755 new mode 100644 index e45bd1bbbf..c0c7742fa0 --- a/preprocess.py +++ b/preprocess.py @@ -1,283 +1,6 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -""" - Pre-process Data / features files and build vocabulary -""" -import codecs -import glob -import gc -import torch -from collections import Counter, defaultdict - -from onmt.utils.logging import init_logger, logger -from onmt.utils.misc import split_corpus -import onmt.inputters as inputters -import onmt.opts as opts -from onmt.utils.parse import ArgumentParser -from onmt.inputters.inputter import _build_fields_vocab,\ - _load_vocab - -from functools import partial -from multiprocessing import Pool - - -def check_existing_pt_files(opt, corpus_type, ids, existing_fields): - """ Check if there are existing .pt files to avoid overwriting them """ - existing_shards = [] - for maybe_id in ids: - if maybe_id: - shard_base = corpus_type + "_" + maybe_id - else: - shard_base = corpus_type - pattern = opt.save_data + '.{}.*.pt'.format(shard_base) - if glob.glob(pattern): - if opt.overwrite: - maybe_overwrite = ("will be overwritten because " - "`-overwrite` option is set.") - else: - maybe_overwrite = ("won't be overwritten, pass the " - "`-overwrite` option if you want to.") - logger.warning("Shards for corpus {} already exist, {}" - .format(shard_base, maybe_overwrite)) - existing_shards += [maybe_id] - return existing_shards - - -def process_one_shard(corpus_params, params): - corpus_type, fields, src_reader, tgt_reader, opt, existing_fields,\ - src_vocab, tgt_vocab = corpus_params - i, (src_shard, tgt_shard, maybe_id, filter_pred) = params - # create one counter per shard - sub_sub_counter = defaultdict(Counter) - assert len(src_shard) == len(tgt_shard) - logger.info("Building shard %d." 
% i) - dataset = inputters.Dataset( - fields, - readers=([src_reader, tgt_reader] - if tgt_reader else [src_reader]), - data=([("src", src_shard), ("tgt", tgt_shard)] - if tgt_reader else [("src", src_shard)]), - dirs=([opt.src_dir, None] - if tgt_reader else [opt.src_dir]), - sort_key=inputters.str2sortkey[opt.data_type], - filter_pred=filter_pred - ) - if corpus_type == "train" and existing_fields is None: - for ex in dataset.examples: - for name, field in fields.items(): - if ((opt.data_type == "audio") and (name == "src")): - continue - try: - f_iter = iter(field) - except TypeError: - f_iter = [(name, field)] - all_data = [getattr(ex, name, None)] - else: - all_data = getattr(ex, name) - for (sub_n, sub_f), fd in zip( - f_iter, all_data): - has_vocab = (sub_n == 'src' and - src_vocab is not None) or \ - (sub_n == 'tgt' and - tgt_vocab is not None) - if (hasattr(sub_f, 'sequential') - and sub_f.sequential and not has_vocab): - val = fd - sub_sub_counter[sub_n].update(val) - if maybe_id: - shard_base = corpus_type + "_" + maybe_id - else: - shard_base = corpus_type - data_path = "{:s}.{:s}.{:d}.pt".\ - format(opt.save_data, shard_base, i) - - logger.info(" * saving %sth %s data shard to %s." - % (i, shard_base, data_path)) - - dataset.save(data_path) - - del dataset.examples - gc.collect() - del dataset - gc.collect() - - return sub_sub_counter - - -def maybe_load_vocab(corpus_type, counters, opt): - src_vocab = None - tgt_vocab = None - existing_fields = None - if corpus_type == "train": - if opt.src_vocab != "": - try: - logger.info("Using existing vocabulary...") - existing_fields = torch.load(opt.src_vocab) - except torch.serialization.pickle.UnpicklingError: - logger.info("Building vocab from text file...") - src_vocab, src_vocab_size = _load_vocab( - opt.src_vocab, "src", counters, - opt.src_words_min_frequency) - if opt.tgt_vocab != "": - tgt_vocab, tgt_vocab_size = _load_vocab( - opt.tgt_vocab, "tgt", counters, - opt.tgt_words_min_frequency) - return src_vocab, tgt_vocab, existing_fields - - -def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt): - assert corpus_type in ['train', 'valid'] - - if corpus_type == 'train': - counters = defaultdict(Counter) - srcs = opt.train_src - tgts = opt.train_tgt - ids = opt.train_ids - elif corpus_type == 'valid': - counters = None - srcs = [opt.valid_src] - tgts = [opt.valid_tgt] - ids = [None] - - src_vocab, tgt_vocab, existing_fields = maybe_load_vocab( - corpus_type, counters, opt) - - existing_shards = check_existing_pt_files( - opt, corpus_type, ids, existing_fields) - - # every corpus has shards, no new one - if existing_shards == ids and not opt.overwrite: - return - - def shard_iterator(srcs, tgts, ids, existing_shards, - existing_fields, corpus_type, opt): - """ - Builds a single iterator yielding every shard of every corpus. 
- """ - for src, tgt, maybe_id in zip(srcs, tgts, ids): - if maybe_id in existing_shards: - if opt.overwrite: - logger.warning("Overwrite shards for corpus {}" - .format(maybe_id)) - else: - if corpus_type == "train": - assert existing_fields is not None,\ - ("A 'vocab.pt' file should be passed to " - "`-src_vocab` when adding a corpus to " - "a set of already existing shards.") - logger.warning("Ignore corpus {} because " - "shards already exist" - .format(maybe_id)) - continue - if ((corpus_type == "train" or opt.filter_valid) - and tgt is not None): - filter_pred = partial( - inputters.filter_example, - use_src_len=opt.data_type == "text", - max_src_len=opt.src_seq_length, - max_tgt_len=opt.tgt_seq_length) - else: - filter_pred = None - src_shards = split_corpus(src, opt.shard_size) - tgt_shards = split_corpus(tgt, opt.shard_size) - for i, (ss, ts) in enumerate(zip(src_shards, tgt_shards)): - yield (i, (ss, ts, maybe_id, filter_pred)) - - shard_iter = shard_iterator(srcs, tgts, ids, existing_shards, - existing_fields, corpus_type, opt) - - with Pool(opt.num_threads) as p: - dataset_params = (corpus_type, fields, src_reader, tgt_reader, - opt, existing_fields, src_vocab, tgt_vocab) - func = partial(process_one_shard, dataset_params) - for sub_counter in p.imap(func, shard_iter): - if sub_counter is not None: - for key, value in sub_counter.items(): - counters[key].update(value) - - if corpus_type == "train": - vocab_path = opt.save_data + '.vocab.pt' - if existing_fields is None: - fields = _build_fields_vocab( - fields, counters, opt.data_type, - opt.share_vocab, opt.vocab_size_multiple, - opt.src_vocab_size, opt.src_words_min_frequency, - opt.tgt_vocab_size, opt.tgt_words_min_frequency) - else: - fields = existing_fields - torch.save(fields, vocab_path) - - -def build_save_vocab(train_dataset, fields, opt): - fields = inputters.build_vocab( - train_dataset, fields, opt.data_type, opt.share_vocab, - opt.src_vocab, opt.src_vocab_size, opt.src_words_min_frequency, - opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency, - vocab_size_multiple=opt.vocab_size_multiple - ) - vocab_path = opt.save_data + '.vocab.pt' - torch.save(fields, vocab_path) - - -def count_features(path): - """ - path: location of a corpus file with whitespace-delimited tokens and - │-delimited features within the token - returns: the number of features in the dataset - """ - with codecs.open(path, "r", "utf-8") as f: - first_tok = f.readline().split(None, 1)[0] - return len(first_tok.split(u"│")) - 1 - - -def main(opt): - ArgumentParser.validate_preprocess_args(opt) - torch.manual_seed(opt.seed) - - init_logger(opt.log_file) - - logger.info("Extracting features...") - - src_nfeats = 0 - tgt_nfeats = 0 - for src, tgt in zip(opt.train_src, opt.train_tgt): - src_nfeats += count_features(src) if opt.data_type == 'text' \ - else 0 - tgt_nfeats += count_features(tgt) # tgt always text so far - logger.info(" * number of source features: %d." % src_nfeats) - logger.info(" * number of target features: %d." 
% tgt_nfeats) - - logger.info("Building `Fields` object...") - fields = inputters.get_fields( - opt.data_type, - src_nfeats, - tgt_nfeats, - dynamic_dict=opt.dynamic_dict, - src_truncate=opt.src_seq_length_trunc, - tgt_truncate=opt.tgt_seq_length_trunc) - - src_reader = inputters.str2reader[opt.data_type].from_opt(opt) - tgt_reader = inputters.str2reader["text"].from_opt(opt) - - logger.info("Building & saving training data...") - build_save_dataset( - 'train', fields, src_reader, tgt_reader, opt) - - if opt.valid_src and opt.valid_tgt: - logger.info("Building & saving validation data...") - build_save_dataset('valid', fields, src_reader, tgt_reader, opt) - - -def _get_parser(): - parser = ArgumentParser(description='preprocess.py') - - opts.config_opts(parser) - opts.preprocess_opts(parser) - return parser +from onmt.bin.preprocess import main if __name__ == "__main__": - parser = _get_parser() - - opt = parser.parse_args() - main(opt) + main() diff --git a/requirements.opt.txt b/requirements.opt.txt index c3e52ccb84..fdbd2d1ee3 100644 --- a/requirements.opt.txt +++ b/requirements.opt.txt @@ -5,9 +5,6 @@ librosa Pillow git+git://github.com/pytorch/audio.git@d92de5b97fc6204db4b1e3ed20c03ac06f5d53f0 pyrouge -pyonmttok opencv-python git+https://github.com/NVIDIA/apex -flask -tensorboard>=1.14 pretrainedmodels diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 8b835b60b6..0000000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -six -tqdm==4.30.* -torch>=1.2 -git+https://github.com/pytorch/text.git@master#wheel=torchtext -future -configargparse diff --git a/server.py b/server.py old mode 100755 new mode 100644 index 7e852e9c24..2e078ba6ee --- a/server.py +++ b/server.py @@ -1,129 +1,6 @@ #!/usr/bin/env python -import configargparse +from onmt.bin.server import main -from flask import Flask, jsonify, request -from onmt.translate import TranslationServer, ServerModelError -STATUS_OK = "ok" -STATUS_ERROR = "error" - - -def start(config_file, - url_root="./translator", - host="0.0.0.0", - port=5000, - debug=True): - def prefix_route(route_function, prefix='', mask='{0}{1}'): - def newroute(route, *args, **kwargs): - return route_function(mask.format(prefix, route), *args, **kwargs) - return newroute - - app = Flask(__name__) - app.route = prefix_route(app.route, url_root) - translation_server = TranslationServer() - translation_server.start(config_file) - - @app.route('/models', methods=['GET']) - def get_models(): - out = translation_server.list_models() - return jsonify(out) - - @app.route('/health', methods=['GET']) - def health(): - out = {} - out['status'] = STATUS_OK - return jsonify(out) - - @app.route('/clone_model/<int:model_id>', methods=['POST']) - def clone_model(model_id): - out = {} - data = request.get_json(force=True) - timeout = -1 - if 'timeout' in data: - timeout = data['timeout'] - del data['timeout'] - - opt = data.get('opt', None) - try: - model_id, load_time = translation_server.clone_model( - model_id, opt, timeout) - except ServerModelError as e: - out['status'] = STATUS_ERROR - out['error'] = str(e) - else: - out['status'] = STATUS_OK - out['model_id'] = model_id - out['load_time'] = load_time - - return jsonify(out) - - @app.route('/unload_model/<int:model_id>', methods=['GET']) - def unload_model(model_id): - out = {"model_id": model_id} - - try: - translation_server.unload_model(model_id) - out['status'] = STATUS_OK - except Exception as e: - out['status'] = STATUS_ERROR - out['error'] = str(e) - - return jsonify(out) - -
@app.route('/translate', methods=['POST']) - def translate(): - inputs = request.get_json(force=True) - out = {} - try: - translation, scores, n_best, times = translation_server.run(inputs) - assert len(translation) == len(inputs) - assert len(scores) == len(inputs) - - out = [[{"src": inputs[i]['src'], "tgt": translation[i], - "n_best": n_best, - "pred_score": scores[i]} - for i in range(len(translation))]] - except ServerModelError as e: - out['error'] = str(e) - out['status'] = STATUS_ERROR - - return jsonify(out) - - @app.route('/to_cpu/<int:model_id>', methods=['GET']) - def to_cpu(model_id): - out = {'model_id': model_id} - translation_server.models[model_id].to_cpu() - - out['status'] = STATUS_OK - return jsonify(out) - - @app.route('/to_gpu/<int:model_id>', methods=['GET']) - def to_gpu(model_id): - out = {'model_id': model_id} - translation_server.models[model_id].to_gpu() - - out['status'] = STATUS_OK - return jsonify(out) - - app.run(debug=debug, host=host, port=port, use_reloader=False, - threaded=True) - - -def _get_parser(): - parser = configargparse.ArgumentParser( - config_file_parser_class=configargparse.YAMLConfigFileParser, - description="OpenNMT-py REST Server") - parser.add_argument("--ip", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default="5000") - parser.add_argument("--url_root", type=str, default="/translator") - parser.add_argument("--debug", "-d", action="store_true") - parser.add_argument("--config", "-c", type=str, - default="./available_models/conf.json") - return parser - - -if __name__ == '__main__': - parser = _get_parser() - args = parser.parse_args() - start(args.config, url_root=args.url_root, host=args.ip, port=args.port, - debug=args.debug) +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index d643c5e549..bb9ee57191 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,29 @@ #!/usr/bin/env python +from setuptools import setup, find_packages -from setuptools import setup -setup(name='OpenNMT-py', - description='A python implementation of OpenNMT', - version='0.9.2', - - packages=['onmt', 'onmt.encoders', 'onmt.modules', 'onmt.tests', - 'onmt.translate', 'onmt.decoders', 'onmt.inputters', - 'onmt.models', 'onmt.utils']) +setup( + name='OpenNMT-py', + description='A python implementation of OpenNMT', + version='0.9.2', + packages=find_packages(), + install_requires=[ + "six", + "tqdm~=4.30.0", + "torch>=1.1", + "torchtext==0.4.0", + "future", + "configargparse", + "tensorboard>=1.14", + "flask", + "pyonmttok", + ], + entry_points={ + "console_scripts": [ + "onmt_server=onmt.bin.server:main", + "onmt_train=onmt.bin.train:main", + "onmt_translate=onmt.bin.translate:main", + "onmt_preprocess=onmt.bin.preprocess:main", + ], + } +) diff --git a/train.py b/train.py old mode 100755 new mode 100644 index d00f161a91..1b03c9bcbc --- a/train.py +++ b/train.py @@ -1,200 +1,6 @@ #!/usr/bin/env python -"""Train models.""" -import os -import signal -import torch - -import onmt.opts as opts -import onmt.utils.distributed - -from onmt.utils.misc import set_random_seed -from onmt.utils.logging import init_logger, logger -from onmt.train_single import main as single_main -from onmt.utils.parse import ArgumentParser -from onmt.inputters.inputter import build_dataset_iter, \ - load_old_vocab, old_style_vocab, build_dataset_iter_multiple - -from itertools import cycle - - -def main(opt): - ArgumentParser.validate_train_opts(opt) - ArgumentParser.update_model_opts(opt) - ArgumentParser.validate_model_opts(opt) - - # Load checkpoint if we resume from a previous
training. - if opt.train_from: - logger.info('Loading checkpoint from %s' % opt.train_from) - checkpoint = torch.load(opt.train_from, - map_location=lambda storage, loc: storage) - logger.info('Loading vocab from checkpoint at %s.' % opt.train_from) - vocab = checkpoint['vocab'] - else: - vocab = torch.load(opt.data + '.vocab.pt') - - # check for code where vocab is saved instead of fields - # (in the future this will be done in a smarter way) - if old_style_vocab(vocab): - fields = load_old_vocab( - vocab, opt.model_type, dynamic_dict=opt.copy_attn) - else: - fields = vocab - - if len(opt.data_ids) > 1: - train_shards = [] - for train_id in opt.data_ids: - shard_base = "train_" + train_id - train_shards.append(shard_base) - train_iter = build_dataset_iter_multiple(train_shards, fields, opt) - else: - if opt.data_ids[0] is not None: - shard_base = "train_" + opt.data_ids[0] - else: - shard_base = "train" - train_iter = build_dataset_iter(shard_base, fields, opt) - - nb_gpu = len(opt.gpu_ranks) - - if opt.world_size > 1: - queues = [] - mp = torch.multiprocessing.get_context('spawn') - semaphore = mp.Semaphore(opt.world_size * opt.queue_size) - # Create a thread to listen for errors in the child processes. - error_queue = mp.SimpleQueue() - error_handler = ErrorHandler(error_queue) - # Train with multiprocessing. - procs = [] - for device_id in range(nb_gpu): - q = mp.Queue(opt.queue_size) - queues += [q] - procs.append(mp.Process(target=run, args=( - opt, device_id, error_queue, q, semaphore), daemon=True)) - procs[device_id].start() - logger.info(" Starting process pid: %d " % procs[device_id].pid) - error_handler.add_child(procs[device_id].pid) - producer = mp.Process(target=batch_producer, - args=(train_iter, queues, semaphore, opt,), - daemon=True) - producer.start() - error_handler.add_child(producer.pid) - - for p in procs: - p.join() - producer.terminate() - - elif nb_gpu == 1: # case 1 GPU only - single_main(opt, 0) - else: # case only CPU - single_main(opt, -1) - - -def batch_producer(generator_to_serve, queues, semaphore, opt): - init_logger(opt.log_file) - set_random_seed(opt.seed, False) - # generator_to_serve = iter(generator_to_serve) - - def pred(x): - """ - Filters batches that belong only - to gpu_ranks of current node - """ - for rank in opt.gpu_ranks: - if x[0] % opt.world_size == rank: - return True - - generator_to_serve = filter( - pred, enumerate(generator_to_serve)) - - def next_batch(device_id): - new_batch = next(generator_to_serve) - semaphore.acquire() - return new_batch[1] - - b = next_batch(0) - - for device_id, q in cycle(enumerate(queues)): - b.dataset = None - if isinstance(b.src, tuple): - b.src = tuple([_.to(torch.device(device_id)) - for _ in b.src]) - else: - b.src = b.src.to(torch.device(device_id)) - b.tgt = b.tgt.to(torch.device(device_id)) - b.indices = b.indices.to(torch.device(device_id)) - b.alignment = b.alignment.to(torch.device(device_id)) \ - if hasattr(b, 'alignment') else None - b.src_map = b.src_map.to(torch.device(device_id)) \ - if hasattr(b, 'src_map') else None - - # hack to dodge unpicklable `dict_keys` - b.fields = list(b.fields) - q.put(b) - b = next_batch(device_id) - - -def run(opt, device_id, error_queue, batch_queue, semaphore): - """ run process """ - try: - gpu_rank = onmt.utils.distributed.multi_init(opt, device_id) - if gpu_rank != opt.gpu_ranks[device_id]: - raise AssertionError("An error occurred in \ - Distributed initialization") - single_main(opt, device_id, batch_queue, semaphore) - except KeyboardInterrupt: - pass # 
killed by parent, do nothing - except Exception: - # propagate exception to parent process, keeping original traceback - import traceback - error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc())) - - -class ErrorHandler(object): - """A class that listens for exceptions in children processes and propagates - the tracebacks to the parent process.""" - - def __init__(self, error_queue): - """ init error handler """ - import signal - import threading - self.error_queue = error_queue - self.children_pids = [] - self.error_thread = threading.Thread( - target=self.error_listener, daemon=True) - self.error_thread.start() - signal.signal(signal.SIGUSR1, self.signal_handler) - - def add_child(self, pid): - """ error handler """ - self.children_pids.append(pid) - - def error_listener(self): - """ error listener """ - (rank, original_trace) = self.error_queue.get() - self.error_queue.put((rank, original_trace)) - os.kill(os.getpid(), signal.SIGUSR1) - - def signal_handler(self, signalnum, stackframe): - """ signal handler """ - for pid in self.children_pids: - os.kill(pid, signal.SIGINT) # kill children processes - (rank, original_trace) = self.error_queue.get() - msg = """\n\n-- Tracebacks above this line can probably - be ignored --\n\n""" - msg += original_trace - raise Exception(msg) - - -def _get_parser(): - parser = ArgumentParser(description='train.py') - - opts.config_opts(parser) - opts.model_opts(parser) - opts.train_opts(parser) - return parser +from onmt.bin.train import main if __name__ == "__main__": - parser = _get_parser() - - opt = parser.parse_args() - main(opt) + main() diff --git a/translate.py b/translate.py old mode 100755 new mode 100644 index 9270359083..5ca91336be --- a/translate.py +++ b/translate.py @@ -1,49 +1,6 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- - -from __future__ import unicode_literals -from itertools import repeat - -from onmt.utils.logging import init_logger -from onmt.utils.misc import split_corpus -from onmt.translate.translator import build_translator - -import onmt.opts as opts -from onmt.utils.parse import ArgumentParser - - -def main(opt): - ArgumentParser.validate_translate_opts(opt) - logger = init_logger(opt.log_file) - - translator = build_translator(opt, report_score=True) - src_shards = split_corpus(opt.src, opt.shard_size) - tgt_shards = split_corpus(opt.tgt, opt.shard_size) \ - if opt.tgt is not None else repeat(None) - shard_pairs = zip(src_shards, tgt_shards) - - for i, (src_shard, tgt_shard) in enumerate(shard_pairs): - logger.info("Translating shard %d." % i) - translator.translate( - src=src_shard, - tgt=tgt_shard, - src_dir=opt.src_dir, - batch_size=opt.batch_size, - batch_type=opt.batch_type, - attn_debug=opt.attn_debug - ) - - -def _get_parser(): - parser = ArgumentParser(description='translate.py') - - opts.config_opts(parser) - opts.translate_opts(parser) - return parser +from onmt.bin.translate import main if __name__ == "__main__": - parser = _get_parser() - - opt = parser.parse_args() - main(opt) + main()
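
Usage note: after this change the root-level preprocess.py, train.py, translate.py and server.py are thin compatibility wrappers, so the same main() functions are reachable three ways: as `python train.py ...` from a source checkout, as the onmt_* console scripts declared in setup.py after a pip install, or programmatically via onmt.bin.*. A minimal sketch of the programmatic route follows; the -model/-src/-output values are illustrative placeholders, not files shipped by this patch:

#!/usr/bin/env python
# Sketch: driving a relocated onmt.bin entry point programmatically.
import sys

from onmt.bin.translate import main as translate_main

# Equivalent to `onmt_translate -model demo.pt -src test.txt -output pred.txt`
# after `pip install .`, or `python translate.py ...` from a source checkout.
sys.argv = [
    "onmt_translate",
    "-model", "demo.pt",    # illustrative checkpoint path
    "-src", "test.txt",     # illustrative source file
    "-output", "pred.txt",  # illustrative predictions path
]
translate_main()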
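
The REST server is only relocated, not changed, so existing clients keep working whether it is launched as `python server.py`, `onmt_server`, or `python -m onmt.bin.server`. A hedged client-side sketch against a local instance; the host, port and model id 100 are assumptions, and the payload follows the src/id convention that TranslationServer.run() consumes:

import requests  # assumed available; not a dependency added by this patch

# --url_root defaults to "/translator" (see _get_parser in onmt/bin/server.py),
# so the translate endpoint lives under that prefix. Host, port, and the
# model id are illustrative.
url = "http://127.0.0.1:5000/translator/translate"
payload = [{"src": "this is a test", "id": 100}]

resp = requests.post(url, json=payload)
# Per the translate() route above, a successful reply is a list of lists of
# dicts carrying "src", "tgt", "n_best" and "pred_score".
print(resp.json())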