
A fast, batched Bi-RNN (GRU) encoder & attention decoder implementation in PyTorch

Usage: refer to the official PyTorch tutorial on attention-based RNN machine translation, with two differences: this implementation handles batched inputs, and it uses a slightly different attention mechanism. For the formula-level differences between the implementations, the illustrations linked below help a lot; a minimal sketch of the batched input layout follows the links.

PyTorch tutorial attention mechanism illustration: http://pytorch.org/tutorials/_images/decoder-network.png
PyTorch official Seq2seq machine translation tutorial: http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
Bahdanau attention illustration: http://images2015.cnblogs.com/blog/670089/201610/670089-20161012111504671-910168246.png
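
For orientation, here is a small self-contained sketch of the batched input layout this kind of implementation consumes: time-major (seq_len, batch) tensors with explicit lengths. The token values below are made up for illustration:

```python
import torch

# Illustrative only: three variable-length sentences as token ids.
sentences = [[5, 12, 7, 2], [9, 3, 2], [4, 2]]
lengths = torch.tensor([len(s) for s in sentences])  # true lengths, any order

# Time-major (seq_len, batch) layout, as in the PyTorch tutorial; 0 is padding.
batch = torch.zeros(int(lengths.max()), len(sentences), dtype=torch.long)
for i, s in enumerate(sentences):
    batch[: len(s), i] = torch.tensor(s)

print(batch)    # columns are sentences, rows are time steps
print(lengths)  # passed alongside so the encoder can pack/mask the padding
```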

The PyTorch tutorial's attention decoder feeds the word embedding into the computation of the attention weights, while in the original paper the weights are supposed to be computed from the encoder outputs. This repository implements the attention decoder as described in the original paper.
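
A minimal sketch of that paper-style scoring, with the weights computed from the encoder outputs (a simplified additive-attention module; class and parameter names here are illustrative, not this repository's exact API):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BahdanauAttention(nn.Module):
    """Additive attention: scores depend on the decoder state and encoder outputs."""
    def __init__(self, hidden_size):
        super().__init__()
        self.W = nn.Linear(hidden_size, hidden_size)    # projects decoder state
        self.U = nn.Linear(hidden_size, hidden_size)    # projects encoder outputs
        self.v = nn.Linear(hidden_size, 1, bias=False)  # scoring vector

    def forward(self, decoder_hidden, encoder_outputs):
        # decoder_hidden:  (batch, hidden)
        # encoder_outputs: (seq_len, batch, hidden)
        energy = torch.tanh(self.W(decoder_hidden).unsqueeze(0)
                            + self.U(encoder_outputs))          # (seq_len, batch, hidden)
        scores = self.v(energy).squeeze(2)                      # (seq_len, batch)
        # Padding mask omitted for brevity; normalize over source positions.
        weights = F.softmax(scores, dim=0)
        context = (weights.unsqueeze(2) * encoder_outputs).sum(0)  # (batch, hidden)
        return context, weights
```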

Update: a dynamic encoder has been added; it does not require the inputs in a batch to be sorted by length.
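
A sketch of what such a dynamic encoder can look like, assuming a PyTorch version (>= 1.1) where pack_padded_sequence accepts enforce_sorted=False; module and argument names are illustrative:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class DynamicEncoder(nn.Module):
    """Bi-GRU encoder accepting batches of unsorted, variable-length sequences."""
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, num_layers, bidirectional=True)

    def forward(self, input_seqs, input_lengths):
        # input_seqs: (seq_len, batch) of token ids, zero-padded
        # input_lengths: true lengths on CPU, in any order
        embedded = self.embedding(input_seqs)                 # (seq_len, batch, embed)
        packed = pack_padded_sequence(embedded, input_lengths,
                                      enforce_sorted=False)   # no pre-sorting required
        outputs, hidden = self.gru(packed)
        outputs, _ = pad_packed_sequence(outputs)             # (seq_len, batch, 2*hidden)
        # Sum the two directions so downstream sizes stay at hidden_size.
        outputs = (outputs[:, :, :self.gru.hidden_size]
                   + outputs[:, :, self.gru.hidden_size:])
        return outputs, hidden
```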


Speed up with batched tensor manipulation

PyTorch supports element-wise reading and writing of tensor values at runtime, but this is slow, especially on a GPU. In the practical-pytorch tutorial (https://github.com/spro/practical-pytorch), attention values are assigned element-wise; that is absolutely correct (and intuitive given the formulas in the paper), but it was slow on our GPU. We therefore re-implemented it with genuinely batched tensor operations, achieving a speedup of more than 10x.
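
For illustration, the core of the batched trick is replacing per-element Python loops with a single batched matrix multiply. The following is a simplified dot-product variant (shapes and names are assumptions, not this repository's exact code):

```python
import torch
import torch.nn.functional as F

batch, seq_len, hidden = 32, 50, 256
decoder_hidden  = torch.randn(batch, hidden)           # current decoder state
encoder_outputs = torch.randn(batch, seq_len, hidden)  # batch-first encoder states

# Slow, element-wise style: one score per (batch, position) pair in Python loops.
slow_scores = torch.zeros(batch, seq_len)
for b in range(batch):
    for t in range(seq_len):
        slow_scores[b, t] = decoder_hidden[b].dot(encoder_outputs[b, t])

# Batched style: a single bmm computes every score at once.
fast_scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(2)).squeeze(2)

assert torch.allclose(slow_scores, fast_scores, atol=1e-4)
weights = F.softmax(fast_scores, dim=1)                                 # (batch, seq_len)
context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)   # (batch, hidden)
```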

This code has worked well in personal projects.
