Add Flash Attention 2 support to ParlerTTS #59

Closed · wants to merge 6 commits
10 changes: 10 additions & 0 deletions README.md
@@ -54,6 +54,16 @@ if torch.xpu.is_available():
torch_dtype = torch.float16 if device != "cpu" else torch.float32

model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)

# Use with Flash Attention 2
# model = ParlerTTSForConditionalGeneration.from_pretrained(
#     "parler-tts/parler_tts_mini_v0.1", attn_implementation="flash_attention_2", torch_dtype=torch.float16
# ).to(device, dtype=torch_dtype)
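# (Flash Attention 2 requires a CUDA GPU, fp16/bf16 weights, and the flash-attn package to be installed.)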


tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

prompt = "Hey, how are you doing today?"
103 changes: 103 additions & 0 deletions helpers/benchmark/benchmark_attention_method.py
@@ -0,0 +1,103 @@
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
from tqdm import tqdm
from dataset import PROMPTS, DESCRIPTIONS
import time

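# Load the same checkpoint once per attention backend (eager, Flash Attention 2, SDPA)
# so that only the attention implementation differs between the benchmarked models.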
model_eager = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler-tts-mini-expresso",
    attn_implementation="eager",
    torch_dtype=torch.float16,
).to("cuda:0")

model_flash = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler-tts-mini-expresso",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda:0")

model_sdpa = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler-tts-mini-expresso",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to("cuda:0")

tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")


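# Warm up the audio codec decoder so one-off CUDA initialisation does not skew the timed runs.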
for i in range(3):
    print(f"Warming up decoder ({i + 1}/3)")
    z = torch.empty(1, 1024, 8).uniform_(-10, 10).to(model_eager.device).to(model_eager.dtype)
    model_eager.audio_encoder.model.decode(z)
    model_flash.audio_encoder.model.decode(z)
    model_sdpa.audio_encoder.model.decode(z)





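# Generate speech for a single (prompt, description) pair, capped at roughly one second of audio.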
def generate_speech(model, prompt, description):
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to("cuda:0")
    prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:0")

    generation_config = model.generation_config

    # Generate roughly the first second of audio (the default max_length is 2580)
    generation_config.max_length = 86

    _ = model.generate(
        input_ids=input_ids,
        prompt_input_ids=prompt_input_ids,
        generation_config=generation_config,
        use_cache=True,
        past_key_values=None,
    )


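# Run NUM_SAMPLE generations, timing the whole loop with CUDA events and tracking peak GPU memory.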
def benchmark(model):
    device = "cuda:0"
    # define events that measure the start and end of the whole benchmark loop
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # reset CUDA memory stats and empty the cache
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()

    NUM_SAMPLE = 50

    latencies = []

    for i in tqdm(range(min(NUM_SAMPLE, len(PROMPTS)))):
        prompt = PROMPTS[i]
        description = DESCRIPTIONS[i]

        start = time.perf_counter()

        generate_speech(model, prompt, description)

        latencies.append(time.perf_counter() - start)

    # record the end time
    end_event.record()
    torch.cuda.synchronize()

    # measure memory footprint and elapsed time
    max_memory = torch.cuda.max_memory_allocated(device)
    elapsed_time = start_event.elapsed_time(end_event) * 1.0e-3

    print("Average time per sample:", elapsed_time / NUM_SAMPLE, "seconds")
    print("Average per-call latency:", sum(latencies) / len(latencies), "seconds")
    print("Max memory footprint:", max_memory * 1e-9, "GB")

if __name__ == "__main__":
    print("Benchmark model with Eager Attention")
    benchmark(model_eager)
    print("Benchmark model with Flash Attention 2")
    benchmark(model_flash)
    print("Benchmark model with SDPA Attention")
    benchmark(model_sdpa)
7 changes: 7 additions & 0 deletions helpers/benchmark/dataset.py
@@ -0,0 +1,7 @@
from datasets import load_dataset

dataset = load_dataset("parler-tts/libritts_r_tags_tagged_10k_generated", 'clean')

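# Parallel lists: PROMPTS[i] is the text to speak and DESCRIPTIONS[i] the matching voice description.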
PROMPTS = dataset['test.clean']['text']
DESCRIPTIONS = dataset['test.clean']['text_description']

18 changes: 18 additions & 0 deletions parler_tts/configuration_parler_tts.py
@@ -236,3 +236,21 @@ def from_sub_models_config(
    # This is a property because you might want to change the codec model on the fly
    def sampling_rate(self):
        return self.audio_encoder.sampling_rate

    # Copied from MusicGen
    @property
    def _attn_implementation(self):
        # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
        if hasattr(self, "_attn_implementation_internal"):
            if self._attn_implementation_internal is None:
                # `config.attn_implementation` should never be None, for backward compatibility.
                return "eager"
            else:
                return self._attn_implementation_internal
        else:
            return "eager"

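    # The setter also updates the decoder sub-config so the requested attention backend reaches the decoder.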
    @_attn_implementation.setter
    def _attn_implementation(self, value):
        self._attn_implementation_internal = value
        self.decoder._attn_implementation = value