Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Source features support for V2.0 #2090

Merged
merged 24 commits into from
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions onmt/bin/build_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def build_vocab_main(opts):

logger.info(f"Counters src:{len(src_counter)}")
logger.info(f"Counters tgt:{len(tgt_counter)}")
for k, v in src_feats_counter["src_feats"].items():
logger.info(f"Counters {k}:{len(v)}")
for feat_name, feat_counter in src_feats_counter.items():
logger.info(f"Counters {feat_name}:{len(feat_counter)}")

def save_counter(counter, save_path):
check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
Expand All @@ -55,7 +55,7 @@ def save_counter(counter, save_path):
save_counter(src_counter, opts.src_vocab)
save_counter(tgt_counter, opts.tgt_vocab)

for k, v in src_feats_counter["src_feats"].items():
for k, v in src_feats_counter.items():
save_counter(v, opts.src_feats_vocab[k])


Expand Down
1 change: 1 addition & 0 deletions onmt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class CorpusName(object):
class SubwordMarker(object):
    """Special marker tokens produced by subword tokenization.

    Used when mapping subword tokens back to whole words (see
    ``subword_map_by_joiner`` / ``subword_map_by_spacer`` in
    ``onmt.utils.alignment``).
    """
    # Spacer-annotated tokenization: marks a token boundary / leading space.
    SPACER = '▁'
    # Joiner-annotated tokenization: marks that a subword attaches to its neighbor.
    JOINER = '■'
    # Case-markup placeholder tokens emitted by the tokenizer when the
    # "case_markup" option is enabled; treated as attached to the following
    # word when building word/subword maps.
    CASE_MARKUP = ["⦅mrk_case_modifier_C⦆", "⦅mrk_begin_case_region_U⦆", "⦅mrk_end_case_region_U⦆"]


class ModelTask(object):
Expand Down
37 changes: 18 additions & 19 deletions onmt/inputters/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,12 @@ def load(self, offset=0, stride=1):
`offset` and `stride` allow to iterate only on every
`stride` example, starting from `offset`.
"""
#import pdb
#pdb.set_trace()
if self.src_feats:
features_files = [open(feat_path, mode='rb') for feat_name, feat_path in self.src_feats.items()]
features_names = []
features_files = []
for feat_name, feat_path in self.src_feats.items():
features_names.append(feat_name)
features_files.append(open(feat_path, mode='rb'))
else:
features_files = []
with exfile_open(self.src, mode='rb') as fs,\
Expand All @@ -146,7 +148,7 @@ def load(self, offset=0, stride=1):
if features:
example["src_feats"] = dict()
for j, feat in enumerate(features):
example["src_feats"][list(self.src_feats.keys())[j]] = feat.decode("utf-8")
example["src_feats"][features_names[j]] = feat.decode("utf-8")
yield example
for f in features_files:
f.close()
Expand Down Expand Up @@ -304,11 +306,9 @@ def write_files_from_queues(sample_path, queues):

def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
"""Build vocab on (strided) subpart of the data."""
#import pdb
#pdb.set_trace()
sub_counter_src = Counter()
sub_counter_tgt = Counter()
sub_counter_src_feats = {'src_feats': defaultdict(Counter)}
sub_counter_src_feats = defaultdict(Counter)
datasets_iterables = build_corpora_iters(
corpora, transforms, opts.data,
skip_empty_level=opts.skip_empty_level,
Expand All @@ -323,7 +323,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
if 'src_feats' in maybe_example:
for feat_name, feat_line in maybe_example["src_feats"].items():
sub_counter_src_feats['src_feats'][feat_name].update(feat_line.split(' '))
sub_counter_src_feats[feat_name].update(feat_line.split(' '))
sub_counter_src.update(src_line.split(' '))
sub_counter_tgt.update(tgt_line.split(' '))
if opts.dump_samples:
Expand Down Expand Up @@ -359,7 +359,7 @@ def build_vocab(opts, transforms, n_sample=3):
corpora = get_corpora(opts, is_train=True)
counter_src = Counter()
counter_tgt = Counter()
counter_src_feats = {'src_feats': defaultdict(Counter)}
counter_src_feats = defaultdict(Counter)
from functools import partial
queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
for i in range(opts.num_threads)]
Expand All @@ -372,16 +372,15 @@ def build_vocab(opts, transforms, n_sample=3):
args=(sample_path, queues),
daemon=True)
write_process.start()
#with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
func = partial(
build_sub_vocab, corpora, transforms,
opts, n_sample, opts.num_threads)
sub_counter_src, sub_counter_tgt, sub_counter_src_feats = func(0)
# for sub_counter_src, sub_counter_tgt in p.imap(
# func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
counter_src_feats.update(sub_counter_src_feats)
with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
func = partial(
build_sub_vocab, corpora, transforms,
opts, n_sample, opts.num_threads)
for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
counter_src_feats.update(sub_counter_src_feats)
if opts.dump_samples:
write_process.join()
return counter_src, counter_tgt, counter_src_feats
Expand Down
7 changes: 3 additions & 4 deletions onmt/inputters/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@


def _get_dynamic_fields(opts):
# NOTE: not support nfeats > 0 yet
#src_nfeats = 0
tgt_nfeats = None #0
# NOTE: not support tgt feats yet
tgt_feats = None
with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0
fields = get_fields('text', opts.src_feats_vocab, tgt_nfeats,
fields = get_fields('text', opts.src_feats_vocab, tgt_feats,
dynamic_dict=opts.copy_attn,
src_truncate=opts.src_seq_length_trunc,
tgt_truncate=opts.tgt_seq_length_trunc,
Expand Down
12 changes: 6 additions & 6 deletions onmt/inputters/inputter.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def get_task_spec_tokens(data_task, pad, bos, eos):

def get_fields(
src_data_type,
n_src_feats,
n_tgt_feats,
src_feats,
tgt_feats,
pad=DefaultTokens.PAD,
bos=DefaultTokens.BOS,
eos=DefaultTokens.EOS,
Expand All @@ -125,11 +125,11 @@ def get_fields(
"""
Args:
src_data_type: type of the source input. Options are [text].
n_src_feats (int): the number of source features (not counting tokens)
src_feats (int): source features dict containing their names
to create a :class:`torchtext.data.Field` for. (If
``src_data_type=="text"``, these fields are stored together
as a ``TextMultiField``).
n_tgt_feats (int): See above.
tgt_feats (int): See above.
anderleich marked this conversation as resolved.
Show resolved Hide resolved
pad (str): Special pad symbol. Used on src and tgt side.
bos (str): Special beginning of sequence symbol. Only relevant
for tgt.
Expand Down Expand Up @@ -158,7 +158,7 @@ def get_fields(
task_spec_tokens = get_task_spec_tokens(data_task, pad, bos, eos)

src_field_kwargs = {
"n_feats": n_src_feats,
"feats": src_feats,
"include_lengths": True,
"pad": task_spec_tokens["src"]["pad"],
"bos": task_spec_tokens["src"]["bos"],
Expand All @@ -169,7 +169,7 @@ def get_fields(
fields["src"] = fields_getters[src_data_type](**src_field_kwargs)

tgt_field_kwargs = {
"n_feats": n_tgt_feats,
"feats": tgt_feats,
"include_lengths": False,
"pad": task_spec_tokens["tgt"]["pad"],
"bos": task_spec_tokens["tgt"]["bos"],
Expand Down
11 changes: 5 additions & 6 deletions onmt/inputters/text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def text_fields(**kwargs):

Args:
base_name (str): Name associated with the field.
n_feats (int): Number of word level feats (not counting the tokens)
feats (int): Word level feats
include_lengths (bool): Optionally return the sequence lengths.
pad (str, optional): Defaults to ``"<blank>"``.
bos (str or NoneType, optional): Defaults to ``"<s>"``.
Expand All @@ -163,7 +163,7 @@ def text_fields(**kwargs):
TextMultiField
"""

n_feats = kwargs["n_feats"]
feats = kwargs["feats"]
include_lengths = kwargs["include_lengths"]
base_name = kwargs["base_name"]
pad = kwargs.get("pad", DefaultTokens.PAD)
Expand All @@ -187,10 +187,9 @@ def text_fields(**kwargs):
fields_.append((base_name, feat))

# Feats fields
#for i in range(n_feats + 1):
if n_feats:
for feat_name in n_feats.keys():
#name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name
if feats:
for feat_name in feats.keys():
# Legacy function, it is not really necessary
tokenize = partial(
_feature_tokenize,
layer=None,
Expand Down
60 changes: 32 additions & 28 deletions onmt/transforms/features.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform, ObservableStats
from onmt.constants import DefaultTokens, SubwordMarker
from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
import re
from collections import defaultdict



@register_transform(name='filterfeats')
class FilterFeatsTransform(Transform):
"""Filter out examples with a mismatch between source and features."""
Expand Down Expand Up @@ -48,49 +51,50 @@ def add_options(cls, parser):
pass

def _parse_opts(self):
pass
super()._parse_opts()
logger.info("Parsed pyonmttok kwargs for src: {}".format(
self.opts.src_onmttok_kwargs))
self.src_onmttok_kwargs = self.opts.src_onmttok_kwargs
anderleich marked this conversation as resolved.
Show resolved Hide resolved

def apply(self, example, is_train=False, stats=None, **kwargs):

if "src_feats" not in example:
# Do nothing
return example

feats_i = 0
inferred_feats = defaultdict(list)
for subword in example["src"]:
next_ = False
for k, v in example["src_feats"].items():
# TODO: what about custom placeholders??
joiner = self.src_onmttok_kwargs["joiner"] if "joiner" in self.src_onmttok_kwargs else SubwordMarker.JOINER
anderleich marked this conversation as resolved.
Show resolved Hide resolved
case_markup = SubwordMarker.CASE_MARKUP if "case_markup" in self.src_onmttok_kwargs else []
# TODO: support joiner_new or spacer_new options. Consistency not ensured currently
anderleich marked this conversation as resolved.
Show resolved Hide resolved

# Placeholders
if re.match(r'⦅\w+⦆', subword):
inferred_feat = "N"
if "joiner_annotate" in self.src_onmttok_kwargs:
word_to_subword_mapping = subword_map_by_joiner(example["src"], marker=joiner, case_markup=case_markup)
elif "spacer_annotate" in self.src_onmttok_kwargs:
# TODO: case markup
word_to_subword_mapping = subword_map_by_spacer(example["src"], marker=joiner)
else:
# TODO: support not reversible tokenization
raise Exception("InferFeats transform does not currently work without either joiner_annotate or spacer_annotate")

# Punctuation only
elif not re.sub(r'(\W)+', '', subword).strip():
inferred_feat = "N"
inferred_feats = defaultdict(list)
for subword, word_id in zip(example["src"], word_to_subword_mapping):
for feat_name, feat_values in example["src_feats"].items():

# Joiner annotate
elif re.search("■", subword):
inferred_feat = v[feats_i]
# If case markup placeholder
if subword in case_markup:
inferred_feat = "<null>"

# Punctuation only (assumes joiner is also some punctuation token)
elif not re.sub(r'(\W)+', '', subword).strip():
inferred_feat = "<null>"

# Whole word
else:
inferred_feat = v[feats_i]
next_ = True
inferred_feat = feat_values[word_id]

inferred_feats[k].append(inferred_feat)

if next_:
feats_i += 1
inferred_feats[feat_name].append(inferred_feat)

# Check all features have been consumed
for k, v in example["src_feats"].items():
assert feats_i == len(v), f'Not all features consumed for {k}'
for feat_name, feat_values in inferred_feats.items():
example["src_feats"][feat_name] = inferred_feats[feat_name]

for k, v in inferred_feats.items():
example["src_feats"][k] = inferred_feats[k]
return example

def _repr_args(self):
Expand Down
4 changes: 2 additions & 2 deletions onmt/utils/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'):
return " ".join(word_align)


def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER):
def subword_map_by_joiner(subwords, marker=SubwordMarker.JOINER, case_markup=[]):
"""Return word id for each subword token (annotate by joiner)."""
flags = [0] * len(subwords)
for i, tok in enumerate(subwords):
if tok.endswith(marker):
if tok.endswith(marker) or tok in case_markup:
flags[i] = 1
if tok.startswith(marker):
assert i >= 1 and flags[i-1] != 1, \
Expand Down