Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jordimas committed Jan 1, 2024
1 parent 1f948cb commit c3ea968
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 12 deletions.
5 changes: 1 addition & 4 deletions src/whisper_ctranslate2/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from faster_whisper.audio import decode_audio



class Diarization:
def __init__(
self,
Expand All @@ -22,9 +21,7 @@ def __init__(
model_name, use_auth_token=use_auth_token
).to(device)

def __call__(
self, audio: str, min_speakers=None, max_speakers=None
):
def __call__(self, audio: str, min_speakers=None, max_speakers=None):
audio = decode_audio(audio)
audio_data = {
"waveform": torch.from_numpy(audio[None, :]),
Expand Down
2 changes: 1 addition & 1 deletion src/whisper_ctranslate2/whisper_ctranslate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def main():
print(f"Time used for inference: {datetime.datetime.now() - start_time}")

start_time = datetime.datetime.now()

if len(hf_token) > 0:
diarize_segments = diarize_model(
audio_path, min_speakers=None, max_speakers=None
Expand Down
11 changes: 4 additions & 7 deletions src/whisper_ctranslate2/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def iterate_subtitles():
subtitle: list[dict] = []
last = result["segments"][0]["words"][0]["start"]
for segment in result["segments"]:

if "speaker" in segment:
speaker = f"[{segment['speaker']}]: "
else:
Expand Down Expand Up @@ -106,13 +105,11 @@ def iterate_subtitles():
if len(subtitle) > 0:
yield subtitle


# if "speaker" in segment:
# segment_text = f"[{segment['speaker']}]: " + segment["text"].strip().replace("-->", "->")
# else:
# segment_text = segment["text"].strip().replace("-->", "->")
# if "speaker" in segment:
# segment_text = f"[{segment['speaker']}]: " + segment["text"].strip().replace("-->", "->")
# else:
# segment_text = segment["text"].strip().replace("-->", "->")


if (
len(result["segments"]) > 0
and "words" in result["segments"][0]
Expand Down

0 comments on commit c3ea968

Please sign in to comment.