Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci

Signed-off-by: Sasha Meister <[email protected]>
  • Loading branch information
pre-commit-ci[bot] authored and ssh-meister committed Oct 2, 2023
1 parent 5b7510e commit 1f168f9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
11 changes: 7 additions & 4 deletions examples/asr/rate_punctuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser
import json
import os
from argparse import ArgumentParser

import pandas as pd
from tabulate import tabulate

Expand Down Expand Up @@ -89,7 +90,7 @@ def read_manifest(input_manifest_path: str) -> list[dict]:
Returns:
samples - list of dict
'''

assert os.path.exists(input_manifest_path), f"Input manifest file is not found: {input_manifest_path}"

with open(input_manifest_path, "r") as manifest:
Expand All @@ -99,8 +100,10 @@ def read_manifest(input_manifest_path: str) -> list[dict]:


def write_manifest(output_manifest_path: str, samples: list[dict]) -> None:

assert os.path.exists(os.path.dirname(output_manifest_path)), f"Directory to save output manifest path does not exists: {output_manifest_path}"

assert os.path.exists(
os.path.dirname(output_manifest_path)
), f"Directory to save output manifest path does not exists: {output_manifest_path}"

'''
Writes samples to .json file.
Expand Down
34 changes: 18 additions & 16 deletions nemo/collections/common/metrics/per.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
from collections import namedtuple
from tqdm import tqdm

def per(references: list[str],
hypotheses: list[str],
punctuation_marks: list[str],
punctuation_mask: str = "[PUNCT]",
) -> None:


def per(
references: list[str], hypotheses: list[str], punctuation_marks: list[str], punctuation_mask: str = "[PUNCT]",
) -> None:

"""
Computes Punctuation Error Rate
Expand All @@ -35,16 +34,19 @@ def per(references: list[str],
Return:
per (float) - Punctuation Error Rate
"""

per_data_obj = PERData(references=references,
hypotheses=hypotheses,
punctuation_marks=punctuation_marks,
punctuation_mask=punctuation_mask)


per_data_obj = PERData(
references=references,
hypotheses=hypotheses,
punctuation_marks=punctuation_marks,
punctuation_mask=punctuation_mask,
)

per_data_obj.compute()

return per_data_obj.per


class PER:
"""
Class for computation puncutation-related absolute amounts of operations and thier rates
Expand Down Expand Up @@ -121,9 +123,9 @@ class PER:
"""

def __init__(self, punctuation_marks: list[str], punctuation_mask: str = "[PUNCT]") -> None:

assert len(punctuation_marks) != 0, f"List of punctuation marks is empty"

self.punctuation_marks = punctuation_marks
self.punctuation_mask = punctuation_mask

Expand Down Expand Up @@ -363,7 +365,7 @@ class PERData:
per_data_obj.per - float, total Punctuation Error Rate between provided pairs of
references and hypotheses.
"""

def __init__(
self,
references: list[str],
Expand Down

0 comments on commit 1f168f9

Please sign in to comment.