Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MusicGen] Add streamer to generate #25320

Merged
merged 9 commits into from
Sep 14, 2023

Conversation

sanchit-gandhi
Copy link
Contributor

What does this PR do?

Adds the streamer to MusicGen's generate, along with an example test for returning chunks of numpy audio arrays on-the-fly as they are generated.

Facilitates using MusicGen with streaming mode as per the Gradio update: gradio-app/gradio#5077

cc @Vaibhavs10 @ylacombe @aliabid94 @abidlabs

@sanchit-gandhi sanchit-gandhi requested a review from gante August 4, 2023 14:51
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Aug 4, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good to me, given the limitations of our streamer class 👍 (printing audio to the stdout/file feels weird, having the option to return a generator would possibly be valuable here, right?)

self.on_finalized_text(audio_values[self.to_print :], stream_end=True)
self.to_print += len(audio_values)

def on_finalized_text(self, audio: np.ndarray, stream_end: bool = False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Since this is an internal function, perhaps on_finalized_chunk (or another non-text name) :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved in 58de438

@huggingface huggingface deleted a comment from github-actions bot Sep 7, 2023
@sanchit-gandhi
Copy link
Contributor Author

Thanks for the review @gante! We now yield successive audio chunks as suggested

@sanchit-gandhi
Copy link
Contributor Author

Ready for core-maintainer review! What are your thoughts about adding a streamer example to the docs @ArthurZucker @gante? The code is quite involved, so I was thinking that maybe this is better as a standalone gradio example

Streamer code:
from queue import Queue
from threading import Thread
from typing import Optional

import numpy as np
import torch

from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
from transformers.generation.streamers import BaseStreamer

import gradio as gr

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")

if device == "cuda:0":
    model.to(device).half();

class MusicgenStreamer(BaseStreamer):
    def __init__(
        self,
        model: MusicgenForConditionalGeneration,
        device: Optional[str] = None,
        play_steps: Optional[int] = 10,
        stride: Optional[int] = None,
        timeout: Optional[float] = None,
    ):
        """
        Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
        useful for applications that benefit from acessing the generated audio in a non-blocking way (e.g. in an interactive
        Gradio demo).

        Parameters:
            model (`MusicgenForConditionalGeneration`):
                The MusicGen model used to generate the audio waveform.
            device (`str`, *optional*):
                The torch device on which to run the computation. If `None`, will default to the device of the model.
            play_steps (`int`, *optional*, defaults to 10):
                The number of generation steps with which to return the generated audio array. Using fewer steps will 
                mean the first chunk is ready faster, but will require more codec decoding steps overall. This value 
                should be tuned to your device and latency requirements.
            stride (`int`, *optional*):
                The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
                the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to 
                play_steps // 6 in the audio space.
            timeout (`int`, *optional*):
                The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
                in `.generate()`, when it is called in a separate thread.
        """
        self.decoder = model.decoder
        self.audio_encoder = model.audio_encoder
        self.generation_config = model.generation_config
        self.device = device if device is not None else model.device

        # variables used in the streaming process
        self.play_steps = play_steps
        if stride is not None:
            self.stride = stride
        else:
            hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
            self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
        self.token_cache = None
        self.to_yield = 0

        # varibles used in the thread process
        self.audio_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def apply_delay_pattern_mask(self, input_ids):
        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
        _, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
            input_ids[:, :1],
            pad_token_id=self.generation_config.decoder_start_token_id,
            max_length=input_ids.shape[-1],
        )
        # apply the pattern mask to the input ids
        input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)

        # revert the pattern delay mask by filtering the pad token id
        input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
            1, self.decoder.num_codebooks, -1
        )

        # append the frame dimension back to the audio codes
        input_ids = input_ids[None, ...]

        # send the input_ids to the correct device
        input_ids = input_ids.to(self.audio_encoder.device)

        output_values = self.audio_encoder.decode(
            input_ids,
            audio_scales=[None],
        )
        audio_values = output_values.audio_values[0, 0]
        return audio_values.cpu().float().numpy()

    def put(self, value):
        batch_size = value.shape[0] // self.decoder.num_codebooks
        if batch_size > 1:
            raise ValueError("MusicgenStreamer only supports batch size 1")

        if self.token_cache is None:
            self.token_cache = value
        else:
            self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)

        if self.token_cache.shape[-1] % self.play_steps == 0:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
            self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
            self.to_yield += len(audio_values) - self.to_yield - self.stride

    def end(self):
        """Flushes any remaining cache and appends the stop symbol."""
        if self.token_cache is not None:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
        else:
            audio_values = np.zeros(self.to_yield)

        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)

    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
        """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.audio_queue.put(audio, timeout=self.timeout)
        if stream_end:
            self.audio_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.audio_queue.get(timeout=self.timeout)
        if not isinstance(value, np.ndarray) and value == self.stop_signal:
            raise StopIteration()
        else:
            return value

sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate

def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0):
    inputs = processor(
        text=text_prompt,
        padding=True,
        return_tensors="pt",
    )

    max_new_tokens = int(frame_rate * audio_length_in_s)
    play_steps = int(frame_rate * play_steps_in_s)

    streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)

    generation_kwargs = dict(
        **inputs.to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    set_seed(0)
    for new_audio in streamer:
        yield (sampling_rate, new_audio)

generator = generate_audio("Techno music with euphoric melodies")

for chunk in generator:
    yield (sampling_rate, chunk)

@sanchit-gandhi
Copy link
Contributor Author

Gently pinging @ArthurZucker for a review!

@ylacombe
Copy link
Contributor

ylacombe commented Sep 13, 2023

This is great @sanchit-gandhi, do you think it's possible to showcase the feature in a Space ?

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! As you mentioned the test is a bit too big, so would be better to have it in a space or in examples?
Would be a bit too much work to have a dummy streamer and not worth it IMO 😉

@sanchit-gandhi
Copy link
Contributor Author

Merging as the test showed that the streamer worked, and we'll showcase this directly in a gradio demo once streaming outputs are confirmed as working!

@sanchit-gandhi sanchit-gandhi merged commit 0dd06c3 into huggingface:main Sep 14, 2023
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
* [MusicGen] Add streamer to generate

* add to for cond generation

* add test

* finish

* torch only

* fix type hint

* yield audio chunks

* fix typehint

* remove test
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
* [MusicGen] Add streamer to generate

* add to for cond generation

* add test

* finish

* torch only

* fix type hint

* yield audio chunks

* fix typehint

* remove test
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023
* [MusicGen] Add streamer to generate

* add to for cond generation

* add test

* finish

* torch only

* fix type hint

* yield audio chunks

* fix typehint

* remove test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants