From 80341c30f6a7be3298401d74933c3f856297490e Mon Sep 17 00:00:00 2001 From: santialferez Date: Tue, 26 Dec 2023 13:01:49 +0100 Subject: [PATCH] update setup.py to install pyannote.audio==3.1.1, update diarize.py to include num_speakers; to fix Issue #592 --- build/lib/whisperx/SubtitlesProcessor.py | 227 +++++++++++ build/lib/whisperx/__init__.py | 4 + build/lib/whisperx/__main__.py | 4 + build/lib/whisperx/alignment.py | 467 ++++++++++++++++++++++ build/lib/whisperx/asr.py | 350 ++++++++++++++++ build/lib/whisperx/assets/mel_filters.npz | Bin 0 -> 4271 bytes build/lib/whisperx/audio.py | 159 ++++++++ build/lib/whisperx/conjunctions.py | 43 ++ build/lib/whisperx/diarize.py | 74 ++++ build/lib/whisperx/transcribe.py | 229 +++++++++++ build/lib/whisperx/types.py | 58 +++ build/lib/whisperx/utils.py | 437 ++++++++++++++++++++ build/lib/whisperx/vad.py | 311 ++++++++++++++ setup.py | 2 +- whisperx/diarize.py | 4 +- 15 files changed, 2366 insertions(+), 3 deletions(-) create mode 100644 build/lib/whisperx/SubtitlesProcessor.py create mode 100644 build/lib/whisperx/__init__.py create mode 100644 build/lib/whisperx/__main__.py create mode 100644 build/lib/whisperx/alignment.py create mode 100644 build/lib/whisperx/asr.py create mode 100644 build/lib/whisperx/assets/mel_filters.npz create mode 100644 build/lib/whisperx/audio.py create mode 100644 build/lib/whisperx/conjunctions.py create mode 100644 build/lib/whisperx/diarize.py create mode 100644 build/lib/whisperx/transcribe.py create mode 100644 build/lib/whisperx/types.py create mode 100644 build/lib/whisperx/utils.py create mode 100644 build/lib/whisperx/vad.py diff --git a/build/lib/whisperx/SubtitlesProcessor.py b/build/lib/whisperx/SubtitlesProcessor.py new file mode 100644 index 00000000..5ffd1afa --- /dev/null +++ b/build/lib/whisperx/SubtitlesProcessor.py @@ -0,0 +1,227 @@ +import math +from conjunctions import get_conjunctions, get_comma +from typing import TextIO + +def normal_round(n): + if n - math.floor(n) < 0.5: + return math.floor(n) + return math.ceil(n) + + +def format_timestamp(seconds: float, is_vtt: bool = False): + + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + separator = '.' if is_vtt else ',' + + hours_marker = f"{hours:02d}:" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{separator}{milliseconds:03d}" + ) + + + +class SubtitlesProcessor: + def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False): + self.comma = get_comma(lang) + self.conjunctions = set(get_conjunctions(lang)) + self.segments = segments + self.lang = lang + self.max_line_length = max_line_length + self.min_char_length_splitter = min_char_length_splitter + self.is_vtt = is_vtt + complex_script_languages = ['th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta', 'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka'] + if self.lang in complex_script_languages: + self.max_line_length = 30 + self.min_char_length_splitter = 20 + + def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None): + k = 0.25 + has_prev_end = i > 0 and 'end' in words[i - 1] + has_next_start = i < len(words) - 1 and 'start' in words[i + 1] + + if has_prev_end: + words[i]['start'] = words[i - 1]['end'] + if has_next_start: + words[i]['end'] = words[i + 1]['start'] + else: + if next_segment_start_time: + words[i]['end'] = next_segment_start_time if next_segment_start_time - words[i - 1]['end'] <= 1 else next_segment_start_time - 0.5 + else: + words[i]['end'] = words[i]['start'] + len(words[i]['word']) * k + + elif has_next_start: + words[i]['start'] = words[i + 1]['start'] - len(words[i]['word']) * k + words[i]['end'] = words[i + 1]['start'] + + else: + if next_segment_start_time: + words[i]['start'] = next_segment_start_time - 1 + words[i]['end'] = next_segment_start_time - 0.5 + else: + words[i]['start'] = 0 + words[i]['end'] = 0 + + + + def process_segments(self, advanced_splitting=True): + subtitles = [] + for i, segment in enumerate(self.segments): + next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None + + if advanced_splitting: + + split_points = self.determine_advanced_split_points(segment, next_segment_start_time) + subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time)) + else: + words = segment['words'] + for i, word in enumerate(words): + if 'start' not in word or 'end' not in word: + self.estimate_timestamp_for_word(words, i, next_segment_start_time) + + subtitles.append({ + 'start': segment['start'], + 'end': segment['end'], + 'text': segment['text'] + }) + + return subtitles + + def determine_advanced_split_points(self, segment, next_segment_start_time=None): + split_points = [] + last_split_point = 0 + char_count = 0 + + words = segment.get('words', segment['text'].split()) + add_space = 0 if self.lang in ['zh', 'ja'] else 1 + + total_char_count = sum(len(word['word']) if isinstance(word, dict) else len(word) + add_space for word in words) + char_count_after = total_char_count + + for i, word in enumerate(words): + word_text = word['word'] if isinstance(word, dict) else word + word_length = len(word_text) + add_space + char_count += word_length + char_count_after -= word_length + + char_count_before = char_count - word_length + + if isinstance(word, dict) and ('start' not in word or 'end' not in word): + self.estimate_timestamp_for_word(words, i, next_segment_start_time) + + if char_count >= self.max_line_length: + midpoint = normal_round((last_split_point + i) / 2) + if char_count_before >= self.min_char_length_splitter: + split_points.append(midpoint) + last_split_point = midpoint + 1 + char_count = sum(len(words[j]['word']) if isinstance(words[j], dict) else len(words[j]) + add_space for j in range(last_split_point, i + 1)) + + elif word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter: + split_points.append(i) + last_split_point = i + 1 + char_count = 0 + + elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter: + split_points.append(i - 1) + last_split_point = i + char_count = word_length + + return split_points + + + def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None): + subtitles = [] + + words = segment.get('words', segment['text'].split()) + total_word_count = len(words) + total_time = segment['end'] - segment['start'] + elapsed_time = segment['start'] + prefix = ' ' if self.lang not in ['zh', 'ja'] else '' + start_idx = 0 + for split_point in split_points: + + fragment_words = words[start_idx:split_point + 1] + current_word_count = len(fragment_words) + + + if isinstance(fragment_words[0], dict): + start_time = fragment_words[0]['start'] + end_time = fragment_words[-1]['end'] + next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None + if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8: + end_time = next_start_time_for_word + else: + fragment = prefix.join(fragment_words).strip() + current_duration = (current_word_count / total_word_count) * total_time + start_time = elapsed_time + end_time = elapsed_time + current_duration + elapsed_time += current_duration + + + subtitles.append({ + 'start': start_time, + 'end': end_time, + 'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words) + }) + + start_idx = split_point + 1 + + # Handle the last fragment + if start_idx < len(words): + fragment_words = words[start_idx:] + current_word_count = len(fragment_words) + + if isinstance(fragment_words[0], dict): + start_time = fragment_words[0]['start'] + end_time = fragment_words[-1]['end'] + else: + fragment = prefix.join(fragment_words).strip() + current_duration = (current_word_count / total_word_count) * total_time + start_time = elapsed_time + end_time = elapsed_time + current_duration + + if next_start_time and (next_start_time - end_time) <= 0.8: + end_time = next_start_time + + subtitles.append({ + 'start': start_time, + 'end': end_time if end_time is not None else segment['end'], + 'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words) + }) + + return subtitles + + + + def save(self, filename="subtitles.srt", advanced_splitting=True): + + subtitles = self.process_segments(advanced_splitting) + + def write_subtitle(file, idx, start_time, end_time, text): + + file.write(f"{idx}\n") + file.write(f"{start_time} --> {end_time}\n") + file.write(text + "\n\n") + + with open(filename, 'w', encoding='utf-8') as file: + if self.is_vtt: + file.write("WEBVTT\n\n") + + if advanced_splitting: + for idx, subtitle in enumerate(subtitles, 1): + start_time = format_timestamp(subtitle['start'], self.is_vtt) + end_time = format_timestamp(subtitle['end'], self.is_vtt) + text = subtitle['text'].strip() + write_subtitle(file, idx, start_time, end_time, text) + + return len(subtitles) \ No newline at end of file diff --git a/build/lib/whisperx/__init__.py b/build/lib/whisperx/__init__.py new file mode 100644 index 00000000..20abaaed --- /dev/null +++ b/build/lib/whisperx/__init__.py @@ -0,0 +1,4 @@ +from .transcribe import load_model +from .alignment import load_align_model, align +from .audio import load_audio +from .diarize import assign_word_speakers, DiarizationPipeline \ No newline at end of file diff --git a/build/lib/whisperx/__main__.py b/build/lib/whisperx/__main__.py new file mode 100644 index 00000000..bc9b04a3 --- /dev/null +++ b/build/lib/whisperx/__main__.py @@ -0,0 +1,4 @@ +from .transcribe import cli + + +cli() diff --git a/build/lib/whisperx/alignment.py b/build/lib/whisperx/alignment.py new file mode 100644 index 00000000..8294983d --- /dev/null +++ b/build/lib/whisperx/alignment.py @@ -0,0 +1,467 @@ +"""" +Forced Alignment with Whisper +C. Max Bain +""" +from dataclasses import dataclass +from typing import Iterable, Union, List + +import numpy as np +import pandas as pd +import torch +import torchaudio +from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + +from .audio import SAMPLE_RATE, load_audio +from .utils import interpolate_nans +from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment +import nltk +from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters + +PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof'] + +LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] + +DEFAULT_ALIGN_MODELS_TORCH = { + "en": "WAV2VEC2_ASR_BASE_960H", + "fr": "VOXPOPULI_ASR_BASE_10K_FR", + "de": "VOXPOPULI_ASR_BASE_10K_DE", + "es": "VOXPOPULI_ASR_BASE_10K_ES", + "it": "VOXPOPULI_ASR_BASE_10K_IT", +} + +DEFAULT_ALIGN_MODELS_HF = { + "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese", + "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn", + "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch", + "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm", + "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese", + "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic", + "cs": "comodoro/wav2vec2-xls-r-300m-cs-250", + "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian", + "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish", + "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian", + "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish", + "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian", + "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek", + "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish", + "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech", + "he": "imvladikon/wav2vec2-xls-r-300m-hebrew", + "vi": 'nguyenvulebinh/wav2vec2-base-vi', + "ko": "kresnik/wav2vec2-large-xlsr-korean", + "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu", + "te": "anuragshas/wav2vec2-large-xlsr-53-telugu", + "hi": "theainerd/Wav2Vec2-large-xlsr-hindi", + "ca": "softcatala/wav2vec2-large-xlsr-catala", + "ml": "gvs/wav2vec2-large-xlsr-malayalam", + "no": "NbAiLab/nb-wav2vec2-1b-bokmaal", + "nn": "NbAiLab/nb-wav2vec2-300m-nynorsk", +} + + +def load_align_model(language_code, device, model_name=None, model_dir=None): + if model_name is None: + # use default model + if language_code in DEFAULT_ALIGN_MODELS_TORCH: + model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code] + elif language_code in DEFAULT_ALIGN_MODELS_HF: + model_name = DEFAULT_ALIGN_MODELS_HF[language_code] + else: + print(f"There is no default alignment model set for this language ({language_code}).\ + Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]") + raise ValueError(f"No default align-model for language: {language_code}") + + if model_name in torchaudio.pipelines.__all__: + pipeline_type = "torchaudio" + bundle = torchaudio.pipelines.__dict__[model_name] + align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device) + labels = bundle.get_labels() + align_dictionary = {c.lower(): i for i, c in enumerate(labels)} + else: + try: + processor = Wav2Vec2Processor.from_pretrained(model_name) + align_model = Wav2Vec2ForCTC.from_pretrained(model_name) + except Exception as e: + print(e) + print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models") + raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)') + pipeline_type = "huggingface" + align_model = align_model.to(device) + labels = processor.tokenizer.get_vocab() + align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()} + + align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type} + + return align_model, align_metadata + + +def align( + transcript: Iterable[SingleSegment], + model: torch.nn.Module, + align_model_metadata: dict, + audio: Union[str, np.ndarray, torch.Tensor], + device: str, + interpolate_method: str = "nearest", + return_char_alignments: bool = False, + print_progress: bool = False, + combined_progress: bool = False, +) -> AlignedTranscriptionResult: + """ + Align phoneme recognition predictions to known transcription. + """ + + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + if len(audio.shape) == 1: + audio = audio.unsqueeze(0) + + MAX_DURATION = audio.shape[1] / SAMPLE_RATE + + model_dictionary = align_model_metadata["dictionary"] + model_lang = align_model_metadata["language"] + model_type = align_model_metadata["type"] + + # 1. Preprocess to keep only characters in dictionary + total_segments = len(transcript) + for sdx, segment in enumerate(transcript): + # strip spaces at beginning / end, but keep track of the amount. + if print_progress: + base_progress = ((sdx + 1) / total_segments) * 100 + percent_complete = (50 + base_progress / 2) if combined_progress else base_progress + print(f"Progress: {percent_complete:.2f}%...") + + num_leading = len(segment["text"]) - len(segment["text"].lstrip()) + num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) + text = segment["text"] + + # split into words + if model_lang not in LANGUAGES_WITHOUT_SPACES: + per_word = text.split(" ") + else: + per_word = text + + clean_char, clean_cdx = [], [] + for cdx, char in enumerate(text): + char_ = char.lower() + # wav2vec2 models use "|" character to represent spaces + if model_lang not in LANGUAGES_WITHOUT_SPACES: + char_ = char_.replace(" ", "|") + + # ignore whitespace at beginning and end of transcript + if cdx < num_leading: + pass + elif cdx > len(text) - num_trailing - 1: + pass + elif char_ in model_dictionary.keys(): + clean_char.append(char_) + clean_cdx.append(cdx) + + clean_wdx = [] + for wdx, wrd in enumerate(per_word): + if any([c in model_dictionary.keys() for c in wrd]): + clean_wdx.append(wdx) + + + punkt_param = PunktParameters() + punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS) + sentence_splitter = PunktSentenceTokenizer(punkt_param) + sentence_spans = list(sentence_splitter.span_tokenize(text)) + + segment["clean_char"] = clean_char + segment["clean_cdx"] = clean_cdx + segment["clean_wdx"] = clean_wdx + segment["sentence_spans"] = sentence_spans + + aligned_segments: List[SingleAlignedSegment] = [] + + # 2. Get prediction matrix from alignment model & align + for sdx, segment in enumerate(transcript): + + t1 = segment["start"] + t2 = segment["end"] + text = segment["text"] + + aligned_seg: SingleAlignedSegment = { + "start": t1, + "end": t2, + "text": text, + "words": [], + } + + if return_char_alignments: + aligned_seg["chars"] = [] + + # check we can align + if len(segment["clean_char"]) == 0: + print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...') + aligned_segments.append(aligned_seg) + continue + + if t1 >= MAX_DURATION: + print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...') + aligned_segments.append(aligned_seg) + continue + + text_clean = "".join(segment["clean_char"]) + tokens = [model_dictionary[c] for c in text_clean] + + f1 = int(t1 * SAMPLE_RATE) + f2 = int(t2 * SAMPLE_RATE) + + # TODO: Probably can get some speedup gain with batched inference here + waveform_segment = audio[:, f1:f2] + # Handle the minimum input length for wav2vec2 models + if waveform_segment.shape[-1] < 400: + lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device) + waveform_segment = torch.nn.functional.pad( + waveform_segment, (0, 400 - waveform_segment.shape[-1]) + ) + else: + lengths = None + + with torch.inference_mode(): + if model_type == "torchaudio": + emissions, _ = model(waveform_segment.to(device), lengths=lengths) + elif model_type == "huggingface": + emissions = model(waveform_segment.to(device)).logits + else: + raise NotImplementedError(f"Align model of type {model_type} not supported.") + emissions = torch.log_softmax(emissions, dim=-1) + + emission = emissions[0].cpu().detach() + + blank_id = 0 + for char, code in model_dictionary.items(): + if char == '[pad]' or char == '': + blank_id = code + + trellis = get_trellis(emission, tokens, blank_id) + path = backtrack(trellis, emission, tokens, blank_id) + + if path is None: + print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') + aligned_segments.append(aligned_seg) + continue + + char_segments = merge_repeats(path, text_clean) + + duration = t2 -t1 + ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) + + # assign timestamps to aligned characters + char_segments_arr = [] + word_idx = 0 + for cdx, char in enumerate(text): + start, end, score = None, None, None + if cdx in segment["clean_cdx"]: + char_seg = char_segments[segment["clean_cdx"].index(cdx)] + start = round(char_seg.start * ratio + t1, 3) + end = round(char_seg.end * ratio + t1, 3) + score = round(char_seg.score, 3) + + char_segments_arr.append( + { + "char": char, + "start": start, + "end": end, + "score": score, + "word-idx": word_idx, + } + ) + + # increment word_idx, nltk word tokenization would probably be more robust here, but us space for now... + if model_lang in LANGUAGES_WITHOUT_SPACES: + word_idx += 1 + elif cdx == len(text) - 1 or text[cdx+1] == " ": + word_idx += 1 + + char_segments_arr = pd.DataFrame(char_segments_arr) + + aligned_subsegments = [] + # assign sentence_idx to each character index + char_segments_arr["sentence-idx"] = None + for sdx, (sstart, send) in enumerate(segment["sentence_spans"]): + curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)] + char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx + + sentence_text = text[sstart:send] + sentence_start = curr_chars["start"].min() + end_chars = curr_chars[curr_chars["char"] != ' '] + sentence_end = end_chars["end"].max() + sentence_words = [] + + for word_idx in curr_chars["word-idx"].unique(): + word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx] + word_text = "".join(word_chars["char"].tolist()).strip() + if len(word_text) == 0: + continue + + # dont use space character for alignment + word_chars = word_chars[word_chars["char"] != " "] + + word_start = word_chars["start"].min() + word_end = word_chars["end"].max() + word_score = round(word_chars["score"].mean(), 3) + + # -1 indicates unalignable + word_segment = {"word": word_text} + + if not np.isnan(word_start): + word_segment["start"] = word_start + if not np.isnan(word_end): + word_segment["end"] = word_end + if not np.isnan(word_score): + word_segment["score"] = word_score + + sentence_words.append(word_segment) + + aligned_subsegments.append({ + "text": sentence_text, + "start": sentence_start, + "end": sentence_end, + "words": sentence_words, + }) + + if return_char_alignments: + curr_chars = curr_chars[["char", "start", "end", "score"]] + curr_chars.fillna(-1, inplace=True) + curr_chars = curr_chars.to_dict("records") + curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars] + aligned_subsegments[-1]["chars"] = curr_chars + + aligned_subsegments = pd.DataFrame(aligned_subsegments) + aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method) + aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method) + # concatenate sentences with same timestamps + agg_dict = {"text": " ".join, "words": "sum"} + if model_lang in LANGUAGES_WITHOUT_SPACES: + agg_dict["text"] = "".join + if return_char_alignments: + agg_dict["chars"] = "sum" + aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict) + aligned_subsegments = aligned_subsegments.to_dict('records') + aligned_segments += aligned_subsegments + + # create word_segments list + word_segments: List[SingleWordSegment] = [] + for segment in aligned_segments: + word_segments += segment["words"] + + return {"segments": aligned_segments, "word_segments": word_segments} + +""" +source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html +""" +def get_trellis(emission, tokens, blank_id=0): + num_frame = emission.size(0) + num_tokens = len(tokens) + + # Trellis has extra diemsions for both time axis and tokens. + # The extra dim for tokens represents (start-of-sentence) + # The extra dim for time axis is for simplification of the code. + trellis = torch.empty((num_frame + 1, num_tokens + 1)) + trellis[0, 0] = 0 + trellis[1:, 0] = torch.cumsum(emission[:, 0], 0) + trellis[0, -num_tokens:] = -float("inf") + trellis[-num_tokens:, 0] = float("inf") + + for t in range(num_frame): + trellis[t + 1, 1:] = torch.maximum( + # Score for staying at the same token + trellis[t, 1:] + emission[t, blank_id], + # Score for changing to the next token + trellis[t, :-1] + emission[t, tokens], + ) + return trellis + +@dataclass +class Point: + token_index: int + time_index: int + score: float + +def backtrack(trellis, emission, tokens, blank_id=0): + # Note: + # j and t are indices for trellis, which has extra dimensions + # for time and tokens at the beginning. + # When referring to time frame index `T` in trellis, + # the corresponding index in emission is `T-1`. + # Similarly, when referring to token index `J` in trellis, + # the corresponding index in transcript is `J-1`. + j = trellis.size(1) - 1 + t_start = torch.argmax(trellis[:, j]).item() + + path = [] + for t in range(t_start, 0, -1): + # 1. Figure out if the current position was stay or change + # Note (again): + # `emission[J-1]` is the emission at time frame `J` of trellis dimension. + # Score for token staying the same from time frame J-1 to T. + stayed = trellis[t - 1, j] + emission[t - 1, blank_id] + # Score for token changing from C-1 at T-1 to J at T. + changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]] + + # 2. Store the path with frame-wise probability. + prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item() + # Return token index and time index in non-trellis coordinate. + path.append(Point(j - 1, t - 1, prob)) + + # 3. Update the token + if changed > stayed: + j -= 1 + if j == 0: + break + else: + # failed + return None + return path[::-1] + +# Merge the labels +@dataclass +class Segment: + label: str + start: int + end: int + score: float + + def __repr__(self): + return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" + + @property + def length(self): + return self.end - self.start + +def merge_repeats(path, transcript): + i1, i2 = 0, 0 + segments = [] + while i1 < len(path): + while i2 < len(path) and path[i1].token_index == path[i2].token_index: + i2 += 1 + score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1) + segments.append( + Segment( + transcript[path[i1].token_index], + path[i1].time_index, + path[i2 - 1].time_index + 1, + score, + ) + ) + i1 = i2 + return segments + +def merge_words(segments, separator="|"): + words = [] + i1, i2 = 0, 0 + while i1 < len(segments): + if i2 >= len(segments) or segments[i2].label == separator: + if i1 != i2: + segs = segments[i1:i2] + word = "".join([seg.label for seg in segs]) + score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs) + words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score)) + i1 = i2 + 1 + i2 = i1 + else: + i2 += 1 + return words diff --git a/build/lib/whisperx/asr.py b/build/lib/whisperx/asr.py new file mode 100644 index 00000000..dba82712 --- /dev/null +++ b/build/lib/whisperx/asr.py @@ -0,0 +1,350 @@ +import os +import warnings +from typing import List, Union, Optional, NamedTuple + +import ctranslate2 +import faster_whisper +import numpy as np +import torch +from transformers import Pipeline +from transformers.pipelines.pt_utils import PipelineIterator + +from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram +from .vad import load_vad_model, merge_chunks +from .types import TranscriptionResult, SingleSegment + +def find_numeral_symbol_tokens(tokenizer): + numeral_symbol_tokens = [] + for i in range(tokenizer.eot): + token = tokenizer.decode([i]).removeprefix(" ") + has_numeral_symbol = any(c in "0123456789%$£" for c in token) + if has_numeral_symbol: + numeral_symbol_tokens.append(i) + return numeral_symbol_tokens + +class WhisperModel(faster_whisper.WhisperModel): + ''' + FasterWhisperModel provides batched inference for faster-whisper. + Currently only works in non-timestamp mode and fixed prompt for all samples in batch. + ''' + + def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None): + batch_size = features.shape[0] + all_tokens = [] + prompt_reset_since = 0 + if options.initial_prompt is not None: + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + previous_tokens = all_tokens[prompt_reset_since:] + prompt = self.get_prompt( + tokenizer, + previous_tokens, + without_timestamps=options.without_timestamps, + prefix=options.prefix, + ) + + encoder_output = self.encode(features) + + max_initial_timestamp_index = int( + round(options.max_initial_timestamp / self.time_precision) + ) + + result = self.model.generate( + encoder_output, + [prompt] * batch_size, + beam_size=options.beam_size, + patience=options.patience, + length_penalty=options.length_penalty, + max_length=self.max_length, + suppress_blank=options.suppress_blank, + suppress_tokens=options.suppress_tokens, + ) + + tokens_batch = [x.sequences_ids[0] for x in result] + + def decode_batch(tokens: List[List[int]]) -> str: + res = [] + for tk in tokens: + res.append([token for token in tk if token < tokenizer.eot]) + # text_tokens = [token for token in tokens if token < self.eot] + return tokenizer.tokenizer.decode_batch(res) + + text = decode_batch(tokens_batch) + + return text + + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: + # When the model is running on multiple GPUs, the encoder output should be moved + # to the CPU since we don't know which GPU will handle the next job. + to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 + # unsqueeze if batch size = 1 + if len(features.shape) == 2: + features = np.expand_dims(features, 0) + features = faster_whisper.transcribe.get_ctranslate2_storage(features) + + return self.model.encode(features, to_cpu=to_cpu) + +class FasterWhisperPipeline(Pipeline): + """ + Huggingface Pipeline wrapper for FasterWhisperModel. + """ + # TODO: + # - add support for timestamp mode + # - add support for custom inference kwargs + + def __init__( + self, + model, + vad, + vad_params: dict, + options : NamedTuple, + tokenizer=None, + device: Union[int, str, "torch.device"] = -1, + framework = "pt", + language : Optional[str] = None, + suppress_numerals: bool = False, + **kwargs + ): + self.model = model + self.tokenizer = tokenizer + self.options = options + self.preset_language = language + self.suppress_numerals = suppress_numerals + self._batch_size = kwargs.pop("batch_size", None) + self._num_workers = 1 + self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) + self.call_count = 0 + self.framework = framework + if self.framework == "pt": + if isinstance(device, torch.device): + self.device = device + elif isinstance(device, str): + self.device = torch.device(device) + elif device < 0: + self.device = torch.device("cpu") + else: + self.device = torch.device(f"cuda:{device}") + else: + self.device = device + + super(Pipeline, self).__init__() + self.vad_model = vad + self._vad_params = vad_params + + def _sanitize_parameters(self, **kwargs): + preprocess_kwargs = {} + if "tokenizer" in kwargs: + preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] + return preprocess_kwargs, {}, {} + + def preprocess(self, audio): + audio = audio['inputs'] + model_n_mels = self.model.feat_kwargs.get("feature_size") + features = log_mel_spectrogram( + audio, + n_mels=model_n_mels if model_n_mels is not None else 80, + padding=N_SAMPLES - audio.shape[0], + ) + return {'inputs': features} + + def _forward(self, model_inputs): + outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options) + return {'text': outputs} + + def postprocess(self, model_outputs): + return model_outputs + + def get_iterator( + self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params + ): + dataset = PipelineIterator(inputs, self.preprocess, preprocess_params) + if "TOKENIZERS_PARALLELISM" not in os.environ: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + # TODO hack by collating feature_extractor and image_processor + + def stack(items): + return {'inputs': torch.stack([x['inputs'] for x in items])} + dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack) + model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size) + final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params) + return final_iterator + + def transcribe( + self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False + ) -> TranscriptionResult: + if isinstance(audio, str): + audio = load_audio(audio) + + def data(audio, segments): + for seg in segments: + f1 = int(seg['start'] * SAMPLE_RATE) + f2 = int(seg['end'] * SAMPLE_RATE) + # print(f2-f1) + yield {'inputs': audio[f1:f2]} + + vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) + vad_segments = merge_chunks( + vad_segments, + chunk_size, + onset=self._vad_params["vad_onset"], + offset=self._vad_params["vad_offset"], + ) + if self.tokenizer is None: + language = language or self.detect_language(audio) + task = task or "transcribe" + self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, + self.model.model.is_multilingual, task=task, + language=language) + else: + language = language or self.tokenizer.language_code + task = task or self.tokenizer.task + if task != self.tokenizer.task or language != self.tokenizer.language_code: + self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer, + self.model.model.is_multilingual, task=task, + language=language) + + if self.suppress_numerals: + previous_suppress_tokens = self.options.suppress_tokens + numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer) + print(f"Suppressing numeral and symbol tokens: {numeral_symbol_tokens}") + new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens + new_suppressed_tokens = list(set(new_suppressed_tokens)) + self.options = self.options._replace(suppress_tokens=new_suppressed_tokens) + + segments: List[SingleSegment] = [] + batch_size = batch_size or self._batch_size + total_segments = len(vad_segments) + for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)): + if print_progress: + base_progress = ((idx + 1) / total_segments) * 100 + percent_complete = base_progress / 2 if combined_progress else base_progress + print(f"Progress: {percent_complete:.2f}%...") + text = out['text'] + if batch_size in [0, 1, None]: + text = text[0] + segments.append( + { + "text": text, + "start": round(vad_segments[idx]['start'], 3), + "end": round(vad_segments[idx]['end'], 3) + } + ) + + # revert the tokenizer if multilingual inference is enabled + if self.preset_language is None: + self.tokenizer = None + + # revert suppressed tokens if suppress_numerals is enabled + if self.suppress_numerals: + self.options = self.options._replace(suppress_tokens=previous_suppress_tokens) + + return {"segments": segments, "language": language} + + + def detect_language(self, audio: np.ndarray): + if audio.shape[0] < N_SAMPLES: + print("Warning: audio is shorter than 30s, language detection may be inaccurate.") + model_n_mels = self.model.feat_kwargs.get("feature_size") + segment = log_mel_spectrogram(audio[: N_SAMPLES], + n_mels=model_n_mels if model_n_mels is not None else 80, + padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) + encoder_output = self.model.encode(segment) + results = self.model.model.detect_language(encoder_output) + language_token, language_probability = results[0][0] + language = language_token[2:-2] + print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...") + return language + +def load_model(whisper_arch, + device, + device_index=0, + compute_type="float16", + asr_options=None, + language : Optional[str] = None, + vad_options=None, + model : Optional[WhisperModel] = None, + task="transcribe", + download_root=None, + threads=4): + '''Load a Whisper model for inference. + Args: + whisper_arch: str - The name of the Whisper model to load. + device: str - The device to load the model on. + compute_type: str - The compute type to use for the model. + options: dict - A dictionary of options to use for the model. + language: str - The language of the model. (use English for now) + model: Optional[WhisperModel] - The WhisperModel instance to use. + download_root: Optional[str] - The root directory to download the model to. + threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers. + Returns: + A Whisper pipeline. + ''' + + if whisper_arch.endswith(".en"): + language = "en" + + model = model or WhisperModel(whisper_arch, + device=device, + device_index=device_index, + compute_type=compute_type, + download_root=download_root, + cpu_threads=threads) + if language is not None: + tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language) + else: + print("No language specified, language will be first be detected for each audio file (increases inference time).") + tokenizer = None + + default_asr_options = { + "beam_size": 5, + "best_of": 5, + "patience": 1, + "length_penalty": 1, + "repetition_penalty": 1, + "no_repeat_ngram_size": 0, + "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": False, + "prompt_reset_on_temperature": 0.5, + "initial_prompt": None, + "prefix": None, + "suppress_blank": True, + "suppress_tokens": [-1], + "without_timestamps": True, + "max_initial_timestamp": 0.0, + "word_timestamps": False, + "prepend_punctuations": "\"'“¿([{-", + "append_punctuations": "\"'.。,,!!??::”)]}、", + "suppress_numerals": False, + } + + if asr_options is not None: + default_asr_options.update(asr_options) + + suppress_numerals = default_asr_options["suppress_numerals"] + del default_asr_options["suppress_numerals"] + + default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options) + + default_vad_options = { + "vad_onset": 0.500, + "vad_offset": 0.363 + } + + if vad_options is not None: + default_vad_options.update(vad_options) + + vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options) + + return FasterWhisperPipeline( + model=model, + vad=vad_model, + options=default_asr_options, + tokenizer=tokenizer, + language=language, + suppress_numerals=suppress_numerals, + vad_params=default_vad_options, + ) diff --git a/build/lib/whisperx/assets/mel_filters.npz b/build/lib/whisperx/assets/mel_filters.npz new file mode 100644 index 0000000000000000000000000000000000000000..28ea26909dbdfd608aef67afc4d74d7961ae4bb6 GIT binary patch literal 4271 zcmZ`-cQjmYw;lx1g6JcN7QKe3LG%_Oh!VX=^k~teM-XGQ(Mu4$_Y%?jkm$lFBkB+( z3yfKIgF zxGiAhze`A@t->QRNVV!%P+W=o}VHkB) z%g>qyRHfN1IQ4-=`Y@0T9qE#o+;4E3VQ!epW1Xt=ZG`I3U|62t?<>5h*W|9VvJc`KZ+)ghnA**Z~ET21Tjf_f8oe`vy zZQNtlOx?dDhS71hnOus5cqj)hfyF@H&4y?@9z{I#&cf>A+s2~~(I>TQF}SaR3_tqa z(7&ZdN^vR*t<~?{9DEoI>0PL@Sl?wa?Z{rGX`*eEx9Nh=z*J3HZL1*Py4z$TD#+;m zSSW(kcOTe(4hqgib_W6&xx+j~-u(p)Nn6?>a%wHk=h7Ay$%lcGoo;gAY zmVV7|!Nb;w(PlH@c24{ple2Y3<*9J@jE=sfLzwu_BiAFPE$0Axp`^Nq!H}eG0?r-X zFj@Pwp^al*p>K{@_Cz`q#(N0Y=OpZy^ z{P$KjLJuk_Y%I)$mh`b{uOW5C5Xcmxk!gt_Zg zw>}6fkD4zRK9!#ems~H%U$>V;_wK38Zf-baU$S!#i;7!HWsi}GuC>%@?lMdgkUGC& zh9gC?O-5BlS2#}?7x0?eP#bOL(cqE{M%LJD$CZnplD)CgQR#KCttD=dZK+Ck5R52; z*%5hZ+SXU7)8k%Y^_1U>yI*By(INn&+ir-_4$#dUwTlMNyR@iGQIaZ+eiYqucu)CB z#i{Ru1w+aU#}DHSyzjG_9c?ToB_YjU#f;N=qel98WBIjIc1!#ePwRR+(go&-by#}@ z+M+klVke5b@lWfZ+O&|c??YvRe)&W)qAgtc>t-IZtbRTG#X}49_Q$>P%-)=0W_QY-x%DPep2Vm9#ci zyQcCc4p2&dLtV1@rPe!%>Y^#9W8#ZH&}^@wJKT7N;R9A7cEq&;Y2CYvd@R+Mn&b5O zVyfS^*H#kD74=J5uhD)o`TXoX>>Si$!cT?TXRxj2pB)w_ljjhTby&Je;X|BESZZT= zC%G5!-$BJf&a~U78d_3zBjrvrkJ0CCl@Rfcf7I(`VTNPnI^B#B$zOfPW zG&mEd?R0+W<`l08O1dkcWKS8wB!Z*Cs%I1nMs-EeB-uu5?t@PuD3|z>je8DKi#X(B z{Z=Rz{4X%?-UnxnHQtkELIZ&=J;fK_t}yu8|IxG0(85e&K>H3!!~zlhyJrgti~o1i zzBS*jTgdG~Exp#B-T)6A+PB ztD-e`j^@XAx}|L&JSEFkRvS_%3b%m86z02#Hfn{Y+qIqQ_muywgt?roUA7oiS1xBD zFxmDMsj_cbBcn*^rn^KIMP{AlHM`NiVm*D&`z~7FH#hf<$L3HmJ+=NdiY5>W?nKD? z8Ox6{9dKyI1o8a-j9BtV-|=lm`<`v>tR^Cln&x1dMYzu{@wq5KW!#K14_QMnpH5K%Pavag+g6(i8i-#Eq zguc}rH3?BxH4SOqZW#7m*aT(U9-n#_Xn^Q19(}eH!xG`nI!GYziVQNcA0)`FDHD%~ zz2$HnxW4BQ{#*@u`dssbAa`|fESn$8i8FdxGZh48_Uf~_Q@tv?4in)6fwSed)k&ITqu|){^(WL~J z?Lb|0ro06J^>f>^2}^e-+$u5bU4IZNfO?75v8lstS15%XYw2ac^pkU34{QhDR(umt zPu~`w2?FP|nn3!RWZ3{?=77@teulahD9*S*k5KmY3*adlM)%{SR~bkZYlx1q@fkE= zI$7+kiw5!ha=dYlO>Z5KgxnZEJsaBm%v#nkX0MN-h%n&KA?N}xU3K3o-3Jpk?ANq2n9&Lh%K_CTvfiN ze>6w~NSSl8$#NEZ^t7h9YOxI=zcAG|a+m6AWei`3Jw7K;b;T${pJa^4RwRt%F>?>M zBmoQqm1`<_W7i!5P~THp-II)Ka^u;=z;}d{;SVj{G_4`9^HaEb!=@Pa;Dw)CH^DjsGxFqmb%o$Bkop$KnH8 zDYN)Bh)5=5!-*|f0Gh4)oZG=TEBr()g^DCtSQhmT3!ZN`Qd-E%@1cE}hm8&Vq5B+C zVF2_O)9IiZ(v(xzTwJIg5|}KVuE(;}|7dVIrT`$d=q_OG|3PY}x*URYkMXXJ6PT1$IFkNyvY_(9UglDi6TaeikPS(!Bnij z;Szn+)I_oxnRz7(WTYTp+IHSWQ?Xd~tQn(Q1r)kThM?NM< z?d6LaBG!H}R$zRy!Ij(}1?xe^+o+!;tqWJ3NgjHl1XNxzusxQ0I#6qzM(_00UPMw* zF*GWW_q&fqAN=uimSKgBu_@jD%MX3hpNY|*4r=e=k1lw2r**IyD(hcq?A+HtUgUy4Dqh5D7|G9q{)TsUj{g~c!xy>9wk^(LiXA4VKGz_zMvJMX#AgsR z34T3hhJ)#&sUaQ1+0PML(?YA~{5?=(MT}X^Vib%};uoI{qGW@wgJ&_M+8S8clsNz2 zPQkxMi`#3+Khwtl>>K>wxc{71{&!qGu&Zzz_wU(7TLTyG){PAu?!cXs?Dp-y0Ekcn AQvd(} literal 0 HcmV?d00001 diff --git a/build/lib/whisperx/audio.py b/build/lib/whisperx/audio.py new file mode 100644 index 00000000..db210fb9 --- /dev/null +++ b/build/lib/whisperx/audio.py @@ -0,0 +1,159 @@ +import os +import subprocess +from functools import lru_cache +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from .utils import exact_div + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input + +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + try: + # Launches a subprocess to decode audio while down-mixing and resampling as necessary. + # Requires the ffmpeg CLI to be installed. + cmd = [ + "ffmpeg", + "-nostdin", + "-threads", + "0", + "-i", + file, + "-f", + "s16le", + "-ac", + "1", + "-acodec", + "pcm_s16le", + "-ar", + str(sr), + "-", + ] + out = subprocess.run(cmd, capture_output=True, check=True).stdout + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select( + dim=axis, index=torch.arange(length, device=array.device) + ) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}" + with np.load( + os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + ) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram( + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int, + padding: int = 0, + device: Optional[Union[str, torch.device]] = None, +): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec diff --git a/build/lib/whisperx/conjunctions.py b/build/lib/whisperx/conjunctions.py new file mode 100644 index 00000000..a3d35ea6 --- /dev/null +++ b/build/lib/whisperx/conjunctions.py @@ -0,0 +1,43 @@ +# conjunctions.py + +conjunctions_by_language = { + 'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'}, + 'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusqu’à', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'}, + 'de': {'und', 'oder', 'aber', 'weil', 'obwohl', 'während', 'wenn', 'wo', 'wie', 'dass', 'bevor', 'nachdem', 'sobald', 'bis', 'außer', 'trotzdem', 'also', 'sowie', 'indem', 'weder', 'sowohl', 'zwar', 'jedoch'}, + 'es': {'y', 'o', 'pero', 'porque', 'aunque', 'sin', 'mientras', 'cuando', 'donde', 'como', 'si', 'que', 'antes', 'después', 'tan', 'hasta', 'a', 'a', 'por', 'ya', 'ni', 'sino'}, + 'it': {'e', 'o', 'ma', 'perché', 'anche', 'mentre', 'quando', 'dove', 'come', 'se', 'che', 'prima', 'dopo', 'appena', 'fino', 'a', 'nonostante', 'quindi', 'poiché', 'né', 'ossia', 'cioè'}, + 'ja': {'そして', 'または', 'しかし', 'なぜなら', 'もし', 'それとも', 'だから', 'それに', 'なのに', 'そのため', 'かつ', 'それゆえに', 'ならば', 'もしくは', 'ため'}, + 'zh': {'和', '或', '但是', '因为', '任何', '也', '虽然', '而且', '所以', '如果', '除非', '尽管', '既然', '即使', '只要', '直到', '然后', '因此', '不但', '而是', '不过'}, + 'nl': {'en', 'of', 'maar', 'omdat', 'hoewel', 'terwijl', 'wanneer', 'waar', 'zoals', 'als', 'dat', 'voordat', 'nadat', 'zodra', 'totdat', 'tenzij', 'ondanks', 'dus', 'zowel', 'noch', 'echter', 'toch'}, + 'uk': {'та', 'або', 'але', 'тому', 'хоча', 'поки', 'бо', 'коли', 'де', 'як', 'якщо', 'що', 'перш', 'після', 'доки', 'незважаючи', 'тому', 'ані'}, + 'pt': {'e', 'ou', 'mas', 'porque', 'embora', 'enquanto', 'quando', 'onde', 'como', 'se', 'que', 'antes', 'depois', 'assim', 'até', 'a', 'apesar', 'portanto', 'já', 'pois', 'nem', 'senão'}, + 'ar': {'و', 'أو', 'لكن', 'لأن', 'مع', 'بينما', 'عندما', 'حيث', 'كما', 'إذا', 'الذي', 'قبل', 'بعد', 'فور', 'حتى', 'إلا', 'رغم', 'لذلك', 'بما'}, + 'cs': {'a', 'nebo', 'ale', 'protože', 'ačkoli', 'zatímco', 'když', 'kde', 'jako', 'pokud', 'že', 'než', 'poté', 'jakmile', 'dokud', 'pokud ne', 'navzdory', 'tak', 'stejně', 'ani', 'tudíž'}, + 'ru': {'и', 'или', 'но', 'потому', 'хотя', 'пока', 'когда', 'где', 'как', 'если', 'что', 'перед', 'после', 'несмотря', 'таким', 'также', 'ни', 'зато'}, + 'pl': {'i', 'lub', 'ale', 'ponieważ', 'chociaż', 'podczas', 'kiedy', 'gdzie', 'jak', 'jeśli', 'że', 'zanim', 'po', 'jak tylko', 'dopóki', 'chyba', 'pomimo', 'więc', 'tak', 'ani', 'czyli'}, + 'hu': {'és', 'vagy', 'de', 'mert', 'habár', 'míg', 'amikor', 'ahol', 'ahogy', 'ha', 'hogy', 'mielőtt', 'miután', 'amint', 'amíg', 'hacsak', 'ellenére', 'tehát', 'úgy', 'sem', 'vagyis'}, + 'fi': {'ja', 'tai', 'mutta', 'koska', 'vaikka', 'kun', 'missä', 'kuten', 'jos', 'että', 'ennen', 'sen jälkeen', 'heti', 'kunnes', 'ellei', 'huolimatta', 'siis', 'sekä', 'eikä', 'vaan'}, + 'fa': {'و', 'یا', 'اما', 'چون', 'اگرچه', 'در حالی', 'وقتی', 'کجا', 'چگونه', 'اگر', 'که', 'قبل', 'پس', 'به محض', 'تا زمانی', 'مگر', 'با وجود', 'پس', 'همچنین', 'نه'}, + 'el': {'και', 'ή', 'αλλά', 'επειδή', 'αν', 'ενώ', 'όταν', 'όπου', 'όπως', 'αν', 'που', 'προτού', 'αφού', 'μόλις', 'μέχρι', 'εκτός', 'παρά', 'έτσι', 'όπως', 'ούτε', 'δηλαδή'}, + 'tr': {'ve', 'veya', 'ama', 'çünkü', 'her ne', 'iken', 'nerede', 'nasıl', 'eğer', 'ki', 'önce', 'sonra', 'hemen', 'kadar', 'rağmen', 'hem', 'ne', 'yani'}, + 'da': {'og', 'eller', 'men', 'fordi', 'selvom', 'mens', 'når', 'hvor', 'som', 'hvis', 'at', 'før', 'efter', 'indtil', 'medmindre', 'således', 'ligesom', 'hverken', 'altså'}, + 'he': {'ו', 'או', 'אבל', 'כי', 'אף', 'בזמן', 'כאשר', 'היכן', 'כיצד', 'אם', 'ש', 'לפני', 'אחרי', 'ברגע', 'עד', 'אלא', 'למרות', 'לכן', 'כמו', 'לא', 'אז'}, + 'vi': {'và', 'hoặc', 'nhưng', 'bởi', 'mặc', 'trong', 'khi', 'ở', 'như', 'nếu', 'rằng', 'trước', 'sau', 'ngay', 'cho', 'trừ', 'mặc', 'vì', 'giống', 'cũng', 'tức'}, + 'ko': {'그리고', '또는','그런데','그래도', '이나', '결국', '마지막으로', '마찬가지로', '반면에', '아니면', '거나', '또는', '그럼에도', '그렇기', '때문에', '덧붙이자면', '게다가', '그러나', '고', '그래서', '랑', '한다면', '하지만', '무엇', '왜냐하면', '비록', '동안', '언제', '어디서', '어떻게', '만약', '그', '전에', '후에', '즉시', '까지', '아니라면', '불구하고', '따라서', '같은', '도'}, + 'ur': {'اور', 'یا', 'مگر', 'کیونکہ', 'اگرچہ', 'جبکہ', 'جب', 'کہاں', 'کس طرح', 'اگر', 'کہ', 'سے پہلے', 'کے بعد', 'جیسے ہی', 'تک', 'اگر نہیں تو', 'کے باوجود', 'اس لئے', 'جیسے', 'نہ'}, + 'hi': {'और', 'या', 'पर', 'तो', 'न', 'फिर', 'हालांकि', 'चूंकि', 'अगर', 'कैसे', 'वह', 'से', 'जो', 'जहां', 'क्या', 'नजदीक', 'पहले', 'बाद', 'के', 'पार', 'माध्यम', 'तक', 'एक', 'जबकि', 'यहां', 'तक', 'दोनों', 'या', 'न', 'हालांकि'} + +} + +commas_by_language = { + 'ja': '、', + 'zh': ',', + 'fa': '،', + 'ur': '،' +} + +def get_conjunctions(lang_code): + return conjunctions_by_language.get(lang_code, set()) + +def get_comma(lang_code): + return commas_by_language.get(lang_code, ',') \ No newline at end of file diff --git a/build/lib/whisperx/diarize.py b/build/lib/whisperx/diarize.py new file mode 100644 index 00000000..c327c932 --- /dev/null +++ b/build/lib/whisperx/diarize.py @@ -0,0 +1,74 @@ +import numpy as np +import pandas as pd +from pyannote.audio import Pipeline +from typing import Optional, Union +import torch + +from .audio import load_audio, SAMPLE_RATE + + +class DiarizationPipeline: + def __init__( + self, + model_name="pyannote/speaker-diarization-3.1", + use_auth_token=None, + device: Optional[Union[str, torch.device]] = "cpu", + ): + if isinstance(device, str): + device = torch.device(device) + self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) + + def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None): + if isinstance(audio, str): + audio = load_audio(audio) + audio_data = { + 'waveform': torch.from_numpy(audio[None, :]), + 'sample_rate': SAMPLE_RATE + } + segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers) + diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker']) + diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start) + diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end) + return diarize_df + + +def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False): + transcript_segments = transcript_result["segments"] + for seg in transcript_segments: + # assign speaker to segment (if any) + diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start']) + diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start']) + # remove no hit, otherwise we look for closest (even negative intersection...) + if not fill_nearest: + dia_tmp = diarize_df[diarize_df['intersection'] > 0] + else: + dia_tmp = diarize_df + if len(dia_tmp) > 0: + # sum over speakers + speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] + seg["speaker"] = speaker + + # assign speaker to words + if 'words' in seg: + for word in seg['words']: + if 'start' in word: + diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start']) + diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start']) + # remove no hit + if not fill_nearest: + dia_tmp = diarize_df[diarize_df['intersection'] > 0] + else: + dia_tmp = diarize_df + if len(dia_tmp) > 0: + # sum over speakers + speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] + word["speaker"] = speaker + + return transcript_result + + +class Segment: + def __init__(self, start, end, speaker=None): + self.start = start + self.end = end + self.speaker = speaker diff --git a/build/lib/whisperx/transcribe.py b/build/lib/whisperx/transcribe.py new file mode 100644 index 00000000..6fff837d --- /dev/null +++ b/build/lib/whisperx/transcribe.py @@ -0,0 +1,229 @@ +import argparse +import gc +import os +import warnings + +import numpy as np +import torch + +from .alignment import align, load_align_model +from .asr import load_model +from .audio import load_audio +from .diarize import DiarizationPipeline, assign_word_speakers +from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float, + optional_int, str2bool) + + +def cli(): + # fmt: off + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") + parser.add_argument("--model", default="small", help="name of the Whisper model to use") + parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") + parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference") + parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference") + parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation") + + parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") + parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced") + parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") + + parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") + parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") + + # alignment params + parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment") + parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.") + parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment") + parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file") + + # vad params + parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected") + parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.") + parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.") + + # diarization params + parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word") + parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file") + parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file") + + parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") + parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") + parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") + parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") + parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") + + parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") + parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly") + + parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") + parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") + parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") + + parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") + parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") + parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") + parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") + + parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line") + parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment") + parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt") + parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line") + + parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") + + parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") + + parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.") + # fmt: on + + args = parser.parse_args().__dict__ + model_name: str = args.pop("model") + batch_size: int = args.pop("batch_size") + output_dir: str = args.pop("output_dir") + output_format: str = args.pop("output_format") + device: str = args.pop("device") + device_index: int = args.pop("device_index") + compute_type: str = args.pop("compute_type") + + # model_flush: bool = args.pop("model_flush") + os.makedirs(output_dir, exist_ok=True) + + align_model: str = args.pop("align_model") + interpolate_method: str = args.pop("interpolate_method") + no_align: bool = args.pop("no_align") + task : str = args.pop("task") + if task == "translate": + # translation cannot be aligned + no_align = True + + return_char_alignments: bool = args.pop("return_char_alignments") + + hf_token: str = args.pop("hf_token") + vad_onset: float = args.pop("vad_onset") + vad_offset: float = args.pop("vad_offset") + + chunk_size: int = args.pop("chunk_size") + + diarize: bool = args.pop("diarize") + min_speakers: int = args.pop("min_speakers") + max_speakers: int = args.pop("max_speakers") + print_progress: bool = args.pop("print_progress") + + if args["language"] is not None: + args["language"] = args["language"].lower() + if args["language"] not in LANGUAGES: + if args["language"] in TO_LANGUAGE_CODE: + args["language"] = TO_LANGUAGE_CODE[args["language"]] + else: + raise ValueError(f"Unsupported language: {args['language']}") + + if model_name.endswith(".en") and args["language"] != "en": + if args["language"] is not None: + warnings.warn( + f"{model_name} is an English-only model but received '{args['language']}'; using English instead." + ) + args["language"] = "en" + align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified + + temperature = args.pop("temperature") + if (increment := args.pop("temperature_increment_on_fallback")) is not None: + temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment)) + else: + temperature = [temperature] + + faster_whisper_threads = 4 + if (threads := args.pop("threads")) > 0: + torch.set_num_threads(threads) + faster_whisper_threads = threads + + asr_options = { + "beam_size": args.pop("beam_size"), + "patience": args.pop("patience"), + "length_penalty": args.pop("length_penalty"), + "temperatures": temperature, + "compression_ratio_threshold": args.pop("compression_ratio_threshold"), + "log_prob_threshold": args.pop("logprob_threshold"), + "no_speech_threshold": args.pop("no_speech_threshold"), + "condition_on_previous_text": False, + "initial_prompt": args.pop("initial_prompt"), + "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")], + "suppress_numerals": args.pop("suppress_numerals"), + } + + writer = get_writer(output_format, output_dir) + word_options = ["highlight_words", "max_line_count", "max_line_width"] + if no_align: + for option in word_options: + if args[option]: + parser.error(f"--{option} not possible with --no_align") + if args["max_line_count"] and not args["max_line_width"]: + warnings.warn("--max_line_count has no effect without --max_line_width") + writer_args = {arg: args.pop(arg) for arg in word_options} + + # Part 1: VAD & ASR Loop + results = [] + tmp_results = [] + # model = load_model(model_name, device=device, download_root=model_dir) + model = load_model(model_name, device=device, device_index=device_index, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads) + + for audio_path in args.pop("audio"): + audio = load_audio(audio_path) + # >> VAD & ASR + print(">>Performing transcription...") + result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress) + results.append((result, audio_path)) + + # Unload Whisper and VAD + del model + gc.collect() + torch.cuda.empty_cache() + + # Part 2: Align Loop + if not no_align: + tmp_results = results + results = [] + align_model, align_metadata = load_align_model(align_language, device, model_name=align_model) + for result, audio_path in tmp_results: + # >> Align + if len(tmp_results) > 1: + input_audio = audio_path + else: + # lazily load audio from part 1 + input_audio = audio + + if align_model is not None and len(result["segments"]) > 0: + if result.get("language", "en") != align_metadata["language"]: + # load new language + print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...") + align_model, align_metadata = load_align_model(result["language"], device) + print(">>Performing alignment...") + result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress) + + results.append((result, audio_path)) + + # Unload align model + del align_model + gc.collect() + torch.cuda.empty_cache() + + # >> Diarize + if diarize: + if hf_token is None: + print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...") + tmp_results = results + print(">>Performing diarization...") + results = [] + diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) + for result, input_audio_path in tmp_results: + diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) + result = assign_word_speakers(diarize_segments, result) + results.append((result, input_audio_path)) + # >> Write + for result, audio_path in results: + result["language"] = align_language + writer(result, audio_path, writer_args) + +if __name__ == "__main__": + cli() diff --git a/build/lib/whisperx/types.py b/build/lib/whisperx/types.py new file mode 100644 index 00000000..68f2d783 --- /dev/null +++ b/build/lib/whisperx/types.py @@ -0,0 +1,58 @@ +from typing import TypedDict, Optional, List + + +class SingleWordSegment(TypedDict): + """ + A single word of a speech. + """ + word: str + start: float + end: float + score: float + +class SingleCharSegment(TypedDict): + """ + A single char of a speech. + """ + char: str + start: float + end: float + score: float + + +class SingleSegment(TypedDict): + """ + A single segment (up to multiple sentences) of a speech. + """ + + start: float + end: float + text: str + + +class SingleAlignedSegment(TypedDict): + """ + A single segment (up to multiple sentences) of a speech with word alignment. + """ + + start: float + end: float + text: str + words: List[SingleWordSegment] + chars: Optional[List[SingleCharSegment]] + + +class TranscriptionResult(TypedDict): + """ + A list of segments and word segments of a speech. + """ + segments: List[SingleSegment] + language: str + + +class AlignedTranscriptionResult(TypedDict): + """ + A list of segments and word segments of a speech. + """ + segments: List[SingleAlignedSegment] + word_segments: List[SingleWordSegment] diff --git a/build/lib/whisperx/utils.py b/build/lib/whisperx/utils.py new file mode 100644 index 00000000..16ce116e --- /dev/null +++ b/build/lib/whisperx/utils.py @@ -0,0 +1,437 @@ +import json +import os +import re +import sys +import zlib +from typing import Callable, Optional, TextIO + +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", + "yue": "cantonese", +} + +# language code lookup by name, with a few language aliases +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", +} + +LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] + +system_encoding = sys.getdefaultencoding() + +if system_encoding != "utf-8": + + def make_safe(string): + # replaces any character not representable using the system default encoding with an '?', + # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). + return string.encode(system_encoding, errors="replace").decode(system_encoding) + +else: + + def make_safe(string): + # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding + return string + + +def exact_div(x, y): + assert x % y == 0 + return x // y + + +def str2bool(string): + str2val = {"True": True, "False": False} + if string in str2val: + return str2val[string] + else: + raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") + + +def optional_int(string): + return None if string == "None" else int(string) + + +def optional_float(string): + return None if string == "None" else float(string) + + +def compression_ratio(text) -> float: + text_bytes = text.encode("utf-8") + return len(text_bytes) / len(zlib.compress(text_bytes)) + + +def format_timestamp( + seconds: float, always_include_hours: bool = False, decimal_marker: str = "." +): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + ) + + +class ResultWriter: + extension: str + + def __init__(self, output_dir: str): + self.output_dir = output_dir + + def __call__(self, result: dict, audio_path: str, options: dict): + audio_basename = os.path.basename(audio_path) + audio_basename = os.path.splitext(audio_basename)[0] + output_path = os.path.join( + self.output_dir, audio_basename + "." + self.extension + ) + + with open(output_path, "w", encoding="utf-8") as f: + self.write_result(result, file=f, options=options) + + def write_result(self, result: dict, file: TextIO, options: dict): + raise NotImplementedError + + +class WriteTXT(ResultWriter): + extension: str = "txt" + + def write_result(self, result: dict, file: TextIO, options: dict): + for segment in result["segments"]: + print(segment["text"].strip(), file=file, flush=True) + + +class SubtitlesWriter(ResultWriter): + always_include_hours: bool + decimal_marker: str + + def iterate_result(self, result: dict, options: dict): + raw_max_line_width: Optional[int] = options["max_line_width"] + max_line_count: Optional[int] = options["max_line_count"] + highlight_words: bool = options["highlight_words"] + max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width + preserve_segments = max_line_count is None or raw_max_line_width is None + + if len(result["segments"]) == 0: + return + + def iterate_subtitles(): + line_len = 0 + line_count = 1 + # the next subtitle to yield (a list of word timings with whitespace) + subtitle: list[dict] = [] + times = [] + last = result["segments"][0]["start"] + for segment in result["segments"]: + for i, original_timing in enumerate(segment["words"]): + timing = original_timing.copy() + long_pause = not preserve_segments + if "start" in timing: + long_pause = long_pause and timing["start"] - last > 3.0 + else: + long_pause = False + has_room = line_len + len(timing["word"]) <= max_line_width + seg_break = i == 0 and len(subtitle) > 0 and preserve_segments + if line_len > 0 and has_room and not long_pause and not seg_break: + # line continuation + line_len += len(timing["word"]) + else: + # new line + timing["word"] = timing["word"].strip() + if ( + len(subtitle) > 0 + and max_line_count is not None + and (long_pause or line_count >= max_line_count) + or seg_break + ): + # subtitle break + yield subtitle, times + subtitle = [] + times = [] + line_count = 1 + elif line_len > 0: + # line break + line_count += 1 + timing["word"] = "\n" + timing["word"] + line_len = len(timing["word"].strip()) + subtitle.append(timing) + times.append((segment["start"], segment["end"], segment.get("speaker"))) + if "start" in timing: + last = timing["start"] + if len(subtitle) > 0: + yield subtitle, times + + if "words" in result["segments"][0]: + for subtitle, _ in iterate_subtitles(): + sstart, ssend, speaker = _[0] + subtitle_start = self.format_timestamp(sstart) + subtitle_end = self.format_timestamp(ssend) + if result["language"] in LANGUAGES_WITHOUT_SPACES: + subtitle_text = "".join([word["word"] for word in subtitle]) + else: + subtitle_text = " ".join([word["word"] for word in subtitle]) + has_timing = any(["start" in word for word in subtitle]) + + # add [$SPEAKER_ID]: to each subtitle if speaker is available + prefix = "" + if speaker is not None: + prefix = f"[{speaker}]: " + + if highlight_words and has_timing: + last = subtitle_start + all_words = [timing["word"] for timing in subtitle] + for i, this_word in enumerate(subtitle): + if "start" in this_word: + start = self.format_timestamp(this_word["start"]) + end = self.format_timestamp(this_word["end"]) + if last != start: + yield last, start, prefix + subtitle_text + + yield start, end, prefix + " ".join( + [ + re.sub(r"^(\s*)(.*)$", r"\1\2", word) + if j == i + else word + for j, word in enumerate(all_words) + ] + ) + last = end + else: + yield subtitle_start, subtitle_end, prefix + subtitle_text + else: + for segment in result["segments"]: + segment_start = self.format_timestamp(segment["start"]) + segment_end = self.format_timestamp(segment["end"]) + segment_text = segment["text"].strip().replace("-->", "->") + if "speaker" in segment: + segment_text = f"[{segment['speaker']}]: {segment_text}" + yield segment_start, segment_end, segment_text + + def format_timestamp(self, seconds: float): + return format_timestamp( + seconds=seconds, + always_include_hours=self.always_include_hours, + decimal_marker=self.decimal_marker, + ) + + +class WriteVTT(SubtitlesWriter): + extension: str = "vtt" + always_include_hours: bool = False + decimal_marker: str = "." + + def write_result(self, result: dict, file: TextIO, options: dict): + print("WEBVTT\n", file=file) + for start, end, text in self.iterate_result(result, options): + print(f"{start} --> {end}\n{text}\n", file=file, flush=True) + + +class WriteSRT(SubtitlesWriter): + extension: str = "srt" + always_include_hours: bool = True + decimal_marker: str = "," + + def write_result(self, result: dict, file: TextIO, options: dict): + for i, (start, end, text) in enumerate( + self.iterate_result(result, options), start=1 + ): + print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) + + +class WriteTSV(ResultWriter): + """ + Write a transcript to a file in TSV (tab-separated values) format containing lines like: + \t\t + + Using integer milliseconds as start and end times means there's no chance of interference from + an environment setting a language encoding that causes the decimal in a floating point number + to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. + """ + + extension: str = "tsv" + + def write_result(self, result: dict, file: TextIO, options: dict): + print("start", "end", "text", sep="\t", file=file) + for segment in result["segments"]: + print(round(1000 * segment["start"]), file=file, end="\t") + print(round(1000 * segment["end"]), file=file, end="\t") + print(segment["text"].strip().replace("\t", " "), file=file, flush=True) + +class WriteAudacity(ResultWriter): + """ + Write a transcript to a text file that audacity can import as labels. + The extension used is "aud" to distinguish it from the txt file produced by WriteTXT. + Yet this is not an audacity project but only a label file! + + Please note : Audacity uses seconds in timestamps not ms! + Also there is no header expected. + + If speaker is provided it is prepended to the text between double square brackets [[]]. + """ + + extension: str = "aud" + + def write_result(self, result: dict, file: TextIO, options: dict): + ARROW = " " + for segment in result["segments"]: + print(segment["start"], file=file, end=ARROW) + print(segment["end"], file=file, end=ARROW) + print( ( ("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True) + + + +class WriteJSON(ResultWriter): + extension: str = "json" + + def write_result(self, result: dict, file: TextIO, options: dict): + json.dump(result, file, ensure_ascii=False) + + +def get_writer( + output_format: str, output_dir: str +) -> Callable[[dict, TextIO, dict], None]: + writers = { + "txt": WriteTXT, + "vtt": WriteVTT, + "srt": WriteSRT, + "tsv": WriteTSV, + "json": WriteJSON, + } + optional_writers = { + "aud": WriteAudacity, + } + + if output_format == "all": + all_writers = [writer(output_dir) for writer in writers.values()] + + def write_all(result: dict, file: TextIO, options: dict): + for writer in all_writers: + writer(result, file, options) + + return write_all + + if output_format in optional_writers: + return optional_writers[output_format](output_dir) + return writers[output_format](output_dir) + +def interpolate_nans(x, method='nearest'): + if x.notnull().sum() > 1: + return x.interpolate(method=method).ffill().bfill() + else: + return x.ffill().bfill() diff --git a/build/lib/whisperx/vad.py b/build/lib/whisperx/vad.py new file mode 100644 index 00000000..ab2c7bbf --- /dev/null +++ b/build/lib/whisperx/vad.py @@ -0,0 +1,311 @@ +import hashlib +import os +import urllib +from typing import Callable, Optional, Text, Union + +import numpy as np +import pandas as pd +import torch +from pyannote.audio import Model +from pyannote.audio.core.io import AudioFile +from pyannote.audio.pipelines import VoiceActivityDetection +from pyannote.audio.pipelines.utils import PipelineModel +from pyannote.core import Annotation, Segment, SlidingWindowFeature +from tqdm import tqdm + +from .diarize import Segment as SegmentX + +VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin" + +def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None): + model_dir = torch.hub._get_torch_home() + os.makedirs(model_dir, exist_ok = True) + if model_fp is None: + model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin") + if os.path.exists(model_fp) and not os.path.isfile(model_fp): + raise RuntimeError(f"{model_fp} exists and is not a regular file") + + if not os.path.isfile(model_fp): + with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + model_bytes = open(model_fp, "rb").read() + if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]: + raise RuntimeError( + "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." + ) + + vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token) + hyperparameters = {"onset": vad_onset, + "offset": vad_offset, + "min_duration_on": 0.1, + "min_duration_off": 0.1} + vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device)) + vad_pipeline.instantiate(hyperparameters) + + return vad_pipeline + +class Binarize: + """Binarize detection scores using hysteresis thresholding, with min-cut operation + to ensure not segments are longer than max_duration. + + Parameters + ---------- + onset : float, optional + Onset threshold. Defaults to 0.5. + offset : float, optional + Offset threshold. Defaults to `onset`. + min_duration_on : float, optional + Remove active regions shorter than that many seconds. Defaults to 0s. + min_duration_off : float, optional + Fill inactive regions shorter than that many seconds. Defaults to 0s. + pad_onset : float, optional + Extend active regions by moving their start time by that many seconds. + Defaults to 0s. + pad_offset : float, optional + Extend active regions by moving their end time by that many seconds. + Defaults to 0s. + max_duration: float + The maximum length of an active segment, divides segment at timestamp with lowest score. + Reference + --------- + Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of + RNN-based Voice Activity Detection", InterSpeech 2015. + + Modified by Max Bain to include WhisperX's min-cut operation + https://arxiv.org/abs/2303.00747 + + Pyannote-audio + """ + + def __init__( + self, + onset: float = 0.5, + offset: Optional[float] = None, + min_duration_on: float = 0.0, + min_duration_off: float = 0.0, + pad_onset: float = 0.0, + pad_offset: float = 0.0, + max_duration: float = float('inf') + ): + + super().__init__() + + self.onset = onset + self.offset = offset or onset + + self.pad_onset = pad_onset + self.pad_offset = pad_offset + + self.min_duration_on = min_duration_on + self.min_duration_off = min_duration_off + + self.max_duration = max_duration + + def __call__(self, scores: SlidingWindowFeature) -> Annotation: + """Binarize detection scores + Parameters + ---------- + scores : SlidingWindowFeature + Detection scores. + Returns + ------- + active : Annotation + Binarized scores. + """ + + num_frames, num_classes = scores.data.shape + frames = scores.sliding_window + timestamps = [frames[i].middle for i in range(num_frames)] + + # annotation meant to store 'active' regions + active = Annotation() + for k, k_scores in enumerate(scores.data.T): + + label = k if scores.labels is None else scores.labels[k] + + # initial state + start = timestamps[0] + is_active = k_scores[0] > self.onset + curr_scores = [k_scores[0]] + curr_timestamps = [start] + t = start + for t, y in zip(timestamps[1:], k_scores[1:]): + # currently active + if is_active: + curr_duration = t - start + if curr_duration > self.max_duration: + search_after = len(curr_scores) // 2 + # divide segment + min_score_div_idx = search_after + np.argmin(curr_scores[search_after:]) + min_score_t = curr_timestamps[min_score_div_idx] + region = Segment(start - self.pad_onset, min_score_t + self.pad_offset) + active[region, k] = label + start = curr_timestamps[min_score_div_idx] + curr_scores = curr_scores[min_score_div_idx+1:] + curr_timestamps = curr_timestamps[min_score_div_idx+1:] + # switching from active to inactive + elif y < self.offset: + region = Segment(start - self.pad_onset, t + self.pad_offset) + active[region, k] = label + start = t + is_active = False + curr_scores = [] + curr_timestamps = [] + curr_scores.append(y) + curr_timestamps.append(t) + # currently inactive + else: + # switching from inactive to active + if y > self.onset: + start = t + is_active = True + + # if active at the end, add final region + if is_active: + region = Segment(start - self.pad_onset, t + self.pad_offset) + active[region, k] = label + + # because of padding, some active regions might be overlapping: merge them. + # also: fill same speaker gaps shorter than min_duration_off + if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0: + if self.max_duration < float("inf"): + raise NotImplementedError(f"This would break current max_duration param") + active = active.support(collar=self.min_duration_off) + + # remove tracks shorter than min_duration_on + if self.min_duration_on > 0: + for segment, track in list(active.itertracks()): + if segment.duration < self.min_duration_on: + del active[segment, track] + + return active + + +class VoiceActivitySegmentation(VoiceActivityDetection): + def __init__( + self, + segmentation: PipelineModel = "pyannote/segmentation", + fscore: bool = False, + use_auth_token: Union[Text, None] = None, + **inference_kwargs, + ): + + super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs) + + def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation: + """Apply voice activity detection + + Parameters + ---------- + file : AudioFile + Processed file. + hook : callable, optional + Hook called after each major step of the pipeline with the following + signature: hook("step_name", step_artefact, file=file) + + Returns + ------- + speech : Annotation + Speech regions. + """ + + # setup hook (e.g. for debugging purposes) + hook = self.setup_hook(file, hook=hook) + + # apply segmentation model (only if needed) + # output shape is (num_chunks, num_frames, 1) + if self.training: + if self.CACHED_SEGMENTATION in file: + segmentations = file[self.CACHED_SEGMENTATION] + else: + segmentations = self._segmentation(file) + file[self.CACHED_SEGMENTATION] = segmentations + else: + segmentations: SlidingWindowFeature = self._segmentation(file) + + return segmentations + + +def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0): + + active = Annotation() + for k, vad_t in enumerate(vad_arr): + region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset) + active[region, k] = 1 + + + if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0: + active = active.support(collar=min_duration_off) + + # remove tracks shorter than min_duration_on + if min_duration_on > 0: + for segment, track in list(active.itertracks()): + if segment.duration < min_duration_on: + del active[segment, track] + + active = active.for_json() + active_segs = pd.DataFrame([x['segment'] for x in active['content']]) + return active_segs + +def merge_chunks( + segments, + chunk_size, + onset: float = 0.5, + offset: Optional[float] = None, +): + """ + Merge operation described in paper + """ + curr_end = 0 + merged_segments = [] + seg_idxs = [] + speaker_idxs = [] + + assert chunk_size > 0 + binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset) + segments = binarize(segments) + segments_list = [] + for speech_turn in segments.get_timeline(): + segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN")) + + if len(segments_list) == 0: + print("No active speech found in audio") + return [] + # assert segments_list, "segments_list is empty." + # Make sur the starting point is the start of the segment. + curr_start = segments_list[0].start + + for seg in segments_list: + if seg.end - curr_start > chunk_size and curr_end-curr_start > 0: + merged_segments.append({ + "start": curr_start, + "end": curr_end, + "segments": seg_idxs, + }) + curr_start = seg.start + seg_idxs = [] + speaker_idxs = [] + curr_end = seg.end + seg_idxs.append((seg.start, seg.end)) + speaker_idxs.append(seg.speaker) + # add final + merged_segments.append({ + "start": curr_start, + "end": curr_end, + "segments": seg_idxs, + }) + return merged_segments diff --git a/setup.py b/setup.py index 989e0643..40db6cc9 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ open(os.path.join(os.path.dirname(__file__), "requirements.txt")) ) ] - + [f"pyannote.audio==3.1.0"], + + [f"pyannote.audio==3.1.1"], entry_points={ "console_scripts": ["whisperx=whisperx.transcribe:cli"], }, diff --git a/whisperx/diarize.py b/whisperx/diarize.py index c1e30bec..c327c932 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -18,14 +18,14 @@ def __init__( device = torch.device(device) self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) - def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None): + def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None): if isinstance(audio, str): audio = load_audio(audio) audio_data = { 'waveform': torch.from_numpy(audio[None, :]), 'sample_rate': SAMPLE_RATE } - segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers) + segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers) diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker']) diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start) diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)