[WIP] The Missing Ingredient in Zero-Shot Neural Machine Translation #1714

Open · wants to merge 13 commits into base: master
15 changes: 14 additions & 1 deletion onmt/models/model.py
@@ -1,5 +1,6 @@
""" Onmt NMT Model base class definition """
import torch.nn as nn
import torch


class NMTModel(nn.Module):
@@ -17,7 +18,8 @@ def __init__(self, encoder, decoder):
self.encoder = encoder
self.decoder = decoder

def forward(self, src, tgt, lengths, bptt=False, with_align=False):
def forward(self, src, tgt, lengths, bptt=False,
with_align=False, encode_tgt=False):
"""Forward propagate a `src` and `tgt` pair for training.
Possibly initialized with a beginning decoder state.

@@ -46,9 +48,20 @@ def forward(self, src, tgt, lengths, bptt=False, with_align=False):

if bptt is False:
self.decoder.init_state(src, memory_bank, enc_state)

dec_out, attns = self.decoder(dec_in, memory_bank,
memory_lengths=lengths,
with_align=with_align)

if encode_tgt:
# tgt for zero shot alignment loss
tgt_lengths = torch.Tensor(tgt.size(1))\
.type_as(memory_bank) \
.long() \
.fill_(tgt.size(0))
embs_tgt, memory_bank_tgt, ltgt = self.encoder(tgt, tgt_lengths)
return dec_out, attns, memory_bank, memory_bank_tgt

return dec_out, attns

def update_dropout(self, dropout):
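Note: the `tgt_lengths` tensor built above simply reports the full padded length `tgt.size(0)` for every example in the batch before `tgt` is re-encoded. A minimal standalone sketch of the same construction, assuming a toy target tensor in OpenNMT-py's (tgt_len, batch, n_feats) layout:

import torch

# Toy stand-in for the (tgt_len, batch, n_feats) target tensor (illustrative only).
tgt = torch.zeros(7, 4, 1, dtype=torch.long)

# Same effect as the chained Tensor(...).type_as(...).long().fill_(...) call above:
# one length entry per batch example, all set to the padded target length.
tgt_lengths = torch.full((tgt.size(1),), tgt.size(0), dtype=torch.long)
print(tgt_lengths)  # tensor([7, 7, 7, 7])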
20 changes: 16 additions & 4 deletions onmt/modules/copy_generator.py
@@ -186,22 +186,24 @@ def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
self.tgt_vocab = tgt_vocab
self.normalize_by_length = normalize_by_length

def _make_shard_state(self, batch, output, range_, attns):
def _make_shard_state(self, batch, output, enc_src, enc_tgt,
range_, attns):
"""See base class for args description."""
if getattr(batch, "alignment", None) is None:
raise AssertionError("using -copy_attn you need to pass in "
"-dynamic_dict during preprocess stage.")

shard_state = super(CopyGeneratorLossCompute, self)._make_shard_state(
batch, output, range_, attns)
batch, output, enc_src, enc_tgt, range_, attns)

shard_state.update({
"copy_attn": attns.get("copy"),
"align": batch.alignment[range_[0] + 1: range_[1]]
})
return shard_state

def _compute_loss(self, batch, output, target, copy_attn, align,
def _compute_loss(self, batch, normalization, output, target,
copy_attn, align, enc_src=None, enc_tgt=None,
std_attn=None, coverage_attn=None):
"""Compute the loss.

@@ -244,8 +246,18 @@ def _compute_loss(self, batch, output, target, copy_attn, align,
offset_align = align[correct_mask] + len(self.tgt_vocab)
target_data[correct_mask] += offset_align

if self.lambda_cosine != 0.0:
cosine_loss, num_ex = self._compute_cosine_loss(enc_src, enc_tgt)
loss += self.lambda_cosine * (cosine_loss / num_ex)
else:
cosine_loss = None
num_ex = 0

# Compute sum of perplexities for stats
stats = self._stats(loss.sum().clone(), scores_data, target_data)
stats = self._stats(loss.sum().clone(),
cosine_loss.clone() if cosine_loss is not None
else cosine_loss,
scores_data, target_data, num_ex)

# this part looks like it belongs in CopyGeneratorLoss
if self.normalize_by_length:
3 changes: 3 additions & 0 deletions onmt/opts.py
@@ -193,6 +193,9 @@ def model_opts(parser):
help='Train a coverage attention layer.')
group.add('--lambda_coverage', '-lambda_coverage', type=float, default=0.0,
help='Lambda value for coverage loss of See et al (2017)')
group.add('--lambda_cosine', '-lambda_cosine', type=float, default=0.0,
help='Lambda value for cosine alignment loss '
'of https://arxiv.org/abs/1903.07091 ')
group.add('--loss_scale', '-loss_scale', type=float, default=0,
help="For FP16 training, the static loss scale to use. If not "
"set, the loss scale is dynamically computed.")
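The new -lambda_cosine option defaults to 0.0, which leaves the cosine alignment loss disabled. A small argparse stand-in showing the intended usage (OpenNMT-py's real parser is configargparse-based; this simplified parser is only an illustrative assumption):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lambda_cosine', '-lambda_cosine', type=float, default=0.0,
                    help='Lambda value for cosine alignment loss')
opt = parser.parse_args(['-lambda_cosine', '1.0'])
print(opt.lambda_cosine)  # 1.0; the 0.0 default keeps the loss turned off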
37 changes: 30 additions & 7 deletions onmt/trainer.py
@@ -70,7 +70,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
model_dtype=opt.model_dtype,
earlystopper=earlystopper,
dropout=dropout,
dropout_steps=dropout_steps)
dropout_steps=dropout_steps,
encode_tgt=True if opt.lambda_cosine > 0 else False)
return trainer


@@ -107,7 +108,8 @@ def __init__(self, model, train_loss, valid_loss, optim,
n_gpu=1, gpu_rank=1, gpu_verbose_level=0,
report_manager=None, with_align=False, model_saver=None,
average_decay=0, average_every=1, model_dtype='fp32',
earlystopper=None, dropout=[0.3], dropout_steps=[0]):
earlystopper=None, dropout=[0.3], dropout_steps=[0],
encode_tgt=False):
# Basic attributes.
self.model = model
self.train_loss = train_loss
@@ -132,6 +134,7 @@ def __init__(self, model, train_loss, valid_loss, optim,
self.earlystopper = earlystopper
self.dropout = dropout
self.dropout_steps = dropout_steps
self.encode_tgt = encode_tgt

for i in range(len(self.accum_count_l)):
assert self.accum_count_l[i] > 0
@@ -314,11 +317,21 @@ def validate(self, valid_iter, moving_average=None):
tgt = batch.tgt

# F-prop through the model.
outputs, attns = valid_model(src, tgt, src_lengths,
with_align=self.with_align)
if self.encode_tgt:
outputs, attns, enc_src, enc_tgt = valid_model(
src, tgt, src_lengths,
with_align=self.with_align,
encode_tgt=self.encode_tgt)
else:
outputs, attns = valid_model(
src, tgt, src_lengths,
with_align=self.with_align)
enc_src, enc_tgt = None, None

# Compute loss.
_, batch_stats = self.valid_loss(batch, outputs, attns)
_, batch_stats = self.valid_loss(
batch, outputs, attns,
enc_src=enc_src, enc_tgt=enc_tgt)

# Update statistics.
stats.update(batch_stats)
@@ -361,8 +374,16 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
if self.accum_count == 1:
self.optim.zero_grad()

outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt,
with_align=self.with_align)
if self.encode_tgt:
outputs, attns, enc_src, enc_tgt = self.model(
src, tgt, src_lengths, bptt=bptt,
with_align=self.with_align, encode_tgt=self.encode_tgt)
else:
outputs, attns = self.model(
src, tgt, src_lengths, bptt=bptt,
with_align=self.with_align)
enc_src, enc_tgt = None, None

bptt = True

# 3. Compute loss.
@@ -371,6 +392,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
batch,
outputs,
attns,
enc_src=enc_src,
enc_tgt=enc_tgt,
normalization=normalization,
shard_size=self.shard_size,
trunc_start=j,
68 changes: 53 additions & 15 deletions onmt/utils/loss.py
@@ -58,7 +58,7 @@ def build_loss_compute(model, tgt_field, opt, train=True):
else:
compute = NMTLossCompute(
criterion, loss_gen, lambda_coverage=opt.lambda_coverage,
lambda_align=opt.lambda_align)
lambda_align=opt.lambda_align, lambda_cosine=opt.lambda_cosine)
compute.to(device)

return compute
@@ -92,7 +92,8 @@ def __init__(self, criterion, generator):
def padding_idx(self):
return self.criterion.ignore_index

def _make_shard_state(self, batch, output, range_, attns=None):
def _make_shard_state(self, batch, enc_src, enc_tgt,
output, range_, attns=None):
"""
Make shard state dictionary for shards() to return iterable
shards for efficient loss computation. Subclass must define
@@ -123,6 +124,8 @@ def __call__(self,
batch,
output,
attns,
enc_src=None,
enc_tgt=None,
normalization=1.0,
shard_size=0,
trunc_start=0,
@@ -157,18 +160,20 @@ def __call__(self,
if trunc_size is None:
trunc_size = batch.tgt.size(0) - trunc_start
trunc_range = (trunc_start, trunc_start + trunc_size)
shard_state = self._make_shard_state(batch, output, trunc_range, attns)
shard_state = self._make_shard_state(
batch, output, enc_src, enc_tgt, trunc_range, attns)
if shard_size == 0:
loss, stats = self._compute_loss(batch, **shard_state)
return loss / float(normalization), stats
loss, stats = self._compute_loss(batch, normalization,
**shard_state)
return loss, stats
batch_stats = onmt.utils.Statistics()
for shard in shards(shard_state, shard_size):
loss, stats = self._compute_loss(batch, **shard)
loss.div(float(normalization)).backward()
loss, stats = self._compute_loss(batch, normalization, **shard)
loss.backward()
batch_stats.update(stats)
return None, batch_stats

def _stats(self, loss, scores, target):
def _stats(self, loss, cosine_loss, scores, target, num_ex):
"""
Args:
loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
@@ -182,7 +187,9 @@ def _stats(self, loss, scores, target):
non_padding = target.ne(self.padding_idx)
num_correct = pred.eq(target).masked_select(non_padding).sum().item()
num_non_padding = non_padding.sum().item()
return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct)
return onmt.utils.Statistics(
loss.item(), cosine_loss.item() if cosine_loss is not None else 0,
num_non_padding, num_correct, num_ex)

def _bottle(self, _v):
return _v.view(-1, _v.size(2))
@@ -227,15 +234,17 @@ class NMTLossCompute(LossComputeBase):
"""

def __init__(self, criterion, generator, normalization="sents",
lambda_coverage=0.0, lambda_align=0.0):
lambda_coverage=0.0, lambda_align=0.0, lambda_cosine=0.0):
super(NMTLossCompute, self).__init__(criterion, generator)
self.lambda_coverage = lambda_coverage
self.lambda_align = lambda_align
self.lambda_cosine = lambda_cosine

def _make_shard_state(self, batch, output, range_, attns=None):
def _make_shard_state(self, batch, output, enc_src, enc_tgt,
range_, attns=None):
shard_state = {
"output": output,
"target": batch.tgt[range_[0] + 1: range_[1], :, 0],
"target": batch.tgt[range_[0] + 1: range_[1], :, 0]
}
if self.lambda_coverage != 0.0:
coverage = attns.get("coverage", None)
@@ -273,9 +282,15 @@ def _make_shard_state(self, batch, output, range_, attns=None):
"align_head": attn_align,
"ref_align": ref_align[:, range_[0] + 1: range_[1], :]
})
if self.lambda_cosine != 0.0:
shard_state.update({
"enc_src": enc_src,
"enc_tgt": enc_tgt
})
return shard_state

def _compute_loss(self, batch, output, target, std_attn=None,
def _compute_loss(self, batch, normalization, output, target,
enc_src=None, enc_tgt=None, std_attn=None,
coverage_attn=None, align_head=None, ref_align=None):

bottled_output = self._bottle(output)
@@ -284,6 +299,7 @@ def _compute_loss(self, batch, output, target, std_attn=None,
gtruth = target.view(-1)

loss = self.criterion(scores, gtruth)

if self.lambda_coverage != 0.0:
coverage_loss = self._compute_coverage_loss(
std_attn=std_attn, coverage_attn=coverage_attn)
@@ -296,7 +312,20 @@ def _compute_loss(self, batch, output, target, std_attn=None,
align_loss = self._compute_alignement_loss(
align_head=align_head, ref_align=ref_align)
loss += align_loss
stats = self._stats(loss.clone(), scores, gtruth)

loss = loss/float(normalization)

if self.lambda_cosine != 0.0:
cosine_loss, num_ex = self._compute_cosine_loss(enc_src, enc_tgt)
loss += self.lambda_cosine * (cosine_loss / num_ex)
else:
cosine_loss = None
num_ex = 0

stats = self._stats(loss.clone() * normalization,
cosine_loss.clone() if cosine_loss is not None
else cosine_loss,
scores, gtruth, num_ex)

return loss, stats

@@ -305,6 +334,15 @@ def _compute_coverage_loss(self, std_attn, coverage_attn):
covloss *= self.lambda_coverage
return covloss

def _compute_cosine_loss(self, enc_src, enc_tgt):
max_src = enc_src.max(axis=0)[0]
max_tgt = enc_tgt.max(axis=0)[0]
cosine_loss = torch.nn.functional.cosine_similarity(
max_src.float(), max_tgt.float(), dim=1)
cosine_loss = 1 - cosine_loss
num_ex = cosine_loss.size(0)
return cosine_loss.sum(), num_ex

def _compute_alignement_loss(self, align_head, ref_align):
"""Compute loss between 2 partial alignment matrix."""
# align_head contains value in [0, 1) presenting attn prob,
@@ -368,7 +406,7 @@ def shards(state, shard_size, eval_only=False):
# over the shards, not over the keys: therefore, the values need
# to be re-zipped by shard and then each shard can be paired
# with the keys.
for shard_tensors in zip(*values):
for i, shard_tensors in enumerate(zip(*values)):
yield dict(zip(keys, shard_tensors))

# Assumed backprop'd
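For reference, `_compute_cosine_loss` max-pools each encoder's states over time and penalizes one minus the cosine similarity between the pooled source and target representations, summed over the batch. A self-contained sketch on random tensors, assuming OpenNMT-py's (length, batch, hidden) memory-bank layout (the sizes are arbitrary):

import torch
import torch.nn.functional as F

src_len, tgt_len, batch, hidden = 9, 7, 4, 16
enc_src = torch.randn(src_len, batch, hidden)   # source memory bank
enc_tgt = torch.randn(tgt_len, batch, hidden)   # re-encoded target

max_src = enc_src.max(dim=0)[0]                 # (batch, hidden), max over time
max_tgt = enc_tgt.max(dim=0)[0]
cos = F.cosine_similarity(max_src, max_tgt, dim=1)   # one similarity per example
cosine_loss, num_ex = (1 - cos).sum(), cos.size(0)
print(cosine_loss / num_ex)                     # average cosine distance per example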
4 changes: 4 additions & 0 deletions onmt/utils/parse.py
@@ -120,6 +120,10 @@ def validate_train_opts(cls, opt):
assert len(opt.attention_dropout) == len(opt.dropout_steps), \
"Number of attention_dropout values must match accum_steps values"

assert not(opt.max_generator_batches > 0 and opt.lambda_cosine != 0), \
"-lambda_cosine loss is not implemented " \
"for max_generator_batches > 0."

@classmethod
def validate_translate_opts(cls, opt):
if opt.beam_size != 1 and opt.random_sampling_topk != 1:
17 changes: 15 additions & 2 deletions onmt/utils/statistics.py
@@ -17,12 +17,15 @@ class Statistics(object):
* elapsed time
"""

def __init__(self, loss=0, n_words=0, n_correct=0):
def __init__(self, loss=0, cosine_loss=0, n_words=0,
n_correct=0, num_ex=0):
self.loss = loss
self.n_words = n_words
self.n_correct = n_correct
self.n_src_words = 0
self.start_time = time.time()
self.cosine_loss = cosine_loss
self.num_ex = num_ex

@staticmethod
def all_gather_stats(stat, max_size=4096):
@@ -81,6 +84,8 @@ def update(self, stat, update_n_src_words=False):
self.loss += stat.loss
self.n_words += stat.n_words
self.n_correct += stat.n_correct
self.cosine_loss += stat.cosine_loss
self.num_ex += stat.num_ex

if update_n_src_words:
self.n_src_words += stat.n_src_words
@@ -97,6 +102,10 @@ def ppl(self):
""" compute perplexity """
return math.exp(min(self.loss / self.n_words, 100))

def cos(self):
""" normalize cosine distance per example"""
return self.cosine_loss / self.num_ex

def elapsed_time(self):
""" compute elapsed time """
return time.time() - self.start_time
@@ -113,8 +122,12 @@ def output(self, step, num_steps, learning_rate, start):
step_fmt = "%2d" % step
if num_steps > 0:
step_fmt = "%s/%5d" % (step_fmt, num_steps)
if self.cosine_loss != 0:
cos_log = "cos: %4.2f; " % (self.cos())
else:
cos_log = ""
logger.info(
("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + cos_log +
"lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec")
% (step_fmt,
self.accuracy(),
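Finally, a short sketch of how the new cosine statistic accumulates across batches and is reported by `cos()`, assuming this branch of OpenNMT-py is installed (the numbers are made up):

from onmt.utils.statistics import Statistics

a = Statistics(loss=10.0, cosine_loss=1.2, n_words=50, n_correct=30, num_ex=8)
b = Statistics(loss=12.0, cosine_loss=0.9, n_words=60, n_correct=35, num_ex=8)
a.update(b)
print(a.cos())  # (1.2 + 0.9) / (8 + 8) ≈ 0.13, the per-example cosine distance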