diff --git a/parallel_wavegan/bin/train.py b/parallel_wavegan/bin/train.py index ea28c3f5..3c642997 100755 --- a/parallel_wavegan/bin/train.py +++ b/parallel_wavegan/bin/train.py @@ -597,6 +597,7 @@ def __call__(self, batch): c_batch = [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)] # convert each batch to tensor, asuume that each item in batch has the same length + y_batch, c_batch = np.array(y_batch), np.array(c_batch) y_batch = torch.tensor(y_batch, dtype=torch.float).unsqueeze(1) # (B, 1, T) c_batch = torch.tensor(c_batch, dtype=torch.float).transpose(2, 1) # (B, C, T')