From 76ca20e255e361eb5de7b213f93eb44914f41888 Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Tue, 24 Aug 2021 13:19:04 +0200 Subject: [PATCH 01/23] Source features support - Initial commit --- onmt/bin/build_vocab.py | 7 +++- onmt/inputters/corpus.py | 61 ++++++++++++++++++++++-------- onmt/inputters/fields.py | 12 ++++-- onmt/inputters/text_dataset.py | 45 +++++++++++++++------- onmt/opts.py | 5 +++ onmt/train_single.py | 3 ++ onmt/transforms/misc.py | 69 ++++++++++++++++++++++++++++++++++ onmt/utils/parse.py | 23 ++++++++++++ 8 files changed, 191 insertions(+), 34 deletions(-) diff --git a/onmt/bin/build_vocab.py b/onmt/bin/build_vocab.py index e106d92180..f922e38842 100644 --- a/onmt/bin/build_vocab.py +++ b/onmt/bin/build_vocab.py @@ -32,11 +32,13 @@ def build_vocab_main(opts): transforms = make_transforms(opts, transforms_cls, fields) logger.info(f"Counter vocab from {opts.n_sample} samples.") - src_counter, tgt_counter = build_vocab( + src_counter, tgt_counter, src_feats_counter = build_vocab( opts, transforms, n_sample=opts.n_sample) logger.info(f"Counters src:{len(src_counter)}") logger.info(f"Counters tgt:{len(tgt_counter)}") + for k, v in src_feats_counter["src_feats"].items(): + logger.info(f"Counters {k}:{len(v)}") def save_counter(counter, save_path): check_path(save_path, exist_ok=opts.overwrite, log=logger.warning) @@ -52,6 +54,9 @@ def save_counter(counter, save_path): else: save_counter(src_counter, opts.src_vocab) save_counter(tgt_counter, opts.tgt_vocab) + + for k, v in src_feats_counter["src_feats"].items(): + save_counter(v, opts.src_feats_vocab[k]) def _get_parser(): diff --git a/onmt/inputters/corpus.py b/onmt/inputters/corpus.py index c8a559f9f8..1810d5298f 100644 --- a/onmt/inputters/corpus.py +++ b/onmt/inputters/corpus.py @@ -7,7 +7,7 @@ from torchtext.data import Dataset as TorchtextDataset, \ Example as TorchtextExample -from collections import Counter +from collections import Counter, defaultdict from contextlib import contextmanager import multiprocessing as mp @@ -74,6 +74,9 @@ def _process(item, is_train): maybe_example['tgt'] = ' '.join(maybe_example['tgt']) if 'align' in maybe_example: maybe_example['align'] = ' '.join(maybe_example['align']) + if 'src_feats' in maybe_example: + for k in maybe_example['src_feats'].keys(): + maybe_example['src_feats'][k] = ' '.join(maybe_example['src_feats'][k]) return maybe_example def _maybe_add_dynamic_dict(self, example, fields): @@ -107,12 +110,13 @@ def __call__(self, bucket): class ParallelCorpus(object): """A parallel corpus file pair that can be loaded to iterate.""" - def __init__(self, name, src, tgt, align=None): + def __init__(self, name, src, tgt, align=None, src_feats=None): """Initialize src & tgt side file path.""" self.id = name self.src = src self.tgt = tgt self.align = align + self.src_feats = src_feats def load(self, offset=0, stride=1): """ @@ -120,10 +124,16 @@ def load(self, offset=0, stride=1): `offset` and `stride` allow to iterate only on every `stride` example, starting from `offset`. 
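        For instance, with ``stride=2, offset=1`` every second example is
        yielded, starting from the second one, so parallel workers can each
        read a disjoint slice of the same corpus files.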
""" + #import pdb + #pdb.set_trace() + if self.src_feats: + features_files = [open(feat_path, mode='rb') for feat_name, feat_path in self.src_feats.items()] + else: + features_files = [] with exfile_open(self.src, mode='rb') as fs,\ exfile_open(self.tgt, mode='rb') as ft,\ exfile_open(self.align, mode='rb') as fa: - for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)): + for i, (sline, tline, align, *features) in enumerate(zip(fs, ft, fa, *features_files)): if (i % stride) == offset: sline = sline.decode('utf-8') tline = tline.decode('utf-8') @@ -133,12 +143,18 @@ def load(self, offset=0, stride=1): } if align is not None: example['align'] = align.decode('utf-8') + if features: + example["src_feats"] = dict() + for j, feat in enumerate(features): + example["src_feats"][list(self.src_feats.keys())[j]] = feat.decode("utf-8") yield example + for f in features_files: + f.close() def __str__(self): cls_name = type(self).__name__ - return '{}({}, {}, align={})'.format( - cls_name, self.src, self.tgt, self.align) + return '{}({}, {}, align={}, src_feats={})'.format( + cls_name, self.src, self.tgt, self.align, self.src_feats) def get_corpora(opts, is_train=False): @@ -150,7 +166,8 @@ def get_corpora(opts, is_train=False): corpus_id, corpus_dict["path_src"], corpus_dict["path_tgt"], - corpus_dict["path_align"]) + corpus_dict["path_align"], + corpus_dict["src_feats"]) else: if CorpusName.VALID in opts.data.keys(): corpora_dict[CorpusName.VALID] = ParallelCorpus( @@ -193,6 +210,9 @@ def _tokenize(self, stream): example['src'], example['tgt'] = src, tgt if 'align' in example: example['align'] = example['align'].strip('\n').split() + if 'src_feats' in example: + for k in example['src_feats'].keys(): + example['src_feats'][k] = example['src_feats'][k].strip('\n').split() yield example def _transform(self, stream): @@ -284,8 +304,11 @@ def write_files_from_queues(sample_path, queues): def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset): """Build vocab on (strided) subpart of the data.""" + #import pdb + #pdb.set_trace() sub_counter_src = Counter() sub_counter_tgt = Counter() + sub_counter_src_feats = {'src_feats': defaultdict(Counter)} datasets_iterables = build_corpora_iters( corpora, transforms, opts.data, skip_empty_level=opts.skip_empty_level, @@ -298,6 +321,9 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset): build_sub_vocab.queues[c_name][offset].put("blank") continue src_line, tgt_line = maybe_example['src'], maybe_example['tgt'] + if 'src_feats' in maybe_example: + for feat_name, feat_line in maybe_example["src_feats"].items(): + sub_counter_src_feats['src_feats'][feat_name].update(feat_line.split(' ')) sub_counter_src.update(src_line.split(' ')) sub_counter_tgt.update(tgt_line.split(' ')) if opts.dump_samples: @@ -309,7 +335,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset): break if opts.dump_samples: build_sub_vocab.queues[c_name][offset].put("break") - return sub_counter_src, sub_counter_tgt + return sub_counter_src, sub_counter_tgt, sub_counter_src_feats def init_pool(queues): @@ -333,6 +359,7 @@ def build_vocab(opts, transforms, n_sample=3): corpora = get_corpora(opts, is_train=True) counter_src = Counter() counter_tgt = Counter() + counter_src_feats = {'src_feats': defaultdict(Counter)} from functools import partial queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size) for i in range(opts.num_threads)] @@ -345,17 +372,19 @@ def build_vocab(opts, transforms, n_sample=3): args=(sample_path, queues), 
daemon=True) write_process.start() - with mp.Pool(opts.num_threads, init_pool, [queues]) as p: - func = partial( - build_sub_vocab, corpora, transforms, - opts, n_sample, opts.num_threads) - for sub_counter_src, sub_counter_tgt in p.imap( - func, range(0, opts.num_threads)): - counter_src.update(sub_counter_src) - counter_tgt.update(sub_counter_tgt) + #with mp.Pool(opts.num_threads, init_pool, [queues]) as p: + func = partial( + build_sub_vocab, corpora, transforms, + opts, n_sample, opts.num_threads) + sub_counter_src, sub_counter_tgt, sub_counter_src_feats = func(0) + # for sub_counter_src, sub_counter_tgt in p.imap( + # func, range(0, opts.num_threads)): + counter_src.update(sub_counter_src) + counter_tgt.update(sub_counter_tgt) + counter_src_feats.update(sub_counter_src_feats) if opts.dump_samples: write_process.join() - return counter_src, counter_tgt + return counter_src, counter_tgt, counter_src_feats def save_transformed_sample(opts, transforms, n_sample=3): diff --git a/onmt/inputters/fields.py b/onmt/inputters/fields.py index 50c4e6c17f..da53071706 100644 --- a/onmt/inputters/fields.py +++ b/onmt/inputters/fields.py @@ -9,10 +9,10 @@ def _get_dynamic_fields(opts): # NOTE: not support nfeats > 0 yet - src_nfeats = 0 - tgt_nfeats = 0 + #src_nfeats = 0 + tgt_nfeats = None #0 with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0 - fields = get_fields('text', src_nfeats, tgt_nfeats, + fields = get_fields('text', opts.src_feats_vocab, tgt_nfeats, dynamic_dict=opts.copy_attn, src_truncate=opts.src_seq_length_trunc, tgt_truncate=opts.tgt_seq_length_trunc, @@ -33,6 +33,12 @@ def build_dynamic_fields(opts, src_specials=None, tgt_specials=None): opts.src_vocab, 'src', counters, min_freq=opts.src_words_min_frequency) + if opts.src_feats_vocab: + for feat_name, filepath in opts.src_feats_vocab.items(): + _, _ = _load_vocab( + filepath, feat_name, counters, + min_freq=0) + if opts.tgt_vocab: _tgt_vocab, _tgt_vocab_size = _load_vocab( opts.tgt_vocab, 'tgt', counters, diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py index a0621f6407..bcb3794979 100644 --- a/onmt/inputters/text_dataset.py +++ b/onmt/inputters/text_dataset.py @@ -171,20 +171,37 @@ def text_fields(**kwargs): eos = kwargs.get("eos", DefaultTokens.EOS) truncate = kwargs.get("truncate", None) fields_ = [] - feat_delim = u"│" if n_feats > 0 else None - for i in range(n_feats + 1): - name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name - tokenize = partial( - _feature_tokenize, - layer=i, - truncate=truncate, - feat_delim=feat_delim) - use_len = i == 0 and include_lengths - feat = Field( - init_token=bos, eos_token=eos, - pad_token=pad, tokenize=tokenize, - include_lengths=use_len) - fields_.append((name, feat)) + + feat_delim = None #u"│" if n_feats > 0 else None + + # Base field + tokenize = partial( + _feature_tokenize, + layer=None, + truncate=truncate, + feat_delim=feat_delim) + feat = Field( + init_token=bos, eos_token=eos, + pad_token=pad, tokenize=tokenize, + include_lengths=include_lengths) + fields_.append((base_name, feat)) + + # Feats fields + #for i in range(n_feats + 1): + if n_feats: + for feat_name in n_feats.keys(): + #name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name + tokenize = partial( + _feature_tokenize, + layer=None, + truncate=truncate, + feat_delim=feat_delim) + feat = Field( + init_token=bos, eos_token=eos, + pad_token=pad, tokenize=tokenize, + include_lengths=False) + fields_.append((feat_name, feat)) + assert fields_[0][0] == base_name # 
sanity check field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:]) return field diff --git a/onmt/opts.py b/onmt/opts.py index ec66f14e95..e5db2947f0 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -132,6 +132,11 @@ def _add_dynamic_fields_opts(parser, build_vocab_only=False): group.add("-share_vocab", "--share_vocab", action="store_true", help="Share source and target vocabulary.") + group.add("-src_feats_vocab", "--src_feats_vocab", + help=("List of paths to save" if build_vocab_only else "List of paths to") + + " src features vocabulary files. " + "Files format: one or \t per line.") + if not build_vocab_only: group.add("-src_vocab_size", "--src_vocab_size", type=int, default=50000, diff --git a/onmt/train_single.py b/onmt/train_single.py index 925c472119..0a4b153c8e 100644 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -58,6 +58,9 @@ def main(opt, fields, transforms_cls, checkpoint, device_id, """Start training on `device_id`.""" # NOTE: It's important that ``opt`` has been validated and updated # at this point. + + #import pdb + #pdb.set_trace() configure_process(opt, device_id) init_logger(opt.log_file) diff --git a/onmt/transforms/misc.py b/onmt/transforms/misc.py index 63df918d0b..8cc9cf0004 100644 --- a/onmt/transforms/misc.py +++ b/onmt/transforms/misc.py @@ -1,6 +1,7 @@ from onmt.utils.logging import logger from onmt.transforms import register_transform from .transform import Transform, ObservableStats +import re class FilterTooLongStats(ObservableStats): @@ -52,6 +53,74 @@ def _repr_args(self): ) +@register_transform(name='inferfeats') +class InferFeatsTransform(Transform): + """Filter out sentence that are too long.""" + + def __init__(self, opts): + super().__init__(opts) + + @classmethod + def add_options(cls, parser): + """Avalilable options relate to this Transform.""" + #group = parser.add_argument_group("Transform/Filter") + #group.add("--src_seq_length", "-src_seq_length", type=int, default=200, + # help="Maximum source sequence length.") + #group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200, + # help="Maximum target sequence length.") + pass + + def _parse_opts(self): + pass + #self.src_seq_length = self.opts.src_seq_length + #self.tgt_seq_length = self.opts.tgt_seq_length + + def apply(self, example, is_train=False, stats=None, **kwargs): + """Return None if too long else return as is.""" + + if "src_feats" not in example: + # Do nothing + return example + + inferred_feats = [] + feats_i = 0 + n_feats = len(example["src_feats"]) + inferred_feats = dict() + #import pdb + #pdb.set_trace() + for subword in example["src"]: + none = True + for k, v in example["src_feats"].items(): + # TODO: what about custom placeholders?? 
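+                # ⦅...⦆ wraps placeholder tokens protected by the tokenizer;
+                # they carry no underlying word, so they get the null tag "N"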
+ if re.match(r'⦅\w+⦆', subword): + inferred_feat = "N" + elif not re.sub(r'(\W)+', '', subword).strip(): + inferred_feat = "N" + else: + inferred_feat = v[feats_i] + none = False + + if k in inferred_feats: + inferred_feats[k].append(inferred_feat) + else: + inferred_feats[k] = [inferred_feat] + if subword.find('■') < 0 and not none: + feats_i += 1 + #import pdb + #pdb.set_trace() + for k, v in inferred_feats.items(): + example["src_feats"][k] = inferred_feats[k] + return example + + def _repr_args(self): + """Return str represent key arguments for class.""" + #return '{}={}, {}={}'.format( + # 'src_seq_length', self.src_seq_length, + # 'tgt_seq_length', self.tgt_seq_length + #) + return "INFERFEATS" + + @register_transform(name='prefix') class PrefixTransform(Transform): """Add Prefix to src (& tgt) sentence.""" diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 2f4f1e1c45..ddaa9f6e60 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -75,6 +75,17 @@ def _validate_data(cls, opt): logger.warning(f"Corpus {cname}'s weight should be given." " We default it to 1 for you.") corpus['weight'] = 1 + + # Check features + src_feats = corpus.get("src_feats", None) + if src_feats is not None: + for feature_name, feature_file in src_feats.items(): + cls._validate_file(feature_file, info=f'{cname}/path_{feature_name}') + if 'inferfeats' not in corpus["transforms"]: + raise ValueError(f"'inferfeats' transform is required when setting source features") + else: + corpus["src_feats"] = None + logger.info(f"Parsed {len(corpora)} corpora from -data.") opt.data = corpora @@ -107,6 +118,18 @@ def _get_all_transform(cls, opt): @classmethod def _validate_fields_opts(cls, opt, build_vocab_only=False): """Check options relate to vocab and fields.""" + + for cname, corpus in opt.data.items(): + if corpus["src_feats"] is not None: + assert opt.src_feats_vocab, \ + "-src_feats_vocab is required if using source features." 
+ import yaml + opt.src_feats_vocab = yaml.safe_load(opt.src_feats_vocab) + + for feature in corpus["src_feats"].keys(): + assert feature in opt.src_feats_vocab, \ + f"No vocab file set for feature {feature}" + if build_vocab_only: if not opt.share_vocab: assert opt.tgt_vocab, \ From a8190ab91ea998ab400baad4e60f7c2be0b9309d Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Tue, 24 Aug 2021 14:49:06 +0200 Subject: [PATCH 02/23] Improved features transforms --- onmt/transforms/features.py | 101 ++++++++++++++++++++++++++++++++++++ onmt/transforms/misc.py | 69 ------------------------ onmt/utils/parse.py | 2 + 3 files changed, 103 insertions(+), 69 deletions(-) create mode 100644 onmt/transforms/features.py diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py new file mode 100644 index 0000000000..fce01c743f --- /dev/null +++ b/onmt/transforms/features.py @@ -0,0 +1,101 @@ +from onmt.utils.logging import logger +from onmt.transforms import register_transform +from .transform import Transform, ObservableStats +import re +from collections import defaultdict + + +@register_transform(name='filterfeats') +class FilterFeatsTransform(Transform): + """Filter out examples with a mismatch between source and features.""" + + def __init__(self, opts): + super().__init__(opts) + + @classmethod + def add_options(cls, parser): + pass + + def _parse_opts(self): + pass + + def apply(self, example, is_train=False, stats=None, **kwargs): + """Return None if mismatch""" + + if 'src_feats' not in example: + # Do nothing + return example + + for feat_name, feat_values in example['src_feats'].items(): + if len(example['src']) != len(feat_values): + logger.warning(f"Skipping example due to mismatch between source and feature {feat_name}") + return None + return example + + def _repr_args(self): + return '' + + +@register_transform(name='inferfeats') +class InferFeatsTransform(Transform): + """Infer features for subword tokenization.""" + + def __init__(self, opts): + super().__init__(opts) + + @classmethod + def add_options(cls, parser): + pass + + def _parse_opts(self): + pass + + def apply(self, example, is_train=False, stats=None, **kwargs): + + if "src_feats" not in example: + # Do nothing + return example + + import pdb + pdb.set_trace() + + feats_i = 0 + inferred_feats = defaultdict(list) + for subword in example["src"]: + next_ = False + for k, v in example["src_feats"].items(): + # TODO: what about custom placeholders?? 
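+                # feats_i indexes the current source word; it advances only
+                # after a subword that carries no joiner, i.e. the token
+                # that closes the current word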
+ + # Placeholders + if re.match(r'⦅\w+⦆', subword): + inferred_feat = "N" + + # Punctuation only + elif not re.sub(r'(\W)+', '', subword).strip(): + inferred_feat = "N" + + # Joiner annotate + elif re.search("■", subword): + inferred_feat = v[feats_i] + + # Whole word + else: + inferred_feat = v[feats_i] + next_ = True + + inferred_feats[k].append(inferred_feat) + + if next_: + feats_i += 1 + + # Check all features have been consumed + for k, v in example["src_feats"].items(): + assert feats_i == len(v), f'Not all features consumed for {k}' + + for k, v in inferred_feats.items(): + example["src_feats"][k] = inferred_feats[k] + pdb.set_trace() + return example + + def _repr_args(self): + return '' \ No newline at end of file diff --git a/onmt/transforms/misc.py b/onmt/transforms/misc.py index 8cc9cf0004..63df918d0b 100644 --- a/onmt/transforms/misc.py +++ b/onmt/transforms/misc.py @@ -1,7 +1,6 @@ from onmt.utils.logging import logger from onmt.transforms import register_transform from .transform import Transform, ObservableStats -import re class FilterTooLongStats(ObservableStats): @@ -53,74 +52,6 @@ def _repr_args(self): ) -@register_transform(name='inferfeats') -class InferFeatsTransform(Transform): - """Filter out sentence that are too long.""" - - def __init__(self, opts): - super().__init__(opts) - - @classmethod - def add_options(cls, parser): - """Avalilable options relate to this Transform.""" - #group = parser.add_argument_group("Transform/Filter") - #group.add("--src_seq_length", "-src_seq_length", type=int, default=200, - # help="Maximum source sequence length.") - #group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200, - # help="Maximum target sequence length.") - pass - - def _parse_opts(self): - pass - #self.src_seq_length = self.opts.src_seq_length - #self.tgt_seq_length = self.opts.tgt_seq_length - - def apply(self, example, is_train=False, stats=None, **kwargs): - """Return None if too long else return as is.""" - - if "src_feats" not in example: - # Do nothing - return example - - inferred_feats = [] - feats_i = 0 - n_feats = len(example["src_feats"]) - inferred_feats = dict() - #import pdb - #pdb.set_trace() - for subword in example["src"]: - none = True - for k, v in example["src_feats"].items(): - # TODO: what about custom placeholders?? 
- if re.match(r'⦅\w+⦆', subword): - inferred_feat = "N" - elif not re.sub(r'(\W)+', '', subword).strip(): - inferred_feat = "N" - else: - inferred_feat = v[feats_i] - none = False - - if k in inferred_feats: - inferred_feats[k].append(inferred_feat) - else: - inferred_feats[k] = [inferred_feat] - if subword.find('■') < 0 and not none: - feats_i += 1 - #import pdb - #pdb.set_trace() - for k, v in inferred_feats.items(): - example["src_feats"][k] = inferred_feats[k] - return example - - def _repr_args(self): - """Return str represent key arguments for class.""" - #return '{}={}, {}={}'.format( - # 'src_seq_length', self.src_seq_length, - # 'tgt_seq_length', self.tgt_seq_length - #) - return "INFERFEATS" - - @register_transform(name='prefix') class PrefixTransform(Transform): """Add Prefix to src (& tgt) sentence.""" diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index ddaa9f6e60..fbf8cd3ea7 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -83,6 +83,8 @@ def _validate_data(cls, opt): cls._validate_file(feature_file, info=f'{cname}/path_{feature_name}') if 'inferfeats' not in corpus["transforms"]: raise ValueError(f"'inferfeats' transform is required when setting source features") + if 'filterfeats' not in corpus["transforms"]: + raise ValueError(f"'filterfeats' transform is required when setting source features") else: corpus["src_feats"] = None From 80d20f9319e65cd7f1d3a414848a569c8190d2ff Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Tue, 24 Aug 2021 15:02:11 +0200 Subject: [PATCH 03/23] Fixed requirements on VALID data --- onmt/transforms/features.py | 4 ---- onmt/utils/parse.py | 18 +++++++++--------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py index fce01c743f..18c28a39b4 100644 --- a/onmt/transforms/features.py +++ b/onmt/transforms/features.py @@ -56,9 +56,6 @@ def apply(self, example, is_train=False, stats=None, **kwargs): # Do nothing return example - import pdb - pdb.set_trace() - feats_i = 0 inferred_feats = defaultdict(list) for subword in example["src"]: @@ -94,7 +91,6 @@ def apply(self, example, is_train=False, stats=None, **kwargs): for k, v in inferred_feats.items(): example["src_feats"][k] = inferred_feats[k] - pdb.set_trace() return example def _repr_args(self): diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index fbf8cd3ea7..3a49f8841b 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -122,15 +122,15 @@ def _validate_fields_opts(cls, opt, build_vocab_only=False): """Check options relate to vocab and fields.""" for cname, corpus in opt.data.items(): - if corpus["src_feats"] is not None: - assert opt.src_feats_vocab, \ - "-src_feats_vocab is required if using source features." - import yaml - opt.src_feats_vocab = yaml.safe_load(opt.src_feats_vocab) - - for feature in corpus["src_feats"].keys(): - assert feature in opt.src_feats_vocab, \ - f"No vocab file set for feature {feature}" + if cname != CorpusName.VALID and corpus["src_feats"] is not None: + assert opt.src_feats_vocab, \ + "-src_feats_vocab is required if using source features." 
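+                    # -src_feats_vocab is given on the command line as a
+                    # YAML/JSON string mapping each feature name to its
+                    # vocabulary file path, e.g. '{"feat0": "vocab.feat0"}'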
+ import yaml + opt.src_feats_vocab = yaml.safe_load(opt.src_feats_vocab) + + for feature in corpus["src_feats"].keys(): + assert feature in opt.src_feats_vocab, \ + f"No vocab file set for feature {feature}" if build_vocab_only: if not opt.share_vocab: From c722028ca8fd064f546a995bf1770a993a710cde Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Thu, 26 Aug 2021 16:07:02 +0200 Subject: [PATCH 04/23] Solved marked issues + Improved FeatInferTransform --- onmt/bin/build_vocab.py | 6 ++-- onmt/constants.py | 1 + onmt/inputters/corpus.py | 37 ++++++++++----------- onmt/inputters/fields.py | 7 ++-- onmt/inputters/inputter.py | 12 +++---- onmt/inputters/text_dataset.py | 11 +++---- onmt/transforms/features.py | 60 ++++++++++++++++++---------------- onmt/utils/alignment.py | 4 +-- 8 files changed, 70 insertions(+), 68 deletions(-) diff --git a/onmt/bin/build_vocab.py b/onmt/bin/build_vocab.py index f922e38842..ed510f09d2 100644 --- a/onmt/bin/build_vocab.py +++ b/onmt/bin/build_vocab.py @@ -37,8 +37,8 @@ def build_vocab_main(opts): logger.info(f"Counters src:{len(src_counter)}") logger.info(f"Counters tgt:{len(tgt_counter)}") - for k, v in src_feats_counter["src_feats"].items(): - logger.info(f"Counters {k}:{len(v)}") + for feat_name, feat_counter in src_feats_counter.items(): + logger.info(f"Counters {feat_name}:{len(feat_counter)}") def save_counter(counter, save_path): check_path(save_path, exist_ok=opts.overwrite, log=logger.warning) @@ -55,7 +55,7 @@ def save_counter(counter, save_path): save_counter(src_counter, opts.src_vocab) save_counter(tgt_counter, opts.tgt_vocab) - for k, v in src_feats_counter["src_feats"].items(): + for k, v in src_feats_counter.items(): save_counter(v, opts.src_feats_vocab[k]) diff --git a/onmt/constants.py b/onmt/constants.py index fb6afb0252..2d5864137b 100644 --- a/onmt/constants.py +++ b/onmt/constants.py @@ -22,6 +22,7 @@ class CorpusName(object): class SubwordMarker(object): SPACER = '▁' JOINER = '■' + CASE_MARKUP = ["⦅mrk_case_modifier_C⦆", "⦅mrk_begin_case_region_U⦆", "⦅mrk_end_case_region_U⦆"] class ModelTask(object): diff --git a/onmt/inputters/corpus.py b/onmt/inputters/corpus.py index 1810d5298f..18ae726ae1 100644 --- a/onmt/inputters/corpus.py +++ b/onmt/inputters/corpus.py @@ -124,10 +124,12 @@ def load(self, offset=0, stride=1): `offset` and `stride` allow to iterate only on every `stride` example, starting from `offset`. 
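        Feature files configured in `src_feats`, if any, are opened
        alongside the src/tgt files and read line-parallel with them.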
""" - #import pdb - #pdb.set_trace() if self.src_feats: - features_files = [open(feat_path, mode='rb') for feat_name, feat_path in self.src_feats.items()] + features_names = [] + features_files = [] + for feat_name, feat_path in self.src_feats.items(): + features_names.append(feat_name) + features_files.append(open(feat_path, mode='rb')) else: features_files = [] with exfile_open(self.src, mode='rb') as fs,\ @@ -146,7 +148,7 @@ def load(self, offset=0, stride=1): if features: example["src_feats"] = dict() for j, feat in enumerate(features): - example["src_feats"][list(self.src_feats.keys())[j]] = feat.decode("utf-8") + example["src_feats"][features_names[j]] = feat.decode("utf-8") yield example for f in features_files: f.close() @@ -304,11 +306,9 @@ def write_files_from_queues(sample_path, queues): def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset): """Build vocab on (strided) subpart of the data.""" - #import pdb - #pdb.set_trace() sub_counter_src = Counter() sub_counter_tgt = Counter() - sub_counter_src_feats = {'src_feats': defaultdict(Counter)} + sub_counter_src_feats = defaultdict(Counter) datasets_iterables = build_corpora_iters( corpora, transforms, opts.data, skip_empty_level=opts.skip_empty_level, @@ -323,7 +323,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset): src_line, tgt_line = maybe_example['src'], maybe_example['tgt'] if 'src_feats' in maybe_example: for feat_name, feat_line in maybe_example["src_feats"].items(): - sub_counter_src_feats['src_feats'][feat_name].update(feat_line.split(' ')) + sub_counter_src_feats[feat_name].update(feat_line.split(' ')) sub_counter_src.update(src_line.split(' ')) sub_counter_tgt.update(tgt_line.split(' ')) if opts.dump_samples: @@ -359,7 +359,7 @@ def build_vocab(opts, transforms, n_sample=3): corpora = get_corpora(opts, is_train=True) counter_src = Counter() counter_tgt = Counter() - counter_src_feats = {'src_feats': defaultdict(Counter)} + counter_src_feats = defaultdict(Counter) from functools import partial queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size) for i in range(opts.num_threads)] @@ -372,16 +372,15 @@ def build_vocab(opts, transforms, n_sample=3): args=(sample_path, queues), daemon=True) write_process.start() - #with mp.Pool(opts.num_threads, init_pool, [queues]) as p: - func = partial( - build_sub_vocab, corpora, transforms, - opts, n_sample, opts.num_threads) - sub_counter_src, sub_counter_tgt, sub_counter_src_feats = func(0) - # for sub_counter_src, sub_counter_tgt in p.imap( - # func, range(0, opts.num_threads)): - counter_src.update(sub_counter_src) - counter_tgt.update(sub_counter_tgt) - counter_src_feats.update(sub_counter_src_feats) + with mp.Pool(opts.num_threads, init_pool, [queues]) as p: + func = partial( + build_sub_vocab, corpora, transforms, + opts, n_sample, opts.num_threads) + for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap( + func, range(0, opts.num_threads)): + counter_src.update(sub_counter_src) + counter_tgt.update(sub_counter_tgt) + counter_src_feats.update(sub_counter_src_feats) if opts.dump_samples: write_process.join() return counter_src, counter_tgt, counter_src_feats diff --git a/onmt/inputters/fields.py b/onmt/inputters/fields.py index da53071706..5f41a3a01f 100644 --- a/onmt/inputters/fields.py +++ b/onmt/inputters/fields.py @@ -8,11 +8,10 @@ def _get_dynamic_fields(opts): - # NOTE: not support nfeats > 0 yet - #src_nfeats = 0 - tgt_nfeats = None #0 + # NOTE: not support tgt feats yet + tgt_feats = None with_align = 
hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0 - fields = get_fields('text', opts.src_feats_vocab, tgt_nfeats, + fields = get_fields('text', opts.src_feats_vocab, tgt_feats, dynamic_dict=opts.copy_attn, src_truncate=opts.src_seq_length_trunc, tgt_truncate=opts.tgt_seq_length_trunc, diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index f6b5c747d0..e385b074bb 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -111,8 +111,8 @@ def get_task_spec_tokens(data_task, pad, bos, eos): def get_fields( src_data_type, - n_src_feats, - n_tgt_feats, + src_feats, + tgt_feats, pad=DefaultTokens.PAD, bos=DefaultTokens.BOS, eos=DefaultTokens.EOS, @@ -125,11 +125,11 @@ def get_fields( """ Args: src_data_type: type of the source input. Options are [text]. - n_src_feats (int): the number of source features (not counting tokens) + src_feats (int): source features dict containing their names to create a :class:`torchtext.data.Field` for. (If ``src_data_type=="text"``, these fields are stored together as a ``TextMultiField``). - n_tgt_feats (int): See above. + tgt_feats (int): See above. pad (str): Special pad symbol. Used on src and tgt side. bos (str): Special beginning of sequence symbol. Only relevant for tgt. @@ -158,7 +158,7 @@ def get_fields( task_spec_tokens = get_task_spec_tokens(data_task, pad, bos, eos) src_field_kwargs = { - "n_feats": n_src_feats, + "feats": src_feats, "include_lengths": True, "pad": task_spec_tokens["src"]["pad"], "bos": task_spec_tokens["src"]["bos"], @@ -169,7 +169,7 @@ def get_fields( fields["src"] = fields_getters[src_data_type](**src_field_kwargs) tgt_field_kwargs = { - "n_feats": n_tgt_feats, + "feats": tgt_feats, "include_lengths": False, "pad": task_spec_tokens["tgt"]["pad"], "bos": task_spec_tokens["tgt"]["bos"], diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py index bcb3794979..0fbbeacea3 100644 --- a/onmt/inputters/text_dataset.py +++ b/onmt/inputters/text_dataset.py @@ -152,7 +152,7 @@ def text_fields(**kwargs): Args: base_name (str): Name associated with the field. - n_feats (int): Number of word level feats (not counting the tokens) + feats (int): Word level feats include_lengths (bool): Optionally return the sequence lengths. pad (str, optional): Defaults to ``""``. bos (str or NoneType, optional): Defaults to ``""``. 
@@ -163,7 +163,7 @@ def text_fields(**kwargs): TextMultiField """ - n_feats = kwargs["n_feats"] + feats = kwargs["feats"] include_lengths = kwargs["include_lengths"] base_name = kwargs["base_name"] pad = kwargs.get("pad", DefaultTokens.PAD) @@ -187,10 +187,9 @@ def text_fields(**kwargs): fields_.append((base_name, feat)) # Feats fields - #for i in range(n_feats + 1): - if n_feats: - for feat_name in n_feats.keys(): - #name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name + if feats: + for feat_name in feats.keys(): + # Legacy function, it is not really necessary tokenize = partial( _feature_tokenize, layer=None, diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py index 18c28a39b4..b2061d1063 100644 --- a/onmt/transforms/features.py +++ b/onmt/transforms/features.py @@ -1,10 +1,13 @@ from onmt.utils.logging import logger from onmt.transforms import register_transform from .transform import Transform, ObservableStats +from onmt.constants import DefaultTokens, SubwordMarker +from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer import re from collections import defaultdict + @register_transform(name='filterfeats') class FilterFeatsTransform(Transform): """Filter out examples with a mismatch between source and features.""" @@ -48,7 +51,10 @@ def add_options(cls, parser): pass def _parse_opts(self): - pass + super()._parse_opts() + logger.info("Parsed pyonmttok kwargs for src: {}".format( + self.opts.src_onmttok_kwargs)) + self.src_onmttok_kwargs = self.opts.src_onmttok_kwargs def apply(self, example, is_train=False, stats=None, **kwargs): @@ -56,41 +62,39 @@ def apply(self, example, is_train=False, stats=None, **kwargs): # Do nothing return example - feats_i = 0 - inferred_feats = defaultdict(list) - for subword in example["src"]: - next_ = False - for k, v in example["src_feats"].items(): - # TODO: what about custom placeholders?? + joiner = self.src_onmttok_kwargs["joiner"] if "joiner" in self.src_onmttok_kwargs else SubwordMarker.JOINER + case_markup = SubwordMarker.CASE_MARKUP if "case_markup" in self.src_onmttok_kwargs else [] + # TODO: support joiner_new or spacer_new options. 
Consistency not ensured currently - # Placeholders - if re.match(r'⦅\w+⦆', subword): - inferred_feat = "N" + if "joiner_annotate" in self.src_onmttok_kwargs: + word_to_subword_mapping = subword_map_by_joiner(example["src"], marker=joiner, case_markup=case_markup) + elif "spacer_annotate" in self.src_onmttok_kwargs: + # TODO: case markup + word_to_subword_mapping = subword_map_by_spacer(example["src"], marker=joiner) + else: + # TODO: support not reversible tokenization + raise Exception("InferFeats transform does not currently work without either joiner_annotate or spacer_annotate") - # Punctuation only - elif not re.sub(r'(\W)+', '', subword).strip(): - inferred_feat = "N" + inferred_feats = defaultdict(list) + for subword, word_id in zip(example["src"], word_to_subword_mapping): + for feat_name, feat_values in example["src_feats"].items(): - # Joiner annotate - elif re.search("■", subword): - inferred_feat = v[feats_i] + # If case markup placeholder + if subword in case_markup: + inferred_feat = "" + + # Punctuation only (assumes joiner is also some punctuation token) + elif not re.sub(r'(\W)+', '', subword).strip(): + inferred_feat = "" - # Whole word else: - inferred_feat = v[feats_i] - next_ = True + inferred_feat = feat_values[word_id] - inferred_feats[k].append(inferred_feat) - - if next_: - feats_i += 1 + inferred_feats[feat_name].append(inferred_feat) - # Check all features have been consumed - for k, v in example["src_feats"].items(): - assert feats_i == len(v), f'Not all features consumed for {k}' + for feat_name, feat_values in inferred_feats.items(): + example["src_feats"][feat_name] = inferred_feats[feat_name] - for k, v in inferred_feats.items(): - example["src_feats"][k] = inferred_feats[k] return example def _repr_args(self): diff --git a/onmt/utils/alignment.py b/onmt/utils/alignment.py index 0a70edb33e..3f7dd641d1 100644 --- a/onmt/utils/alignment.py +++ b/onmt/utils/alignment.py @@ -120,11 +120,11 @@ def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'): return " ".join(word_align) -def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER): +def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER, case_markup=[]): """Return word id for each subword token (annotate by joiner).""" flags = [0] * len(subwords) for i, tok in enumerate(subwords): - if tok.endswith(marker): + if tok.endswith(marker) or tok in case_markup: flags[i] = 1 if tok.startswith(marker): assert i >= 1 and flags[i-1] != 1, \ From ad94f21196d323d6f6b8fb390ff09dce06860e88 Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Fri, 27 Aug 2021 09:34:17 +0200 Subject: [PATCH 05/23] Added some unittests to check features transforms --- onmt/tests/test_subword_marker.py | 15 ++++++++++++++- onmt/tests/test_transform.py | 22 ++++++++++++++++++++++ onmt/transforms/features.py | 15 ++++++--------- onmt/utils/alignment.py | 16 +++++++--------- 4 files changed, 49 insertions(+), 19 deletions(-) diff --git a/onmt/tests/test_subword_marker.py b/onmt/tests/test_subword_marker.py index e827d52ffa..d1cb0b153f 100644 --- a/onmt/tests/test_subword_marker.py +++ b/onmt/tests/test_subword_marker.py @@ -2,6 +2,7 @@ from onmt.transforms.bart import word_start_finder from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer +from onmt.constants import DefaultTokens, SubwordMarker class TestWordStartFinder(unittest.TestCase): @@ -37,7 +38,19 @@ class TestSubwordGroup(unittest.TestCase): def test_subword_group_joiner(self): data_in = ['however', '■,', 'according', 'to', 'the', 
'logs', '■,', 'she', 'is', 'hard', '■-■', 'working', '■.'] # noqa: E501 true_out = [0, 0, 1, 2, 3, 4, 4, 5, 6, 7, 7, 7, 7] - out = subword_map_by_joiner(data_in) + out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP) + self.assertEqual(out, true_out) + + def test_subword_group_joiner_with_markup(self): + data_in = ['⦅mrk_case_modifier_C⦆', 'however', '■,', 'according', 'to', 'the', 'logs', '■,', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■-■', 'working', '■.', '⦅mrk_end_case_region_U⦆'] # noqa: E501 + true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7] + out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP) + self.assertEqual(out, true_out) + + def test_subword_group_naive(self): + data_in = ['however', ',', 'according', 'to', 'the', 'logs', ',', 'she', 'is', 'hard', '-', 'working', '.'] # noqa: E501 + true_out = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP) self.assertEqual(out, true_out) def test_subword_group_spacer(self): diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index 4bfa8be3bc..d23169d9a1 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -509,3 +509,25 @@ def test_span_infilling(self): # n_masked = math.ceil(n_words * bart_noise.mask_ratio) # print(f"Text Span Infilling: {infillied} / {tokens}") # print(n_words, n_masked) + +class TestFeaturesTransform(unittest.TestCase): + def test_inferfeats(self): + inferfeats_cls = get_transforms_cls(["inferfeats"])["inferfeats"] + opt = Namespace(src_onmttok_kwargs={"mode": "conservative", "joiner": "■", "joiner_annotate": True, "case_markup": True}) + inferfeats_transform = inferfeats_cls(opt) + + ex_in = { + "src": ['however', '■,', 'according', 'to', 'the', 'logs', '■,', 'she', 'is', 'hard', '■-■', 'working', '■.'], + "tgt": ['however', '■,', 'according', 'to', 'the', 'logs', '■,', 'she', 'is', 'hard', '■-■', 'working', '■.'] + } + ex_out = inferfeats_transform.apply(ex_in) + self.assertIs(ex_out, ex_in) + + ex_in["src_feats"] = {"feat_0": ["A", "A", "A", "A", "B", "A", "A", "C"]} + ex_out = inferfeats_transform.apply(ex_in) + self.assertEqual(ex_out["src_feats"]["feat_0"], ["A", "", "A", "A", "A", "B", "", "A", "A", "C", "", "C", ""]) + + ex_in["src"] = ['⦅mrk_case_modifier_C⦆', 'however', '■,', 'according', 'to', 'the', 'logs', '■,', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■-■', 'working', '■.', '⦅mrk_end_case_region_U⦆'] + ex_in["src_feats"] = {"feat_0": ["A", "A", "A", "A", "B", "A", "A", "C"]} + ex_out = inferfeats_transform.apply(ex_in) + self.assertEqual(ex_out["src_feats"]["feat_0"], ["", "A", "", "A", "A", "A", "B", "", "", "A", "A", "C", "", "C", "", ""]) diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py index b2061d1063..9574b37c2b 100644 --- a/onmt/transforms/features.py +++ b/onmt/transforms/features.py @@ -67,26 +67,23 @@ def apply(self, example, is_train=False, stats=None, **kwargs): # TODO: support joiner_new or spacer_new options. 
Consistency not ensured currently if "joiner_annotate" in self.src_onmttok_kwargs: - word_to_subword_mapping = subword_map_by_joiner(example["src"], marker=joiner, case_markup=case_markup) + word_to_subword_mapping = subword_map_by_joiner(example["src"], marker=joiner, case_markup=case_markup) elif "spacer_annotate" in self.src_onmttok_kwargs: - # TODO: case markup - word_to_subword_mapping = subword_map_by_spacer(example["src"], marker=joiner) - else: - # TODO: support not reversible tokenization - raise Exception("InferFeats transform does not currently work without either joiner_annotate or spacer_annotate") + # TODO: case markup + word_to_subword_mapping = subword_map_by_spacer(example["src"], marker=joiner) + else: + # TODO: support not reversible tokenization + raise Exception("InferFeats transform does not currently work without either joiner_annotate or spacer_annotate") inferred_feats = defaultdict(list) for subword, word_id in zip(example["src"], word_to_subword_mapping): for feat_name, feat_values in example["src_feats"].items(): - # If case markup placeholder if subword in case_markup: inferred_feat = "" - # Punctuation only (assumes joiner is also some punctuation token) elif not re.sub(r'(\W)+', '', subword).strip(): inferred_feat = "" - else: inferred_feat = feat_values[word_id] diff --git a/onmt/utils/alignment.py b/onmt/utils/alignment.py index 3f7dd641d1..a34a08dc7e 100644 --- a/onmt/utils/alignment.py +++ b/onmt/utils/alignment.py @@ -122,17 +122,15 @@ def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'): def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER, case_markup=[]): """Return word id for each subword token (annotate by joiner).""" - flags = [0] * len(subwords) + flags = [1] * len(subwords) for i, tok in enumerate(subwords): - if tok.endswith(marker) or tok in case_markup: - flags[i] = 1 - if tok.startswith(marker): - assert i >= 1 and flags[i-1] != 1, \ + if tok.endswith(marker) or (tok in case_markup and tok.find("end")<0): + flags[i] = 0 + if tok.startswith(marker) or (tok in case_markup and tok.find("end")>=0): + assert i >= 1 and flags[i-1] != 0, \ "Sentence `{}` not correct!".format(" ".join(subwords)) - flags[i-1] = 1 - marker_acc = list(accumulate([0] + flags[:-1])) - word_group = [(i - maker_sofar) for i, maker_sofar - in enumerate(marker_acc)] + flags[i-1] = 0 + word_group = list(accumulate([0] + flags[:-1])) return word_group From e79ec8a28520eccef12844883949d2533f8e8a06 Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Fri, 27 Aug 2021 09:49:57 +0200 Subject: [PATCH 06/23] Updated parameters types in docstrings --- onmt/inputters/inputter.py | 4 ++-- onmt/inputters/text_dataset.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index e385b074bb..ffd8c77fb1 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -125,11 +125,11 @@ def get_fields( """ Args: src_data_type: type of the source input. Options are [text]. - src_feats (int): source features dict containing their names + src_feats (Optional[Dict]): source features dict containing their names to create a :class:`torchtext.data.Field` for. (If ``src_data_type=="text"``, these fields are stored together as a ``TextMultiField``). - tgt_feats (int): See above. + tgt_feats (Optional[Dict]): See above. pad (str): Special pad symbol. Used on src and tgt side. bos (str): Special beginning of sequence symbol. Only relevant for tgt. 
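For reference, a minimal sketch of how the updated `get_fields` signature is
driven by the feature configuration (the feature name and vocab path below are
hypothetical; in the training flow this call is made by `_get_dynamic_fields`
with `opts.src_feats_vocab`):

```python
from onmt.inputters.inputter import get_fields

# One extra Field is created inside the returned TextMultiField for every
# entry in src_feats; only the dict keys (the feature names) matter here.
fields = get_fields(
    'text',
    src_feats={"feat0": "exp/data.vocab.feat0"},
    tgt_feats=None,  # target-side features are not supported yet
)
```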
diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py index 0fbbeacea3..4cd3453341 100644 --- a/onmt/inputters/text_dataset.py +++ b/onmt/inputters/text_dataset.py @@ -152,7 +152,7 @@ def text_fields(**kwargs): Args: base_name (str): Name associated with the field. - feats (int): Word level feats + feats (Optional[Dict]): Word level feats include_lengths (bool): Optionally return the sequence lengths. pad (str, optional): Defaults to ``""``. bos (str or NoneType, optional): Defaults to ``""``. From 2f15edd2fcd7dfca429196d3c430858c0cf651e6 Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Fri, 27 Aug 2021 10:03:59 +0200 Subject: [PATCH 07/23] First version of the FAQ --- docs/source/FAQ.md | 51 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index f40fada251..948aa50c8f 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -477,3 +477,54 @@ Training options to perform vocabulary update are: * `-update_vocab`: set this option * `-reset_optim`: set the value to "states" * `-train_from`: checkpoint path + + +## How can I use source word features? + +Extra information can be added to the words in the source sentences by defining word features. + +Features should be defined in a separate file using blank spaces as a separator and with each row corresponding to a source sentence. An example of the input files: + +data.src +``` +however, according to the logs, she is hard-working. +``` + +feat0.txt +``` +A C C C C A A B +``` + +**Notes** +- Prior tokenization is not necessary, features will be inferred by using the `FeatInferTransform` transform. +- `FilterFeatsTransform` and `FeatInferTransform` are required in order to ensure the functionality. 
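+
+After tokenization, the `inferfeats` transform propagates each word-level
+label to every subword of that word; tokens that carry no word content
+(joiners, punctuation-only tokens, placeholders) receive a null value instead,
+shown below as `<null>` (the literal symbol is an implementation detail of the
+transform). For the example above, with joiner annotation:
+
+```
+however  ■,      according  to  the  logs  ■,      she  is  hard  ■-■     working  ■.
+A        <null>  C          C   C    C     <null>  A    A   B     <null>  B        <null>
+```
+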
+ +Sample config file: + +``` +data: + dummy: + path_src: data/train/data.src + path_tgt: data/train/data.tgt + src_feats: + feat_0: data/train/data.src.feat_0 + feat_1: data/train/data.src.feat_1 + transforms: [filterfeats, onmt_tokenize, inferfeats, filtertoolong] + weight: 1 + valid: + path_src: data/valid/data.src + path_tgt: data/valid/data.tgt + src_feats: + feat_0: data/valid/data.src.feat_0 + feat_1: data/valid/data.src.feat_1 + transforms: [filterfeats, onmt_tokenize, inferfeats] + +# # Vocab opts +src_vocab: exp/data.vocab.src +tgt_vocab: exp/data.vocab.tgt +src_feats_vocab: + feat_0: exp/data.vocab.feat_0 + feat_1: exp/data.vocab.feat_1 +feat_merge: "sum" + +``` \ No newline at end of file From 5a31039cedb2d490a1873e8f3bb355009cc5f337 Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Fri, 27 Aug 2021 12:12:33 +0200 Subject: [PATCH 08/23] Improved InferFeatsTransform --- onmt/tests/test_transform.py | 2 +- onmt/transforms/features.py | 24 ++++++++++-------------- onmt/utils/alignment.py | 2 +- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index d23169d9a1..d99bc607de 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -513,7 +513,7 @@ def test_span_infilling(self): class TestFeaturesTransform(unittest.TestCase): def test_inferfeats(self): inferfeats_cls = get_transforms_cls(["inferfeats"])["inferfeats"] - opt = Namespace(src_onmttok_kwargs={"mode": "conservative", "joiner": "■", "joiner_annotate": True, "case_markup": True}) + opt = Namespace(reversible_tokenization="joiner") inferfeats_transform = inferfeats_cls(opt) ex_in = { diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py index 9574b37c2b..01dfcd0251 100644 --- a/onmt/transforms/features.py +++ b/onmt/transforms/features.py @@ -48,13 +48,14 @@ def __init__(self, opts): @classmethod def add_options(cls, parser): - pass + """Avalilable options related to this Transform.""" + group = parser.add_argument_group("Transform/InferFeats") + group.add("--reversible_tokenization", "-reversible_tokenization", default="joiner", + choices=["joiner", "spacer"], help="Type of reversible tokenization applied on the tokenizer.") def _parse_opts(self): super()._parse_opts() - logger.info("Parsed pyonmttok kwargs for src: {}".format( - self.opts.src_onmttok_kwargs)) - self.src_onmttok_kwargs = self.opts.src_onmttok_kwargs + self.reversible_tokenization = self.opts.reversible_tokenization def apply(self, example, is_train=False, stats=None, **kwargs): @@ -62,24 +63,19 @@ def apply(self, example, is_train=False, stats=None, **kwargs): # Do nothing return example - joiner = self.src_onmttok_kwargs["joiner"] if "joiner" in self.src_onmttok_kwargs else SubwordMarker.JOINER - case_markup = SubwordMarker.CASE_MARKUP if "case_markup" in self.src_onmttok_kwargs else [] # TODO: support joiner_new or spacer_new options. 
Consistency not ensured currently - if "joiner_annotate" in self.src_onmttok_kwargs: - word_to_subword_mapping = subword_map_by_joiner(example["src"], marker=joiner, case_markup=case_markup) - elif "spacer_annotate" in self.src_onmttok_kwargs: + if self.reversible_tokenization == "joiner": + word_to_subword_mapping = subword_map_by_joiner(example["src"]) + else: #Spacer # TODO: case markup - word_to_subword_mapping = subword_map_by_spacer(example["src"], marker=joiner) - else: - # TODO: support not reversible tokenization - raise Exception("InferFeats transform does not currently work without either joiner_annotate or spacer_annotate") + word_to_subword_mapping = subword_map_by_spacer(example["src"]) inferred_feats = defaultdict(list) for subword, word_id in zip(example["src"], word_to_subword_mapping): for feat_name, feat_values in example["src_feats"].items(): # If case markup placeholder - if subword in case_markup: + if subword in SubwordMarker.CASE_MARKUP: inferred_feat = "" # Punctuation only (assumes joiner is also some punctuation token) elif not re.sub(r'(\W)+', '', subword).strip(): diff --git a/onmt/utils/alignment.py b/onmt/utils/alignment.py index a34a08dc7e..d9b1919a8f 100644 --- a/onmt/utils/alignment.py +++ b/onmt/utils/alignment.py @@ -120,7 +120,7 @@ def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'): return " ".join(word_align) -def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER, case_markup=[]): +def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP): """Return word id for each subword token (annotate by joiner).""" flags = [1] * len(subwords) for i, tok in enumerate(subwords): From 1d785b9c95d7860d002abd3f1fb0c0b943ac0354 Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Fri, 27 Aug 2021 12:58:19 +0200 Subject: [PATCH 09/23] Added integration test --- data/data_features/src-train.feat0 | 3 +++ data/data_features/src-train.txt | 3 +++ data/data_features/src-val.feat0 | 1 + data/data_features/src-val.txt | 1 + data/data_features/tgt-train.txt | 3 +++ data/data_features/tgt-val.txt | 1 + data/features_data.yaml | 11 +++++++++ onmt/tests/pull_request_chk.sh | 36 ++++++++++++++++++++++++++---- 8 files changed, 55 insertions(+), 4 deletions(-) create mode 100644 data/data_features/src-train.feat0 create mode 100644 data/data_features/src-train.txt create mode 100644 data/data_features/src-val.feat0 create mode 100644 data/data_features/src-val.txt create mode 100644 data/data_features/tgt-train.txt create mode 100644 data/data_features/tgt-val.txt create mode 100644 data/features_data.yaml diff --git a/data/data_features/src-train.feat0 b/data/data_features/src-train.feat0 new file mode 100644 index 0000000000..7e189f2c33 --- /dev/null +++ b/data/data_features/src-train.feat0 @@ -0,0 +1,3 @@ +A A A A B A A A C +A B C D E +C B A B \ No newline at end of file diff --git a/data/data_features/src-train.txt b/data/data_features/src-train.txt new file mode 100644 index 0000000000..8a3ec35c2b --- /dev/null +++ b/data/data_features/src-train.txt @@ -0,0 +1,3 @@ +however, according to the logs, she is a hard-working. +however, according to the logs, +she is a hard-working. 
\ No newline at end of file diff --git a/data/data_features/src-val.feat0 b/data/data_features/src-val.feat0 new file mode 100644 index 0000000000..4ab4a9e651 --- /dev/null +++ b/data/data_features/src-val.feat0 @@ -0,0 +1 @@ +C B A B \ No newline at end of file diff --git a/data/data_features/src-val.txt b/data/data_features/src-val.txt new file mode 100644 index 0000000000..0cc723ce39 --- /dev/null +++ b/data/data_features/src-val.txt @@ -0,0 +1 @@ +she is a hard-working. \ No newline at end of file diff --git a/data/data_features/tgt-train.txt b/data/data_features/tgt-train.txt new file mode 100644 index 0000000000..8a3ec35c2b --- /dev/null +++ b/data/data_features/tgt-train.txt @@ -0,0 +1,3 @@ +however, according to the logs, she is a hard-working. +however, according to the logs, +she is a hard-working. \ No newline at end of file diff --git a/data/data_features/tgt-val.txt b/data/data_features/tgt-val.txt new file mode 100644 index 0000000000..0cc723ce39 --- /dev/null +++ b/data/data_features/tgt-val.txt @@ -0,0 +1 @@ +she is a hard-working. \ No newline at end of file diff --git a/data/features_data.yaml b/data/features_data.yaml new file mode 100644 index 0000000000..fa9b665f9c --- /dev/null +++ b/data/features_data.yaml @@ -0,0 +1,11 @@ +# Corpus opts: +data: + corpus_1: + path_src: data/data_features/src-train.txt + path_tgt: data/data_features/tgt-train.txt + src_feats: + feat0: data/data_features/src-train.feat0 + transforms: [filterfeats, inferfeats] + valid: + path_src: data/data_features/src-val.txt + path_tgt: data/data_features/tgt-val.txt diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh index b282cc7f1e..4dedf053f2 100755 --- a/onmt/tests/pull_request_chk.sh +++ b/onmt/tests/pull_request_chk.sh @@ -67,10 +67,22 @@ PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH} ${PYTHON} onmt/bin/build_vocab.py \ -save_data $TMP_OUT_DIR/onmt \ -src_vocab $TMP_OUT_DIR/onmt.vocab.src \ -tgt_vocab $TMP_OUT_DIR/onmt.vocab.tgt \ - -n_sample 5000 >> ${LOG_FILE} 2>&1 + -n_sample 5000 -overwrite >> ${LOG_FILE} 2>&1 [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} -rm -r $TMP_OUT_DIR/sample +rm -f -r $TMP_OUT_DIR/sample + +echo -n "[+] Testing vocabulary building with features..." +PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH} ${PYTHON} onmt/bin/build_vocab.py \ + -config ${DATA_DIR}/features_data.yaml \ + -save_data $TMP_OUT_DIR/onmt_feat \ + -src_vocab $TMP_OUT_DIR/onmt_feat.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/onmt_feat.vocab.tgt \ + -src_feats_vocab '{"feat0": "${TMP_OUT_DIR}/onmt_feat.vocab.feat0"}' \ + -n_sample -1 -overwrite>> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} +rm -f -r $TMP_OUT_DIR/sample # # Training test @@ -254,8 +266,24 @@ ${PYTHON} onmt/bin/train.py \ [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} -rm $TMP_OUT_DIR/onmt.vocab* -rm $TMP_OUT_DIR/onmt.model* +echo -n " [+] Testing training with features..." +${PYTHON} onmt/bin/train.py \ + -config ${DATA_DIR}/features_data.yaml \ + -src_vocab $TMP_OUT_DIR/onmt_feat.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/onmt_feat.vocab.tgt \ + -src_feats_vocab '{"feat0": "${TMP_OUT_DIR}/onmt_feat.vocab.feat0"}' \ + -src_vocab_size 1000 -tgt_vocab_size 1000 \ + -rnn_size 2 -batch_size 10 \ + -word_vec_size 5 -rnn_size 10 \ + -report_every 5 -train_steps 10 \ + -save_model $TMP_OUT_DIR/onmt.model \ + -save_checkpoint_steps 10 >> ${LOG_FILE} 2>&1 +[ "$?" 
-eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} + +rm -f $TMP_OUT_DIR/onmt.vocab* +rm -f $TMP_OUT_DIR/onmt.model* +rm -f $TMP_OUT_DIR/onmt_feat.vocab.* # # Translation test From a76c737102282cef4cee062af6c4f8af510da654 Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Fri, 27 Aug 2021 15:40:10 +0200 Subject: [PATCH 10/23] Added tests in github CI config --- .github/workflows/push.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 9780df63fc..7ac5d50b86 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -42,6 +42,16 @@ jobs: -src_vocab /tmp/onmt.vocab.src \ -tgt_vocab /tmp/onmt.vocab.tgt \ && rm -rf /tmp/sample + - name: Test vocabulary build with features + run: | + python onmt/bin/build_vocab.py \ + -config data/features_data.yaml \ + -save_data /tmp/onmt_feat \ + -src_vocab /tmp/onmt_feat.vocab.src \ + -tgt_vocab /tmp/onmt_feat.vocab.tgt \ + -src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \ + -n_sample -1 \ + && rm -rf /tmp/sample - name: Test field/transform dump run: | # The dumped fields are used later when testing tools @@ -169,6 +179,19 @@ jobs: -state_dim 256 \ -n_steps 10 \ -n_node 64 + - name: Testing training with features + run: | + python onmt/bin/train.py \ + -config data/features_data.yaml \ + -src_vocab /tmp/onmt_feat.vocab.src \ + -tgt_vocab /tmp/onmt_feat.vocab.tgt \ + -src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \ + -src_vocab_size 1000 -tgt_vocab_size 1000 \ + -rnn_size 2 -batch_size 10 \ + -word_vec_size 5 -rnn_size 10 \ + -report_every 5 -train_steps 10 \ + -save_model /tmp/onmt.model \ + -save_checkpoint_steps 10 - name: Test RNN translation run: | head data/src-test.txt > /tmp/src-test.txt From 0909e31a9fc71ca57c6c12827bd511713827120d Mon Sep 17 00:00:00 2001 From: Ander Corral Date: Mon, 30 Aug 2021 15:15:27 +0200 Subject: [PATCH 11/23] New inference initial commit for review --- onmt/inputters/dataset_base.py | 4 ++-- onmt/inputters/text_dataset.py | 25 ++++++++++++++++++++----- onmt/opts.py | 3 +++ onmt/translate/translator.py | 5 +++-- onmt/utils/parse.py | 2 +- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/onmt/inputters/dataset_base.py b/onmt/inputters/dataset_base.py index aeec428aaf..1e761d0261 100644 --- a/onmt/inputters/dataset_base.py +++ b/onmt/inputters/dataset_base.py @@ -116,7 +116,7 @@ def __init__(self, fields, readers, data, sort_key, filter_pred=None): self.sort_key = sort_key can_copy = 'src_map' in fields and 'alignment' in fields - read_iters = [r.read(dat[1], dat[0]) for r, dat in zip(readers, data)] + read_iters = [r.read(dat, name, feats) for r, (name, dat, feats) in zip(readers, data)] # self.src_vocabs is used in collapse_copy_scores and Translator.py self.src_vocabs = [] @@ -162,5 +162,5 @@ def config(fields): for name, field in fields: if field["data"] is not None: readers.append(field["reader"]) - data.append((name, field["data"])) + data.append((name, field["data"], field["features"])) return readers, data diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py index 4cd3453341..1f7fc77686 100644 --- a/onmt/inputters/text_dataset.py +++ b/onmt/inputters/text_dataset.py @@ -9,7 +9,7 @@ class TextDataReader(DataReaderBase): - def read(self, sequences, side): + def read(self, sequences, side, features={}): """Read text data from disk. 
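+        When ``features`` is given, it maps a feature name to a file path or
+        an iterable of lines; each feature stream is read line-parallel with
+        ``sequences`` and attached to the yielded example dict.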
From a76c737102282cef4cee062af6c4f8af510da654 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Fri, 27 Aug 2021 15:40:10 +0200
Subject: [PATCH 10/23] Added tests in github CI config

---
 .github/workflows/push.yml | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 9780df63fc..7ac5d50b86 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -42,6 +42,16 @@ jobs:
           -src_vocab /tmp/onmt.vocab.src \
           -tgt_vocab /tmp/onmt.vocab.tgt \
           && rm -rf /tmp/sample
+      - name: Test vocabulary build with features
+        run: |
+          python onmt/bin/build_vocab.py \
+            -config data/features_data.yaml \
+            -save_data /tmp/onmt_feat \
+            -src_vocab /tmp/onmt_feat.vocab.src \
+            -tgt_vocab /tmp/onmt_feat.vocab.tgt \
+            -src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \
+            -n_sample -1 \
+            && rm -rf /tmp/sample
       - name: Test field/transform dump
         run: |
           # The dumped fields are used later when testing tools
@@ -169,6 +179,19 @@ jobs:
           -state_dim 256 \
           -n_steps 10 \
           -n_node 64
+      - name: Testing training with features
+        run: |
+          python onmt/bin/train.py \
+            -config data/features_data.yaml \
+            -src_vocab /tmp/onmt_feat.vocab.src \
+            -tgt_vocab /tmp/onmt_feat.vocab.tgt \
+            -src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \
+            -src_vocab_size 1000 -tgt_vocab_size 1000 \
+            -rnn_size 2 -batch_size 10 \
+            -word_vec_size 5 -rnn_size 10 \
+            -report_every 5 -train_steps 10 \
+            -save_model /tmp/onmt.model \
+            -save_checkpoint_steps 10
       - name: Test RNN translation
         run: |
           head data/src-test.txt > /tmp/src-test.txt

From 0909e31a9fc71ca57c6c12827bd511713827120d Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Mon, 30 Aug 2021 15:15:27 +0200
Subject: [PATCH 11/23] New inference initial commit for review

---
 onmt/inputters/dataset_base.py |  4 ++--
 onmt/inputters/text_dataset.py | 25 ++++++++++++++++++++-----
 onmt/opts.py                   |  3 +++
 onmt/translate/translator.py   |  5 +++--
 onmt/utils/parse.py            |  2 +-
 5 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/onmt/inputters/dataset_base.py b/onmt/inputters/dataset_base.py
index aeec428aaf..1e761d0261 100644
--- a/onmt/inputters/dataset_base.py
+++ b/onmt/inputters/dataset_base.py
@@ -116,7 +116,7 @@ def __init__(self, fields, readers, data, sort_key, filter_pred=None):
         self.sort_key = sort_key

         can_copy = 'src_map' in fields and 'alignment' in fields

-        read_iters = [r.read(dat[1], dat[0]) for r, dat in zip(readers, data)]
+        read_iters = [r.read(dat, name, feats) for r, (name, dat, feats) in zip(readers, data)]

         # self.src_vocabs is used in collapse_copy_scores and Translator.py
         self.src_vocabs = []
@@ -162,5 +162,5 @@ def config(fields):
         for name, field in fields:
             if field["data"] is not None:
                 readers.append(field["reader"])
-                data.append((name, field["data"]))
+                data.append((name, field["data"], field["features"]))
         return readers, data
diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py
index 4cd3453341..1f7fc77686 100644
--- a/onmt/inputters/text_dataset.py
+++ b/onmt/inputters/text_dataset.py
@@ -9,7 +9,7 @@


 class TextDataReader(DataReaderBase):
-    def read(self, sequences, side):
+    def read(self, sequences, side, features={}):
         """Read text data from disk.

         Args:
@@ -25,10 +25,26 @@ def read(self, sequences, side):
         """
         if isinstance(sequences, str):
             sequences = DataReaderBase._read_file(sequences)
-        for i, seq in enumerate(sequences):
+
+        features_names = []
+        features_values = []
+        for feat_name, v in features.items():
+            features_names.append(feat_name)
+            if isinstance(v, str):
+                features_values.append(DataReaderBase._read_file(v))
+            else:
+                features_values.append(v)
+        for i, (seq, *feats) in enumerate(zip(sequences, *features_values)):
+            ex_dict = {}
             if isinstance(seq, bytes):
                 seq = seq.decode("utf-8")
-            yield {side: seq, "indices": i}
+            ex_dict[side] = seq
+            for j, f in enumerate(feats):
+                if isinstance(f, bytes):
+                    f = f.decode("utf-8")
+                ex_dict[features_names[j]] = f
+            ex_dict["indices"] = i
+            yield {side: ex_dict}


 def text_sort_key(ex):
@@ -140,8 +156,7 @@ def preprocess(self, x):
             lists of tokens/feature tags for the sentence. The output
             is ordered like ``self.fields``.
         """
-
-        return [f.preprocess(x) for _, f in self.fields]
+        return [f.preprocess(x[fn]) for fn, f in self.fields]

     def __getitem__(self, item):
         return self.fields[item]
diff --git a/onmt/opts.py b/onmt/opts.py
index e5db2947f0..c6ef415f03 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -760,6 +760,9 @@ def translate_opts(parser):
     group.add('--src', '-src', required=True,
              help="Source sequence to decode (one line per "
                   "sequence)")
+    group.add("-src_feats", "--src_feats", required=False,
+             help="Source sequence features (one line per "
+                  "sequence)")
     group.add('--tgt', '-tgt',
              help='True target sequence (optional)')
     group.add('--tgt_prefix', '-tgt_prefix', action='store_true',
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 9bdd4ee4c7..d5329e9cc6 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -333,6 +333,7 @@ def _gold_score(
     def translate(
             self,
             src,
+            src_feats={},
             tgt=None,
             batch_size=None,
             batch_type="sents",
@@ -363,8 +364,8 @@ def translate(
         if self.tgt_prefix and tgt is None:
             raise ValueError("Prefix should be feed to tgt if -tgt_prefix.")

-        src_data = {"reader": self.src_reader, "data": src}
-        tgt_data = {"reader": self.tgt_reader, "data": tgt}
+        src_data = {"reader": self.src_reader, "data": src, "features": src_feats}
+        tgt_data = {"reader": self.tgt_reader, "data": tgt, "features": {}}
         _readers, _data = inputters.Dataset.config(
             [("src", src_data), ("tgt", tgt_data)]
         )
diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py
index 3a49f8841b..4a12a5fe4d 100644
--- a/onmt/utils/parse.py
+++ b/onmt/utils/parse.py
@@ -320,4 +320,4 @@ def validate_train_opts(cls, opt):

     @classmethod
     def validate_translate_opts(cls, opt):
-        pass
+        opt.src_feats = eval(opt.src_feats) if opt.src_feats else {}

From d68f01060c89d55393d9f5387b26e335f4665ac7 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Mon, 30 Aug 2021 15:34:56 +0200
Subject: [PATCH 12/23] Fixed indices issue

---
 onmt/inputters/text_dataset.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py
index 1f7fc77686..46346b2a82 100644
--- a/onmt/inputters/text_dataset.py
+++ b/onmt/inputters/text_dataset.py
@@ -43,8 +43,7 @@ def read(self, sequences, side, features={}):
                 if isinstance(f, bytes):
                     f = f.decode("utf-8")
                 ex_dict[features_names[j]] = f
-            ex_dict["indices"] = i
-            yield {side: ex_dict}
+            yield {side: ex_dict, "indices": i}


 def text_sort_key(ex):
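After [PATCH 11] and [PATCH 12], `TextDataReader.read()` yields a nested dict per example: the side entry holds the raw string plus one entry per feature, while `indices` stays at the top level. A short usage sketch, assuming a checkout with these patches applied; it mirrors the unit test added in [PATCH 14] below:

```python
# Sketch only: expected shape of the examples yielded with features.
from onmt.inputters.text_dataset import TextDataReader

strings = [b"hello world", b"good morning"]
features = {"feat_0": [b"A A", b"B B"]}

rdr = TextDataReader()
for i, ex in enumerate(rdr.read(strings, "src", features)):
    # e.g. {"src": {"src": "hello world", "feat_0": "A A"}, "indices": 0}
    assert ex["indices"] == i
    assert ex["src"]["feat_0"] == features["feat_0"][i].decode("utf-8")
    print(ex)
```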
From ab8dd7a6c48b6dedbe1e9aed974cbbabc525f17a Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Mon, 30 Aug 2021 18:39:36 +0200
Subject: [PATCH 13/23] Checked it correctly uses features

---
 onmt/bin/translate.py    | 16 +++++++++++++---
 onmt/inputters/corpus.py | 28 ++++++++++++++++++----------
 2 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/onmt/bin/translate.py b/onmt/bin/translate.py
index 0b5434f89a..4e3e126ae2 100755
--- a/onmt/bin/translate.py
+++ b/onmt/bin/translate.py
@@ -6,6 +6,7 @@
 import onmt.opts as opts
 from onmt.utils.parse import ArgumentParser
+from collections import defaultdict


 def translate(opt):
@@ -15,12 +16,21 @@ def translate(opt):
     translator = build_translator(opt, logger=logger, report_score=True)
     src_shards = split_corpus(opt.src, opt.shard_size)
     tgt_shards = split_corpus(opt.tgt, opt.shard_size)
-    shard_pairs = zip(src_shards, tgt_shards)
-
-    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
+    features_shards = []
+    features_names = []
+    for feat_name, feat_path in opt.src_feats.items():
+        features_shards.append(split_corpus(feat_path, opt.shard_size))
+        features_names.append(feat_name)
+    shard_pairs = zip(src_shards, tgt_shards, *features_shards)
+
+    for i, (src_shard, tgt_shard, *features_shard) in enumerate(shard_pairs):
+        features_shard_ = defaultdict(list)
+        for j, x in enumerate(features_shard):
+            features_shard_[features_names[j]] = x
         logger.info("Translating shard %d." % i)
         translator.translate(
             src=src_shard,
+            src_feats=features_shard_,
             tgt=tgt_shard,
             batch_size=opt.batch_size,
             batch_type=opt.batch_type,
diff --git a/onmt/inputters/corpus.py b/onmt/inputters/corpus.py
index 18ae726ae1..1fb4c33a73 100644
--- a/onmt/inputters/corpus.py
+++ b/onmt/inputters/corpus.py
@@ -11,6 +11,7 @@
 from contextlib import contextmanager

 import multiprocessing as mp
+from collections import defaultdict


 @contextmanager
@@ -70,13 +71,19 @@ def _process(item, is_train):
             example, is_train=is_train, corpus_name=cid)
         if maybe_example is None:
             return None
-        maybe_example['src'] = ' '.join(maybe_example['src'])
-        maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
+
+        maybe_example['src'] = {"src": ' '.join(maybe_example['src'])}
+
+        # Make features part of src as in MultiTextField
+        if 'src_feats' in maybe_example:
+            for feat_name, feat_value in maybe_example['src_feats'].items():
+                maybe_example['src'][feat_name] = ' '.join(feat_value)
+            del maybe_example["src_feats"]
+
+        maybe_example['tgt'] = {"tgt": ' '.join(maybe_example['tgt'])}
         if 'align' in maybe_example:
             maybe_example['align'] = ' '.join(maybe_example['align'])
-        if 'src_feats' in maybe_example:
-            for k in maybe_example['src_feats'].keys():
-                maybe_example['src_feats'][k] = ' '.join(maybe_example['src_feats'][k])
+
         return maybe_example

     def _maybe_add_dynamic_dict(self, example, fields):
@@ -176,7 +183,8 @@ def get_corpora(opts, is_train=False):
                 CorpusName.VALID,
                 opts.data[CorpusName.VALID]["path_src"],
                 opts.data[CorpusName.VALID]["path_tgt"],
-                opts.data[CorpusName.VALID]["path_align"])
+                opts.data[CorpusName.VALID]["path_align"],
+                opts.data[CorpusName.VALID]["src_feats"])
         else:
             return None
     return corpora_dict
@@ -321,11 +329,11 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
                     build_sub_vocab.queues[c_name][offset].put("blank")
                 continue
             src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
-            if 'src_feats' in maybe_example:
-                for feat_name, feat_line in maybe_example["src_feats"].items():
+            for feat_name, feat_line in maybe_example["src"].items():
+                if feat_name != "src":
                     sub_counter_src_feats[feat_name].update(feat_line.split(' '))
-            sub_counter_src.update(src_line.split(' '))
-            sub_counter_tgt.update(tgt_line.split(' '))
+            sub_counter_src.update(src_line["src"].split(' '))
+            sub_counter_tgt.update(tgt_line["tgt"].split(' '))
             if opts.dump_samples:
                 build_sub_vocab.queues[c_name][offset].put(
                     (i, src_line, tgt_line))
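The shard loop above advances one `split_corpus` generator per feature file in lockstep with the source and target shards, then re-keys each shard by feature name. A simplified, self-contained sketch of that idea (`split_corpus` below is a reduced stand-in for `onmt.utils.misc.split_corpus`, working on in-memory lines instead of files):

```python
# Simplified illustration of shard-wise iteration over src + feature data.
from itertools import islice

def split_corpus(lines, shard_size):
    """Reduced stand-in for onmt.utils.misc.split_corpus."""
    it = iter(lines)
    while True:
        shard = list(islice(it, shard_size))
        if not shard:
            return
        yield shard

src = ["she is a hard-working .", "however , she is tired ."]
feat0 = ["C B A B D", "A B C B A D"]
features_names = ["feat0"]

src_shards = split_corpus(src, shard_size=1)
features_shards = [split_corpus(feat0, shard_size=1)]  # one generator per feature
for i, (src_shard, *features_shard) in enumerate(zip(src_shards, *features_shards)):
    features_shard_ = {features_names[j]: x for j, x in enumerate(features_shard)}
    print(i, src_shard, features_shard_)  # shards stay aligned line-for-line
```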
From 1f360e19b5ffb2c28b86a5bddaae80610041cea6 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Tue, 31 Aug 2021 11:22:53 +0200
Subject: [PATCH 14/23] Updated unittests for text dataset

---
 onmt/tests/test_text_dataset.py | 26 +++++++++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/onmt/tests/test_text_dataset.py b/onmt/tests/test_text_dataset.py
index e4d22e9c0a..4477bca7fe 100644
--- a/onmt/tests/test_text_dataset.py
+++ b/onmt/tests/test_text_dataset.py
@@ -79,7 +79,8 @@ def test_preprocess_shape(self):
                 self.INIT_CASES, self.PARAMS):
             init_case = self.initialize_case(init_case, params)
             mf = TextMultiField(**init_case)
-            sample_str = "dummy input here ."
+
+            sample_str = {"base_field": "dummy input here .", "a": "A A B D", "r": "C C C C", "b": "D F E D", "zbase_field": "another dummy input ."}
             proc = mf.preprocess(sample_str)
             self.assertEqual(len(proc), len(init_case["feats_fields"]) + 1)
@@ -147,7 +148,7 @@ def test_read(self):
         ]
         rdr = TextDataReader()
         for i, ex in enumerate(rdr.read(strings, "src")):
-            self.assertEqual(ex["src"], strings[i].decode("utf-8"))
+            self.assertEqual(ex["src"], {"src": strings[i].decode("utf-8")})


 class TestTextDataReaderFromFS(unittest.TestCase):
@@ -174,4 +175,23 @@ def tearDownClass(cls):
     def test_read(self):
         rdr = TextDataReader()
         for i, ex in enumerate(rdr.read(self.FILE_NAME, "src")):
-            self.assertEqual(ex["src"], self.STRINGS[i].decode("utf-8"))
+            self.assertEqual(ex["src"], {"src": self.STRINGS[i].decode("utf-8")})
+
+class TestTextDataReaderWithFeatures(unittest.TestCase):
+    def test_read(self):
+        strings = [
+            "hello world".encode("utf-8"),
+            "this's a string with punctuation .".encode("utf-8"),
+            "ThIs Is A sTrInG wItH oDD CapitALIZAtion".encode("utf-8")
+        ]
+        features = {
+            "feat_0": [
+                "A A".encode("utf-8"),
+                "A A B B C".encode("utf-8"),
+                "A A D D E E".encode("utf-8")
+            ]
+        }
+
+        rdr = TextDataReader()
+        for i, ex in enumerate(rdr.read(strings, "src", features)):
+            self.assertEqual(ex["src"], {"src": strings[i].decode("utf-8"), "feat_0": features["feat_0"][i].decode("utf-8")})
\ No newline at end of file

From 22e0e0dc940b7de5a58fe818d107641d4e56b9a1 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Tue, 31 Aug 2021 11:29:24 +0200
Subject: [PATCH 15/23] Fixed corpus save test

---
 onmt/inputters/corpus.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onmt/inputters/corpus.py b/onmt/inputters/corpus.py
index 1fb4c33a73..95237ca96b 100644
--- a/onmt/inputters/corpus.py
+++ b/onmt/inputters/corpus.py
@@ -423,7 +423,7 @@ def save_transformed_sample(opts, transforms, n_sample=3):
             maybe_example = DatasetAdapter._process(item, is_train=True)
             if maybe_example is None:
                 continue
-            src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
+            src_line, tgt_line = maybe_example['src']['src'], maybe_example['tgt']['tgt']
             f_src.write(src_line + '\n')
             f_tgt.write(tgt_line + '\n')
             if n_sample > 0 and i >= n_sample:

From 4e5e53708ec5d660640552dae1d9cb3b90b46a97 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Tue, 31 Aug 2021 11:39:06 +0200
Subject: [PATCH 16/23] Fixed issues with new examples

---
 onmt/inputters/dataset_base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onmt/inputters/dataset_base.py b/onmt/inputters/dataset_base.py
index 1e761d0261..65322d9a4c 100644
--- a/onmt/inputters/dataset_base.py
+++ b/onmt/inputters/dataset_base.py
@@ -41,7 +41,7 @@ def _dynamic_dict(example, src_field, tgt_field):
         ``example``, changed as described.
     """

-    src = src_field.tokenize(example["src"])
+    src = src_field.tokenize(example["src"]["src"])
     # make a small vocab containing just the tokens in the source sequence
     unk = src_field.unk_token
     pad = src_field.pad_token
@@ -60,7 +60,7 @@ def _dynamic_dict(example, src_field, tgt_field):
     example["src_ex_vocab"] = src_ex_vocab

     if "tgt" in example:
-        tgt = tgt_field.tokenize(example["tgt"])
+        tgt = tgt_field.tokenize(example["tgt"]["tgt"])
         mask = torch.LongTensor(
             [unk_idx] + [src_ex_vocab.stoi[w] for w in tgt] + [unk_idx])
         example["alignment"] = mask
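Because examples are now nested, the copy-mechanism helper `_dynamic_dict` has to tokenize `example["src"]["src"]` rather than `example["src"]`. A much-simplified sketch of what it computes per example; the real code builds a torchtext `Vocab` and returns `torch.LongTensor` masks, so everything below is illustrative only:

```python
# Simplified illustration of the per-example copy vocabulary; not the real
# implementation, which wraps this in torchtext Vocab/LongTensor machinery.
example = {"src": {"src": "she is a hard-working ."},
           "tgt": {"tgt": "she is a hard-working ."}}

src = example["src"]["src"].split()              # stand-in for src_field.tokenize
itos = ["<unk>", "<blank>"] + sorted(set(src))   # per-example vocab, unk/pad first
stoi = {tok: i for i, tok in enumerate(itos)}
unk_idx = 0

# "alignment" marks which target tokens can be copied from this source sentence
alignment = [unk_idx] + [stoi.get(w, unk_idx)
                         for w in example["tgt"]["tgt"].split()] + [unk_idx]
print(alignment)  # [0, 6, 5, 3, 4, 2, 0]
```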
From 1e1c5409b8872bcef6acd6b10dcb8192c3927010 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Tue, 31 Aug 2021 11:48:03 +0200
Subject: [PATCH 17/23] Fixed issues with new examples

---
 onmt/translate/translator.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index d5329e9cc6..abb475f70d 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -926,6 +926,7 @@ def _align_forward(self, batch, predictions):
     def translate(
             self,
             src,
+            src_feats={},
             tgt=None,
             batch_size=None,
             batch_type="sents",
@@ -946,6 +947,7 @@ def translate(

         return super(GeneratorLM, self).translate(
             src,
+            src_feats,
             tgt,
             batch_size=1,
             batch_type=batch_type,

From a0bd55f8dbbbd0d60025f5c40c1663d2ef823d12 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Tue, 31 Aug 2021 14:30:13 +0200
Subject: [PATCH 18/23] Added integration tests and updated FAQ

---
 .github/workflows/push.yml     |  7 +++++++
 docs/source/FAQ.md             | 10 +++++++++-
 onmt/tests/pull_request_chk.sh | 12 +++++++++++-
 3 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 7ac5d50b86..66d892efff 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -192,6 +192,13 @@ jobs:
             -report_every 5 -train_steps 10 \
             -save_model /tmp/onmt.model \
             -save_checkpoint_steps 10
+      - name: Testing translation with features
+        run: |
+          python translate.py \
+            -model /tmp/onmt.model_step_10.pt \
+            -src data/data_features/src-test.txt \
+            -src_feats "{'feat0': 'data/data_features/src-test.feat0'}" \
+            -verbose
       - name: Test RNN translation
         run: |
           head data/src-test.txt > /tmp/src-test.txt
diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md
index 948aa50c8f..e633140f4b 100644
--- a/docs/source/FAQ.md
+++ b/docs/source/FAQ.md
@@ -527,4 +527,12 @@ src_feats_vocab:
   feat_1: exp/data.vocab.feat_1
 feat_merge: "sum"
-```
\ No newline at end of file
+```
+
+During inference you can pass features by using the `--src_feats` argument.
+
+**Important note!** During inference, input sentence is expected to be tokenized. Therefore feature inferring should be handled prior to running the translate command. Example:
+
+```bash
+python translate.py -model model_step_10.pt -src ../data.txt.tok -output ../data.out --src_feats "{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}"
+```
diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh
index 4dedf053f2..70cd76823a 100755
--- a/onmt/tests/pull_request_chk.sh
+++ b/onmt/tests/pull_request_chk.sh
@@ -276,7 +276,7 @@ ${PYTHON} onmt/bin/train.py \
     -rnn_size 2 -batch_size 10 \
     -word_vec_size 5 -rnn_size 10 \
     -report_every 5 -train_steps 10 \
-    -save_model $TMP_OUT_DIR/onmt.model \
+    -save_model $TMP_OUT_DIR/onmt.features.model \
     -save_checkpoint_steps 10 >> ${LOG_FILE} 2>&1
 [ "$?" -eq 0 ] || error_exit
 echo "Succeeded" | tee -a ${LOG_FILE}
@@ -297,6 +297,16 @@ ${PYTHON} translate.py -model ${TEST_DIR}/test_model.pt -src $TMP_OUT_DIR/src-te
 echo "Succeeded" | tee -a ${LOG_FILE}
 rm $TMP_OUT_DIR/src-test.txt

+echo -n " [+] Testing NMT translation with features..."
+${PYTHON} translate.py \
+    -model ${TMP_OUT_DIR}/onmt.features.model_step_10.pt \
+    -src ${DATA_DIR}/data_features/src-test.txt \
+    -src_feats "{'feat0': '${DATA_DIR}/data_features/src-test.feat0'}" \
+    -verbose >> ${LOG_FILE} 2>&1
+[ "$?" -eq 0 ] || error_exit
+echo "Succeeded" | tee -a ${LOG_FILE}
+rm -f $TMP_OUT_DIR/onmt.features.model*
+
 echo -n " [+] Testing NMT ensemble translation..."
 head ${DATA_DIR}/src-test.txt > $TMP_OUT_DIR/src-test.txt
 ${PYTHON} translate.py -model ${TEST_DIR}/test_model.pt ${TEST_DIR}/test_model.pt \

From f5b1eefebd3076e21849dbbb05c816f115ab8019 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Tue, 31 Aug 2021 14:38:09 +0200
Subject: [PATCH 19/23] Added test data files

---
 data/data_features/src-test.feat0 | 1 +
 data/data_features/src-test.txt   | 1 +
 2 files changed, 2 insertions(+)
 create mode 100644 data/data_features/src-test.feat0
 create mode 100644 data/data_features/src-test.txt

diff --git a/data/data_features/src-test.feat0 b/data/data_features/src-test.feat0
new file mode 100644
index 0000000000..4ab4a9e651
--- /dev/null
+++ b/data/data_features/src-test.feat0
@@ -0,0 +1 @@
+C B A B
\ No newline at end of file
diff --git a/data/data_features/src-test.txt b/data/data_features/src-test.txt
new file mode 100644
index 0000000000..0cc723ce39
--- /dev/null
+++ b/data/data_features/src-test.txt
@@ -0,0 +1 @@
+she is a hard-working.
\ No newline at end of file
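The new `src-test.feat0` is line- and token-aligned with `src-test.txt`: "C B A B" carries one label per whitespace token of "she is a hard-working.". Training enforces this invariant through `FilterFeatsTransform`, but nothing checks it at inference time, so a pre-flight check along these lines can be useful (`check_feature_alignment` is a hypothetical helper, not part of the series):

```python
# Hypothetical pre-flight check: every feature line must carry exactly one
# label per whitespace token of the corresponding source line.
def check_feature_alignment(src_path, feat_path):
    with open(src_path) as fs, open(feat_path) as ff:
        for n, (s, f) in enumerate(zip(fs, ff), start=1):
            if len(s.split()) != len(f.split()):
                raise ValueError(
                    "line %d: %d source tokens vs %d feature labels"
                    % (n, len(s.split()), len(f.split())))

check_feature_alignment("data/data_features/src-test.txt",
                        "data/data_features/src-test.feat0")
```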
From 705b94fbf10a4a90e6d2eb8c7ed89122f94c8f57 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Thu, 9 Sep 2021 10:23:42 +0200
Subject: [PATCH 20/23] Fixed some issues

---
 docs/source/FAQ.md                | 13 ++++++++++++-
 onmt/inputters/corpus.py          |  7 ++++---
 onmt/inputters/text_dataset.py    |  4 ++++
 onmt/opts.py                      |  4 ++--
 onmt/tests/test_subword_marker.py | 20 +++++++++++++++++++-
 onmt/transforms/features.py       |  3 ---
 onmt/translate/translator.py      |  1 +
 onmt/utils/alignment.py           | 24 ++++++++++++++++++++++--
 8 files changed, 64 insertions(+), 12 deletions(-)

diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md
index e633140f4b..3e73ec78dc 100644
--- a/docs/source/FAQ.md
+++ b/docs/source/FAQ.md
@@ -498,6 +498,7 @@ A C C C C A A B
 **Notes**
 - Prior tokenization is not necessary, features will be inferred by using the `FeatInferTransform` transform.
 - `FilterFeatsTransform` and `FeatInferTransform` are required in order to ensure the functionality.
+- Not possible to do shared embeddings (at least with `feat_merge: concat` method)

 Sample config file:
@@ -529,10 +530,20 @@ feat_merge: "sum"
 ```

-During inference you can pass features by using the `--src_feats` argument.
+During inference you can pass features by using the `--src_feats` argument. `src_feats` is expected to be a Python-like dict, mapping each feature name to its data file.
+
+```
+{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}
+```

 **Important note!** During inference, input sentence is expected to be tokenized. Therefore feature inferring should be handled prior to running the translate command. Example:

 ```bash
 python translate.py -model model_step_10.pt -src ../data.txt.tok -output ../data.out --src_feats "{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}"
 ```
+
+When using the Transformer arquitechture make sure the following options are appropiately set:
+
+- `src_word_vec_size` and `tgt_word_vec_size` or `word_vec_size`
+- `feat_merge`: how to handle features vecs
+- `feat_vec_size` and maybe `feat_vec_exponent`
diff --git a/onmt/inputters/corpus.py b/onmt/inputters/corpus.py
index 95237ca96b..cc903d3154 100644
--- a/onmt/inputters/corpus.py
+++ b/onmt/inputters/corpus.py
@@ -75,6 +75,7 @@ def _process(item, is_train):
         maybe_example['src'] = {"src": ' '.join(maybe_example['src'])}

         # Make features part of src as in MultiTextField
+        # {'src': {'src': ..., 'feat1': ...., 'feat2': ....}}
         if 'src_feats' in maybe_example:
             for feat_name, feat_value in maybe_example['src_feats'].items():
                 maybe_example['src'][feat_name] = ' '.join(feat_value)
@@ -328,12 +329,12 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
             if opts.dump_samples:
                 build_sub_vocab.queues[c_name][offset].put("blank")
                 continue
-            src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
+            src_line, tgt_line = maybe_example['src']['src'], maybe_example['tgt']['tgt']
             for feat_name, feat_line in maybe_example["src"].items():
                 if feat_name != "src":
                     sub_counter_src_feats[feat_name].update(feat_line.split(' '))
-            sub_counter_src.update(src_line["src"].split(' '))
-            sub_counter_tgt.update(tgt_line["tgt"].split(' '))
+            sub_counter_src.update(src_line.split(' '))
+            sub_counter_tgt.update(tgt_line.split(' '))
             if opts.dump_samples:
                 build_sub_vocab.queues[c_name][offset].put(
                     (i, src_line, tgt_line))
diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py
index 46346b2a82..46df8c91c3 100644
--- a/onmt/inputters/text_dataset.py
+++ b/onmt/inputters/text_dataset.py
@@ -17,6 +17,9 @@ def read(self, sequences, side, features={}):
                 path to text file or iterable of the actual text data.
             side (str): Prefix used in return dict. Usually
                 ``"src"`` or ``"tgt"``.
+            features: (Dict[str or Iterable[str]]):
+                dictionary mapping feature names with th path to feature
+                file or iterable of the actual feature data.

         Yields:
             dictionaries whose keys are the names of fields and whose
@@ -53,6 +56,7 @@ def text_sort_key(ex):
     return len(ex.src[0])


+# Legacy function. Currently it only truncates input if truncate is set.
 # mix this with partial
 def _feature_tokenize(
         string, layer=0, tok_delim=None, feat_delim=None, truncate=None):
diff --git a/onmt/opts.py b/onmt/opts.py
index c6ef415f03..4c37ab952d 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -761,8 +761,8 @@ def translate_opts(parser):
              help="Source sequence to decode (one line per "
                   "sequence)")
     group.add("-src_feats", "--src_feats", required=False,
-             help="Source sequence features (one line per "
-                  "sequence)")
+             help="Source sequence features (dict format). "
+                  "Ex: {'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}")
     group.add('--tgt', '-tgt',
              help='True target sequence (optional)')
     group.add('--tgt_prefix', '-tgt_prefix', action='store_true',
diff --git a/onmt/tests/test_subword_marker.py b/onmt/tests/test_subword_marker.py
index d1cb0b153f..1b8337b56e 100644
--- a/onmt/tests/test_subword_marker.py
+++ b/onmt/tests/test_subword_marker.py
@@ -41,12 +41,18 @@ def test_subword_group_joiner(self):
         out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP)
         self.assertEqual(out, true_out)

-    def test_subword_group_joiner_with_markup(self):
+    def test_subword_group_joiner_with_case_markup(self):
         data_in = ['⦅mrk_case_modifier_C⦆', 'however', '■,', 'according', 'to', 'the', 'logs', '■,', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■-■', 'working', '■.', '⦅mrk_end_case_region_U⦆']  # noqa: E501
         true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7]
         out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP)
         self.assertEqual(out, true_out)

+    def test_subword_group_joiner_with_new_joiner(self):
+        data_in = ['⦅mrk_case_modifier_C⦆', 'however', '■', ',', 'according', 'to', 'the', 'logs', '■', ',', '⦅mrk_begin_case_region_U⦆', 'she', 'is', 'hard', '■', '-', '■', 'working', '■', '.', '⦅mrk_end_case_region_U⦆']  # noqa: E501
+        true_out = [0, 0, 0, 0, 1, 2, 3, 4, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7]
+        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP)
+        self.assertEqual(out, true_out)
+
     def test_subword_group_naive(self):
         data_in = ['however', ',', 'according', 'to', 'the', 'logs', ',', 'she', 'is', 'hard', '-', 'working', '.']  # noqa: E501
         true_out = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
@@ -63,6 +69,18 @@ def test_subword_group_spacer(self):
         no_dummy_out = subword_map_by_spacer(no_dummy)
         self.assertEqual(no_dummy_out, true_out)

+    def test_subword_group_spacer_with_case_markup(self):
+        data_in = ['⦅mrk_case_modifier_C⦆', '▁however', ',', '▁according', '▁to', '▁the', '▁logs', ',', '▁⦅mrk_begin_case_region_U⦆', '▁she', '▁is', '▁hard', '-', 'working', '.', '▁⦅mrk_end_case_region_U⦆']  # noqa: E501
+        true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7]
+        out = subword_map_by_spacer(data_in)
+        self.assertEqual(out, true_out)
+
+    def test_subword_group_spacer_with_spacer_new(self):
+        data_in = ['⦅mrk_case_modifier_C⦆', '▁', 'however', ',', '▁', 'according', '▁', 'to', '▁', 'the', '▁', 'logs', ',', '▁', '⦅mrk_begin_case_region_U⦆', '▁', 'she', '▁', 'is', '▁', 'hard', '-', 'working', '.', '▁', '⦅mrk_end_case_region_U⦆']  # noqa: E501
+        true_out = [0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7]
+        out = subword_map_by_spacer(data_in)
+        self.assertEqual(out, true_out)
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py
index 01dfcd0251..7598c6f07f 100644
--- a/onmt/transforms/features.py
+++ b/onmt/transforms/features.py
@@ -63,12 +63,9 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
             # Do nothing
             return example

-        # TODO: support joiner_new or spacer_new options. Consistency not ensured currently
-
         if self.reversible_tokenization == "joiner":
             word_to_subword_mapping = subword_map_by_joiner(example["src"])
         else: #Spacer
-            # TODO: case markup
             word_to_subword_mapping = subword_map_by_spacer(example["src"])

         inferred_feats = defaultdict(list)
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index abb475f70d..4d37e982fa 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -346,6 +346,7 @@ def translate(
         Args:
             src: See :func:`self.src_reader.read()`.
             tgt: See :func:`self.tgt_reader.read()`.
+            src_feats: See :func:`self.src_reader.read()`.
             batch_size (int): size of examples per mini-batch
             attn_debug (bool): enables the attention logging
             align_debug (bool): enables the word alignment logging
diff --git a/onmt/utils/alignment.py b/onmt/utils/alignment.py
index d9b1919a8f..d775cf920c 100644
--- a/onmt/utils/alignment.py
+++ b/onmt/utils/alignment.py
@@ -134,9 +134,29 @@ def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER, case_markup=SubwordMarker.CASE_MARKUP):
     return word_group


-def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER):
+def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER, case_markup=SubwordMarker.CASE_MARKUP):
     """Return word id for each subword token (annotate by spacer)."""
-    word_group = list(accumulate([int(marker in x) for x in subwords]))
+    flags = [0] * len(subwords)
+    for i, tok in enumerate(subwords):
+        if marker in tok:
+            if tok.replace(marker, "") in case_markup:
+                if i < len(subwords)-1:
+                    flags[i] = 1
+            else:
+                if i > 0:
+                    previous = subwords[i-1].replace(marker, "")
+                    if previous not in case_markup:
+                        flags[i] = 1
+
+    # In case there is a final case_markup when new_spacer is on
+    for i in range(1, len(subwords)-1):
+        if subwords[-i] in case_markup:
+            flags[-i] = 0
+        elif subwords[-i] == marker:
+            flags[-i] = 0
+            break
+
+    word_group = list(accumulate(flags))
     if word_group[0] == 1:  # when dummy prefix is set
         word_group = [item - 1 for item in word_group]
     return word_group
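The reworked `subword_map_by_spacer` above replaces the old one-line `accumulate` with an explicit flag pass so that case-markup placeholders and standalone spacer tokens no longer open spurious word groups. To make the underlying grouping idea concrete, here is a self-contained toy version of the plain-spacer case only (it deliberately ignores case markup, which the patched function handles):

```python
# Toy illustration of spacer-based grouping (no case-markup handling):
# a token containing the spacer starts a new word; everything else
# continues the current word.
from itertools import accumulate

def toy_map_by_spacer(subwords, marker="\u2581"):
    word_group = list(accumulate(int(marker in tok) for tok in subwords))
    if word_group and word_group[0] == 1:   # dummy prefix on first token
        word_group = [g - 1 for g in word_group]
    return word_group

print(toy_map_by_spacer(["\u2581she", "\u2581is", "\u2581hard", "-", "working", "."]))
# [0, 1, 2, 2, 2, 2] -> "she", "is", "hard-working."
```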
From d18392caeebd38d0ee34ec4921412c35ab4c55e4 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Thu, 9 Sep 2021 10:34:23 +0200
Subject: [PATCH 21/23] Remove pdb traces

---
 onmt/train_single.py        | 3 ---
 onmt/transforms/features.py | 1 -
 2 files changed, 4 deletions(-)

diff --git a/onmt/train_single.py b/onmt/train_single.py
index 0a4b153c8e..925c472119 100644
--- a/onmt/train_single.py
+++ b/onmt/train_single.py
@@ -58,9 +58,6 @@ def main(opt, fields, transforms_cls, checkpoint, device_id,
     """Start training on `device_id`."""
     # NOTE: It's important that ``opt`` has been validated and updated
     # at this point.
-
-    #import pdb
-    #pdb.set_trace()
     configure_process(opt, device_id)
     init_logger(opt.log_file)
diff --git a/onmt/transforms/features.py b/onmt/transforms/features.py
index 7598c6f07f..24f02e30fe 100644
--- a/onmt/transforms/features.py
+++ b/onmt/transforms/features.py
@@ -7,7 +7,6 @@
 from collections import defaultdict

-
 @register_transform(name='filterfeats')
 class FilterFeatsTransform(Transform):
     """Filter out examples with a mismatch between source and features."""

From a736103400276ce5f8cbdb84c3bc15c218b7a74b Mon Sep 17 00:00:00 2001
From: François Hernandez
Date: Thu, 9 Sep 2021 09:04:36 +0000
Subject: [PATCH 22/23] fix some typos

---
 docs/source/FAQ.md             | 2 +-
 onmt/inputters/corpus.py       | 2 +-
 onmt/inputters/text_dataset.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md
index 3e73ec78dc..b194de3bb9 100644
--- a/docs/source/FAQ.md
+++ b/docs/source/FAQ.md
@@ -542,7 +542,7 @@ During inference you can pass features by using the `--src_feats` argument. `src
 python translate.py -model model_step_10.pt -src ../data.txt.tok -output ../data.out --src_feats "{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}"
 ```

-When using the Transformer arquitechture make sure the following options are appropiately set:
+When using the Transformer architechture make sure the following options are appropriately set:

 - `src_word_vec_size` and `tgt_word_vec_size` or `word_vec_size`
 - `feat_merge`: how to handle features vecs
diff --git a/onmt/inputters/corpus.py b/onmt/inputters/corpus.py
index cc903d3154..87da65139b 100644
--- a/onmt/inputters/corpus.py
+++ b/onmt/inputters/corpus.py
@@ -74,7 +74,7 @@ def _process(item, is_train):

         maybe_example['src'] = {"src": ' '.join(maybe_example['src'])}

-        # Make features part of src as in MultiTextField
+        # Make features part of src as in TextMultiField
         # {'src': {'src': ..., 'feat1': ...., 'feat2': ....}}
         if 'src_feats' in maybe_example:
             for feat_name, feat_value in maybe_example['src_feats'].items():
diff --git a/onmt/inputters/text_dataset.py b/onmt/inputters/text_dataset.py
index 46df8c91c3..a55d2593b2 100644
--- a/onmt/inputters/text_dataset.py
+++ b/onmt/inputters/text_dataset.py
@@ -18,7 +18,7 @@ def read(self, sequences, side, features={}):
             side (str): Prefix used in return dict. Usually
                 ``"src"`` or ``"tgt"``.
             features: (Dict[str or Iterable[str]]):
-                dictionary mapping feature names with th path to feature
+                dictionary mapping feature names with the path to feature
                 file or iterable of the actual feature data.

         Yields:

From bd4a01da5a73f07a205490a8c76097bf1e02acf4 Mon Sep 17 00:00:00 2001
From: Ander Corral
Date: Thu, 9 Sep 2021 11:37:21 +0200
Subject: [PATCH 23/23] Fixed typo

---
 docs/source/FAQ.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md
index b194de3bb9..8f618f6c6e 100644
--- a/docs/source/FAQ.md
+++ b/docs/source/FAQ.md
@@ -542,7 +542,7 @@ During inference you can pass features by using the `--src_feats` argument. `src
 python translate.py -model model_step_10.pt -src ../data.txt.tok -output ../data.out --src_feats "{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}"
 ```

-When using the Transformer architechture make sure the following options are appropriately set:
+When using the Transformer architecture make sure the following options are appropriately set:

 - `src_word_vec_size` and `tgt_word_vec_size` or `word_vec_size`
 - `feat_merge`: how to handle features vecs
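For completeness, a rough sketch of driving the patched translator from Python instead of the CLI. The option plumbing is abbreviated, all file and checkpoint names are placeholders, and details may differ from the final API:

```python
# Rough, hedged sketch of programmatic translation with features; the option
# setup is abbreviated and every path below is a placeholder.
from onmt.translate.translator import build_translator
from onmt.utils.parse import ArgumentParser
import onmt.opts as opts

parser = ArgumentParser()
opts.config_opts(parser)
opts.translate_opts(parser)
opt = parser.parse_args([
    "-model", "onmt.features.model_step_10.pt",  # placeholder checkpoint
    "-src", "src-test.txt",
    "-src_feats", "{'feat0': 'src-test.feat0'}",
])
ArgumentParser.validate_translate_opts(opt)      # parses opt.src_feats into a dict

translator = build_translator(opt, report_score=True)
with open(opt.src, "rb") as f, open(opt.src_feats["feat0"], "rb") as ff:
    translator.translate(src=list(f),
                         src_feats={"feat0": list(ff)},
                         batch_size=1)
```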