Skip to content

Commit

Permalink
#16 done
Browse files (browse the repository at this point in the history)
  • Loading branch information
nicolay-r committed Dec 29, 2024
1 parent ea237de commit 34b9233
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion bulk_translate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(self, translate_spans, translation_model, **custom_args_dict):
TextSpansParser(src_func=lambda text: [text] if isinstance(text, str) else text),
MLTextTranslatorPipelineItem(
batch_translate_model=translation_model.get_func(**custom_args_dict),
do_translate_entity=translate_spans),
do_translate_entity=translate_spans,
is_span_func=lambda term: isinstance(term, Span)),
MapPipelineItem(map_func=lambda term:
([term.DisplayValue] + ([term.content] if term.content is not None else []))
if isinstance(term, Span) else term),
Expand Down
10 changes: 6 additions & 4 deletions bulk_translate/src/pipeline/translator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from bulk_translate.src.pipeline.context import PipelineContext
from bulk_translate.src.pipeline.items.base import BasePipelineItem
from bulk_translate.src.span import Span


class MLTextTranslatorPipelineItem(BasePipelineItem):
""" Machine learning based translator pipeline item.
"""

def __init__(self, batch_translate_model, do_translate_entity=True, **kwargs):
def __init__(self, batch_translate_model, is_span_func, do_translate_entity=True, **kwargs):
""" Model, which is based on translation of the text,
represented as a list of words.
"""
super(MLTextTranslatorPipelineItem, self).__init__(**kwargs)
self.__do_translate_entity = do_translate_entity
self.__translate = batch_translate_model
self.__is_span = is_span_func

def fast_most_accurate_approach(self, input_data, entity_placeholder_template="<entityTag={}/>"):
""" This approach assumes that the translation won't corrupt the original
Expand All @@ -31,7 +32,7 @@ def __optionally_register(prts):
for part in input_data:
if isinstance(part, str) and part.strip():
parts_to_join.append(part)
elif isinstance(part, Span):
elif self.__is_span(part):
entity_index = len(origin_entities)
parts_to_join.append(entity_placeholder_template.format(entity_index))
# Register entity information for later restoration.
Expand Down Expand Up @@ -93,7 +94,7 @@ def __optionally_register(prts):
for _, part in enumerate(input_data):
if isinstance(part, str) and part.strip():
parts_to_join.append(part)
elif isinstance(part, Span):
elif self.__is_span(part):
# First, register the prior parts that were merged.
__optionally_register(parts_to_join)
# Register entity information for later restoration.
Expand All @@ -115,6 +116,7 @@ def __optionally_register(prts):
return translated_parts

def apply_core(self, input_data, pipeline_ctx):
assert(isinstance(pipeline_ctx, PipelineContext))
assert(isinstance(input_data, list))

fast_accurate = self.fast_most_accurate_approach(input_data)
Expand Down

0 comments on commit 34b9233

Please sign in to comment.