return encoder representations only if necessary
francoishernandez committed Feb 17, 2020
1 parent 306b2e5 commit ff858ac
Showing 3 changed files with 35 additions and 19 deletions.
onmt/models/model.py (17 changes: 9 additions & 8 deletions)

@@ -46,22 +46,23 @@ def forward(self, src, tgt, lengths, bptt=False,

         enc_state, memory_bank, lengths = self.encoder(src, lengths)
 
-        if bptt is False:
-            self.decoder.init_state(src, memory_bank, enc_state)
-
-        dec_out, attns = self.decoder(dec_in, memory_bank,
-                                      memory_lengths=lengths,
-                                      with_align=with_align)
-
         if encode_tgt:
             # tgt for zero shot alignment loss
             tgt_lengths = torch.Tensor(tgt.size(1))\
                 .type_as(memory_bank) \
                 .long() \
                 .fill_(tgt.size(0))
             embs_tgt, memory_bank_tgt, ltgt = self.encoder(tgt, tgt_lengths)
         else:
             memory_bank_tgt = None
-        return dec_out, attns, memory_bank, memory_bank_tgt
 
+        if bptt is False:
+            self.decoder.init_state(src, memory_bank, enc_state)
+        dec_out, attns = self.decoder(dec_in, memory_bank,
+                                      memory_lengths=lengths,
+                                      with_align=with_align)
+
+        if encode_tgt:
+            return dec_out, attns, memory_bank, memory_bank_tgt
+        return dec_out, attns
 
     def update_dropout(self, dropout):
         self.encoder.update_dropout(dropout)
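Note: the chained tgt_lengths construction in this hunk is the least obvious part. It builds a length vector with one entry per batch column, each filled with the full (padded) target length, on the same device as memory_bank. A small standalone sketch of what it produces (toy shapes, not part of the commit):

import torch

tgt = torch.zeros(12, 4, 1, dtype=torch.long)  # (tgt_len, batch, feats) toy target
memory_bank = torch.randn(10, 4, 8)            # toy encoder states

# The diff's chained construction:
tgt_lengths = torch.Tensor(tgt.size(1)) \
    .type_as(memory_bank) \
    .long() \
    .fill_(tgt.size(0))

# Equivalent one-liner: every batch element gets the padded target length.
expected = torch.full((tgt.size(1),), tgt.size(0),
                      dtype=torch.long, device=memory_bank.device)
assert torch.equal(tgt_lengths, expected)

After this change forward has two return arities: the usual (dec_out, attns) pair, and (dec_out, attns, memory_bank, memory_bank_tgt) when encode_tgt is set, which is exactly what the trainer changes below branch on.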
onmt/trainer.py (33 changes: 24 additions & 9 deletions)

@@ -317,13 +317,21 @@ def validate(self, valid_iter, moving_average=None):
                 tgt = batch.tgt
 
                 # F-prop through the model.
-                outputs, attns, enc_src, enc_tgt = valid_model(
-                    src, tgt, src_lengths,
-                    with_align=self.with_align)
+                if self.encode_tgt:
+                    outputs, attns, enc_src, enc_tgt = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align,
+                        encode_tgt=self.encode_tgt)
+                else:
+                    outputs, attns = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align)
+                    enc_src, enc_tgt = None, None
 
                 # Compute loss.
                 _, batch_stats = self.valid_loss(
-                    batch, outputs, attns, enc_src, enc_tgt)
+                    batch, outputs, attns,
+                    enc_src=enc_src, enc_tgt=enc_tgt)
 
                 # Update statistics.
                 stats.update(batch_stats)
@@ -366,9 +374,16 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
             if self.accum_count == 1:
                 self.optim.zero_grad()
 
-            outputs, attns, enc_src, enc_tgt = self.model(
-                src, tgt, src_lengths, bptt=bptt,
-                with_align=self.with_align, encode_tgt=self.encode_tgt)
+            if self.encode_tgt:
+                outputs, attns, enc_src, enc_tgt = self.model(
+                    src, tgt, src_lengths, bptt=bptt,
+                    with_align=self.with_align, encode_tgt=self.encode_tgt)
+            else:
+                outputs, attns = self.model(
+                    src, tgt, src_lengths, bptt=bptt,
+                    with_align=self.with_align)
+                enc_src, enc_tgt = None, None
+
             bptt = True
 
             # 3. Compute loss.
@@ -377,8 +392,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                 batch,
                 outputs,
                 attns,
-                enc_src,
-                enc_tgt,
+                enc_src=enc_src,
+                enc_tgt=enc_tgt,
                 normalization=normalization,
                 shard_size=self.shard_size,
                 trunc_start=j,
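Note: both updated call sites, validate() and _gradient_accumulation(), follow the same dispatch pattern: branch on self.encode_tgt, and fill the missing encoder states with None so the downstream loss call stays uniform. A condensed sketch of the shared pattern (run_model is a hypothetical helper, not in the commit):

def run_model(model, src, tgt, src_lengths, with_align=False,
              encode_tgt=False, bptt=False):
    """Always return a 4-tuple, whichever arity the model returned."""
    if encode_tgt:
        outputs, attns, enc_src, enc_tgt = model(
            src, tgt, src_lengths, bptt=bptt,
            with_align=with_align, encode_tgt=True)
    else:
        outputs, attns = model(
            src, tgt, src_lengths, bptt=bptt, with_align=with_align)
        enc_src, enc_tgt = None, None  # keeps the loss call uniform
    return outputs, attns, enc_src, enc_tgt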
onmt/utils/loss.py (4 changes: 2 additions & 2 deletions)

@@ -124,8 +124,8 @@ def __call__(self,
                 batch,
                 output,
                 attns,
-                enc_src,
-                enc_tgt,
+                enc_src=None,
+                enc_tgt=None,
                 normalization=1.0,
                 shard_size=0,
                 trunc_start=0,
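Note: giving enc_src and enc_tgt None defaults is what lets callers that never compute encoder representations keep the old three-argument call. A minimal standalone sketch of the pattern (compute_loss is a hypothetical stand-in for the loss __call__, not the commit's code):

def compute_loss(batch, output, attns,
                 enc_src=None, enc_tgt=None,
                 normalization=1.0, shard_size=0, trunc_start=0):
    if enc_src is not None and enc_tgt is not None:
        pass  # would add the zero-shot alignment term here
    return None, "batch_stats"  # placeholder for (loss, stats)

compute_loss("batch", "output", "attns")  # old-style call, unchanged
compute_loss("batch", "output", "attns",
             enc_src="src_states", enc_tgt="tgt_states")  # alignment-aware call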
