forked from m-bain/whisperX
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update setup.py to install pyannote.audio==3.1.1, update diarize.py t…
…o include num_speakers; to fix Issue m-bain#592
- Loading branch information
1 parent
3171bb5
commit 80341c3
Showing
15 changed files
with
2,366 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
import math | ||
from conjunctions import get_conjunctions, get_comma | ||
from typing import TextIO | ||
|
||
def normal_round(n): | ||
if n - math.floor(n) < 0.5: | ||
return math.floor(n) | ||
return math.ceil(n) | ||
|
||
|
||
def format_timestamp(seconds: float, is_vtt: bool = False): | ||
|
||
assert seconds >= 0, "non-negative timestamp expected" | ||
milliseconds = round(seconds * 1000.0) | ||
|
||
hours = milliseconds // 3_600_000 | ||
milliseconds -= hours * 3_600_000 | ||
|
||
minutes = milliseconds // 60_000 | ||
milliseconds -= minutes * 60_000 | ||
|
||
seconds = milliseconds // 1_000 | ||
milliseconds -= seconds * 1_000 | ||
|
||
separator = '.' if is_vtt else ',' | ||
|
||
hours_marker = f"{hours:02d}:" | ||
return ( | ||
f"{hours_marker}{minutes:02d}:{seconds:02d}{separator}{milliseconds:03d}" | ||
) | ||
|
||
|
||
|
||
class SubtitlesProcessor: | ||
def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False): | ||
self.comma = get_comma(lang) | ||
self.conjunctions = set(get_conjunctions(lang)) | ||
self.segments = segments | ||
self.lang = lang | ||
self.max_line_length = max_line_length | ||
self.min_char_length_splitter = min_char_length_splitter | ||
self.is_vtt = is_vtt | ||
complex_script_languages = ['th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta', 'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka'] | ||
if self.lang in complex_script_languages: | ||
self.max_line_length = 30 | ||
self.min_char_length_splitter = 20 | ||
|
||
def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None): | ||
k = 0.25 | ||
has_prev_end = i > 0 and 'end' in words[i - 1] | ||
has_next_start = i < len(words) - 1 and 'start' in words[i + 1] | ||
|
||
if has_prev_end: | ||
words[i]['start'] = words[i - 1]['end'] | ||
if has_next_start: | ||
words[i]['end'] = words[i + 1]['start'] | ||
else: | ||
if next_segment_start_time: | ||
words[i]['end'] = next_segment_start_time if next_segment_start_time - words[i - 1]['end'] <= 1 else next_segment_start_time - 0.5 | ||
else: | ||
words[i]['end'] = words[i]['start'] + len(words[i]['word']) * k | ||
|
||
elif has_next_start: | ||
words[i]['start'] = words[i + 1]['start'] - len(words[i]['word']) * k | ||
words[i]['end'] = words[i + 1]['start'] | ||
|
||
else: | ||
if next_segment_start_time: | ||
words[i]['start'] = next_segment_start_time - 1 | ||
words[i]['end'] = next_segment_start_time - 0.5 | ||
else: | ||
words[i]['start'] = 0 | ||
words[i]['end'] = 0 | ||
|
||
|
||
|
||
def process_segments(self, advanced_splitting=True): | ||
subtitles = [] | ||
for i, segment in enumerate(self.segments): | ||
next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None | ||
|
||
if advanced_splitting: | ||
|
||
split_points = self.determine_advanced_split_points(segment, next_segment_start_time) | ||
subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time)) | ||
else: | ||
words = segment['words'] | ||
for i, word in enumerate(words): | ||
if 'start' not in word or 'end' not in word: | ||
self.estimate_timestamp_for_word(words, i, next_segment_start_time) | ||
|
||
subtitles.append({ | ||
'start': segment['start'], | ||
'end': segment['end'], | ||
'text': segment['text'] | ||
}) | ||
|
||
return subtitles | ||
|
||
def determine_advanced_split_points(self, segment, next_segment_start_time=None): | ||
split_points = [] | ||
last_split_point = 0 | ||
char_count = 0 | ||
|
||
words = segment.get('words', segment['text'].split()) | ||
add_space = 0 if self.lang in ['zh', 'ja'] else 1 | ||
|
||
total_char_count = sum(len(word['word']) if isinstance(word, dict) else len(word) + add_space for word in words) | ||
char_count_after = total_char_count | ||
|
||
for i, word in enumerate(words): | ||
word_text = word['word'] if isinstance(word, dict) else word | ||
word_length = len(word_text) + add_space | ||
char_count += word_length | ||
char_count_after -= word_length | ||
|
||
char_count_before = char_count - word_length | ||
|
||
if isinstance(word, dict) and ('start' not in word or 'end' not in word): | ||
self.estimate_timestamp_for_word(words, i, next_segment_start_time) | ||
|
||
if char_count >= self.max_line_length: | ||
midpoint = normal_round((last_split_point + i) / 2) | ||
if char_count_before >= self.min_char_length_splitter: | ||
split_points.append(midpoint) | ||
last_split_point = midpoint + 1 | ||
char_count = sum(len(words[j]['word']) if isinstance(words[j], dict) else len(words[j]) + add_space for j in range(last_split_point, i + 1)) | ||
|
||
elif word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter: | ||
split_points.append(i) | ||
last_split_point = i + 1 | ||
char_count = 0 | ||
|
||
elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter: | ||
split_points.append(i - 1) | ||
last_split_point = i | ||
char_count = word_length | ||
|
||
return split_points | ||
|
||
|
||
def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None): | ||
subtitles = [] | ||
|
||
words = segment.get('words', segment['text'].split()) | ||
total_word_count = len(words) | ||
total_time = segment['end'] - segment['start'] | ||
elapsed_time = segment['start'] | ||
prefix = ' ' if self.lang not in ['zh', 'ja'] else '' | ||
start_idx = 0 | ||
for split_point in split_points: | ||
|
||
fragment_words = words[start_idx:split_point + 1] | ||
current_word_count = len(fragment_words) | ||
|
||
|
||
if isinstance(fragment_words[0], dict): | ||
start_time = fragment_words[0]['start'] | ||
end_time = fragment_words[-1]['end'] | ||
next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None | ||
if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8: | ||
end_time = next_start_time_for_word | ||
else: | ||
fragment = prefix.join(fragment_words).strip() | ||
current_duration = (current_word_count / total_word_count) * total_time | ||
start_time = elapsed_time | ||
end_time = elapsed_time + current_duration | ||
elapsed_time += current_duration | ||
|
||
|
||
subtitles.append({ | ||
'start': start_time, | ||
'end': end_time, | ||
'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words) | ||
}) | ||
|
||
start_idx = split_point + 1 | ||
|
||
# Handle the last fragment | ||
if start_idx < len(words): | ||
fragment_words = words[start_idx:] | ||
current_word_count = len(fragment_words) | ||
|
||
if isinstance(fragment_words[0], dict): | ||
start_time = fragment_words[0]['start'] | ||
end_time = fragment_words[-1]['end'] | ||
else: | ||
fragment = prefix.join(fragment_words).strip() | ||
current_duration = (current_word_count / total_word_count) * total_time | ||
start_time = elapsed_time | ||
end_time = elapsed_time + current_duration | ||
|
||
if next_start_time and (next_start_time - end_time) <= 0.8: | ||
end_time = next_start_time | ||
|
||
subtitles.append({ | ||
'start': start_time, | ||
'end': end_time if end_time is not None else segment['end'], | ||
'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words) | ||
}) | ||
|
||
return subtitles | ||
|
||
|
||
|
||
def save(self, filename="subtitles.srt", advanced_splitting=True): | ||
|
||
subtitles = self.process_segments(advanced_splitting) | ||
|
||
def write_subtitle(file, idx, start_time, end_time, text): | ||
|
||
file.write(f"{idx}\n") | ||
file.write(f"{start_time} --> {end_time}\n") | ||
file.write(text + "\n\n") | ||
|
||
with open(filename, 'w', encoding='utf-8') as file: | ||
if self.is_vtt: | ||
file.write("WEBVTT\n\n") | ||
|
||
if advanced_splitting: | ||
for idx, subtitle in enumerate(subtitles, 1): | ||
start_time = format_timestamp(subtitle['start'], self.is_vtt) | ||
end_time = format_timestamp(subtitle['end'], self.is_vtt) | ||
text = subtitle['text'].strip() | ||
write_subtitle(file, idx, start_time, end_time, text) | ||
|
||
return len(subtitles) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .transcribe import load_model | ||
from .alignment import load_align_model, align | ||
from .audio import load_audio | ||
from .diarize import assign_word_speakers, DiarizationPipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .transcribe import cli | ||
|
||
|
||
cli() |
Oops, something went wrong.