From f3be44db1f0307c9b92158b00ddf5df7016f4fbd Mon Sep 17 00:00:00 2001
From: Nicolay Rusnachenko
Date: Sun, 22 Dec 2024 15:45:46 +0000
Subject: [PATCH] #10 and #12

---
 README.md                                     |  3 +-
 bulk_translate/api.py                         | 14 ++----
 bulk_translate/src/pipeline/context.py        | 34 +++++++++++++
 bulk_translate/src/pipeline/items/__init__.py |  0
 bulk_translate/src/pipeline/items/base.py     | 49 +++++++++++++++++++
 bulk_translate/src/pipeline/items/map.py      | 12 +++++
 bulk_translate/src/pipeline/launcher.py       | 28 +++++++++++
 bulk_translate/src/pipeline/translator.py     |  5 +-
 bulk_translate/src/pipeline/utils.py          | 32 ++++++++++++
 bulk_translate/src/spans_parser.py            |  3 +-
 dependencies.txt                              |  1 -
 11 files changed, 164 insertions(+), 17 deletions(-)
 create mode 100644 bulk_translate/src/pipeline/context.py
 create mode 100644 bulk_translate/src/pipeline/items/__init__.py
 create mode 100644 bulk_translate/src/pipeline/items/base.py
 create mode 100644 bulk_translate/src/pipeline/items/map.py
 create mode 100644 bulk_translate/src/pipeline/launcher.py
 create mode 100644 bulk_translate/src/pipeline/utils.py

diff --git a/README.md b/README.md
index 0e6dcb4..60d2d8e 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,5 @@
 # bulk-translate 0.25.0
 ![](https://img.shields.io/badge/Python-3.9-brightgreen.svg)
-![](https://img.shields.io/badge/AREkit-0.25.1-orange.svg)
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nicolay-r/bulk-translate/blob/master/bulk_translate_demo.ipynb)
 [![PyPI downloads](https://img.shields.io/pypi/dm/bulk-translate.svg)](https://pypistats.org/packages/bulk-translate)
@@ -78,7 +77,7 @@ python -m bulk_translate.translate \
 
 ## Powered by
 
-* AREkit [[github]](https://github.com/nicolay-r/AREkit)
+The pipeline construction components were taken from AREkit [[github]](https://github.com/nicolay-r/AREkit)

diff --git a/bulk_translate/api.py b/bulk_translate/api.py
index a42182f..37a73ac 100644
--- a/bulk_translate/api.py
+++ b/bulk_translate/api.py
@@ -1,10 +1,9 @@
-from arekit.common.pipeline.batching import BatchingPipelineLauncher
-from arekit.common.pipeline.context import PipelineContext
-from arekit.common.pipeline.items.base import BasePipelineItem
-from arekit.common.pipeline.items.map import MapPipelineItem
-from arekit.common.pipeline.utils import BatchIterator
-
+from bulk_translate.src.pipeline.context import PipelineContext
+from bulk_translate.src.pipeline.items.base import BasePipelineItem
+from bulk_translate.src.pipeline.items.map import MapPipelineItem
+from bulk_translate.src.pipeline.launcher import BatchingPipelineLauncher
 from bulk_translate.src.pipeline.translator import MLTextTranslatorPipelineItem
+from bulk_translate.src.pipeline.utils import BatchIterator
 from bulk_translate.src.service_prompt import DataService
 from bulk_translate.src.span import Span
 from bulk_translate.src.spans_parser import TextSpansParser
@@ -41,8 +40,5 @@ def iter_translated_data(self, data_dict_it, prompt, batch_size=1):
         # Target.
         d = ctx._d
 
-        for m in ['parent_ctx']:
-            del d[m]
-
         for batch_ind in range(len(d["input"])):
             yield {k: v[batch_ind] for k, v in d.items()}
diff --git a/bulk_translate/src/pipeline/context.py b/bulk_translate/src/pipeline/context.py
new file mode 100644
index 0000000..39bbdb6
--- /dev/null
+++ b/bulk_translate/src/pipeline/context.py
@@ -0,0 +1,34 @@
+class PipelineContext(object):
+    """ Context of parameters utilized in the pipeline.
+    """
+
+    def __init__(self, d):
+        assert(isinstance(d, dict))
+        self._d = d
+
+    def __provide(self, param):
+        if param not in self._d:
+            raise Exception(f"Key `{param}` is not in the dictionary.\n{self._d}")
+        return self._d[param]
+
+    # region public
+
+    def provide(self, param):
+        return self.__provide(param)
+
+    def provide_or_none(self, param):
+        return self.__provide(param) if param in self._d else None
+
+    def update(self, param, value, is_new_key=False):
+        if is_new_key and param in self._d:
+            raise Exception(f"Key `{param}` is already present in the pipeline context dictionary.")
+        self._d[param] = value
+
+    # endregion
+
+    # region base methods
+
+    def __contains__(self, item):
+        return item in self._d
+
+    # endregion
diff --git a/bulk_translate/src/pipeline/items/__init__.py b/bulk_translate/src/pipeline/items/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/bulk_translate/src/pipeline/items/base.py b/bulk_translate/src/pipeline/items/base.py
new file mode 100644
index 0000000..ce4580d
--- /dev/null
+++ b/bulk_translate/src/pipeline/items/base.py
@@ -0,0 +1,49 @@
+from bulk_translate.src.pipeline.context import PipelineContext
+
+
+class BasePipelineItem(object):
+    """ Single pipeline item that might be instantiated and embedded into a pipeline.
+    """
+
+    def __init__(self, src_key="result", result_key="result", src_func=None):
+        assert(isinstance(src_key, str) or src_key is None)
+        assert(callable(src_func) or src_func is None)
+        self.__src_key = src_key
+        self._src_func = src_func
+        self.__result_key = result_key
+
+    @property
+    def ResultKey(self):
+        return self.__result_key
+
+    @property
+    def SupportBatching(self):
+        """ By default, a pipeline item is not designed for batching.
+        """
+        return False
+
+    def get_source(self, src_ctx, call_func=True, force_key=None):
+        """ Extract the input element for processing.
+        """
+        assert(isinstance(src_ctx, PipelineContext))
+
+        # If there is no information about the key, then we consider the source to be absent.
+        if self.__src_key is None:
+            return None
+
+        # Extracting the actual source.
+        src_data = src_ctx.provide(self.__src_key if force_key is None else force_key)
+        if self._src_func is not None and call_func:
+            src_data = self._src_func(src_data)
+
+        return src_data
+
+    def apply_core(self, input_data, pipeline_ctx):
+        """By default, we do nothing."""
+        pass
+
+    def apply(self, input_data, pipeline_ctx=None):
+        """ Performs input processing and updates it for further pipeline items.
+        """
+        output_data = self.apply_core(input_data=input_data, pipeline_ctx=pipeline_ctx)
+        return output_data if output_data is not None else input_data
diff --git a/bulk_translate/src/pipeline/items/map.py b/bulk_translate/src/pipeline/items/map.py
new file mode 100644
index 0000000..3778a8e
--- /dev/null
+++ b/bulk_translate/src/pipeline/items/map.py
@@ -0,0 +1,12 @@
+from bulk_translate.src.pipeline.items.base import BasePipelineItem
+
+
+class MapPipelineItem(BasePipelineItem):
+
+    def __init__(self, map_func=None, **kwargs):
+        assert(callable(map_func))
+        super(MapPipelineItem, self).__init__(**kwargs)
+        self._map_func = map_func
+
+    def apply_core(self, input_data, pipeline_ctx):
+        return map(self._map_func, input_data)
diff --git a/bulk_translate/src/pipeline/launcher.py b/bulk_translate/src/pipeline/launcher.py
new file mode 100644
index 0000000..18fa102
--- /dev/null
+++ b/bulk_translate/src/pipeline/launcher.py
@@ -0,0 +1,28 @@
+from bulk_translate.src.pipeline.context import PipelineContext
+from bulk_translate.src.pipeline.items.base import BasePipelineItem
+
+
+class BatchingPipelineLauncher:
+
+    @staticmethod
+    def run(pipeline, pipeline_ctx, src_key=None):
+        assert(isinstance(pipeline, list))
+        assert(isinstance(pipeline_ctx, PipelineContext))
+        assert(isinstance(src_key, str) or src_key is None)
+
+        for ind, item in enumerate(filter(lambda itm: itm is not None, pipeline)):
+            assert(isinstance(item, BasePipelineItem))
+
+            # Handle the content of the batch or the batch itself.
+            content = item.get_source(pipeline_ctx, call_func=False, force_key=src_key if ind == 0 else None)
+            handled_batch = [item._src_func(i) if item._src_func is not None else i for i in content]
+
+            if item.SupportBatching:
+                batch_result = list(item.apply(input_data=handled_batch, pipeline_ctx=pipeline_ctx))
+            else:
+                batch_result = [item.apply(input_data=input_data, pipeline_ctx=pipeline_ctx)
+                                for input_data in handled_batch]
+
+            pipeline_ctx.update(param=item.ResultKey, value=batch_result, is_new_key=False)
+
+        return pipeline_ctx
diff --git a/bulk_translate/src/pipeline/translator.py b/bulk_translate/src/pipeline/translator.py
index a2a22bb..cd53905 100644
--- a/bulk_translate/src/pipeline/translator.py
+++ b/bulk_translate/src/pipeline/translator.py
@@ -1,6 +1,5 @@
-from arekit.common.pipeline.context import PipelineContext
-from arekit.common.pipeline.items.base import BasePipelineItem
-
+from bulk_translate.src.pipeline.context import PipelineContext
+from bulk_translate.src.pipeline.items.base import BasePipelineItem
 from bulk_translate.src.span import Span
diff --git a/bulk_translate/src/pipeline/utils.py b/bulk_translate/src/pipeline/utils.py
new file mode 100644
index 0000000..bd451bd
--- /dev/null
+++ b/bulk_translate/src/pipeline/utils.py
@@ -0,0 +1,32 @@
+class BatchIterator:
+
+    def __init__(self, data_iter, batch_size, end_value=None):
+        assert(isinstance(batch_size, int) and batch_size > 0)
+        assert(callable(end_value) or end_value is None)
+        self.__data_iter = data_iter
+        self.__index = 0
+        self.__batch_size = batch_size
+        self.__end_value = end_value
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        buffer = []
+        while True:
+            try:
+                data = next(self.__data_iter)
+            except StopIteration:
+                break
+            buffer.append(data)
+            if len(buffer) == self.__batch_size:
+                break
+
+        if len(buffer) > 0:
+            self.__index += 1
+            return buffer
+
+        if self.__end_value is None:
+            raise StopIteration
+        else:
+            return self.__end_value()
diff --git a/bulk_translate/src/spans_parser.py b/bulk_translate/src/spans_parser.py
index 8b42641..8c409a1 100644
--- a/bulk_translate/src/spans_parser.py
+++ b/bulk_translate/src/spans_parser.py
@@ -1,5 +1,4 @@
-from arekit.common.pipeline.items.base import BasePipelineItem
-
+from bulk_translate.src.pipeline.items.base import BasePipelineItem
 from bulk_translate.src.span import Span
diff --git a/dependencies.txt b/dependencies.txt
index fa2872d..7a8d986 100644
--- a/dependencies.txt
+++ b/dependencies.txt
@@ -1,2 +1 @@
-arekit>=0.25.1
 source_iter>=0.24.2
\ No newline at end of file
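
For reference, a minimal usage sketch of how the relocated pipeline components above fit together. This sketch is not part of the patch: UppercaseItem, the sample rows, and the "input" key are illustrative assumptions; only the imported classes and the methods they expose come from the files added above.

# Minimal sketch (assumption, not shipped with the patch): a one-item pipeline
# built from the relocated components. BatchIterator groups rows into batches,
# and BatchingPipelineLauncher runs each item over a shared PipelineContext.
from bulk_translate.src.pipeline.context import PipelineContext
from bulk_translate.src.pipeline.items.base import BasePipelineItem
from bulk_translate.src.pipeline.launcher import BatchingPipelineLauncher
from bulk_translate.src.pipeline.utils import BatchIterator


class UppercaseItem(BasePipelineItem):
    """ Hypothetical item that upper-cases a single input string. """

    def apply_core(self, input_data, pipeline_ctx):
        return input_data.upper()


pipeline = [UppercaseItem(result_key="result")]

# Each batch gets its own context; src_key="input" feeds the first item.
for batch in BatchIterator(iter(["hello", "world", "third row"]), batch_size=2):
    ctx = BatchingPipelineLauncher.run(pipeline=pipeline,
                                       pipeline_ctx=PipelineContext({"input": batch}),
                                       src_key="input")
    print(ctx.provide("result"))  # ['HELLO', 'WORLD'], then ['THIRD ROW']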