Skip to content

Commit

Permalink
Fix error
Browse files Browse the repository at this point in the history
Signed-off-by: hsiehjackson <[email protected]>
  • Loading branch information
hsiehjackson committed Apr 14, 2023
1 parent c183ac4 commit 98beb1c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions nemo/collections/tts/modules/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,13 @@ def output_types(self):
}

def get_speaker_embedding(self, speaker, reference_spec, reference_spec_lens):
"""spk_emb: BxD"""
"""spk_emb: Bx1xD"""
if self.speaker_encoder is not None:
spk_emb = self.speaker_encoder(speaker, reference_spec, reference_spec_lens)
spk_emb = self.speaker_encoder(speaker, reference_spec, reference_spec_lens).unsqueeze(1)
elif self.speaker_emb is not None:
if speaker is None:
raise ValueError('Please give speaker id to get lookup speaker embedding.')
spk_emb = self.speaker_emb(speaker)
spk_emb = self.speaker_emb(speaker).unsqueeze(1)
else:
spk_emb = None

Expand Down Expand Up @@ -281,7 +281,7 @@ def forward(
# Calculate speaker embedding
spk_emb = self.get_speaker_embedding(
speaker=speaker, reference_spec=reference_spec, reference_spec_lens=reference_spec_lens,
).unsqueeze(1)
)

# Input FFT
enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb)
Expand Down

0 comments on commit 98beb1c

Please sign in to comment.