use jiwer instead of evaluate in benchmarks (#1159)
MahmoudAshraf97 authored Nov 20, 2024
1 parent 491852e commit 9c8ef76
Showing 3 changed files with 23 additions and 32 deletions.
43 changes: 20 additions & 23 deletions benchmark/evaluate_yt_commons.py
@@ -5,9 +5,9 @@
 from io import BytesIO
 
 from datasets import load_dataset
-from evaluate import load
+from jiwer import wer
 from pytubefix import YouTube
-from torch.utils.data import DataLoader
+from pytubefix.exceptions import VideoUnavailable
 from tqdm import tqdm
 from transformers.models.whisper.english_normalizer import EnglishTextNormalizer

@@ -17,15 +17,19 @@
 def url_to_audio(row):
     buffer = BytesIO()
     yt = YouTube(row["link"])
-    video = (
-        yt.streams.filter(only_audio=True, mime_type="audio/mp4")
-        .order_by("bitrate")
-        .desc()
-        .first()
-    )
-    video.stream_to_buffer(buffer)
-    buffer.seek(0)
-    row["audio"] = decode_audio(buffer)
+    try:
+        video = (
+            yt.streams.filter(only_audio=True, mime_type="audio/mp4")
+            .order_by("bitrate")
+            .desc()
+            .last()
+        )
+        video.stream_to_buffer(buffer)
+        buffer.seek(0)
+        row["audio"] = decode_audio(buffer)
+    except VideoUnavailable:
+        print(f'Failed to download: {row["link"]}')
+        row["audio"] = []
     return row


@@ -39,27 +43,22 @@ def url_to_audio(row):
 )
 args = parser.parse_args()
 
-# define the evaluation metric
-wer_metric = load("wer")
-
 with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
     normalizer = EnglishTextNormalizer(json.load(f))
 
 dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True).map(
     url_to_audio
 )
-dataset = iter(
-    DataLoader(dataset["test"], batch_size=1, prefetch_factor=4, num_workers=2)
-)
 
 model = WhisperModel("large-v3", device="cuda")
 pipeline = BatchedInferencePipeline(model, device="cuda")
 
-
 all_transcriptions = []
 all_references = []
 # iterate over the dataset and run inference
-for i, row in tqdm(enumerate(dataset), desc="Evaluating..."):
+for i, row in tqdm(enumerate(dataset["test"]), desc="Evaluating..."):
+    if not row["audio"]:
+        continue
     result, info = pipeline.transcribe(
         row["audio"][0],
         batch_size=8,
@@ -77,7 +76,5 @@ def url_to_audio(row):
 all_references = [normalizer(reference) for reference in all_references]
 
 # compute the WER metric
-wer = 100 * wer_metric.compute(
-    predictions=all_transcriptions, references=all_references
-)
-print("WER: %.3f" % wer)
+word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
+print("WER: %.3f" % word_error_rate)
1 change: 0 additions & 1 deletion benchmark/requirements.benchmark.txt
@@ -1,6 +1,5 @@
 transformers
 jiwer
-evaluate
 datasets
 memory_profiler
 py3nvml
11 changes: 3 additions & 8 deletions benchmark/wer_benchmark.py
@@ -3,7 +3,7 @@
 import os
 
 from datasets import load_dataset
-from evaluate import load
+from jiwer import wer
 from tqdm import tqdm
 from transformers.models.whisper.english_normalizer import EnglishTextNormalizer

@@ -25,9 +25,6 @@
 # load the dataset with streaming mode
 dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
 
-# define the evaluation metric
-wer_metric = load("wer")
-
 with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
     normalizer = EnglishTextNormalizer(json.load(f))

@@ -58,7 +55,5 @@ def inference(batch):
 all_references = [normalizer(reference) for reference in all_references]
 
 # compute the WER metric
-wer = 100 * wer_metric.compute(
-    predictions=all_transcriptions, references=all_references
-)
-print("WER: %.3f" % wer)
+word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
+print("WER: %.3f" % word_error_rate)
