inference code example #8

Open
eschmidbauer opened this issue Jul 18, 2024 · 147 comments

Comments

@eschmidbauer

eschmidbauer commented Jul 18, 2024

Is there inference code? I could not find any, but I read through other issues and found this:

I'll write an inference script next so we can do some quick experiments.

Originally posted by @manmay-nakhashi in #1 (comment)

@manmay-nakhashi
Contributor

I'll put together some code this weekend.

@G-structure
Contributor

Okay, so I took a shot at hacking together some inference code. I trained a model for 400k steps on the MushanW/GLOBE dataset; when I test it I get a cacophony that is starting to resemble the TTS prompts, but the intermediate mel spectrogram is of very poor quality, so something might be wrong with my approach.

image

import os

import torch
import torchaudio
from torchaudio.transforms import GriffinLim, InverseMelScale, Resample

from einops import rearrange

from e2_tts_pytorch.e2_tts import (
    E2TTS,
    DurationPredictor,
    MelSpec
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

duration_predictor = DurationPredictor(
    transformer = dict(
        dim = 80,
        depth = 2,
    )
)

model = E2TTS(
    duration_predictor = duration_predictor,
    transformer = dict(
        dim = 80,
        depth = 4,
        skip_connect_type = 'concat'
    )
)

# these must match the mel settings the model was trained with
n_fft = 1024
hop_length = 256
n_mels = 80
sample_rate = 22050
checkpoint_path = "./e2tts.pt"

def vocoder(melspec):
    # Griffin-Lim fallback: mel (..., n_mels, frames) -> linear spectrogram -> waveform.
    # Assumes MelSpec outputs log-compressed mels, so undo the log first.
    melspec = melspec.exp()
    inverse_melscale_transform = InverseMelScale(
        n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate,
        norm="slaney", f_min=0, f_max=8000
    ).to(melspec.device)
    spectrogram = inverse_melscale_transform(melspec)
    # power must match the power of the spectrogram MelSpec was computed from
    transform = GriffinLim(n_fft=n_fft, hop_length=hop_length, power=2).to(melspec.device)
    return transform(spectrogram)

# inference only needs the model weights, not the optimizer state
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"loaded checkpoint from step {checkpoint['step']}")

model = model.to(device).eval()

ref_waveform, ref_sample_rate = torchaudio.load("ref.wav", normalize=True)
ref_waveform = ref_waveform.mean(dim=0, keepdim=True)  # downmix to mono: (1, time)
# resample only; changing the playback speed as well would distort the prompt
ref_waveform = Resample(orig_freq=ref_sample_rate, new_freq=sample_rate)(ref_waveform)

mel_model = MelSpec()                  # must use the same settings as training
mel = mel_model(ref_waveform)          # (1, n_mels, frames)
mel = torch.cat([mel, mel], dim=0)     # one prompt per text below
mel = rearrange(mel, 'b d n -> b n d')

text = ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "Waves crashed against the cliffs, their thunderous applause echoing for miles."]

with torch.inference_mode():
    sample = model.sample(mel[:, :25].to(device), text = text, vocoder=vocoder)
sample = sample.cpu()

# save one mono file per generated sample
for i, waveform in enumerate(sample, start = 1):
    torchaudio.save(f"output_channel_{i}.wav", waveform.unsqueeze(0), sample_rate)

@changjinhan

I'm also looking forward to it

@lucasnewman
Contributor

I haven't figured out what's up with the text conditioning yet, but here's a rough sample (it doesn't use the duration predictor) of the generation flow in a notebook. I left in some debugging outputs so you can see the flow resolving visually. The voice cloning aspect seems to work fine with different speakers, fwiw, they just say nonsense at the moment 😅

(This is from a quick ~100M param model I trained with ~1/100th the FLOPs used in the paper.)

generate.ipynb

@Coice

Coice commented Jul 23, 2024

@lucasnewman does it always output the reference audio, regardless of what you use as input for the reference text? I also left out the duration predictor, I wound up just simply doubling the input duration and doubling the reference text, since if it can't do the doubling, it sure won't work for anything else 🙃

I couldn't get it to generate anything aside from the input reference audio. I was told by the author "train it more", but I put considerable resources into it, and it never improved.

@lucasnewman
Contributor

@lucasnewman does it always output the reference audio, regardless of what you use as input for the reference text? I also left out the duration predictor, I wound up just simply doubling the input duration and doubling the reference text, since if it can't do the doubling, it sure won't work for anything else 🙃

Yep, this is exactly what I'm doing, and more or less what you see in the notebook — I just hard-coded the duration to keep it simple.

I couldn't get it to generate anything aside from the input reference audio. I was told by the author "train it more", but I put considerable resources into it, and it never improved.

Make sure you've pulled the duration fix from a few days ago if you're explicitly passing the duration as an int; otherwise generation stops right after the conditioning. You don't need to retrain your model, as the fix only affects sampling.
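When the duration is passed explicitly, it is the total number of mel frames (the conditioning frames count toward it), so it has to exceed the prompt length. A minimal sketch of one common heuristic, assuming the new text is spoken at the same rate as the prompt's transcript; `estimate_total_frames` and the characters-per-frame rule are hypothetical, not part of e2_tts_pytorch:

```python
def estimate_total_frames(cond_frames: int, ref_text: str, gen_text: str) -> int:
    """Hypothetical heuristic: prompt frames + frames needed for the new text.

    Assumes frames-per-character in the prompt carries over to the new text.
    """
    if cond_frames <= 0 or not ref_text:
        raise ValueError("need a non-empty prompt and its transcript")
    # ceil(cond_frames * len(gen_text) / len(ref_text)) using integer arithmetic
    gen_frames = (cond_frames * len(gen_text) + len(ref_text) - 1) // len(ref_text)
    # the total must exceed the prompt, otherwise nothing new is generated
    return cond_frames + max(1, gen_frames)

# the "doubling" check: same text again -> twice the prompt length
print(estimate_total_frames(200, "son, he would tell him.", "son, he would tell him."))  # 400
```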

@Ryu1845

Ryu1845 commented Jul 23, 2024

I haven't figured out what's up with the text conditioning yet, but here's a rough sample (it doesn't use the duration predictor) of the generation flow in a notebook. I left in some debugging outputs so you can see the flow resolving visually. The voice cloning aspect seems to work fine with different speakers, fwiw, they just say nonsense at the moment 😅

(This is from a quick ~100M param model I trained with ~1/100th the FLOPs used in the paper.)

generate.ipynb

The output sample seems to be gibberish after what I assume to be the prompt(?)

Thank you for your work though!

generated.mp4

@Coice

Coice commented Jul 23, 2024

@lucasnewman The code I was using was based on a modified version of Voicebox. I did try training an early version of this repo, but at the time it was giving NaNs.

Just to be clear, if you put in any other text, do you still get the exact reference audio? The model I trained always ignored the text embeddings, I'm just wondering if you have the same issue. It looks like it just learns to pass through the input.

Another thing you can try is to pass in a masked region to see if it can do the training objective in inference mode. Can it do the infill? (When I tested this with my model, it was just gibberish, but the unmasked regions were basically the original audio.)
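The infill check can be set up by keeping the mel frames outside a chosen span as conditioning and blanking the span itself. A minimal sketch, assuming mels shaped (batch, frames, mel_dim) as elsewhere in this thread; `make_infill_cond` is a hypothetical helper, and how the masked conditioning is handed to the model depends on the repo version, so that call is left out:

```python
import torch

def make_infill_cond(mel: torch.Tensor, start: int, end: int):
    """Blank frames [start, end) of a (batch, frames, dim) mel.

    Returns the masked conditioning and a boolean mask that is True
    where the model has to generate (the infill region).
    """
    batch, frames, _ = mel.shape
    if not 0 <= start < end <= frames:
        raise ValueError("span must lie inside the mel")
    infill_mask = torch.zeros(batch, frames, dtype=torch.bool, device=mel.device)
    infill_mask[:, start:end] = True
    # zero out the region to be generated; everything else stays ground truth
    cond = mel.masked_fill(infill_mask.unsqueeze(-1), 0.0)
    return cond, infill_mask

mel = torch.randn(1, 100, 80)
cond, infill_mask = make_infill_cond(mel, 40, 60)
```

After sampling, comparing the generated frames inside the mask against the original (and checking that the unmasked frames are reproduced) separates "ignores the text" from "cannot infill at all".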

@lucasnewman
Contributor

@lucasnewman The code I was using was based on a modified version of Voicebox. I did try training an early version of this repo, but at the time it was giving NaNs.

Just to be clear, if you put in any other text, do you still get the exact reference audio? The model I trained always ignored the text embeddings, I'm just wondering if you have the same issue. It looks like it just learns to pass through the input.

I haven't tried, but that would correlate to what I was referencing with the text conditioning.

Another thing you can try is to pass in a masked region to see if it can do the training objective in inference mode. Can it do the infill? (When I tested this with my model, it was just gibberish, but the unmasked regions were basically the original audio.)

Yeah, this is effectively the same task with a different mask region, so I would expect similar results for now since the text conditioning doesn't seem to be working right. I don't actually have a ton of extra time to spend on debugging it, but you're welcome to run some experiments! The latest version of the code is almost exactly what I trained.

@juliakorovsky

I'm trying to train this model with another repo (a slightly modified Voicebox repo) on around 300 hours of data. Only at around 400,000 iterations did it start to output something that sounds like speech (but the speech was gibberish). I also get random noise a lot, as if the model were unable to fill in the blanks. The voice the model uses also doesn't resemble the target voice yet. I'm thinking about increasing gradient accumulation to match the paper's batch size, in case the model just doesn't see enough data per iteration.
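Matching the paper's batch through gradient accumulation comes down to dividing the target frames per update by what one optimizer step sees without accumulation. A minimal sketch; the 307,200-frame target is the E2 TTS paper's effective mini-batch size (quoted later in this thread), while `grad_accum_steps` and the per-GPU numbers are illustrative placeholders:

```python
import math

PAPER_FRAMES_PER_UPDATE = 307_200  # effective mini-batch size reported for E2 TTS

def grad_accum_steps(frames_per_gpu_batch: int, num_gpus: int,
                     target_frames: int = PAPER_FRAMES_PER_UPDATE) -> int:
    """Micro-batches to accumulate so each update sees at least target_frames."""
    frames_per_step = frames_per_gpu_batch * num_gpus
    return math.ceil(target_frames / frames_per_step)

# e.g. 4 GPUs, each fitting 38,400 mel frames per micro-batch
print(grad_accum_steps(38_400, 4))  # 2
```

Each micro-batch loss should then be scaled by 1/steps before backward; Hugging Face Accelerate can handle this via `Accelerator(gradient_accumulation_steps=...)` together with `accelerator.accumulate(model)`.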

@HaiFengZeng

HaiFengZeng commented Jul 29, 2024

I think the model is hard to train when it has to learn the alignment from text to mel spectrogram directly. Has anyone gotten reasonable results? I also get speech with the same timbre, but it doesn't match the input text, so I think the model hasn't learned the alignment properly.
Update: after training longer, some of the results seem more related to the input text (some words missing, some wrong words), but much better.

@eschmidbauer
Author

I'm trying to test inference again. It is very slow, and it appears the code is using the CPU instead of the GPU: in the image, GPU utilization is 0%, VRAM usage is 9%, and CPU usage is 2718%.
Any ideas why this might be happening?
image

@AbrarMahmud

Is there any pre-trained checkpoint available for this model?

Thanks in advance!

@eschmidbauer
Author

I was eventually able to get inference to work by changing this line:

sample = e2tts.sample(mel[:,:25].to("cuda"), text = text)

But I only get noise as output. Has anyone else been able to get inference to work?
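The one-line fix above moves the prompt mel to CUDA; the model's parameters have to live on the same device, otherwise PyTorch either raises a device-mismatch error or everything quietly runs on the CPU. A minimal sketch of keeping both together, with a small `nn.Linear` as a hypothetical stand-in for the E2TTS model and `to_inference_device` as a hypothetical helper:

```python
import torch
from torch import nn

def to_inference_device(model: nn.Module, *tensors: torch.Tensor):
    """Move the model and all input tensors to one device (CUDA if available)."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    return model, [t.to(device) for t in tensors], device

stand_in = nn.Linear(80, 80)          # placeholder for the E2TTS model
cond = torch.randn(2, 25, 80)         # (batch, frames, mel_dim) prompt mel
model, (cond,), device = to_inference_device(stand_in, cond)

with torch.inference_mode():
    out = model(cond)

# the parameters and the output should both report the chosen device
print(next(model.parameters()).device, out.device)
```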

@eschmidbauer
Author

Shared a checkpoint here.

@Coice

Coice commented Aug 6, 2024

@eschmidbauer Were you successful in getting it to generate speech from the text input?

@eschmidbauer
Author

@Coice no, but maybe my inference script needs work. Maybe someone else is able to generate speech and share the code

@juliakorovsky

juliakorovsky commented Aug 6, 2024

I was eventually able to get inference to work by changing this line:

sample = e2tts.sample(mel[:,:25].to("cuda"), text = text)

But I only get noise as output. Has anyone else been able to get inference to work?

I'll add what I know in case someone's interested. I'm trying to train E2 TTS with another repo on a small dataset. I had to rewrite some code because I started from the Voicebox repo. I trained it for a couple of weeks, but the network only generated noise. I decided to print the gradients for all parameters and found that the attention gradients were always zero. After some digging, I found that I had accidentally set my attention dropout to 1. When I fixed it, I got something resembling speech instead of noise. The model still can't pronounce many sounds properly, but at least I can see now that it learns. If your model outputs only noise even at 400,000 iterations (400k is just an example; theoretically at that stage it should be able to generate something), I would recommend double-checking the gradients: maybe there's a mistake and the gradients are None or zero, or you might have vanishing gradients.
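The gradient check described above can be run right after `loss.backward()`. A minimal sketch; `find_dead_parameters` and the toy module are hypothetical, and the toy multiplies one branch by zero to mimic what an attention dropout of 1 does to that branch's output:

```python
import torch
from torch import nn

def find_dead_parameters(model: nn.Module, atol: float = 0.0):
    """After loss.backward(), list parameters whose gradient is missing or ~zero."""
    no_grad, zero_grad = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            no_grad.append(name)        # never reached by the loss
        elif param.grad.abs().max().item() <= atol:
            zero_grad.append(name)      # reached, but the gradient is zero
    return no_grad, zero_grad

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.gated = nn.Linear(4, 4)    # output multiplied by 0, like dropout p=1
        self.unused = nn.Linear(4, 4)   # never called in forward

    def forward(self, x):
        return self.used(x) + 0.0 * self.gated(x)

torch.manual_seed(0)
model = Toy()
model(torch.randn(8, 4)).sum().backward()
no_grad, zero_grad = find_dead_parameters(model)
print(no_grad)    # ['unused.weight', 'unused.bias']
print(zero_grad)  # ['gated.weight', 'gated.bias']
```

Passing a small positive `atol` also flags gradients that have vanished to near zero rather than exactly zero.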

@G-structure
Contributor

Yeah, as we're finding, this model requires a considerable amount of training.

From the paper.

We utilized the Libriheavy dataset [30] to train our models. The Libriheavy dataset comprises 50,000 hours of read English speech from 6,736 speakers, accompanied by transcriptions that preserve case and punctuation marks.

We modeled the 100-dimensional log mel-filterbank features, extracted every 10.7 milliseconds from audio samples with a 24 kHz sampling rate.

All models were trained for 800,000 mini-batch updates with an effective mini-batch size of 307,200 audio frames.

This means the model saw about 0.9131 hours of audio per mini-batch (307,200 frames × 10.7 ms), or about 730,480 hours in total: roughly 15 epochs over Libriheavy.
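The arithmetic behind those figures, using only the numbers quoted from the paper (the 730,480 above comes from rounding the per-batch figure to 0.9131 h before multiplying):

```python
FRAMES_PER_BATCH = 307_200     # effective mini-batch size (audio frames)
FRAME_SHIFT_S = 10.7e-3        # frame shift reported in the paper
UPDATES = 800_000              # mini-batch updates
DATASET_HOURS = 50_000         # Libriheavy

hours_per_batch = FRAMES_PER_BATCH * FRAME_SHIFT_S / 3600
total_hours = hours_per_batch * UPDATES
epochs = total_hours / DATASET_HOURS

print(f"{hours_per_batch:.4f} h/batch, {total_hours:,.0f} h total, {epochs:.1f} epochs")
# 0.9131 h/batch, 730,453 h total, 14.6 epochs
```

The 10% training point is then about 73,000 hours of audio seen.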

From the WER graphs, we observe that the Voicebox models demonstrated a good WER even at the 10% training point, owing to the use of frame-wise phoneme alignment. On the other hand, E2 TTS required significantly more training to converge. Interestingly, E2 TTS achieved a better WER at the end of the training. We speculate this is because the E2 TTS model learned a more effective grapheme-to-phoneme mapping based on the large training data, compared to what was used for Voicebox. From the SIM-o graphs, we also observed that E2 TTS required more training iteration, but it ultimately achieved a better result at the end of the training. We believe this suggests the superiority of E2 TTS, where the audio model and duration model are jointly learned as a single flow-matching Transformer.

image

From this chart, we can infer that the model doesn't really start to learn how to speak until it has seen about 73,048 hours of audio (10% of training).

Here is the output I get after training a model with the same specs as the paper on 4,440.9 hours of audio. It is trying to say "son, he would tell him. son, he would tell him." Note that the first utterance is the reference audio; the second half is the TTS generation.

e2tts_500k.mp4

@lucasnewman
Contributor

@cyber-phys FWIW that lines up with my napkin math and your sample sounds similar to my experiments.

I tried a trick where I used a scale factor that ramps from (0, 1] for the random times selection for a few thousand steps, forcing the model to learn stronger conditioning from close-to-noise time steps, which seemed to help a little bit with pronunciation in a low data regime (you could recognize words with ~10k hours of audio training), but nothing close to the quality of Voicebox, which obviously has a big alignment advantage.

It seems like you need a bunch of training over 50k+ hours of audio to make a dent on this one, which is kind of cool because it's possible to just brute force the alignment, but also probably out of reach for most academic/unfunded settings, unfortunately.

@skirdey

skirdey commented Aug 7, 2024

I have a training job running that has seen around 2,000,000 speech samples out of 13M total. I am training on multilingual datasets, so it will most likely take a while before it can produce coherent speech. But it does "speak" a combination of languages now, with no apparent alignment to the text prompt.
output_1089.webm

You can find latest checkpoint here https://drive.google.com/drive/folders/11m6ftmJbxua7-pVkQCA6qbfLMlsfC_Ls?usp=drive_link

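from schedulefree import AdamWScheduleFree
from e2_tts_pytorch.e2_tts import E2TTS, DurationPredictor

# Initialize the duration predictor and TTS model
duration_predictor = DurationPredictor(
    transformer=dict(
        dim=512,
        depth=6,
        heads=2,
        dim_head=64,
        max_seq_len=4000
    )
)

e2tts = E2TTS(
    duration_predictor=duration_predictor,
    num_channels=100,
    transformer=dict(
        dim=1024,
        depth=24,
        skip_connect_type='concat',
        heads=16,
        dim_head=64,
        max_seq_len=4000
    ),
    text_num_embeds=256,
    cond_drop_prob=0.25,
)
optimizer = AdamWScheduleFree(e2tts.parameters(), lr=3e-5)
# schedule-free optimizers must be put in train mode before stepping
# (and optimizer.eval() before evaluating or saving a checkpoint)
optimizer.train()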

@HaiFengZeng

I'd like to share a sample from another modified repo. It really needs a lot of resources (I only have 4x RTX 4090 GPUs); I trained for nearly two weeks and the result still seems to need more training. I only used two datasets: GigaSpeech and LibriTTS.

yy-2.mp4

text: you are very handsome.

ref_yy-2.mp4

@juliakorovsky

@cyber-phys FWIW that lines up with my napkin math and your sample sounds similar to my experiments.

I tried a trick where I used a scale factor that ramps from (0, 1] for the random times selection for a few thousand steps, forcing the model to learn stronger conditioning from close-to-noise time steps, which seemed to help a little bit with pronunciation in a low data regime (you could recognize words with ~10k hours of audio training), but nothing close to the quality of Voicebox, which obviously has a big alignment advantage.

It seems like you need a bunch of training over 50k+ hours of audio to make a dent on this one, which is kind of cool because it's possible to just brute force the alignment, but also probably out of reach for most academic/unfunded settings, unfortunately.

Could you show the code for the scale factor trick? Or link to it if it's included in this repo.

@lucasnewman
Contributor

lucasnewman commented Aug 7, 2024

It was just an experiment to see if the text conditioning was working at all — I'm not sure it's a great idea in general.

My intuition was that the joint training objective is particularly difficult for alignment because the "fingerprint" of the flow is pretty well established in the first ~2-3% of the ODE steps and at that point the model will primarily use the flow from the previous timestep for prediction. If we force the model to predict from near-noise earlier, we can bias the training objective towards the text conditioning at the start.

(Also I forgot to mention that I used phonemes instead of the raw byte encoding to make it a little easier on the model because I'm using a smaller dataset.)

You can reproduce it with something like:

diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index b43a8d5..c83cf05 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -694,7 +694,8 @@ class E2TTS(Module):
         *,
         text: Int['b nt'] | List[str] | None = None,
         times: Int['b'] | None = None,
-        lens: Int['b'] | None = None,
+        times_scale: Float | None = None,
+        lens: Int['b'] | None = None
     ):
         # handle raw wave
 
@@ -740,6 +741,10 @@ class E2TTS(Module):
         # t is random times from above
 
         times = torch.rand((batch,), dtype = dtype, device = self.device)
+
+        if exists(times_scale):
+            times = times * max(1e-3, times_scale)
+        
         t = rearrange(times, 'b -> b 1 1')
 
         # sample xt (w in the paper)

And then in your trainer class define num_time_scale_steps and do:

if global_step < self.num_time_scale_steps:
    scale_progress = float(global_step) / float(self.num_time_scale_steps)
    times_scale = min(1.0, scale_progress)
else:
    times_scale = None

loss, *_ = self.model(mel_spec, text = text, lens = mel_lengths, times_scale = times_scale)
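For reference, the ramp from the diff and trainer snippet above can be written as one standalone function. A minimal sketch, not code from the repo: `sample_times` is a hypothetical name, and times are drawn on the CPU for simplicity:

```python
import torch

def sample_times(batch: int, global_step: int, num_time_scale_steps: int,
                 min_scale: float = 1e-3) -> torch.Tensor:
    """Flow-matching times t ~ U(0, 1), shrunk toward 0 (noise) early in training.

    During the first num_time_scale_steps the upper bound of t ramps
    linearly from min_scale to 1; afterwards t ~ U(0, 1) as usual.
    """
    times = torch.rand(batch)
    if global_step < num_time_scale_steps:
        times_scale = max(min_scale, global_step / num_time_scale_steps)
        times = times * times_scale
    return times

torch.manual_seed(0)
early = sample_times(1000, 0, 100)     # all t below 1e-3: near-pure noise inputs
mid = sample_times(1000, 50, 100)      # t in [0, 0.5)
late = sample_times(1000, 100, 100)    # ramp finished: t in [0, 1)
```

Because only the training-time distribution of t changes, the sampler itself is untouched by this sketch.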

@lucidrains
Owner

Screen Shot 2024-08-07 at 1 46 59 PM

@HaiFengZeng

It was just an experiment to see if the text conditioning was working at all — I'm not sure it's a great idea in general.

Good idea. How do you run inference when time_scale is applied? Will it use fewer NFE steps?

@acul3

acul3 commented Aug 13, 2024

Has anyone gotten good results yet (both text alignment and speaker similarity/quality)?

4xA100 for a week on 30k hours, and it still produces speech that is inconsistent with the text.

@changjinhan

changjinhan commented Aug 13, 2024

@acul3 I'm working with a similar setup, just using GLOBE. Here’s what I have so far—could you share your intermediate results as well?

Text: (separated from the main spacecraft and began its descent to the moon's surface.) Waves crashed against the cliffs, their thunderous applause echoing for miles.

e2tts_680k.mp4

@acul3

acul3 commented Aug 13, 2024

@changjinhan How long did you train it?

I am training a multilingual model (Indonesian and Malay).

The output is acceptable, but it seems to have a hard time following the text.

Here is my config:

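from schedulefree import AdamWScheduleFree
from e2_tts_pytorch.e2_tts import E2TTS, DurationPredictor

# Initialize the duration predictor and TTS model
duration_predictor = DurationPredictor(
    transformer=dict(
        dim=512,
        depth=6,
        heads=2,
        dim_head=64,
        max_seq_len=4000
    )
)

e2tts = E2TTS(
    duration_predictor=duration_predictor,
    num_channels=100,
    transformer=dict(
        dim=1024,
        depth=24,
        skip_connect_type='concat',
        heads=16,
        dim_head=64,
        max_seq_len=4000
    ),
    text_num_embeds=256,
    cond_drop_prob=0.25,
)
optimizer = AdamWScheduleFree(e2tts.parameters(), lr=3e-5)
# schedule-free optimizers must be put in train mode before stepping
# (and optimizer.eval() before evaluating or saving a checkpoint)
optimizer.train()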

Can you share yours? Much appreciated.

@changjinhan

changjinhan commented Aug 14, 2024

@acul3 I trained it for 6 days and my config is as follows:

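from itertools import chain

from torch.optim import Adam
from e2_tts_pytorch.e2_tts import E2TTS, DurationPredictor

duration_predictor = DurationPredictor(
    transformer = dict(
        dim = 512,
        depth = 6,
    )
)

e2tts = E2TTS(
    num_channels=80,
    transformer = dict(
        dim = 512,
        depth = 12,
        skip_connect_type = 'concat'
    )
)
# the duration predictor is not passed into E2TTS here, so its parameters are added separately
optimizer = Adam(chain(e2tts.parameters(), duration_predictor.parameters()), lr=7.5e-5)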

@manmay-nakhashi
Copy link
Contributor

manmay-nakhashi commented Sep 9, 2024

Datasets used:
https://huggingface.co/datasets/MushanW/GLOBE
https://huggingface.co/datasets/ShoukanLabs/AniSpeech
https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0
lr = 3e-4
epochs: 4
extended phonemes, like @lucasnewman did

from e2_tts_pytorch.e2_tts import E2TTS

e2tts = E2TTS(
    tokenizer = 'phoneme_en',
    cond_drop_prob = 0.2,
    interpolated_text=True,
    transformer = dict(
        dim = 384,
        depth = 12,
        heads = 6,
        max_seq_len = 2048,
        skip_connect_type = 'concat'
    ),
    mel_spec_kwargs = dict(
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        sampling_rate = 24000,
    ),
    immiscible=True,
    frac_lengths_mask = (0.5, 0.9)
)

generated.wav
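With these mel_spec_kwargs, each mel frame advances about hop_length / sampling_rate seconds, so max_seq_len bounds how much audio (prompt plus generated) fits in one sample. A quick sketch of the conversion, assuming max_seq_len counts mel frames; `seconds_to_frames` is a hypothetical helper:

```python
hop_length = 256        # from mel_spec_kwargs above
sampling_rate = 24_000  # from mel_spec_kwargs above
max_seq_len = 2048      # transformer max_seq_len above (assumed to count mel frames)

frame_ms = 1000 * hop_length / sampling_rate
max_seconds = max_seq_len * hop_length / sampling_rate

def seconds_to_frames(seconds: float) -> int:
    """Approximate number of mel frames for a clip of the given length."""
    return round(seconds * sampling_rate / hop_length)

print(f"{frame_ms:.2f} ms per frame, max {max_seconds:.1f} s per sample")
# 10.67 ms per frame, max 21.8 s per sample
```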

@ILG2021

ILG2021 commented Sep 9, 2024

dataset used https://huggingface.co/datasets/MushanW/GLOBE https://huggingface.co/datasets/ShoukanLabs/AniSpeech https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0 lr = lr=3e-4 epochs: 4 extended phonemes like @lucasnewman did

Hello, can you share the pretrained model?

@lucidrains
Owner

lucidrains commented Sep 9, 2024

@manmay-nakhashi you are welcome to join us for dinner, but you'll have to fly here 😄 hit me up if you are ever in SF though!

@lucidrains
Owner

I think I will also just make the transformer always use concat U-Net skips, unless there are any objections.

@manmay-nakhashi
Contributor

@lucidrains Ha ha, I'm most likely coming there next month.

@lucidrains
Copy link
Owner

@manmay-nakhashi that's when we are planning on dinner! when are you arriving? send me an email and we can coordinate

@manmay-nakhashi
Contributor

@ILG2021 I'll share a pretrained model tomorrow. It's not fully trained, but it can at least save some alignment time.

@ILG2021

ILG2021 commented Sep 9, 2024

@ILG2021 i'll share one pretrained model tomorrow, not fully trained though, it can at-least save some alignment time.

Thank you.

@JingRH

JingRH commented Sep 11, 2024

Has anyone noticed this repository: https://github.com/feizc/FluxMusic? They also implicitly use a double-stream structure. This kind of multimodal fusion architecture might truly be the direction for future development.

@yangyyt

yangyyt commented Sep 11, 2024

@ILG2021 i'll share one pretrained model tomorrow, not fully trained though, it can at-least save some alignment time.

Looking forward to you sharing the model, @manmay-nakhashi.

@manmay-nakhashi
Contributor

manmay-nakhashi commented Sep 11, 2024

e2-tts-pytorch-small
@ILG2021 @yangyyt @lucidrains

@manmay-nakhashi
Contributor

@lucidrains Can we keep the interpolation method in?
I have trained a model with it and it works.
This checkpoint was trained with interpolated_text=True.

@lucidrains
Owner

@lucidrains can we keep interpolation method in ?. i have trained a model and it works. this checkpoint is trained with interpolated_text=True.

added it back!

@lucidrains
Owner

@skirdey are you near SF? do you want to join us for dinner mid next month? you were the one who alerted me to this paper!

@Oktai15

Oktai15 commented Sep 12, 2024

@manmay-nakhashi thanks for your checkpoint, but current lucidrains main branch doesn't work with it. Could you share correct commit/branch/link to actual code (especially, inference script)?

@manmay-nakhashi
Contributor

manmay-nakhashi commented Sep 12, 2024

@Oktai15 I trained on 0.9.1, but over the weekend I'll look at what changed between the current version and 0.9.1 that breaks it.

@Oktai15

Oktai15 commented Sep 12, 2024

@manmay-nakhashi can you share the inference script?

@lucidrains
Copy link
Owner

@manmay-nakhashi thanks for your checkpoint, but current lucidrains main branch doesn't work with it. Could you share correct commit/branch/link to actual code (especially, inference script)?

you can always just install the specific version
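For example, `pip install e2-tts-pytorch==0.9.1` (assuming the PyPI distribution name `e2-tts-pytorch`) pins the version the checkpoint was trained on. A small sketch to confirm which version is actually installed before loading the checkpoint; `installed_version` is a hypothetical helper built on the standard library's `importlib.metadata`:

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(dist: str = "e2-tts-pytorch"):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None

# a checkpoint trained on 0.9.1 should be loaded with 0.9.1 installed
v = installed_version()
if v != "0.9.1":
    print(f"installed e2-tts-pytorch version is {v}; the shared checkpoint was trained on 0.9.1")
```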

@evelynyhc

Will the trained duration predictor be saved? And is the duration model used at inference time?

@fakerybakery

@Oktai15 i trained on 0.9.1 , but i'll see over weekend, what's difference between current and 0.9.1 which is not working.

@manmay-nakhashi would you mind sharing what code you used for inference on 0.9.1? thanks!

@manmay-nakhashi
Contributor

manmay-nakhashi commented Oct 3, 2024

@fakerybakery

import datetime
from pathlib import Path

import torch
import torchaudio
torch.backends.cudnn.benchmark = True

from einops import rearrange

from vocos import Vocos

from e2_tts_pytorch.e2_tts import E2TTS

import matplotlib.pyplot as plt

vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

checkpoint_path = ""
audio_path = ""
text = ""

# load the model (hyperparameters must match the ones used for training)
e2tts = E2TTS(
    tokenizer = 'phoneme_en',
    cond_drop_prob = 0.2,
    interpolated_text=True,
    transformer = dict(
        dim = 384,
        depth = 12,
        heads = 6,
        max_seq_len = 2048,
        skip_connect_type = 'concat'
    ),
    mel_spec_kwargs = dict(
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        sampling_rate = 24000,
    ),
    immiscible=True,
    frac_lengths_mask = (0.5, 0.9)
)

checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only = True)
# strict=False silently skips mismatched keys, so report them
missing, unexpected = e2tts.load_state_dict(checkpoint['model_state_dict'], strict=False)
print(f"missing keys: {missing}\nunexpected keys: {unexpected}")
e2tts.eval()

# load a sample audio file as mono 24 kHz
audio, sr = torchaudio.load(Path(audio_path).expanduser())
audio = audio.mean(dim=0, keepdim=True)
if sr != 24000:
    resampler = torchaudio.transforms.Resample(sr, 24000)
    audio = resampler(audio)
original_mel_spec = e2tts.mel_spec(audio).squeeze(0)   # (d, n)

# visualize the mel spectrogram
plt.figure(figsize=(6, 4))
plt.imshow(original_mel_spec.numpy(), origin='lower', aspect='auto')
plt.colorbar()
plt.show()

# keep the first half as the prompt; the second half is generated
mask_length = original_mel_spec.shape[1] // 2
mel_spec = rearrange(original_mel_spec, 'd n -> 1 n d')[:, :mask_length]

print(f"Text: {text}")

# use an accelerator (e.g. cuda) if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
e2tts = e2tts.to(device)
mel_spec = mel_spec.to(device)

start_date = datetime.datetime.now()

with torch.inference_mode():
    text_ids = e2tts.tokenizer([text]).to(device)
    # duration is the total length (prompt + generated) in frames; here it is
    # fixed to the original clip length, so no duration predictor is used
    generated = e2tts.sample(
        cond = mel_spec,
        text = text_ids,
        duration = original_mel_spec.shape[1],
        steps = 32,
        cfg_strength = 1.0
    )
    # some versions return a list of per-sample (n, d) mels instead of a tensor
    if isinstance(generated, list):
        generated = torch.stack(generated)

    print(f"Generated: {generated.shape} in {datetime.datetime.now() - start_date}")
    generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')

    # vocode into audio (Vocos expects (batch, mel_dim, frames))
    wave = vocos.decode(original_mel_spec.unsqueeze(0).cpu())
    wave2 = vocos.decode(generated_mel_spec.cpu())

print(f"wave: {wave.shape}, wave2: {wave2.shape}")

# save the original and generated audio
torchaudio.save("original.wav", wave, 24_000)
torchaudio.save("generated.wav", wave2, 24_000)

@BlazJurisic

Thank you for the code, @manmay-nakhashi. Unfortunately, I still get tokenizer errors on ==0.9.1 when trying to encode text.

@manmay-nakhashi
Copy link
Contributor

@BlazJurisic Take the new g2p changes from the current master.

@fakerybakery

Hi @manmay-nakhashi,
Thanks for sharing the code!
I'm getting the following issue when trying to run it:

Traceback (most recent call last):
  File "run.py", line 97, in <module>
    print(f"Generated: {generated.shape} in {datetime.datetime.now() - start_date}")
                        ^^^^^^^^^^^^^^^
AttributeError: 'list' object has no attribute 'shape'

@JohnHerry

I may be misunderstanding the DurationPredictor; can anyone help?
In the code, I see that the DurationPredictor predicts a total duration from the prompt speech, which means it is predicting the prompt speech length in frames.
Then, during inference, it uses a total length of len1 + len2, where len1 is filled with the prompt speech frames and len2 (predicted by the DurationPredictor from the prompt speech, plus 1) is the part E2-TTS has to generate.
That means that no matter how long the input text is, the number of generated speech frames is fixed at len2. But len2 is predicted by the DurationPredictor, which is trained mainly on mel frames rather than text tokens. Is that reasonable?
Why doesn't the DurationPredictor predict the total length mainly from the input text tokens instead of the prompt speech frames?
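The len1 + len2 bookkeeping described above can be made concrete. A minimal sketch, assuming the predicted total is clamped so that at least one frame beyond the prompt is generated (one reading of the "plus 1"); both helpers are hypothetical, not e2_tts_pytorch functions. It shows that if the predictor outputs a total shorter than the prompt, only a single new frame is left to generate:

```python
def infill_lengths(cond_len: int, predicted_total: int) -> tuple[int, int]:
    """Return (total_len, new_len) for prompt-conditioned infilling.

    Assumption: the predicted total is clamped so that at least one frame
    beyond the prompt is generated.
    """
    total_len = max(predicted_total, cond_len + 1)
    return total_len, total_len - cond_len


def infill_mask(cond_len: int, total_len: int) -> list[bool]:
    """False = prompt frame kept from the reference mel, True = frame to generate."""
    return [False] * cond_len + [True] * (total_len - cond_len)


total, new = infill_lengths(cond_len=25, predicted_total=1)
print(total, new)         # 26 1
print(infill_mask(3, 5))  # [False, False, False, True, True]
```

Under this assumption, a predictor that always outputs a tiny value produces almost no new audio, which is why a usable total duration (predicted or supplied) matters so much.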

@JohnHerry

Okay, so I took a shot at hacking together some inference code. I trained a model for 400k steps on the MushanW/GLOBE dataset. When I test it, I get a cacophony that is starting to resemble the TTS prompts, but the intermediate mel spectrogram is of very poor quality, so something might be wrong with my approach.


import os

import torch
import torchaudio
from torchaudio.transforms import GriffinLim, InverseSpectrogram, InverseMelScale, Resample, Speed

from einops import rearrange
from accelerate import Accelerator

from torch.optim import Adam

from e2_tts_pytorch.e2_tts import (
    E2TTS,
    DurationPredictor,
    MelSpec
)

duration_predictor = DurationPredictor(
    transformer = dict(
        dim = 80,
        depth = 2,
    )
)

model = E2TTS(
    duration_predictor = duration_predictor,
    transformer = dict(
        dim = 80,
        depth = 4,
        skip_connect_type = 'concat'
    )
)

n_fft = 1024
sample_rate = 22050
checkpoint_path = "./e2tts.pt"

def exists(v):
    return v is not None

def vocoder(melspec):
    inverse_melscale_transform = InverseMelScale(n_stft=n_fft // 2 + 1, n_mels=80, sample_rate=sample_rate, norm="slaney", f_min=0, f_max=8000)
    spectrogram = inverse_melscale_transform(melspec)
    transform = GriffinLim(n_fft=n_fft, hop_length=256, power=2)
    waveform = transform(spectrogram)
    return waveform

def load_checkpoint(checkpoint_path, model, accelerator, optimizer):
    if not exists(checkpoint_path) or not os.path.exists(checkpoint_path):
        return 0

    checkpoint = torch.load(checkpoint_path)
    accelerator.unwrap_model(model).load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['step']

accelerator = Accelerator(log_with="all")

optimizer = Adam(model.parameters(), lr=1e-4)

start_step = load_checkpoint(checkpoint_path=checkpoint_path, model=model, accelerator=accelerator, optimizer=optimizer)

ref_waveform, ref_sample_rate = torchaudio.load("ref.wav", normalize=True)
resampler = Resample(orig_freq=ref_sample_rate, new_freq=sample_rate)
ref_waveform = resampler(ref_waveform)
# No extra Speed transform: resampling already converts to sample_rate, while Speed would
# also change the tempo (and it returns a (waveform, lengths) tuple rather than a tensor).
ref_waveform_resampled = ref_waveform[:1]  # first channel, kept as a (1, time) tensor

mel_model = MelSpec()
mel = mel_model(ref_waveform_resampled)
mel = torch.cat([mel, mel], dim=0)
mel = rearrange(mel, 'b d n -> b n d')

text = ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "Waves crashed against the cliffs, their thunderous applause echoing for miles."]
sample = model.sample(mel[:,:25], text = text, vocoder=vocoder)
sample = sample.to('cpu')

waveform = sample

mono_channel_1 = waveform[0].unsqueeze(0)
mono_channel_2 = waveform[1].unsqueeze(0)

torchaudio.save("output_channel_1.wav", mono_channel_1, sample_rate)
torchaudio.save("output_channel_2.wav", mono_channel_2, sample_rate)

Does that mean 25 frames of the target speaker are enough to synthesize their timbre?
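For scale, it helps to convert `mel[:,:25]` into seconds. A minimal sketch, assuming `MelSpec` uses a hop of 256 samples (the hop the script's `GriffinLim` call uses) and the script's 22,050 Hz rate; the helpers are hypothetical and ignore the extra frame that centered STFT padding adds:

```python
import math


def frames_to_seconds(num_frames: int, hop_length: int = 256, sample_rate: int = 22050) -> float:
    """Approximate duration covered by num_frames mel frames (one frame per hop)."""
    return num_frames * hop_length / sample_rate


def seconds_to_frames(seconds: float, hop_length: int = 256, sample_rate: int = 22050) -> int:
    """Number of mel frames needed to cover the given duration, rounded up."""
    return math.ceil(seconds * sample_rate / hop_length)


print(round(frames_to_seconds(25), 3))  # 0.29
print(seconds_to_frames(3.0))           # 259
```

Under these assumptions, 25 frames is under a third of a second of the reference speaker, so a prompt of a few seconds (a few hundred frames) is a more realistic conditioning length.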

@JohnHerry

JohnHerry commented Dec 2, 2024

I trained on more than 100,000 hours of speech for over 500,000 steps (about half an epoch), but I still cannot get good results.
The training loss no longer seems to go down.
The DurationPredictor, trained for 380,000 steps, always seems to output the value 1, which makes it useless.
For test inference I use a constant duration=1000, meaning the condition and generated mel together span 1,000 frames. With the default sample function this generates pure noise, nothing useful.
When I then increased the diffusion steps from the default 35 to 100, it did produce something: the condition audio, then a burst of noise, then the generated audio for the input text. The generated audio is missing the beginning of the text, and it sounds different from the condition speaker.
And because the constant duration=1000 may not fit the text, the generated audio has very poor prosody, even more robotic than traditional FastSpeech TTS.
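Instead of a fixed duration=1000, one common workaround is to scale the total length by text length, using the prompt's own speaking rate. A minimal sketch, assuming you have the prompt's transcript and that characters are a reasonable proxy for speaking time; `estimate_total_frames` is a hypothetical helper whose result would be passed in place of the constant 1000:

```python
def estimate_total_frames(cond_frames: int, cond_text: str, gen_text: str, speed: float = 1.0) -> int:
    """Estimate prompt + generated mel length from the prompt's speaking rate.

    Assumption: frames per character measured on the prompt carry over to the
    new text; speed > 1 shortens the generated part.
    """
    frames_per_char = cond_frames / max(len(cond_text), 1)
    gen_frames = int(frames_per_char * len(gen_text) / speed)
    return cond_frames + max(gen_frames, 1)  # always leave at least one frame to generate


print(estimate_total_frames(200, "a" * 50, "b" * 100))  # 600
```

This sidesteps a DurationPredictor that has collapsed to a constant, at the cost of ignoring punctuation pauses and per-word rate changes.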

@JohnHerry

The sample results contain all sorts of overlapping sounds. Is something wrong with the sampling process?

@ndhuynh02

ndhuynh02 commented Dec 17, 2024

Since E2TTS takes a mel spectrogram as input and its job is to fill in the masked part (which is why the input mel is cropped), I believe the result still contains the cropped input. Therefore, we need to remove the beginning of the output.
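Removing the prompt portion amounts to slicing off the first cond_len time steps of each output item. A minimal sketch, assuming each batch item has time on its first axis (e.g. an `n d` mel taken from a `b n d` batch) and that the per-item prompt lengths are known; `strip_prompt` is a hypothetical helper that works the same on tensors and plain lists:

```python
def strip_prompt(batch, cond_lens):
    """Drop the first cond_len time steps of each item in a batch.

    Each item is assumed to have time on its first axis (e.g. an 'n d' mel
    taken from a 'b n d' batch); plain lists work the same way.
    """
    if len(batch) != len(cond_lens):
        raise ValueError("need one prompt length per batch item")
    return [item[c:] for item, c in zip(batch, cond_lens)]


print(strip_prompt([[1, 2, 3, 4], [5, 6, 7]], [2, 1]))  # [[3, 4], [6, 7]]
```

If the trimming is done on the vocoded waveform instead of the mel, the prompt length must first be converted from frames to samples (roughly cond_len × hop_length, assuming the vocoder's hop matches the mel's).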

FYI, I am also having difficulties running inference with this model. In the paper, the authors say that G2P is removed, so using phoneme_en as the tokenizer doesn't seem right. In addition, model hyperparameters such as depth, dim, heads, etc. need to be modified to match the values given in the paper.
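A G2P-free input can be as simple as mapping raw text to character-level ids. A minimal sketch, assuming UTF-8 byte ids (one id per character for ASCII text) and a padding id of -1; `char_tokenize` and both choices are illustrative, not the vocabulary of e2_tts_pytorch or of the paper:

```python
def char_tokenize(texts, pad_id=-1):
    """Map each string to UTF-8 byte ids and right-pad the batch with pad_id.

    Character-level input with no G2P step; byte ids and pad_id=-1 are
    illustrative choices, not a specific library's vocabulary.
    """
    ids = [list(t.encode("utf-8")) for t in texts]
    max_len = max(len(seq) for seq in ids)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in ids]


print(char_tokenize(["hi", "hey"]))  # [[104, 105, -1], [104, 101, 121]]
```

Byte-level ids keep the vocabulary fixed at 256 entries for any language, whereas a per-character vocabulary is more compact for scripts like Chinese, where one character spans several UTF-8 bytes.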

@kunit17

kunit17 commented Jan 2, 2025

In the transformer, is the same mask being applied to both the text and audio inputs? It seems like that to me.

def forward(
    self,
    x: Float['b n d'],
    times: Float['b'] | Float[''] | None = None,
    mask: Bool['b n'] | None = None,
    text_embed: Float['b n dt'] | None = None,
