High vram usage for long audio. #6

hykilpikonna opened this issue Aug 11, 2023 · 1 comment

The model seems to process the entire audio at once, which leads to high VRAM usage for long audio. I was trying to compute MERT embeddings for a 9:58 audio clip on an A100 80GB GPU, and it tried to allocate 90GB of VRAM.

[Screenshot: CUDA out-of-memory error showing the ~90 GB allocation attempt]
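
For reference, here is roughly how to reproduce this; the m-a-p/MERT-v1-95M checkpoint and the dummy 10-minute input below are just placeholders, adjust to whatever model and audio you use:

import torch
from transformers import Wav2Vec2FeatureExtractor, AutoModel

# Assumed checkpoint; any MERT-v1 variant should behave the same way here.
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True).to("cuda")

sr = processor.sampling_rate            # 24 kHz for MERT-v1
audio = torch.randn(sr * 600).numpy()   # ~10 minutes of dummy mono audio

inputs = processor(audio, sampling_rate=sr, return_tensors="pt").to("cuda")
with torch.no_grad():
    # The whole clip goes through the model in a single forward pass,
    # so memory grows with the full sequence length.
    out = model(**inputs, output_hidden_states=True)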

Is it possible to split the audio first, process each segment, and obtain the same results? I tried splitting the audio into 60-second windows with the code below. Even though I managed to get the segmented embeddings into the same shape, they show a large mean squared error compared to the result of passing the entire audio in at once.

# Inside a class method; self.sr, self.processor, self.model, self.device are the
# sampling rate and the MERT processor/model set up elsewhere.
window_length = int(self.sr * 60)       # 60-second windows
overlap_length = int(self.sr * 4.987)   # ~4.987 s overlap (5 s minus one frame at the 75 Hz frame rate)
overlap_frames = int(4.987 * 75) - 1    # the same overlap expressed in output frames (75 Hz frame rate)
embeddings = []

print("Audio shape:", audio.shape)
print("Window length:", window_length)
print("Overlap length:", overlap_length)
print("Overlap frames:", overlap_frames)

# Iterate over audio with overlap
for start in range(0, audio.shape[0], window_length - overlap_length):
    end = start + window_length
    segment = audio[start:end]
    print("Segment:", segment.shape)
    # if len(segment) < window_length:
    #     break

    # Process each segment
    inputs = self.processor(segment, sampling_rate=self.sr, return_tensors="pt").to(self.device)
    with torch.no_grad():
        out = self.model(**inputs, output_hidden_states=True)
        out = torch.stack(out.hidden_states).squeeze() # [13 layers, timeframes, 768]
        out = out[11] # [timeframes, 768]
        
        print("Frames before:", out.shape[0])

        # Drop the overlapping frames from the end of every segment except the last
        if end < audio.shape[0]:
            out = out[:-overlap_frames, :]

        print("Frames after:", out.shape[0])

    embeddings.append(out)

# Concatenate the per-segment embeddings along the time axis
out = torch.cat(embeddings, dim=0)

return out
hykilpikonna commented Aug 11, 2023

Here is the absolute error between the original and the segmented calculations for a 4-minute audio, plotted below. Oddly, the overlapping regions are not the only thing affected; the error seems to bleed into the entire rest of each segment.

[Plot: per-frame absolute error between the full-audio and segmented embeddings over the 4-minute clip]
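
For reference, the per-frame error in the plot can be computed along these lines (a minimal sketch; the function and variable names are illustrative, with emb_full / emb_segmented being the [timeframes, 768] outputs of the full-audio and segmented code paths above):

import torch
import matplotlib.pyplot as plt

def plot_abs_error(emb_full: torch.Tensor, emb_segmented: torch.Tensor) -> None:
    """Plot the mean absolute error per frame between two [timeframes, 768] embeddings."""
    n = min(emb_full.shape[0], emb_segmented.shape[0])              # guard against off-by-one frame counts
    abs_err = (emb_full[:n] - emb_segmented[:n]).abs().mean(dim=1)  # average over the 768 feature dims
    plt.plot(abs_err.cpu().numpy())
    plt.xlabel("frame (75 Hz)")
    plt.ylabel("mean absolute error")
    plt.show()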
