Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly handle corrupted SMT models #68

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions machine/translation/thot/thot_smt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def __init__(
else:
self._config_filename = Path(config)
parameters = ThotSmtParameters.load(config)
if not Path(parameters.translation_model_filename_prefix + ".ttable").is_file():
raise FileNotFoundError("The translation model could not be found.")
if not Path(parameters.language_model_filename_prefix).is_file():
raise FileNotFoundError("The language model could not be found.")
self._parameters = parameters
self.source_tokenizer = source_tokenizer
self.target_tokenizer = target_tokenizer
Expand Down
6 changes: 4 additions & 2 deletions machine/translation/thot/thot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def load_smt_model(word_alignment_model_type: ThotWordAlignmentModelType, parame
model_type = ta.AlignmentModelType.IBM4

model = tt.SmtModel(model_type)
model.load_translation_model(parameters.translation_model_filename_prefix)
model.load_language_model(parameters.language_model_filename_prefix)
if not model.load_translation_model(parameters.translation_model_filename_prefix):
raise RuntimeError("Unable to load translation model.")
if not model.load_language_model(parameters.language_model_filename_prefix):
raise RuntimeError("Unable to load language model.")
model.non_monotonicity = parameters.model_non_monotonicity
model.w = parameters.model_w
model.a = parameters.model_a
Expand Down
7 changes: 4 additions & 3 deletions machine/translation/thot/thot_word_alignment_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,16 @@ def load(self, prefix_filename: StrPath) -> None:
prefix_filename = Path(prefix_filename)
if not (prefix_filename.parent / (prefix_filename.name + ".src")).is_file():
raise FileNotFoundError("The word alignment model configuration could not be found.")
self._prefix_filename = prefix_filename
self._model.clear()
self._model.load(str(prefix_filename))
if not self._model.load(str(prefix_filename)):
raise RuntimeError("Unable to load word alignment model.")
self._prefix_filename = prefix_filename

def create_new(self, prefix_filename: StrPath) -> None:
if self._owned:
raise RuntimeError("The word alignment model is owned by an SMT model.")
self._prefix_filename = Path(prefix_filename)
self._model.clear()
self._prefix_filename = Path(prefix_filename)

def save(self) -> None:
if self._prefix_filename is not None:
Expand Down
46 changes: 23 additions & 23 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ charset-normalizer = "^2.1.1"

### extras
sentencepiece = "^0.1.95"
sil-thot = "^3.4.0"
sil-thot = "^3.4.2"
# huggingface extras
transformers = "^4.34.0"
datasets = "^2.4.0"
Expand Down
4 changes: 2 additions & 2 deletions tests/corpora/test_text_file_text_corpus.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
from pytest import raises
from testutils.corpora_test_helpers import TEXT_TEST_PROJECT_PATH

from machine.corpora import TextFileTextCorpus


def test_does_not_exist() -> None:
with pytest.raises(FileNotFoundError):
with raises(FileNotFoundError):
TextFileTextCorpus("does-not-exist.txt")


Expand Down
4 changes: 2 additions & 2 deletions tests/jobs/test_nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from io import StringIO
from typing import Iterator

import pytest
from decoy import Decoy, matchers
from pytest import raises

from machine.annotations import Range
from machine.corpora import DictionaryTextCorpus
Expand All @@ -27,7 +27,7 @@ def test_run(decoy: Decoy) -> None:
def test_cancel(decoy: Decoy) -> None:
env = _TestEnvironment(decoy)
checker = _CancellationChecker(3)
with pytest.raises(CanceledError):
with raises(CanceledError):
env.job.run(check_canceled=checker.check_canceled)

assert env.target_pretranslations == ""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pytest import approx
from pathlib import Path
from tempfile import TemporaryDirectory

from pytest import approx, raises
from testutils.thot_test_helpers import TOY_CORPUS_FAST_ALIGN_PATH

from machine.translation import WordAlignmentMatrix
Expand Down Expand Up @@ -107,3 +110,11 @@ def test_get_avg_translation_score_symmetrized() -> None:
matrix = model.align(source_segment, target_segment)
score = model.get_avg_translation_score(source_segment, target_segment, matrix)
assert score == approx(0.36, abs=0.01)


def test_constructor_model_corrupted() -> None:
    """Check that constructing a FastAlign model from a corrupted file raises.

    A ``.src`` file holding junk text is written next to a temporary model
    prefix, and the constructor is expected to raise ``RuntimeError``.
    """
    with TemporaryDirectory() as temp_dir:
        model_prefix = Path(temp_dir) / "src_trg_invswm"
        # Only the ".src" companion file exists, and its content is not a valid model.
        corrupted_file = model_prefix.with_name(model_prefix.name + ".src")
        corrupted_file.write_text("corrupted", encoding="utf-8")
        with raises(RuntimeError):
            ThotFastAlignWordAlignmentModel(model_prefix)
35 changes: 34 additions & 1 deletion tests/translation/thot/test_thot_smt_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from pathlib import Path
from tempfile import TemporaryDirectory

from pytest import raises
from testutils.thot_test_helpers import TOY_CORPUS_FAST_ALIGN_CONFIG_FILENAME, TOY_CORPUS_HMM_CONFIG_FILENAME

from machine.translation.thot import ThotSmtModel, ThotWordAlignmentModelType
from machine.translation.thot import ThotSmtModel, ThotSmtParameters, ThotWordAlignmentModelType


def test_translate_target_segment_hmm() -> None:
Expand Down Expand Up @@ -95,6 +99,35 @@ def test_get_word_graph_empty_segment_fast_align() -> None:
assert word_graph.is_empty


def test_constructor_model_not_found() -> None:
    """Check that pointing the SMT model at missing model files raises.

    Both the translation-model and language-model prefixes name a path that
    does not exist, and ``FileNotFoundError`` is expected.
    """
    missing_prefix = "does-not-exist"
    with raises(FileNotFoundError):
        # The parameters are built inside the ``raises`` block so that an error
        # from either constructor satisfies the expectation.
        parameters = ThotSmtParameters(
            translation_model_filename_prefix=missing_prefix,
            language_model_filename_prefix=missing_prefix,
        )
        ThotSmtModel(ThotWordAlignmentModelType.HMM, parameters)


def test_constructor_model_corrupted() -> None:
    """Check that constructing an SMT model from corrupted model files raises.

    A translation table (``tm/src_trg.ttable``) and a language model
    (``lm/trg.lm``) containing junk text are created in a temporary directory.
    Both files exist, and ``RuntimeError`` is expected from construction.
    """
    with TemporaryDirectory() as temp_dir:
        root = Path(temp_dir)
        tm_prefix = root / "tm" / "src_trg"
        lm_path = root / "lm" / "trg.lm"
        for model_dir in (tm_prefix.parent, lm_path.parent):
            model_dir.mkdir()

        # Both files are present on disk but hold text that is not a valid model.
        tm_prefix.with_name("src_trg.ttable").write_text("corrupted", encoding="utf-8")
        lm_path.write_text("corrupted", encoding="utf-8")

        with raises(RuntimeError):
            parameters = ThotSmtParameters(
                translation_model_filename_prefix=str(tm_prefix),
                language_model_filename_prefix=str(lm_path),
            )
            ThotSmtModel(ThotWordAlignmentModelType.HMM, parameters)


def _create_hmm_model() -> ThotSmtModel:
    """Build a ``ThotSmtModel`` using HMM word alignment.

    The model is configured from ``TOY_CORPUS_HMM_CONFIG_FILENAME``.
    """
    alignment_type = ThotWordAlignmentModelType.HMM
    return ThotSmtModel(alignment_type, TOY_CORPUS_HMM_CONFIG_FILENAME)

Expand Down