-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MusicGen] Add streamer to generate #25320
[MusicGen] Add streamer to generate #25320
Conversation
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks good to me, given the limitations of our streamer class 👍 (printing audio to the stdout/file feels weird, having the option to return a generator would possibly be valuable here, right?)
self.on_finalized_text(audio_values[self.to_print :], stream_end=True) | ||
self.to_print += len(audio_values) | ||
|
||
def on_finalized_text(self, audio: np.ndarray, stream_end: bool = False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Since this is an internal function, perhaps on_finalized_chunk
(or another non-text name) :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Resolved in 58de438
Thanks for the review @gante! We now yield successive audio chunks as suggested |
Ready for core-maintainer review! What are your thoughts about adding a streamer example to the docs @ArthurZucker @gante? The code is quite involved, so I was thinking that maybe this is better as a standalone gradio example Streamer code:from queue import Queue
from threading import Thread
from typing import Optional
import numpy as np
import torch
from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
from transformers.generation.streamers import BaseStreamer
import gradio as gr
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
if device == "cuda:0":
model.to(device).half();
class MusicgenStreamer(BaseStreamer):
def __init__(
self,
model: MusicgenForConditionalGeneration,
device: Optional[str] = None,
play_steps: Optional[int] = 10,
stride: Optional[int] = None,
timeout: Optional[float] = None,
):
"""
Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
useful for applications that benefit from acessing the generated audio in a non-blocking way (e.g. in an interactive
Gradio demo).
Parameters:
model (`MusicgenForConditionalGeneration`):
The MusicGen model used to generate the audio waveform.
device (`str`, *optional*):
The torch device on which to run the computation. If `None`, will default to the device of the model.
play_steps (`int`, *optional*, defaults to 10):
The number of generation steps with which to return the generated audio array. Using fewer steps will
mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
should be tuned to your device and latency requirements.
stride (`int`, *optional*):
The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
play_steps // 6 in the audio space.
timeout (`int`, *optional*):
The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
in `.generate()`, when it is called in a separate thread.
"""
self.decoder = model.decoder
self.audio_encoder = model.audio_encoder
self.generation_config = model.generation_config
self.device = device if device is not None else model.device
# variables used in the streaming process
self.play_steps = play_steps
if stride is not None:
self.stride = stride
else:
hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
self.token_cache = None
self.to_yield = 0
# varibles used in the thread process
self.audio_queue = Queue()
self.stop_signal = None
self.timeout = timeout
def apply_delay_pattern_mask(self, input_ids):
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
_, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids[:, :1],
pad_token_id=self.generation_config.decoder_start_token_id,
max_length=input_ids.shape[-1],
)
# apply the pattern mask to the input ids
input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)
# revert the pattern delay mask by filtering the pad token id
input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
1, self.decoder.num_codebooks, -1
)
# append the frame dimension back to the audio codes
input_ids = input_ids[None, ...]
# send the input_ids to the correct device
input_ids = input_ids.to(self.audio_encoder.device)
output_values = self.audio_encoder.decode(
input_ids,
audio_scales=[None],
)
audio_values = output_values.audio_values[0, 0]
return audio_values.cpu().float().numpy()
def put(self, value):
batch_size = value.shape[0] // self.decoder.num_codebooks
if batch_size > 1:
raise ValueError("MusicgenStreamer only supports batch size 1")
if self.token_cache is None:
self.token_cache = value
else:
self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
if self.token_cache.shape[-1] % self.play_steps == 0:
audio_values = self.apply_delay_pattern_mask(self.token_cache)
self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
self.to_yield += len(audio_values) - self.to_yield - self.stride
def end(self):
"""Flushes any remaining cache and appends the stop symbol."""
if self.token_cache is not None:
audio_values = self.apply_delay_pattern_mask(self.token_cache)
else:
audio_values = np.zeros(self.to_yield)
self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
"""Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
self.audio_queue.put(audio, timeout=self.timeout)
if stream_end:
self.audio_queue.put(self.stop_signal, timeout=self.timeout)
def __iter__(self):
return self
def __next__(self):
value = self.audio_queue.get(timeout=self.timeout)
if not isinstance(value, np.ndarray) and value == self.stop_signal:
raise StopIteration()
else:
return value
sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate
def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0):
inputs = processor(
text=text_prompt,
padding=True,
return_tensors="pt",
)
max_new_tokens = int(frame_rate * audio_length_in_s)
play_steps = int(frame_rate * play_steps_in_s)
streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
generation_kwargs = dict(
**inputs.to(device),
streamer=streamer,
max_new_tokens=max_new_tokens,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
set_seed(0)
for new_audio in streamer:
yield (sampling_rate, new_audio)
generator = generate_audio("Techno music with euphoric melodies")
for chunk in generator:
yield (sampling_rate, chunk) |
Gently pinging @ArthurZucker for a review! |
This is great @sanchit-gandhi, do you think it's possible to showcase the feature in a Space ? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! As you mentioned the test is a bit too big, so would be better to have it in a space or in examples?
Would be a bit too much work to have a dummy streamer and not worth it IMO 😉
Merging as the test showed that the streamer worked, and we'll showcase this directly in a gradio demo once streaming outputs are confirmed as working! |
* [MusicGen] Add streamer to generate * add to for cond generation * add test * finish * torch only * fix type hint * yield audio chunks * fix typehint * remove test
* [MusicGen] Add streamer to generate * add to for cond generation * add test * finish * torch only * fix type hint * yield audio chunks * fix typehint * remove test
* [MusicGen] Add streamer to generate * add to for cond generation * add test * finish * torch only * fix type hint * yield audio chunks * fix typehint * remove test
What does this PR do?
Adds the
streamer
to MusicGen's generate, along with an example test for returning chunks of numpy audio arrays on-the-fly as they are generated.Facilitates using MusicGen with streaming mode as per the Gradio update: gradio-app/gradio#5077
cc @Vaibhavs10 @ylacombe @aliabid94 @abidlabs