diff --git a/examples/seq2seq/train.py b/examples/seq2seq/train.py index 4b3cf01801..6b8b3bde0e 100644 --- a/examples/seq2seq/train.py +++ b/examples/seq2seq/train.py @@ -98,7 +98,7 @@ def encode(self, inputs): def decode(self, inputs): """Decode from list of integers to string.""" chars = [] - for elem in inputs: + for elem in inputs.tolist(): if elem == self.eos_id: break chars.append(self._indices_char[elem])