
Optimization Suggestion for 'torchaudio.functional.merge_tokens' Function #3614

Closed · eyalcohen308 opened this issue Sep 21, 2023 · 4 comments · Fixed by #3615

Comments

eyalcohen308 (Contributor) commented Sep 21, 2023

🚀 The feature

Optimize the merge_tokens function in torchaudio.functional to leverage PyTorch's tensor operations, improving efficiency and reducing runtime.

Proposed Implementation:

from typing import List

import torch
from torchaudio.functional import TokenSpan


def merge_tokens_optimized(tokens: torch.Tensor, scores: torch.Tensor, blank: int = 0) -> List[TokenSpan]:
    """Removes repeated tokens and blank tokens from the given CTC token sequence.

    Args:
        tokens (torch.Tensor): Alignment tokens (unbatched)
        scores (torch.Tensor): Alignment scores (unbatched)
        blank (int, optional): Blank token. Defaults to 0.

    Returns:
        List[TokenSpan]: list of TokenSpan
    """

    # Compute the difference between consecutive tokens. Prepending and appending a -1
    # (which cannot match any token) ensures the first and last tokens open/close a span.
    diff = torch.diff(
        tokens, prepend=torch.tensor([-1], device=tokens.device), append=torch.tensor([-1], device=tokens.device)
    )
    # Indices where the token value changes; consecutive pairs delimit candidate spans.
    # Spans starting at a blank token are filtered out in the comprehension below.
    change_points = torch.nonzero(diff != 0).squeeze().tolist()

    tokens = tokens.tolist()
    # Create a TokenSpan for each non-blank span, averaging the scores within it.
    spans = [
        TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
        for start, end in zip(change_points[:-1], change_points[1:])
        if (token := tokens[start]) != blank
    ]
    return spans
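
To illustrate the change-point trick on a toy input (a sketch, not part of the proposed patch): for the sequence [1, 1, 0, 2] with blank = 0, the padded diff is nonzero exactly at span boundaries.

import torch

# Toy sequence: token 1 repeated twice, a blank (0), then token 2.
tokens = torch.tensor([1, 1, 0, 2])
pad = torch.tensor([-1])
diff = torch.diff(tokens, prepend=pad, append=pad)
print(diff)                                # tensor([ 2,  0, -1,  2, -3])
print(torch.nonzero(diff != 0).squeeze())  # tensor([0, 2, 3, 4])
# Consecutive change points give candidate spans (0, 2), (2, 3) and (3, 4);
# (2, 3) starts at the blank token and is dropped, leaving
# TokenSpan(token=1, start=0, end=2) and TokenSpan(token=2, start=3, end=4).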

Motivation, pitch

While working on some CTC token sequence operations, I noticed that the current implementation of the merge_tokens function could benefit from tensor operations, potentially offering significant runtime improvements. Specifically, the current loop-based approach can be streamlined by taking advantage of PyTorch's vectorized primitives. In my tests, the proposed optimization runs approximately 7 times faster over 1000 runs while producing equivalent outputs.

Alternatives

Proposed tensor-based optimization: this approach leverages PyTorch's tensor operations, eliminating explicit Python loops and considerably reducing execution time.

Additional context

To validate the accuracy of the optimized function, the following check compares the outputs of the original and proposed implementations:

from typing import List

import torchaudio.functional as F
from torchaudio.functional import TokenSpan

def token_span_outputs_equal(span_list_1: List[TokenSpan], span_list_2: List[TokenSpan]) -> bool:
    """Compares two lists of TokenSpan objects for equality."""
    
    if len(span_list_1) != len(span_list_2):
        return False

    for span_1, span_2 in zip(span_list_1, span_list_2):
        if (
            span_1.token != span_2.token or 
            span_1.start != span_2.start or 
            span_1.end != span_2.end or 
            abs(span_1.score - span_2.score) > 1e-6  # Allowing a small tolerance for floating-point comparisons
        ):
            return False

    return True
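
Note: aligned_tokens and alignment_scores are assumed to come from an earlier forced-alignment step and are not defined in this snippet. For a self-contained run, synthetic stand-ins of matching length (hypothetical, mirroring the benchmark setup further below) suffice:

import torch

# Hypothetical stand-ins for real alignment outputs; any token/score
# tensors of equal length exercise the equivalence check.
aligned_tokens = torch.randint(256, (200,), dtype=torch.int64)
alignment_scores = torch.randn((200,), dtype=torch.float32)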
output_original = F.merge_tokens(aligned_tokens, alignment_scores)
output_optimized = merge_tokens_optimized(aligned_tokens, alignment_scores)

assert token_span_outputs_equal(output_original, output_optimized), "Outputs are not equivalent!"

Benchmark code:

import timeit

pytorch_time = timeit.timeit(lambda: F.merge_tokens(aligned_tokens, alignment_scores), number=1000)
new_method_time = timeit.timeit(lambda: merge_tokens_optimized(aligned_tokens, alignment_scores), number=1000)

print(f"PyTorch method: {pytorch_time}")
print(f"New method: {new_method_time}")

Output:

PyTorch method: 7.622203521430492
New method: 1.0753349959850311

Integrating this optimization could bring about significant improvements in both performance and maintainability, benefiting a broad spectrum of users.

mthrok (Collaborator) commented Sep 21, 2023

Hi @eyalcohen308

Thanks for the suggestion. A 7x improvement sounds great. Would you like to make a PR? If not, I will try to find some time to adopt this later.

There are unit tests for this, which you can run with: cd test && pytest torchaudio_unittest/functional/functional_cpu_test.py -k test_merge_repeated_tokens

@parameterized.expand(
    [
        ([], [], []),
        ([F.TokenSpan(1, 0, 1, 1.0)], [1], [1.0]),
        ([F.TokenSpan(1, 0, 2, 0.5)], [1, 1], [0.4, 0.6]),
        ([F.TokenSpan(1, 0, 3, 0.6)], [1, 1, 1], [0.5, 0.6, 0.7]),
        ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(2, 1, 2, 0.9)], [1, 2], [0.8, 0.9]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 1, 3, 0.5)], [1, 2, 2], [1.0, 0.4, 0.6]),
        ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(1, 2, 3, 1.0)], [1, 0, 1], [0.8, 0.9, 1.0]),
        ([F.TokenSpan(1, 0, 1, 0.8), F.TokenSpan(2, 2, 3, 1.0)], [1, 0, 2], [0.8, 0.9, 1.0]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 2, 4, 0.5)], [1, 0, 1, 1], [1.0, 0.1, 0.4, 0.6]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 2, 4, 0.5)], [1, 0, 2, 2], [1.0, 0.1, 0.4, 0.6]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 3, 4, 0.4)], [1, 0, 0, 1], [1.0, 0.9, 0.7, 0.4]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 3, 4, 0.4)], [1, 0, 0, 2], [1.0, 0.9, 0.7, 0.4]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(1, 3, 5, 0.5)], [1, 0, 0, 1, 1], [1.0, 0.9, 0.8, 0.6, 0.4]),
        ([F.TokenSpan(1, 0, 1, 1.0), F.TokenSpan(2, 3, 5, 0.5)], [1, 0, 0, 2, 2], [1.0, 0.9, 0.8, 0.6, 0.4]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 2, 3, 0.5)], [1, 1, 2], [1.0, 0.8, 0.5]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 3, 4, 0.7)], [1, 1, 0, 1], [1.0, 0.8, 0.1, 0.7]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 3, 4, 0.7)], [1, 1, 0, 2], [1.0, 0.8, 0.1, 0.7]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 3, 5, 0.4)], [1, 1, 0, 1, 1], [1.0, 0.8, 0.1, 0.5, 0.3]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 3, 5, 0.4)], [1, 1, 0, 2, 2], [1.0, 0.8, 0.1, 0.5, 0.3]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 4, 5, 0.3)], [1, 1, 0, 0, 1], [1.0, 0.8, 0.1, 0.5, 0.3]),
        ([F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 4, 5, 0.3)], [1, 1, 0, 0, 2], [1.0, 0.8, 0.1, 0.5, 0.3]),
        (
            [F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(1, 4, 6, 0.2)],
            [1, 1, 0, 0, 1, 1],
            [1.0, 0.8, 0.6, 0.5, 0.3, 0.1],
        ),
        (
            [F.TokenSpan(1, 0, 2, 0.9), F.TokenSpan(2, 4, 6, 0.2)],
            [1, 1, 0, 0, 2, 2],
            [1.0, 0.8, 0.6, 0.5, 0.3, 0.1],
        ),
    ]
)
def test_merge_repeated_tokens(self, expected, tokens, scores):
    scores_ = torch.tensor(scores, dtype=torch.float32, device=self.device)
    tokens_ = torch.tensor(tokens, dtype=torch.int64, device=self.device)
    spans = F.merge_tokens(tokens_, scores_, blank=0)
    print(tokens_, scores_)
    self._assert_tokens(spans, expected)
    # Append blanks at the beginning and at the end.
    for num_prefix, num_suffix in itertools.product([0, 1, 2], repeat=2):
        tokens_ = ([0] * num_prefix) + tokens + ([0] * num_suffix)
        scores_ = ([0.1] * num_prefix) + scores + ([0.1] * num_suffix)
        tokens_ = torch.tensor(tokens_, dtype=torch.int64, device=self.device)
        scores_ = torch.tensor(scores_, dtype=torch.float32, device=self.device)
        expected_ = [F.TokenSpan(s.token, s.start + num_prefix, s.end + num_prefix, s.score) for s in expected]
        print(tokens_, scores_)
        spans = F.merge_tokens(tokens_, scores_, blank=0)
        self._assert_tokens(spans, expected_)

eyalcohen308 (Contributor, Author)
@mthrok I will open a PR, thanks

mthrok pushed a commit that referenced this issue Sep 21, 2023
Optimizes merge_tokens method as discussed in #3614 

Co-authored-by: Eyal Cohen <[email protected]>
mthrok (Collaborator) commented Sep 21, 2023

@eyalcohen308 Thanks, it's been merged. Note that we have release 2.1 scheduled in a couple of weeks, but the library code for the release branch is finalized, so this commit will not be part of 2.1; it will be part of 2.2 unless it's cherry-picked to 2.1.

mthrok (Collaborator) commented Sep 21, 2023

FYI, I confirmed that the optimization works on both CPU and CUDA:

Device   CPU    CUDA
old      4.31   16.61
new      1.89   4.49

(times in seconds)

Code:
import timeit

import torch
import torchaudio.functional as F


def test(device):
    d = torch.device(device)
    tokens = torch.randint(256, (200, ), dtype=torch.int32, device=d)
    scores = torch.randn((200, ), dtype=torch.float32, device=d)
    elapsed = timeit.timeit(lambda: F.merge_tokens(tokens, scores), number=1000)
    print(f"{device}: {elapsed}")


test("cpu")
test("cuda")
