
Commit

refactor: Add filter-replace-utils for serializing and deserializing filter words replacements (#154)

royshil authored Sep 9, 2024
1 parent e3c6951 commit ec56c74
Showing 10 changed files with 139 additions and 72 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -123,6 +123,7 @@ target_sources(
  src/translation/language_codes.cpp
  src/translation/translation.cpp
  src/translation/translation-utils.cpp
  src/ui/filter-replace-utils.cpp
  src/translation/translation-language-utils.cpp
  src/ui/filter-replace-dialog.cpp)

@@ -147,6 +148,7 @@ if(ENABLE_TESTS)
    src/whisper-utils/vad-processing.cpp
    src/translation/language_codes.cpp
    src/translation/translation.cpp
    src/ui/filter-replace-utils.cpp
    src/translation/translation-language-utils.cpp)

  find_libav(${CMAKE_PROJECT_NAME}-tests)
2 changes: 1 addition & 1 deletion src/tests/README.md
@@ -112,7 +112,7 @@ The JSON config file can look e.g. like
"silero_vad_model_file": ".../obs-localvocal/data/models/silero-vad/silero_vad.onnx",
"ct2_model_folder": ".../obs-localvocal/models/m2m-100-418M",
"fix_utf8": true,
"suppress_sentences": "다음 영상에서 만나요!\nMBC 뉴스 김지경입니다\nMBC 뉴스 김성현입니다\n구독과 좋아요 눌러주세요!\n구독과 좋아요는 저에게 아주 큰\n다음 영상에서 만나요\n끝까지 시청해주셔서 감사합니다\n구독과 좋아요 부탁드립니다!\nMBC 뉴스 이준범입니다\nMBC 뉴스 문재인입니다\nMBC 뉴스 김지연입니다\nMBC 뉴스 안영백입니다.\nMBC 뉴스 이덕영입니다\nMBC 뉴스 김상현입니다\n구독과 좋아요 눌러주세요!\n구독과 좋아요 부탁드",
"filter_words_replace": "[{\"key\": \"다음 영상에서 만나요!\", \"value\":\"\"}]",
"overlap_ms": 150,
"log_level": "debug",
"whisper_sampling_method": 0
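Note that `filter_words_replace` carries a JSON array escaped into a JSON string. A minimal standalone sketch of decoding that value with the helper this commit introduces (not part of the diff; the pair layout is read off the test config above):

```cpp
#include "ui/filter-replace-utils.h"

#include <string>
#include <tuple>
#include <vector>

int main()
{
	// Unescaped, the config value above is: [{"key": "다음 영상에서 만나요!", "value": ""}]
	std::string cfg = R"([{"key": "다음 영상에서 만나요!", "value": ""}])";
	std::vector<std::tuple<std::string, std::string>> pairs =
		deserialize_filter_words_replace(cfg);
	// pairs[0] maps the phrase to an empty string, i.e. the phrase would be
	// replaced with nothing (removed) in transcripts.
	return 0;
}
```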
120 changes: 83 additions & 37 deletions src/tests/evaluate_output.py
@@ -1,33 +1,80 @@
import Levenshtein
import argparse
from diff_match_patch import diff_match_patch
import unicodedata
import re
import difflib

def visualize_differences(ref_text, hyp_text):
    dmp = diff_match_patch()
    diffs = dmp.diff_main(hyp_text, ref_text, checklines=True)
    html = dmp.diff_prettyHtml(diffs)
    return html
def remove_accents(text):
    return ''.join(c for c in unicodedata.normalize('NFD', text)
                   if unicodedata.category(c) != 'Mn')

def calculate_wer(ref_text, hyp_text):
    ref_words = ref_text.split()
    hyp_words = hyp_text.split()
def clean_text(text):
    # Remove punctuation and special characters
    text = re.sub(r'[^\w\s]', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text

    distance = Levenshtein.distance(ref_words, hyp_words)
    wer = distance / len(ref_words)
def normalize_spanish_gender_postfixes(text):
    # Replace a word-final 'a' with 'e' to neutralize gendered endings
    text = re.sub(r'\b(\w+?)(a)\b', r'\1e', text)
    return text

def tokenize(text, should_remove_accents=False, remove_punctuation=False):
    # Convert to lowercase, remove accents, clean text, and split
    if should_remove_accents:
        text = remove_accents(text)
        text = normalize_spanish_gender_postfixes(text)
    if remove_punctuation:
        text = clean_text(text)
    tokens = text.lower().split()
    return tokens

def calculate_wer(ref_text_tokens, hyp_text_tokens):
    distance = Levenshtein.distance(ref_text_tokens, hyp_text_tokens, weights=(1, 1, 1))
    wer = distance / max(len(ref_text_tokens), len(hyp_text_tokens))
    return wer

def calculate_cer(ref_text, hyp_text):
def calculate_cer(ref_text_tokens, hyp_text_tokens):
    # Join tokens into a single string
    ref_text = ' '.join(ref_text_tokens)
    hyp_text = ' '.join(hyp_text_tokens)
    distance = Levenshtein.distance(ref_text, hyp_text)
    cer = distance / len(ref_text)
    return cer

def compare_tokens(ref_tokens, hyp_tokens):
    comparisons = []
    for ref_token, hyp_token in zip(ref_tokens, hyp_tokens):
        distance = Levenshtein.distance(ref_token, hyp_token)
        comparison = {'ref_token': ref_token, 'hyp_token': hyp_token, 'error_rate': distance / len(ref_token)}
        comparisons.append(comparison)
    return comparisons
def print_alignment(ref_words, hyp_words):
    d = difflib.Differ()
    diff = list(d.compare(ref_words, hyp_words))

    print("\nToken-by-token alignment:")
    print("Reference | Hypothesis")
    print("-" * 30)

    ref_token = hyp_token = ""
    for token in diff:
        if token.startswith(" "):  # Common token
            if ref_token or hyp_token:
                print(f"{ref_token:<10} | {hyp_token:<10}")
                ref_token = hyp_token = ""
            print(f"{token[2:]:<10} | {token[2:]:<10}")
        elif token.startswith("- "):  # Token in reference, not in hypothesis
            ref_token = token[2:]
        elif token.startswith("+ "):  # Token in hypothesis, not in reference
            hyp_token = token[2:]
            if ref_token:
                print(f"{ref_token:<10} | {hyp_token:<10} (Substitution)")
                ref_token = hyp_token = ""
            else:
                print(f"{'':10} | {hyp_token:<10} (Insertion)")
                hyp_token = ""

    # Print any remaining tokens
    if ref_token:
        print(f"{ref_token:<10} | {'':10} (Deletion)")
    elif hyp_token:
        print(f"{'':10} | {hyp_token:<10} (Insertion)")


def read_text_from_file(file_path, join_sentences=True):
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
@@ -41,28 +88,27 @@ def read_text_from_file(file_path, join_sentences=True):
parser = argparse.ArgumentParser(description='Evaluate output')
parser.add_argument('ref_file_path', type=str, help='Path to the reference file')
parser.add_argument('hyp_file_path', type=str, help='Path to the hypothesis file')
parser.add_argument('--remove_accents', action='store_true', help='Remove accents from text')
parser.add_argument('--remove_punctuation', action='store_true', help='Remove punctuation from text')
parser.add_argument('--print_alignment', action='store_true', help='Print the alignment to the console')
parser.add_argument('--write_tokens', action='store_true', help='Write the tokens to a file')
args = parser.parse_args()

ref_text = read_text_from_file(args.ref_file_path)
hyp_text = read_text_from_file(args.hyp_file_path)
wer = calculate_wer(ref_text, hyp_text)
cer = calculate_cer(ref_text, hyp_text)
print("Word Error Rate (WER):", wer)
print("Character Error Rate (CER):", cer)

ref_text = '\n'.join(read_text_from_file(args.ref_file_path, join_sentences=False))
hyp_text = '\n'.join(read_text_from_file(args.hyp_file_path, join_sentences=False))
html_diff = visualize_differences(ref_text, hyp_text)
with open("diff_visualization.html", "w", encoding="utf-8") as file:
    file.write(html_diff)
ref_text = read_text_from_file(args.ref_file_path, join_sentences=True)
hyp_text = read_text_from_file(args.hyp_file_path, join_sentences=True)
ref_tokens = tokenize(ref_text, should_remove_accents=args.remove_accents, remove_punctuation=args.remove_punctuation)
hyp_tokens = tokenize(hyp_text, should_remove_accents=args.remove_accents, remove_punctuation=args.remove_punctuation)

from Bio.Align import PairwiseAligner
if args.print_alignment:
    print_alignment(ref_tokens, hyp_tokens)

aligner = PairwiseAligner()
if args.write_tokens:
with open("ref_tokens.txt", "w", encoding="utf-8") as file:
file.write('\n'.join(ref_tokens))
with open("hyp_tokens.txt", "w", encoding="utf-8") as file:
file.write('\n'.join(hyp_tokens))

alignments = aligner.align(ref_text, hyp_text)
wer = calculate_wer(ref_tokens, hyp_tokens)

# write the first alignment to a file
with open("alignment.txt", "w", encoding="utf-8") as file:
    file.write(alignments[0].format())
print(f"\"{args.ref_file_path}\" WER: \"{wer:.2}\"")

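Summarizing the updated script's metrics (read off the code above): both are Levenshtein distances, over tokens for WER and over the characters of the re-joined token strings for CER. Note that this WER divides by the longer of the two sequences, which keeps the score in [0, 1] but differs from the textbook convention of dividing by the reference length alone:

```latex
\mathrm{WER} = \frac{\mathrm{lev}(r_{1..n},\, h_{1..m})}{\max(n, m)}
\qquad
\mathrm{CER} = \frac{\mathrm{lev}(\mathrm{join}(r),\, \mathrm{join}(h))}{\lvert \mathrm{join}(r) \rvert}
```

where r and h are the normalized reference and hypothesis token sequences produced by `tokenize`.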
7 changes: 7 additions & 0 deletions src/tests/localvocal-offline-test.cpp
@@ -20,6 +20,7 @@
#include "whisper-utils/vad-processing.h"
#include "audio-file-utils.h"
#include "translation/language_codes.h"
#include "ui/filter-replace-utils.h"

#include <stdio.h>
#include <stdlib.h>
@@ -430,6 +431,12 @@ int wmain(int argc, wchar_t *argv[])
config["no_context"] ? "true" : "false");
gf->whisper_params.no_context = config["no_context"];
}
if (config.contains("filter_words_replace")) {
obs_log(LOG_INFO, "Setting filter_words_replace to %s",
config["filter_words_replace"]);
gf->filter_words_replace = deserialize_filter_words_replace(
config["filter_words_replace"]);
}
// set log level
if (logLevelStr == "debug") {
gf->log_level = LOG_DEBUG;
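The new branch mirrors the existing `no_context` handling: if the test's JSON config carries the field, parse it with the new helper and assign it onto the filter data struct. A reduced sketch of that pattern (standalone; `transcription_filter_data_stub` and its field are assumptions inferred from the assignment above):

```cpp
#include <nlohmann/json.hpp>

#include <string>
#include <tuple>
#include <vector>

#include "ui/filter-replace-utils.h"

// Hypothetical stand-in for transcription_filter_data; only the field
// touched by the diff above is modeled here.
struct transcription_filter_data_stub {
	std::vector<std::tuple<std::string, std::string>> filter_words_replace;
};

void apply_filter_config(const nlohmann::json &config, transcription_filter_data_stub *gf)
{
	// Only touch the field when the config provides it, as in the test above.
	if (config.contains("filter_words_replace")) {
		gf->filter_words_replace = deserialize_filter_words_replace(
			config["filter_words_replace"].get<std::string>());
	}
}
```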
1 change: 1 addition & 0 deletions src/transcription-filter-properties.cpp
@@ -11,6 +11,7 @@
#include "model-utils/model-downloader-types.h"
#include "translation/language_codes.h"
#include "ui/filter-replace-dialog.h"
#include "ui/filter-replace-utils.h"

#include <string>
#include <vector>
1 change: 1 addition & 0 deletions src/transcription-filter.cpp
@@ -30,6 +30,7 @@
#include "translation/translation.h"
#include "translation/translation-includes.h"
#include "ui/filter-replace-dialog.h"
#include "ui/filter-replace-utils.h"

void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_source)
{
29 changes: 0 additions & 29 deletions src/ui/filter-replace-dialog.cpp
@@ -73,32 +73,3 @@ void FilterReplaceDialog::editFilter(QTableWidgetItem *item)
	// use the row number to update the filter_words_replace map
	ctx->filter_words_replace[item->row()] = std::make_tuple(key, value);
}

std::string serialize_filter_words_replace(
	const std::vector<std::tuple<std::string, std::string>> &filter_words_replace)
{
	if (filter_words_replace.empty()) {
		return "[]";
	}
	// use JSON to serialize the filter_words_replace map
	nlohmann::json j;
	for (const auto &entry : filter_words_replace) {
		j.push_back({{"key", std::get<0>(entry)}, {"value", std::get<1>(entry)}});
	}
	return j.dump();
}

std::vector<std::tuple<std::string, std::string>>
deserialize_filter_words_replace(const std::string &filter_words_replace_str)
{
	if (filter_words_replace_str.empty()) {
		return {};
	}
	// use JSON to deserialize the filter_words_replace map
	std::vector<std::tuple<std::string, std::string>> filter_words_replace;
	nlohmann::json j = nlohmann::json::parse(filter_words_replace_str);
	for (const auto &entry : j) {
		filter_words_replace.push_back(std::make_tuple(entry["key"], entry["value"]));
	}
	return filter_words_replace;
}
5 changes: 0 additions & 5 deletions src/ui/filter-replace-dialog.h
@@ -27,9 +27,4 @@ private slots:
	void editFilter(QTableWidgetItem *item);
};

std::string serialize_filter_words_replace(
const std::vector<std::tuple<std::string, std::string>> &filter_words_replace);
std::vector<std::tuple<std::string, std::string>>
deserialize_filter_words_replace(const std::string &filter_words_replace_str);

#endif // FILTERREPLACEDIALOG_H
32 changes: 32 additions & 0 deletions src/ui/filter-replace-utils.cpp
@@ -0,0 +1,32 @@
#include "filter-replace-utils.h"

#include <nlohmann/json.hpp>

std::string serialize_filter_words_replace(
	const std::vector<std::tuple<std::string, std::string>> &filter_words_replace)
{
	if (filter_words_replace.empty()) {
		return "[]";
	}
	// use JSON to serialize the filter_words_replace map
	nlohmann::json j;
	for (const auto &entry : filter_words_replace) {
		j.push_back({{"key", std::get<0>(entry)}, {"value", std::get<1>(entry)}});
	}
	return j.dump();
}

std::vector<std::tuple<std::string, std::string>>
deserialize_filter_words_replace(const std::string &filter_words_replace_str)
{
	if (filter_words_replace_str.empty()) {
		return {};
	}
	// use JSON to deserialize the filter_words_replace map
	std::vector<std::tuple<std::string, std::string>> filter_words_replace;
	nlohmann::json j = nlohmann::json::parse(filter_words_replace_str);
	for (const auto &entry : j) {
		filter_words_replace.push_back(std::make_tuple(entry["key"], entry["value"]));
	}
	return filter_words_replace;
}
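A quick round-trip sketch of the two helpers above (standalone; the sample pairs are illustrative, and the assertion just states the invariant the pair of functions is expected to satisfy):

```cpp
#include "filter-replace-utils.h"

#include <cassert>
#include <string>
#include <tuple>
#include <vector>

int main()
{
	std::vector<std::tuple<std::string, std::string>> words = {
		{"MBC 뉴스", ""}, {"colour", "color"}};

	// Serializes to a JSON array of {"key": ..., "value": ...} objects.
	std::string serialized = serialize_filter_words_replace(words);

	// Deserializing the serialized form reproduces the original pairs.
	assert(deserialize_filter_words_replace(serialized) == words);
	return 0;
}
```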
12 changes: 12 additions & 0 deletions src/ui/filter-replace-utils.h
@@ -0,0 +1,12 @@
#ifndef FILTER_REPLACE_UTILS_H
#define FILTER_REPLACE_UTILS_H

#include <string>
#include <vector>

std::string serialize_filter_words_replace(
const std::vector<std::tuple<std::string, std::string>> &filter_words_replace);
std::vector<std::tuple<std::string, std::string>>
deserialize_filter_words_replace(const std::string &filter_words_replace_str);

#endif /* FILTER_REPLACE_UTILS_H */
