-
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,098 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from typing import Generator, Iterable, List | ||
|
||
from ..annotations.range import Range | ||
|
||
|
||
def split(s: str, ranges: Iterable[Range[int]]) -> List[str]:
    """Cut ``s`` into one substring per range, in the order the ranges are given.

    Args:
        s: The string to slice.
        ranges: Objects exposing ``start`` and ``end`` attributes; they are used
            directly as the bounds of a Python slice, so ``end`` is exclusive.

    Returns:
        A list with one substring for every element of ``ranges``.
    """
    # The loop variable is not named ``range`` so the builtin is not shadowed.
    return [s[span.start : span.end] for span in ranges]
|
||
|
||
def get_ranges(s: str, tokens: Iterable[str]) -> Generator[Range[int], None, None]:
    """Lazily locate each token inside ``s`` and yield the range it occupies.

    Tokens are searched for in order. Each search resumes where the previous
    match ended, so matches advance left to right and never overlap.

    Args:
        s: The string to search.
        tokens: The substrings to find, in the order they are expected to occur.

    Yields:
        ``Range.create(begin, end)`` for every token, where ``begin`` is the index
        of the match and ``end`` is the index just past it.

    Raises:
        ValueError: If a token does not occur at or after the current search
            position. Because this is a generator, the error surfaces only when
            the offending token is reached during iteration.
    """
    cursor = 0
    for token in tokens:
        begin = s.find(token, cursor)
        if begin < 0:
            raise ValueError(f"The string does not contain the specified token: {token}.")
        # The end of this match is also where the next search starts.
        cursor = begin + len(token)
        yield Range.create(begin, cursor)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import annotations | ||
|
||
from typing import List | ||
|
||
from .edit_operation import EditOperation | ||
|
||
|
||
class EcmScoreInfo: | ||
def __init__(self) -> None: | ||
self._scores: List[float] = [] | ||
self._operations: List[EditOperation] = [] | ||
|
||
@property | ||
def scores(self) -> List[float]: | ||
return self._scores | ||
|
||
@property | ||
def operations(self) -> List[EditOperation]: | ||
return self._operations | ||
|
||
def update_positions(self, prev_esi: EcmScoreInfo, positions: List[int]) -> None: | ||
while len(self.scores) < len(prev_esi.scores): | ||
self.scores.append(0.0) | ||
|
||
while len(self.operations) < len(prev_esi.operations): | ||
self.operations.append(EditOperation.NONE) | ||
|
||
for i in range(len(positions)): | ||
self.scores[positions[i]] = prev_esi.scores[positions[i]] | ||
if len(prev_esi.operations) > i: | ||
self.operations[positions[i]] = prev_esi.operations[positions[i]] | ||
|
||
def remove_last(self) -> None: | ||
if len(self.scores) > 1: | ||
self.scores.pop() | ||
if len(self.operations) > 1: | ||
self.operations.pop() | ||
|
||
def get_last_ins_prefix_word_from_esi(self) -> List[int]: | ||
results = [0] * len(self.operations) | ||
|
||
for j in range(len(self.operations) - 1, -1, -1): | ||
if self.operations[j] == EditOperation.HIT: | ||
results[j] = j - 1 | ||
elif self.operations[j] == EditOperation.INSERT: | ||
tj = j | ||
while tj >= 0 and self.operations[tj] == EditOperation.INSERT: | ||
tj -= 1 | ||
if self.operations[tj] == EditOperation.HIT or self.operations[tj] == EditOperation.SUBSTITUTE: | ||
tj -= 1 | ||
results[j] = tj | ||
elif self.operations[j] == EditOperation.DELETE: | ||
results[j] = j | ||
elif self.operations[j] == EditOperation.SUBSTITUTE: | ||
results[j] = j - 1 | ||
elif self.operations[j] == EditOperation.NONE: | ||
results[j] = 0 | ||
|
||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Generic, Iterable, List, Tuple, TypeVar | ||
|
||
from .edit_operation import EditOperation | ||
|
||
Seq = TypeVar("Seq")
Item = TypeVar("Item")


class EditDistance(ABC, Generic[Seq, Item]):
    """Template for a dynamic-programming edit distance between two sequences.

    Subclasses supply how to measure a sequence, how to read one of its items,
    and what each edit costs. This base class fills in the distance matrix and
    walks it backwards to recover the edit operations.

    Matrix convention: cell ``[i][j]`` is computed from the first ``i`` items of
    ``x`` and the first ``j`` items of ``y``; row 0 is built from insertions only
    and column 0 from deletions only.
    """

    @abstractmethod
    def _get_count(self, seq: Seq) -> int:
        """Return the number of items in ``seq``."""
        ...

    @abstractmethod
    def _get_item(self, seq: Seq, index: int) -> Item:
        """Return the item of ``seq`` at zero-based ``index``."""
        ...

    @abstractmethod
    def _get_hit_cost(self, x: Item, y: Item, is_complete: bool) -> float:
        """Return the cost charged when ``_is_hit`` reports a match for ``x`` and ``y``."""
        ...

    @abstractmethod
    def _get_substitution_cost(self, x: Item, y: Item, is_complete: bool) -> float:
        """Return the cost charged when ``_is_hit`` reports no match for ``x`` and ``y``."""
        ...

    @abstractmethod
    def _get_deletion_cost(self, x: Item) -> float:
        """Return the cost of deleting item ``x`` (an item of the first sequence)."""
        ...

    @abstractmethod
    def _get_insertion_cost(self, y: Item) -> float:
        """Return the cost of inserting item ``y`` (an item of the second sequence)."""
        ...

    @abstractmethod
    def _is_hit(self, x: Item, y: Item, is_complete: bool) -> bool:
        """Return whether ``x`` and ``y`` count as a match."""
        ...

    def _init_dist_matrix(self, x: Seq, y: Seq) -> List[List[float]]:
        """Allocate a zero-filled square matrix large enough for ``x`` and ``y``.

        The side length is ``max(count(x), count(y)) + 1``; every row is an
        independent list.
        """
        side = max(self._get_count(x), self._get_count(y)) + 1
        return [[0.0] * side for _ in range(side)]

    def _compute_dist_matrix(
        self, x: Seq, y: Seq, is_last_item_complete: bool, use_prefix_del_op: bool
    ) -> Tuple[float, List[List[float]]]:
        """Fill the distance matrix row by row.

        Args:
            x: First sequence (rows).
            y: Second sequence (columns).
            is_last_item_complete: Forwarded as ``is_complete`` for the last column
                only; every other column is always treated as complete.
            use_prefix_del_op: Forwarded to ``_process_dist_matrix_cell``.

        Returns:
            ``(distance, matrix)`` where ``distance`` is the value of cell
            ``[count(x)][count(y)]``.
        """
        dist_matrix = self._init_dist_matrix(x, y)

        x_count = self._get_count(x)
        y_count = self._get_count(y)
        for i in range(x_count + 1):
            row = dist_matrix[i]
            for j in range(y_count + 1):
                is_complete = j != y_count or is_last_item_complete
                row[j] = self._process_dist_matrix_cell(
                    x, y, dist_matrix, use_prefix_del_op, is_complete, i, j
                )[0]

        return dist_matrix[x_count][y_count], dist_matrix

    def _process_dist_matrix_cell(
        self, x: Seq, y: Seq, dist_matrix: List[List[float]], use_prefix_del_op: bool, is_complete: bool, i: int, j: int
    ) -> Tuple[float, int, int, EditOperation]:
        """Compute the cheapest way to reach cell ``(i, j)``.

        Reads only the neighbouring cells ``[i-1][j-1]``, ``[i-1][j]`` and
        ``[i][j-1]``; it does not write to ``dist_matrix``.

        Returns:
            ``(cost, i_pred, j_pred, op)``: the cell cost, the coordinates of the
            predecessor cell that produced it, and the edit operation taken.
        """
        # Boundary cells first: the origin, then the insert-only first row and
        # the delete-only first column.
        if i == 0 and j == 0:
            return (0.0, 0, 0, EditOperation.NONE)

        if i == 0:
            return (
                dist_matrix[0][j - 1] + self._get_insertion_cost(self._get_item(y, j - 1)),
                0,
                j - 1,
                EditOperation.INSERT,
            )

        if j == 0:
            return (
                dist_matrix[i - 1][0] + self._get_deletion_cost(self._get_item(x, i - 1)),
                i - 1,
                0,
                EditOperation.DELETE,
            )

        x_item = self._get_item(x, i - 1)
        y_item = self._get_item(y, j - 1)

        # Diagonal move: hit or substitution. It is the initial candidate, and the
        # strict ``<`` comparisons below mean ties are resolved in its favour
        # (then deletion, then insertion).
        if self._is_hit(x_item, y_item, is_complete):
            diag_cost = self._get_hit_cost(x_item, y_item, is_complete)
            op = EditOperation.HIT
        else:
            diag_cost = self._get_substitution_cost(x_item, y_item, is_complete)
            op = EditOperation.SUBSTITUTE
        best_cost = dist_matrix[i - 1][j - 1] + diag_cost
        i_pred, j_pred = i - 1, j - 1

        # Vertical move: deletion, free in the last column when prefix deletion is on.
        if use_prefix_del_op and j == self._get_count(y):
            del_cost = 0.0
        else:
            del_cost = self._get_deletion_cost(x_item)
        candidate = dist_matrix[i - 1][j] + del_cost
        if candidate < best_cost:
            best_cost = candidate
            i_pred, j_pred = i - 1, j
            # NOTE(review): a subclass deletion cost of exactly 0 is also labelled
            # PREFIX_DELETE here; kept as written.
            op = EditOperation.PREFIX_DELETE if del_cost == 0 else EditOperation.DELETE

        # Horizontal move: insertion.
        candidate = dist_matrix[i][j - 1] + self._get_insertion_cost(y_item)
        if candidate < best_cost:
            best_cost = candidate
            i_pred, j_pred = i, j - 1
            op = EditOperation.INSERT

        return (best_cost, i_pred, j_pred, op)

    def _get_operations(
        self,
        x: Seq,
        y: Seq,
        dist_matrix: List[List[float]],
        is_last_item_complete: bool,
        use_prefix_del_op: bool,
        i: int,
        j: int,
    ) -> Iterable[EditOperation]:
        """Trace back from cell ``(i, j)`` to the origin and return the edit operations.

        ``PREFIX_DELETE`` steps are followed but left out of the result.

        Returns:
            An iterator over the operations in forward order (origin first).
        """
        y_count = self._get_count(y)
        ops: List[EditOperation] = []
        while i > 0 or j > 0:
            # Completeness is decided from the column we are leaving, i.e. before
            # ``j`` is replaced by its predecessor.
            is_complete = j != y_count or is_last_item_complete
            _, i, j, op = self._process_dist_matrix_cell(x, y, dist_matrix, use_prefix_del_op, is_complete, i, j)
            if op != EditOperation.PREFIX_DELETE:
                ops.append(op)
        # Operations were collected end-to-start; flip them for the caller.
        return reversed(ops)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from math import log | ||
from typing import Sequence | ||
|
||
from .ecm_score_info import EcmScoreInfo | ||
from .edit_operation import EditOperation | ||
from .segment_edit_distance import SegmentEditDistance | ||
from .translation_result_builder import TranslationResultBuilder | ||
from .translation_sources import TranslationSources | ||
|
||
|
||
class ErrorCorrectionModel:
    """Configures a ``SegmentEditDistance`` with probabilistic costs and drives it.

    Each edit cost is stored as the negative natural logarithm of a probability.
    """

    def __init__(self) -> None:
        self._segment_edit_distance = SegmentEditDistance()
        self.set_error_model_parameters(voc_size=128, hit_prob=0.8, ins_factor=1, subst_factor=1, del_factor=1)

    def set_error_model_parameters(
        self, voc_size: int, hit_prob: float, ins_factor: float, subst_factor: float, del_factor: float
    ) -> None:
        """Derive the four edit costs and store them on the segment edit distance.

        The probability mass left after a hit (``1 - hit_prob``) is divided by a
        weighted total of the three factors; each edit probability is that quotient
        times its own factor. When ``voc_size`` is non-zero the insertion factor is
        weighted by ``voc_size`` and the substitution factor by ``voc_size - 1``.

        Raises:
            ValueError: From ``math.log`` when a resulting probability is not
                positive (for example ``hit_prob`` of 0 or 1, or a factor of 0).
                Costs assigned before the failing one remain set.
            ZeroDivisionError: When the weighted total of the factors is 0.
        """
        if voc_size == 0:
            weight_total = ins_factor + subst_factor + del_factor
        else:
            weight_total = (ins_factor * voc_size) + (subst_factor * (voc_size - 1)) + del_factor
        unit_prob = (1 - hit_prob) / weight_total

        ins_prob = unit_prob * ins_factor
        subst_prob = unit_prob * subst_factor
        del_prob = unit_prob * del_factor

        sed = self._segment_edit_distance
        sed.hit_cost = -log(hit_prob)
        sed.insertion_cost = -log(ins_prob)
        sed.substitution_cost = -log(subst_prob)
        sed.deletion_cost = -log(del_prob)

    def setup_initial_esi(self, initial_esi: EcmScoreInfo) -> None:
        """Reset ``initial_esi`` to one score (the distance of two empty segments) and no operations."""
        score = self._segment_edit_distance.compute([], [])
        initial_esi.scores.clear()
        initial_esi.scores.append(score)
        initial_esi.operations.clear()

    def setup_esi(self, esi: EcmScoreInfo, prev_esi: EcmScoreInfo, word: str) -> None:
        """Reset ``esi`` to a single score and a single ``EditOperation.NONE``.

        The score is ``prev_esi.scores[0]`` plus the distance between ``[word]`` and
        an empty segment.
        """
        score = self._segment_edit_distance.compute([word], [])
        esi.scores.clear()
        esi.scores.append(prev_esi.scores[0] + score)
        esi.operations.clear()
        esi.operations.append(EditOperation.NONE)

    def extend_initial_esi(
        self, initial_esi: EcmScoreInfo, prev_initial_esi: EcmScoreInfo, prefix_diff: Sequence[str]
    ) -> None:
        """Delegate to ``incr_compute_prefix_first_row`` using the score lists of both objects."""
        self._segment_edit_distance.incr_compute_prefix_first_row(
            initial_esi.scores, prev_initial_esi.scores, prefix_diff
        )

    def extend_esi(
        self,
        esi: EcmScoreInfo,
        prev_esi: EcmScoreInfo,
        word: str,
        prefix_diff: Sequence[str],
        is_last_word_complete: bool,
    ) -> None:
        """Delegate to ``incr_compute_prefix`` and append the operations it returns to ``esi``."""
        new_ops = self._segment_edit_distance.incr_compute_prefix(
            esi.scores, prev_esi.scores, word, prefix_diff, is_last_word_complete
        )
        esi.operations.extend(new_ops)

    def correct_prefix(
        self,
        builder: TranslationResultBuilder,
        uncorrected_prefix_len: int,
        prefix: Sequence[str],
        is_last_word_complete: bool,
    ) -> int:
        """Bring the builder's target tokens in line with ``prefix``.

        When ``uncorrected_prefix_len`` is 0, every prefix word is appended to the
        builder as a ``TranslationSources.PREFIX`` token and the number of prefix
        words is returned. Otherwise the edit operations between the builder's
        target tokens from ``uncorrected_prefix_len`` onward and ``prefix`` are
        computed, and the result of ``builder.correct_prefix`` is returned.
        """
        if uncorrected_prefix_len == 0:
            for prefix_word in prefix:
                builder.append_token(prefix_word, TranslationSources.PREFIX, -1)
            return len(prefix)

        remaining_tokens = builder.target_tokens[uncorrected_prefix_len:]
        _, word_ops, char_ops = self._segment_edit_distance.compute_prefix(
            remaining_tokens, prefix, is_last_word_complete, use_prefix_del_op=False
        )
        return builder.correct_prefix(word_ops, char_ops, prefix, is_last_word_complete)
Oops, something went wrong.