Skip to content

Commit

Permalink
Optimize merge_tokens method (pytorch#3615)
Browse files Browse the repository at this point in the history
Optimizes merge_tokens method as discussed in pytorch#3614 

Co-authored-by: Eyal Cohen <[email protected]>
  • Loading branch information
2 people authored and mthrok committed Sep 21, 2023
1 parent 420d9ac commit 6ea1133
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions torchaudio/functional/_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,14 @@ def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSp
if len(tokens) != len(scores):
raise ValueError("`tokens` and `scores` must be the same length.")

t_prev = blank
i = start = -1
spans = []
for t, token in enumerate(tokens):
if token != t_prev:
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, t, scores[start:t].mean().item()))
if token != blank:
i += 1
start = t
t_prev = token
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, len(tokens), scores[start:].mean().item()))
diff = torch.diff(
tokens, prepend=torch.tensor([-1], device=tokens.device), append=torch.tensor([-1], device=tokens.device)
)
changes_wo_blank = torch.nonzero((diff != 0)).squeeze().tolist()
tokens = tokens.tolist()
spans = [
TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
for start, end in zip(changes_wo_blank[:-1], changes_wo_blank[1:])
if (token := tokens[start]) != blank
]
return spans

0 comments on commit 6ea1133

Please sign in to comment.