[WIP] Add target features #2315

Open · wants to merge 1 commit into base: master
3 changes: 3 additions & 0 deletions data/data_features/tgt-train-with-feats.txt
@@ -0,0 +1,3 @@
however,│A according│A to│A the│A logs,│B she│A is│A a│A hard-working.│C
however,│A according│B to│C the│D logs,│E
she│C is│B a│A hard-working.│B
1 change: 1 addition & 0 deletions data/data_features/tgt-val-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
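Note on the data format used in these fixture files: every target token carries one label per feature, appended with the `│` separator, mirroring the existing source-feature convention. A minimal sketch of how such a line splits into a token stream plus feature streams — the helper name and separator constant below are illustrative, not the actual `parse_features` utility the loaders call:

```python
# Hypothetical helper for illustration only; OpenNMT-py's own parse_features()
# (not shown in this diff) is what ParallelCorpus actually uses.
FEAT_SEP = "│"

def split_tokens_and_feats(line, n_feats=1):
    """Split 'tok│A tok│B ...' into a token string and n_feats feature strings."""
    tokens, feats = [], [[] for _ in range(n_feats)]
    for item in line.strip().split(" "):
        parts = item.split(FEAT_SEP)
        tokens.append(parts[0])
        for i in range(n_feats):
            feats[i].append(parts[1 + i])
    return " ".join(tokens), [" ".join(f) for f in feats]

text, feats = split_tokens_and_feats("she│C is│B a│A hard-working.│B")
print(text)   # she is a hard-working.
print(feats)  # ['C B A B']
```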
35 changes: 27 additions & 8 deletions onmt/bin/build_vocab.py
@@ -45,6 +45,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
sub_counter_src = Counter()
sub_counter_tgt = Counter()
sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
sub_counter_tgt_feats = [Counter() for _ in range(opts.n_tgt_feats)]
datasets_iterables = build_corpora_iters(
corpora, transforms, opts.data,
skip_empty_level=opts.skip_empty_level,
@@ -63,26 +64,36 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
sub_counter_src.update(src_line.split(' '))
sub_counter_tgt.update(tgt_line.split(' '))

src_feats_lines = []
if 'feats' in maybe_example['src']:
src_feats_lines = maybe_example['src']['feats']
for i in range(opts.n_src_feats):
sub_counter_src_feats[i].update(
src_feats_lines[i].split(' '))
else:
src_feats_lines = []

tgt_feats_lines = []
if maybe_example["tgt"] is not None:
if 'feats' in maybe_example['tgt']:
tgt_feats_lines = maybe_example['tgt']['feats']
for i in range(opts.n_tgt_feats):
sub_counter_tgt_feats[i].update(
tgt_feats_lines[i].split(' '))

if opts.dump_samples:
src_pretty_line = append_features_to_text(
src_line, src_feats_lines)
tgt_pretty_line = append_features_to_text(
tgt_line, tgt_feats_lines)
build_sub_vocab.queues[c_name][offset].put(
(i, src_pretty_line, tgt_line))
(i, src_pretty_line, tgt_pretty_line))
if n_sample > 0 and ((i+1) * stride + offset) >= n_sample:
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put("break")
break
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put("break")
return sub_counter_src, sub_counter_tgt, sub_counter_src_feats
return (sub_counter_src, sub_counter_tgt,
sub_counter_src_feats, sub_counter_tgt_feats)


def init_pool(queues):
@@ -107,6 +118,7 @@ def build_vocab(opts, transforms, n_sample=3):
counter_src = Counter()
counter_tgt = Counter()
counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
counter_tgt_feats = [Counter() for _ in range(opts.n_tgt_feats)]
from functools import partial
queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
for i in range(opts.num_threads)]
@@ -123,15 +135,18 @@ def build_vocab(opts, transforms, n_sample=3):
func = partial(
build_sub_vocab, corpora, transforms,
opts, n_sample, opts.num_threads)
for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
for (sub_counter_src, sub_counter_tgt,
sub_counter_src_feats, sub_counter_tgt_feats) in p.imap(
func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
for i in range(opts.n_src_feats):
counter_src_feats[i].update(sub_counter_src_feats[i])
for i in range(opts.n_tgt_feats):
counter_tgt_feats[i].update(sub_counter_tgt_feats[i])
if opts.dump_samples:
write_process.join()
return counter_src, counter_tgt, counter_src_feats
return counter_src, counter_tgt, counter_src_feats, counter_tgt_feats


def build_vocab_main(opts):
@@ -157,13 +172,15 @@ def build_vocab_main(opts):
transforms = make_transforms(opts, transforms_cls, None)

logger.info(f"Counter vocab from {opts.n_sample} samples.")
src_counter, tgt_counter, src_feats_counter = build_vocab(
opts, transforms, n_sample=opts.n_sample)
src_counter, tgt_counter, src_feats_counter, tgt_feats_counter = \
build_vocab(opts, transforms, n_sample=opts.n_sample)

logger.info(f"Counters src: {len(src_counter)}")
logger.info(f"Counters tgt: {len(tgt_counter)}")
for i, feat_counter in enumerate(src_feats_counter):
logger.info(f"Counters src feat_{i}: {len(feat_counter)}")
for i, feat_counter in enumerate(tgt_feats_counter):
logger.info(f"Counters tgt feat_{i}: {len(feat_counter)}")

def save_counter(counter, save_path):
check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -182,6 +199,8 @@ def save_counter(counter, save_path):

for i, c in enumerate(src_feats_counter):
save_counter(c, f"{opts.src_vocab}_feat{i}")
for i, c in enumerate(tgt_feats_counter):
save_counter(c, f"{opts.tgt_vocab}_feat{i}")


def _get_parser():
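Taken together, these changes make `build_vocab` return one `Counter` per target feature and write each of them to `{opts.tgt_vocab}_feat{i}`, symmetric with the source-feature path. A self-contained sketch of the per-feature counting, assuming feature lines are space-joined label strings (names below are illustrative, not the module's internals):

```python
from collections import Counter

# Sketch: one Counter per target feature column, updated from
# space-separated feature strings (one string per feature, per sentence).
n_tgt_feats = 1
counter_tgt_feats = [Counter() for _ in range(n_tgt_feats)]

corpus_feat_lines = [["A B C D E"], ["C B A B"]]  # [sentence][feature]
for tgt_feats_lines in corpus_feat_lines:
    for i in range(n_tgt_feats):
        counter_tgt_feats[i].update(tgt_feats_lines[i].split(" "))

print(counter_tgt_feats[0].most_common(2))  # [('B', 3), ('A', 2)]
```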
27 changes: 21 additions & 6 deletions onmt/decoders/ensemble.py
@@ -99,14 +99,29 @@ def forward(self, hidden, attn=None, src_map=None):
by averaging distributions from models in the ensemble.
All models in the ensemble must share a target vocabulary.
"""
distributions = torch.stack(
[mg(h) if attn is None else mg(h, attn, src_map)
for h, mg in zip(hidden, self.model_generators)]
)
distributions, feats_distributions = [], []
n_feats = len(self.model_generators[0].feats_generators)
for h, mg in zip(hidden, self.model_generators):
scores, feats_scores = \
(mg(h) if attn is None else mg(h, attn, src_map))
distributions.append(scores)
feats_distributions.append(feats_scores)

distributions = torch.stack(distributions)

stacked_feats_distributions = []
for i in range(n_feats):
stacked_feats_distributions.append(
torch.stack([feat_distribution[i]
for feat_distribution in feats_distributions]))
if self._raw_probs:
return torch.log(torch.exp(distributions).mean(0))
return (torch.log(torch.exp(distributions).mean(0)),
[torch.log(torch.exp(d).mean(0))
for d in stacked_feats_distributions])
else:
return distributions.mean(0)
return (distributions.mean(0),
[d.mean(0) for d in stacked_feats_distributions])


class EnsembleModel(NMTModel):
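For reference, the reduction applied per head is unchanged: with `_raw_probs` the per-model distributions are averaged in probability space and converted back to log space, otherwise the scores are averaged directly; the feature heads simply get the same treatment as the word head. A small sketch of that reduction, assuming each model emits log-probabilities:

```python
import torch

def average_ensemble(distributions, raw_probs=True):
    """distributions: list of [batch, vocab] log-prob tensors, one per model."""
    stacked = torch.stack(distributions)               # [n_models, batch, vocab]
    if raw_probs:
        # average in probability space, then go back to log space
        return torch.log(torch.exp(stacked).mean(0))
    # otherwise average the (log-)scores directly
    return stacked.mean(0)

a = torch.log(torch.tensor([[0.7, 0.3]]))
b = torch.log(torch.tensor([[0.5, 0.5]]))
print(torch.exp(average_ensemble([a, b])))  # tensor([[0.6000, 0.4000]])
```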
30 changes: 29 additions & 1 deletion onmt/inputters/inputter.py
@@ -34,10 +34,11 @@ def build_vocab(opt, specials):
""" Build vocabs dict to be stored in the checkpoint
based on vocab files having each line [token, count]
Args:
opt: src_vocab, tgt_vocab, n_src_feats
opt: src_vocab, tgt_vocab, n_src_feats, n_tgt_feats
Return:
vocabs: {'src': pyonmttok.Vocab, 'tgt': pyonmttok.Vocab,
'src_feats' : [pyonmttok.Vocab, ...],
'tgt_feats' : [pyonmttok.Vocab, ...],
'data_task': seq2seq or lm
}
"""
@@ -103,6 +104,25 @@ def _pad_vocab_to_multiple(vocab, multiple):
src_feats_vocabs.append(src_f_vocab)
vocabs["src_feats"] = src_feats_vocabs

if opt.n_tgt_feats > 0:
tgt_feats_vocabs = []
for i in range(opt.n_tgt_feats):
tgt_f_vocab = _read_vocab_file(f"{opt.tgt_vocab}_feat{i}", 1)
tgt_f_vocab = pyonmttok.build_vocab_from_tokens(
tgt_f_vocab,
maximum_size=0,
minimum_frequency=1,
special_tokens=[DefaultTokens.UNK,
DefaultTokens.PAD,
DefaultTokens.BOS,
DefaultTokens.EOS])
tgt_f_vocab.default_id = tgt_f_vocab[DefaultTokens.UNK]
if opt.vocab_size_multiple > 1:
tgt_f_vocab = _pad_vocab_to_multiple(tgt_f_vocab,
opt.vocab_size_multiple)
tgt_feats_vocabs.append(tgt_f_vocab)
vocabs["tgt_feats"] = tgt_feats_vocabs

vocabs['data_task'] = opt.data_task

return vocabs
@@ -147,6 +167,9 @@ def vocabs_to_dict(vocabs):
if 'src_feats' in vocabs.keys():
vocabs_dict['src_feats'] = [feat_vocab.ids_to_tokens
for feat_vocab in vocabs['src_feats']]
if 'tgt_feats' in vocabs.keys():
vocabs_dict['tgt_feats'] = [feat_vocab.ids_to_tokens
for feat_vocab in vocabs['tgt_feats']]
vocabs_dict['data_task'] = vocabs['data_task']
return vocabs_dict

@@ -168,4 +191,9 @@ def dict_to_vocabs(vocabs_dict):
for feat_vocab in vocabs_dict['src_feats']:
vocabs['src_feats'].append(
pyonmttok.build_vocab_from_tokens(feat_vocab))
if 'tgt_feats' in vocabs_dict.keys():
vocabs['tgt_feats'] = []
for feat_vocab in vocabs_dict['tgt_feats']:
vocabs['tgt_feats'].append(
pyonmttok.build_vocab_from_tokens(feat_vocab))
return vocabs
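With this in place the vocabs dict can carry a `tgt_feats` entry next to `src_feats`, each a list of `pyonmttok.Vocab` objects rebuilt from the dumped `{tgt_vocab}_feat{i}` files. A hedged sketch of building one feature vocab the same way (assumes pyonmttok is installed; the literal special-token strings stand in for the DefaultTokens constants and are assumptions here):

```python
import pyonmttok

# Sketch: a tiny target-feature vocabulary with UNK as the fallback id,
# mirroring the build done in build_vocab() above. Label set is illustrative.
labels = ["A", "B", "C"]
feat_vocab = pyonmttok.build_vocab_from_tokens(
    labels,
    maximum_size=0,
    minimum_frequency=1,
    special_tokens=["<unk>", "<blank>", "<s>", "</s>"])
feat_vocab.default_id = feat_vocab["<unk>"]

print(len(feat_vocab))   # the 3 labels plus the special tokens
print(feat_vocab["B"])   # integer id assigned to label "B"
```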
29 changes: 23 additions & 6 deletions onmt/inputters/text_corpus.py
@@ -40,14 +40,17 @@ class ParallelCorpus(object):
"""A parallel corpus file pair that can be loaded to iterate."""

def __init__(self, name, src, tgt, align=None,
n_src_feats=0, src_feats_defaults=None):
n_src_feats=0, src_feats_defaults=None,
n_tgt_feats=0, tgt_feats_defaults=None):
"""Initialize src & tgt side file path."""
self.id = name
self.src = src
self.tgt = tgt
self.align = align
self.n_src_feats = n_src_feats
self.src_feats_defaults = src_feats_defaults
self.n_tgt_feats = n_tgt_feats
self.tgt_feats_defaults = tgt_feats_defaults

def load(self, offset=0, stride=1):
"""
@@ -68,6 +71,10 @@ def load(self, offset=0, stride=1):
defaults=self.src_feats_defaults)
if tline is not None:
tline = tline.decode('utf-8')
tline, tfeats = parse_features(
tline,
n_feats=self.n_tgt_feats,
defaults=self.tgt_feats_defaults)
# 'src_original' and 'tgt_original' store the
# original line before tokenization. These
# fields are used later on in the feature
@@ -82,15 +89,19 @@ def load(self, offset=0, stride=1):
example['align'] = align.decode('utf-8')

if sfeats is not None:
example['src_feats'] = [f for f in sfeats]
example['src_feats'] = sfeats
if tline is not None and tfeats is not None:
example['tgt_feats'] = tfeats
yield example

def __str__(self):
cls_name = type(self).__name__
return f'{cls_name}({self.id}, {self.src}, {self.tgt}, ' \
f'align={self.align}, ' \
f'n_src_feats={self.n_src_feats}, ' \
f'src_feats_defaults="{self.src_feats_defaults}")'
f'src_feats_defaults="{self.src_feats_defaults}", ' \
f'n_tgt_feats={self.n_tgt_feats}, ' \
f'tgt_feats_defaults="{self.tgt_feats_defaults}")'


def get_corpora(opts, task=CorpusTask.TRAIN):
@@ -104,7 +115,9 @@ def get_corpora(opts, task=CorpusTask.TRAIN):
corpus_dict["path_tgt"],
corpus_dict["path_align"],
n_src_feats=opts.n_src_feats,
src_feats_defaults=opts.src_feats_defaults)
src_feats_defaults=opts.src_feats_defaults,
n_tgt_feats=opts.n_tgt_feats,
tgt_feats_defaults=opts.tgt_feats_defaults)
elif task == CorpusTask.VALID:
if CorpusName.VALID in opts.data.keys():
corpora_dict[CorpusName.VALID] = ParallelCorpus(
@@ -113,7 +126,9 @@ def get_corpora(opts, task=CorpusTask.TRAIN):
opts.data[CorpusName.VALID]["path_tgt"],
opts.data[CorpusName.VALID]["path_align"],
n_src_feats=opts.n_src_feats,
src_feats_defaults=opts.src_feats_defaults)
src_feats_defaults=opts.src_feats_defaults,
n_tgt_feats=opts.n_tgt_feats,
tgt_feats_defaults=opts.tgt_feats_defaults)
else:
return None
else:
@@ -122,7 +137,9 @@ def get_corpora(opts, task=CorpusTask.TRAIN):
opts.src,
opts.tgt,
n_src_feats=opts.n_src_feats,
src_feats_defaults=opts.src_feats_defaults)
src_feats_defaults=opts.src_feats_defaults,
n_tgt_feats=opts.n_tgt_feats,
tgt_feats_defaults=opts.tgt_feats_defaults)
return corpora_dict
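With the new constructor arguments threaded through all three tasks, every example yielded by `ParallelCorpus.load()` can now carry a `tgt_feats` field next to `src_feats`. A sketch of the expected example shape — field values are made up for illustration, and only the fields relevant to this change are shown:

```python
# Illustrative example dict, assuming n_src_feats=0 and n_tgt_feats=1.
example = {
    "src": "source sentence goes here",   # plain source text (no features in this sketch)
    "tgt": "she is a hard-working.",      # target text with feature labels stripped
    "tgt_feats": ["C B A B"],             # one space-joined label string per target feature
}

assert len(example["tgt_feats"]) == 1
assert len(example["tgt"].split()) == len(example["tgt_feats"][0].split())
```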


60 changes: 40 additions & 20 deletions onmt/model_builder.py
Expand Up @@ -11,7 +11,7 @@
from onmt.encoders import str2enc
from onmt.decoders import str2dec
from onmt.inputters.inputter import dict_to_vocabs
from onmt.modules import Embeddings, CopyGenerator
from onmt.modules import Embeddings, Generator
from onmt.utils.misc import use_gpu
from onmt.utils.logging import logger
from onmt.utils.parse import ArgumentParser
@@ -193,20 +193,49 @@ def use_embeddings_from_checkpoint(vocabs, model, generator, checkpoint):
emb_name
][old_i]
if side == 'tgt':
generator.state_dict()['weight'][i] = checkpoint[
'generator'
]['weight'][old_i]
generator.state_dict()['bias'][i] = checkpoint[
'generator'
]['bias'][old_i]
# TODO: check feats generators
generator.state_dict()['tgt_generator.weight'][i] = \
checkpoint['generator']['tgt_generator.weight'][old_i]
generator.state_dict()['tgt_generator.bias'][i] = \
checkpoint['generator']['tgt_generator.bias'][old_i]
else:
# Just for debugging purposes
new_tokens.append(tok)
logger.info("%s: %d new tokens" % (side, len(new_tokens)))

# Remove old vocabulary associated embeddings
del checkpoint['model'][emb_name]
del checkpoint['generator']['weight'], checkpoint['generator']['bias']
del checkpoint['generator']['tgt_generator.weight']
del checkpoint['generator']['tgt_generator.bias']


def build_generator(model_opt, vocabs, decoder):
gen_sizes = [len(vocabs['tgt'])]
if 'tgt_feats' in vocabs:
gen_sizes += [len(feat_vocab) for feat_vocab in vocabs['tgt_feats']]

if model_opt.share_decoder_embeddings:
hid_sizes = ([model_opt.dec_hid_size -
(model_opt.feat_vec_size * (len(gen_sizes) - 1))]
+ [model_opt.feat_vec_size] * (len(gen_sizes) - 1))
else:
hid_sizes = [model_opt.dec_hid_size] * len(gen_sizes)

pad_idx = vocabs['tgt'][DefaultTokens.PAD]
generator = Generator(hid_sizes, gen_sizes,
shared=model_opt.share_decoder_embeddings,
copy_attn=model_opt.copy_attn,
pad_idx=pad_idx)

if model_opt.share_decoder_embeddings:
if not model_opt.copy_attn:
generator.generators[0].weight = \
decoder.embeddings.word_lut.weight
else:
generator.generators[0].linear.weight = \
decoder.embeddings.word_lut.weight

return generator


def build_base_model(model_opt, vocabs, gpu, checkpoint=None, gpu_id=None):
Expand Down Expand Up @@ -243,18 +272,9 @@ def build_base_model(model_opt, vocabs, gpu, checkpoint=None, gpu_id=None):

model = build_task_specific_model(model_opt, vocabs)

# Build Generator.
if not model_opt.copy_attn:
generator = nn.Linear(model_opt.dec_hid_size,
len(vocabs['tgt']))
if model_opt.share_decoder_embeddings:
generator.weight = model.decoder.embeddings.word_lut.weight
else:
vocab_size = len(vocabs['tgt'])
pad_idx = vocabs['tgt'][DefaultTokens.PAD]
generator = CopyGenerator(model_opt.dec_hid_size, vocab_size, pad_idx)
if model_opt.share_decoder_embeddings:
generator.linear.weight = model.decoder.embeddings.word_lut.weight
# Build Generators
# Next token prediction and possibly target features generators
generator = build_generator(model_opt, vocabs, model.decoder)

# Load the model states from checkpoint or initialize them.
if checkpoint is None or model_opt.update_vocab:
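The new `Generator` module imported from `onmt.modules` is not part of this diff; judging from `build_generator` it bundles one projection per output stream — the word vocabulary first, then one head per target feature — and, under `share_decoder_embeddings`, reserves only the first `hid_sizes[0]` units of the decoder state for the shared word head. A rough sketch of such a multi-head generator, purely as an assumption about its shape (copy attention is not handled here):

```python
import torch
import torch.nn as nn

class MultiHeadGenerator(nn.Module):
    """Sketch: one linear head per output stream (tokens first, then features)."""

    def __init__(self, hid_sizes, gen_sizes):
        super().__init__()
        self.hid_sizes = hid_sizes
        self.generators = nn.ModuleList(
            [nn.Sequential(nn.Linear(h, g), nn.LogSoftmax(dim=-1))
             for h, g in zip(hid_sizes, gen_sizes)])

    def forward(self, dec_out):
        # Split the decoder output into one chunk per head and project each chunk.
        chunks = torch.split(dec_out, self.hid_sizes, dim=-1)
        scores = [gen(chunk) for gen, chunk in zip(self.generators, chunks)]
        return scores[0], scores[1:]   # (word scores, list of feature scores)

# 512-dim decoder state: 500 units for the word head, 12 for one feature head.
gen = MultiHeadGenerator(hid_sizes=[500, 12], gen_sizes=[32000, 6])
words, feats = gen(torch.randn(2, 10, 512))
print(words.shape, feats[0].shape)  # torch.Size([2, 10, 32000]) torch.Size([2, 10, 6])
```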
2 changes: 1 addition & 1 deletion onmt/models/model.py
@@ -66,7 +66,7 @@ def forward(self, src, tgt, src_len, bptt=False, with_align=False):
* enc_out + enc_final_hs in the case of CNNs
* src in the case of Transformer
"""
dec_in = tgt[:, :-1, :]
dec_in = tgt[:, :-1, :1]
enc_out, enc_final_hs, src_len = self.encoder(src, src_len)
if not bptt:
self.decoder.init_state(src, enc_out, enc_final_hs)
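The decoder-input change above is the only adjustment needed in the model forward: `tgt` now carries the token index plus one column per target feature in its last dimension, and only the token column (index 0) is fed to the decoder embeddings, while the feature columns presumably serve as labels for the extra generator heads. A tiny shape check under that assumed layout:

```python
import torch

# Assumed layout: tgt is [batch, tgt_len, 1 + n_tgt_feats]
# (token ids in column 0, feature ids in the remaining columns).
batch, tgt_len, n_tgt_feats = 2, 5, 1
tgt = torch.randint(0, 100, (batch, tgt_len, 1 + n_tgt_feats))

dec_in = tgt[:, :-1, :1]   # drop the final step, keep only the token column
print(dec_in.shape)        # torch.Size([2, 4, 1])
```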