Skip to content

Commit

Permalink
Add InteractiveTranslator
Browse files Browse the repository at this point in the history
  • Loading branch information
ddaspit committed Oct 24, 2023
1 parent b6a6ea1 commit 58343ff
Show file tree
Hide file tree
Showing 14 changed files with 1,098 additions and 16 deletions.
17 changes: 17 additions & 0 deletions machine/tokenization/tokenization_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Generator, Iterable, List

from ..annotations.range import Range


def split(s: str, ranges: Iterable[Range[int]]) -> List[str]:
    """Cut ``s`` into one substring per range.

    Args:
        s: The string to slice.
        ranges: Ranges whose ``start`` and ``end`` attributes are used as the slice bounds into ``s``.

    Returns:
        The substrings ``s[start:end]``, in the order the ranges were supplied.
    """
    # The loop variable is deliberately not called ``range`` so the builtin is not shadowed.
    return [s[span.start : span.end] for span in ranges]


def get_ranges(s: str, tokens: Iterable[str]) -> Generator[Range[int], None, None]:
    """Lazily locate each token in ``s``, scanning from left to right.

    Every token is searched for starting at the end of the previous match, so
    the yielded ranges never overlap and appear in token order.

    Args:
        s: The string the tokens are searched in.
        tokens: The tokens to locate, in the order they are expected to occur in ``s``.

    Yields:
        ``Range.create(start, end)`` for each token, where ``s[start:end]`` equals the token.

    Raises:
        ValueError: If a token cannot be found at or after the end of the previous match.
    """
    cursor = 0
    for token in tokens:
        found_at = s.find(token, cursor)
        if found_at < 0:
            raise ValueError(f"The string does not contain the specified token: {token}.")
        # The next search resumes right after this match.
        cursor = found_at + len(token)
        yield Range.create(found_at, cursor)
4 changes: 2 additions & 2 deletions machine/translation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .ibm1_word_alignment_model import Ibm1WordAlignmentModel
from .ibm1_word_confidence_estimator import Ibm1WordConfidenceEstimator
from .ibm2_word_alignment_model import Ibm2WordAlignmentModel
from .interactive_translation_engine import InterativeTranslationEngine
from .interactive_translation_engine import InteractiveTranslationEngine
from .interactive_translation_model import InteractiveTranslationModel
from .null_trainer import NullTrainer
from .phrase import Phrase
Expand Down Expand Up @@ -40,7 +40,7 @@
"Ibm1WordConfidenceEstimator",
"Ibm2WordAlignmentModel",
"InteractiveTranslationModel",
"InterativeTranslationEngine",
"InteractiveTranslationEngine",
"MAX_SEGMENT_LENGTH",
"NullTrainer",
"Phrase",
Expand Down
59 changes: 59 additions & 0 deletions machine/translation/ecm_score_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from typing import List

from .edit_operation import EditOperation


class EcmScoreInfo:
    """Mutable container for two lists: float scores and the edit operations that go with them."""

    def __init__(self) -> None:
        self._scores: List[float] = []
        self._operations: List[EditOperation] = []

    @property
    def scores(self) -> List[float]:
        """The live (mutable) list of scores."""
        return self._scores

    @property
    def operations(self) -> List[EditOperation]:
        """The live (mutable) list of edit operations."""
        return self._operations

    def update_positions(self, prev_esi: EcmScoreInfo, positions: List[int]) -> None:
        """Copy the entries at ``positions`` from ``prev_esi`` into this instance.

        Both lists are first padded (scores with ``0.0``, operations with
        ``EditOperation.NONE``) until they are at least as long as the
        corresponding lists of ``prev_esi``.

        Args:
            prev_esi: The instance to copy scores and operations from.
            positions: Indices, valid in ``prev_esi.scores``, of the entries to copy.
        """
        missing_scores = len(prev_esi.scores) - len(self.scores)
        if missing_scores > 0:
            self.scores.extend([0.0] * missing_scores)

        missing_ops = len(prev_esi.operations) - len(self.operations)
        if missing_ops > 0:
            self.operations.extend([EditOperation.NONE] * missing_ops)

        for i, pos in enumerate(positions):
            self.scores[pos] = prev_esi.scores[pos]
            # NOTE(review): the guard compares against the loop index rather than ``pos``;
            # presumably it only exists to skip instances that carry no operations - confirm.
            if len(prev_esi.operations) > i:
                self.operations[pos] = prev_esi.operations[pos]

    def remove_last(self) -> None:
        """Drop the last score and the last operation, always keeping at least one entry in each list."""
        if len(self.scores) > 1:
            self.scores.pop()
        if len(self.operations) > 1:
            self.operations.pop()

    def get_last_ins_prefix_word_from_esi(self) -> List[int]:
        """Map every operation index to an index derived from the operation stored there.

        For index ``j`` the result is:

        * ``j - 1`` for ``HIT`` and ``SUBSTITUTE``;
        * ``j`` for ``DELETE``;
        * for ``INSERT``: the index of the nearest non-``INSERT`` entry at or before ``j``,
          reduced by one more if that entry is a ``HIT`` or ``SUBSTITUTE`` (``-1`` when
          every entry up to ``j`` is an ``INSERT``);
        * ``0`` for ``NONE`` and any other operation.

        Returns:
            A list with one integer per entry of ``operations``.
        """
        ops = self.operations
        results = [0] * len(ops)

        for j, op in enumerate(ops):
            if op == EditOperation.HIT or op == EditOperation.SUBSTITUTE:
                results[j] = j - 1
            elif op == EditOperation.INSERT:
                tj = j
                while tj >= 0 and ops[tj] == EditOperation.INSERT:
                    tj -= 1
                # Without the ``tj >= 0`` guard, ``ops[-1]`` would silently wrap around to the
                # last element when the run of INSERTs reaches the start of the list.
                if tj >= 0 and (ops[tj] == EditOperation.HIT or ops[tj] == EditOperation.SUBSTITUTE):
                    tj -= 1
                results[j] = tj
            elif op == EditOperation.DELETE:
                results[j] = j
            # NONE (and anything else) keeps the initial 0.

        return results
132 changes: 132 additions & 0 deletions machine/translation/edit_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from abc import ABC, abstractmethod
from typing import Generic, Iterable, List, Tuple, TypeVar

from .edit_operation import EditOperation

Seq = TypeVar("Seq")
Item = TypeVar("Item")


class EditDistance(ABC, Generic[Seq, Item]):
    """Template for a dynamic-programming edit distance between two sequences.

    Subclasses supply how sequences are measured and indexed and what each
    edit costs; this class fills the distance matrix and walks it backwards to
    recover the edit operations.
    """

    @abstractmethod
    def _get_count(self, seq: Seq) -> int:
        """Return the number of items in ``seq``."""
        ...

    @abstractmethod
    def _get_item(self, seq: Seq, index: int) -> Item:
        """Return the item of ``seq`` at the zero-based ``index``."""
        ...

    @abstractmethod
    def _get_hit_cost(self, x: Item, y: Item, is_complete: bool) -> float:
        """Return the cost charged when ``x`` and ``y`` are considered a hit."""
        ...

    @abstractmethod
    def _get_substitution_cost(self, x: Item, y: Item, is_complete: bool) -> float:
        """Return the cost of substituting ``y`` for ``x``."""
        ...

    @abstractmethod
    def _get_deletion_cost(self, x: Item) -> float:
        """Return the cost of deleting ``x``."""
        ...

    @abstractmethod
    def _get_insertion_cost(self, y: Item) -> float:
        """Return the cost of inserting ``y``."""
        ...

    @abstractmethod
    def _is_hit(self, x: Item, y: Item, is_complete: bool) -> bool:
        """Return whether ``x`` and ``y`` count as a hit."""
        ...

    def _init_dist_matrix(self, x: Seq, y: Seq) -> List[List[float]]:
        """Allocate a zero-filled square matrix with ``max(len(x), len(y)) + 1`` rows and columns."""
        dim = max(self._get_count(x), self._get_count(y))
        return [[0.0 for _ in range(dim + 1)] for _ in range(dim + 1)]

    def _compute_dist_matrix(
        self, x: Seq, y: Seq, is_last_item_complete: bool, use_prefix_del_op: bool
    ) -> Tuple[float, List[List[float]]]:
        """Fill the distance matrix for ``x`` against ``y``.

        Args:
            x: The sequence indexed by matrix rows.
            y: The sequence indexed by matrix columns.
            is_last_item_complete: Completeness flag passed to the cost hooks for the last column of ``y``;
                every other column is always treated as complete.
            use_prefix_del_op: Whether deletions in the last column of ``y`` are free.

        Returns:
            The distance in cell ``[len(x)][len(y)]`` and the filled matrix.
        """
        dist_matrix = self._init_dist_matrix(x, y)

        x_count = self._get_count(x)
        y_count = self._get_count(y)
        for i in range(x_count + 1):
            for j in range(y_count + 1):
                is_complete = j != y_count or is_last_item_complete
                dist_matrix[i][j], _, _, _ = self._process_dist_matrix_cell(
                    x, y, dist_matrix, use_prefix_del_op, is_complete, i, j
                )

        return dist_matrix[x_count][y_count], dist_matrix

    def _process_dist_matrix_cell(
        self, x: Seq, y: Seq, dist_matrix: List[List[float]], use_prefix_del_op: bool, is_complete: bool, i: int, j: int
    ) -> Tuple[float, int, int, EditOperation]:
        """Compute the cheapest way to reach cell ``(i, j)`` from its already-filled neighbours.

        Returns:
            A tuple ``(cost, i_pred, j_pred, op)``: the minimal cost of the cell, the coordinates of the
            predecessor cell that produced it, and the edit operation that leads from the predecessor.
        """
        # Border cells have a single possible predecessor.
        if i == 0 and j == 0:
            return (0.0, 0, 0, EditOperation.NONE)

        if i == 0:
            return (
                dist_matrix[0][j - 1] + self._get_insertion_cost(self._get_item(y, j - 1)),
                0,
                j - 1,
                EditOperation.INSERT,
            )

        if j == 0:
            return (
                dist_matrix[i - 1][0] + self._get_deletion_cost(self._get_item(x, i - 1)),
                i - 1,
                0,
                EditOperation.DELETE,
            )

        x_item = self._get_item(x, i - 1)
        y_item = self._get_item(y, j - 1)

        # Candidate 1: diagonal move (hit or substitution). It wins ties because the other
        # candidates only replace it when strictly cheaper.
        if self._is_hit(x_item, y_item, is_complete):
            diag_cost = self._get_hit_cost(x_item, y_item, is_complete)
            op = EditOperation.HIT
        else:
            diag_cost = self._get_substitution_cost(x_item, y_item, is_complete)
            op = EditOperation.SUBSTITUTE
        best_cost = dist_matrix[i - 1][j - 1] + diag_cost
        i_pred = i - 1
        j_pred = j - 1

        # Candidate 2: deletion of x_item; free in the last column when prefix deletes are enabled.
        if use_prefix_del_op and j == self._get_count(y):
            del_cost = 0.0
        else:
            del_cost = self._get_deletion_cost(x_item)
        cost = dist_matrix[i - 1][j] + del_cost
        if cost < best_cost:
            best_cost = cost
            i_pred = i - 1
            j_pred = j
            op = EditOperation.PREFIX_DELETE if del_cost == 0 else EditOperation.DELETE

        # Candidate 3: insertion of y_item.
        cost = dist_matrix[i][j - 1] + self._get_insertion_cost(y_item)
        if cost < best_cost:
            best_cost = cost
            i_pred = i
            j_pred = j - 1
            op = EditOperation.INSERT

        return (best_cost, i_pred, j_pred, op)

    def _get_operations(
        self,
        x: Seq,
        y: Seq,
        dist_matrix: List[List[float]],
        is_last_item_complete: bool,
        use_prefix_del_op: bool,
        i: int,
        j: int,
    ) -> Iterable[EditOperation]:
        """Trace back from cell ``(i, j)`` to the origin and return the operations in forward order.

        ``PREFIX_DELETE`` steps are followed during the trace but left out of the result.
        """
        y_count = self._get_count(y)
        ops: List[EditOperation] = []
        while i > 0 or j > 0:
            # The completeness flag must be derived from the current column before ``j`` is reassigned.
            is_complete = j != y_count or is_last_item_complete
            _, i, j, op = self._process_dist_matrix_cell(x, y, dist_matrix, use_prefix_del_op, is_complete, i, j)
            if op != EditOperation.PREFIX_DELETE:
                ops.append(op)
        return reversed(ops)
81 changes: 81 additions & 0 deletions machine/translation/error_correction_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from math import log
from typing import Sequence

from .ecm_score_info import EcmScoreInfo
from .edit_operation import EditOperation
from .segment_edit_distance import SegmentEditDistance
from .translation_result_builder import TranslationResultBuilder
from .translation_sources import TranslationSources


class ErrorCorrectionModel:
    """Scores and corrects translation prefixes using a segment-level edit distance.

    The edit costs are negative log probabilities derived from a simple error
    model (see :meth:`set_error_model_parameters`).
    """

    def __init__(self) -> None:
        self._segment_edit_distance = SegmentEditDistance()
        self.set_error_model_parameters(voc_size=128, hit_prob=0.8, ins_factor=1, subst_factor=1, del_factor=1)

    @staticmethod
    def _neg_log(prob: float) -> float:
        """Return ``-log(prob)``, mapping a zero probability to an infinite cost."""
        # math.log(0) raises ValueError; an impossible event is modelled as infinitely expensive instead.
        return float("inf") if prob == 0 else -log(prob)

    def set_error_model_parameters(
        self, voc_size: int, hit_prob: float, ins_factor: float, subst_factor: float, del_factor: float
    ) -> None:
        """Derive the four edit costs from the error-model parameters.

        The probability mass not used by hits (``1 - hit_prob``) is spread over the error
        outcomes: with a vocabulary, ``voc_size`` insertions, ``voc_size - 1`` substitutions and
        one deletion, each weighted by its factor; with ``voc_size == 0``, just the three factors.

        Args:
            voc_size: Vocabulary size used to count insertion and substitution outcomes (0 to ignore it).
            hit_prob: Probability of a hit.
            ins_factor: Relative weight of an insertion.
            subst_factor: Relative weight of a substitution.
            del_factor: Relative weight of a deletion.
        """
        if voc_size == 0:
            outcomes = ins_factor + subst_factor + del_factor
        else:
            outcomes = (ins_factor * voc_size) + (subst_factor * (voc_size - 1)) + del_factor
        e = (1 - hit_prob) / outcomes

        sed = self._segment_edit_distance
        sed.hit_cost = self._neg_log(hit_prob)
        sed.insertion_cost = self._neg_log(e * ins_factor)
        sed.substitution_cost = self._neg_log(e * subst_factor)
        sed.deletion_cost = self._neg_log(e * del_factor)

    def setup_initial_esi(self, initial_esi: EcmScoreInfo) -> None:
        """Reset ``initial_esi`` to a single score (empty vs. empty segment) and no operations."""
        score = self._segment_edit_distance.compute([], [])
        initial_esi.scores.clear()
        initial_esi.scores.append(score)
        initial_esi.operations.clear()

    def setup_esi(self, esi: EcmScoreInfo, prev_esi: EcmScoreInfo, word: str) -> None:
        """Reset ``esi`` to the first score of ``prev_esi`` plus the cost of ``word`` against an empty prefix."""
        score = self._segment_edit_distance.compute([word], [])
        esi.scores.clear()
        esi.scores.append(prev_esi.scores[0] + score)
        esi.operations.clear()
        esi.operations.append(EditOperation.NONE)

    def extend_initial_esi(
        self, initial_esi: EcmScoreInfo, prev_initial_esi: EcmScoreInfo, prefix_diff: Sequence[str]
    ) -> None:
        """Extend the scores of ``initial_esi`` for the newly added prefix words in ``prefix_diff``."""
        self._segment_edit_distance.incr_compute_prefix_first_row(
            initial_esi.scores, prev_initial_esi.scores, prefix_diff
        )

    def extend_esi(
        self,
        esi: EcmScoreInfo,
        prev_esi: EcmScoreInfo,
        word: str,
        prefix_diff: Sequence[str],
        is_last_word_complete: bool,
    ) -> None:
        """Extend ``esi`` for the newly added prefix words and append the resulting operations."""
        ops = self._segment_edit_distance.incr_compute_prefix(
            esi.scores, prev_esi.scores, word, prefix_diff, is_last_word_complete
        )
        esi.operations.extend(ops)

    def correct_prefix(
        self,
        builder: TranslationResultBuilder,
        uncorrected_prefix_len: int,
        prefix: Sequence[str],
        is_last_word_complete: bool,
    ) -> int:
        """Make the start of the translation in ``builder`` agree with ``prefix``.

        Args:
            builder: The result builder whose leading tokens are to be corrected.
            uncorrected_prefix_len: Number of leading tokens in ``builder`` that form the uncorrected prefix.
            prefix: The prefix the translation must start with.
            is_last_word_complete: Whether the last word of ``prefix`` is complete.

        Returns:
            ``len(prefix)`` when there is no uncorrected prefix; otherwise whatever
            ``builder.correct_prefix`` returns.
        """
        if uncorrected_prefix_len == 0:
            # Nothing to align against: the prefix is simply taken over verbatim.
            for w in prefix:
                builder.append_token(w, TranslationSources.PREFIX, -1)
            return len(prefix)

        # Align the *first* ``uncorrected_prefix_len`` tokens (the uncorrected prefix) with the
        # requested prefix; slicing from that index onwards would compare the wrong tokens.
        _, word_ops, char_ops = self._segment_edit_distance.compute_prefix(
            builder.target_tokens[:uncorrected_prefix_len], prefix, is_last_word_complete, use_prefix_del_op=False
        )
        return builder.correct_prefix(word_ops, char_ops, prefix, is_last_word_complete)
Loading

0 comments on commit 58343ff

Please sign in to comment.