diff --git a/arekit/contrib/utils/pipelines/items/sampling/bert.py b/arekit/contrib/utils/pipelines/items/sampling/bert.py index 272691b4..58c31283 100644 --- a/arekit/contrib/utils/pipelines/items/sampling/bert.py +++ b/arekit/contrib/utils/pipelines/items/sampling/bert.py @@ -1,6 +1,7 @@ from arekit.common.experiment.api.base_samples_io import BaseSamplesIO from arekit.common.experiment.data_type import DataType from arekit.common.folding.base import BaseDataFolding +from arekit.common.pipeline.base import BasePipeline from arekit.common.pipeline.context import PipelineContext from arekit.common.pipeline.items.base import BasePipelineItem from arekit.contrib.utils.utils_folding import folding_iter_states @@ -9,7 +10,7 @@ class BertExperimentInputSerializerPipelineItem(BasePipelineItem): - def __init__(self, sample_rows_provider, samples_io, save_labels_func, balance_func, storage): + def __init__(self, rows_provider, samples_io, save_labels_func, balance_func, storage): """ sample_rows_formatter: how we format input texts for a BERT model, for example: - single text @@ -21,7 +22,7 @@ def __init__(self, sample_rows_provider, samples_io, save_labels_func, balance_f assert(isinstance(samples_io, BaseSamplesIO)) super(BertExperimentInputSerializerPipelineItem, self).__init__() - self.__sample_rows_provider = sample_rows_provider + self.__rows_provider = rows_provider self.__balance_func = balance_func self.__samples_io = samples_io self.__save_labels_func = save_labels_func @@ -31,11 +32,12 @@ def __init__(self, sample_rows_provider, samples_io, save_labels_func, balance_f def __serialize_iteration(self, data_type, pipeline, data_folding): assert(isinstance(data_type, DataType)) + assert(isinstance(pipeline, BasePipeline)) repos = { "sample": InputDataSerializationHelper.create_samples_repo( keep_labels=self.__save_labels_func(data_type), - rows_provider=self.__sample_rows_provider, + rows_provider=self.__rows_provider, storage=self.__storage), } @@ -46,7 +48,6 @@ def __serialize_iteration(self, data_type, pipeline, data_folding): } for description, repo in repos.items(): - InputDataSerializationHelper.fill_and_write( repo=repo, pipeline=pipeline, diff --git a/arekit/contrib/utils/pipelines/items/sampling/networks.py b/arekit/contrib/utils/pipelines/items/sampling/networks.py index 9144329a..dacabdc4 100644 --- a/arekit/contrib/utils/pipelines/items/sampling/networks.py +++ b/arekit/contrib/utils/pipelines/items/sampling/networks.py @@ -1,3 +1,4 @@ +from arekit.common.experiment.api.base_samples_io import BaseSamplesIO from arekit.common.experiment.data_type import DataType from arekit.common.folding.base import BaseDataFolding from arekit.common.pipeline.base import BasePipeline @@ -8,7 +9,6 @@ from arekit.contrib.networks.embedding import Embedding from arekit.contrib.networks.input.providers.sample import NetworkSampleRowProvider from arekit.contrib.utils.io_utils.embedding import NpEmbeddingIO -from arekit.contrib.utils.io_utils.samples import SamplesIO from arekit.contrib.utils.utils_folding import folding_iter_states from arekit.contrib.utils.serializer import InputDataSerializationHelper @@ -33,7 +33,7 @@ def __init__(self, save_labels_func, rows_provider, samples_io, save_embedding: bool save embedding and all the related information to it. """ - assert(isinstance(samples_io, SamplesIO)) + assert(isinstance(samples_io, BaseSamplesIO)) assert(isinstance(emb_io, NpEmbeddingIO)) assert(isinstance(rows_provider, NetworkSampleRowProvider)) assert(isinstance(save_embedding, bool)) @@ -49,14 +49,14 @@ def __init__(self, save_labels_func, rows_provider, samples_io, self.__storage = storage self.__rows_provider = rows_provider - def __serialize_iteration(self, data_type, pipeline, rows_provider, data_folding): + def __serialize_iteration(self, data_type, pipeline, data_folding): assert(isinstance(data_type, DataType)) assert(isinstance(pipeline, BasePipeline)) repos = { "sample": InputDataSerializationHelper.create_samples_repo( keep_labels=self.__save_labels_func(data_type), - rows_provider=rows_provider, + rows_provider=self.__rows_provider, storage=self.__storage), } @@ -86,10 +86,7 @@ def __handle_iteration(self, data_type_pipelines, data_folding): self.__rows_provider.clear_embedding_pairs() for data_type, pipeline in data_type_pipelines.items(): - self.__serialize_iteration(pipeline=pipeline, - data_type=data_type, - rows_provider=self.__rows_provider, - data_folding=data_folding) + self.__serialize_iteration(data_type=data_type, pipeline=pipeline, data_folding=data_folding) if not (self.__save_embedding and self.__rows_provider.HasEmbeddingPairs): return diff --git a/tests/contrib/utils/test_csv_stream_write.py b/tests/contrib/utils/test_csv_stream_write.py index 1a0f99f3..834ee867 100644 --- a/tests/contrib/utils/test_csv_stream_write.py +++ b/tests/contrib/utils/test_csv_stream_write.py @@ -53,7 +53,7 @@ def __launch(self, writer, target_extention): samples_io = SamplesIO(self.__output_dir, writer, target_extension=target_extention) pipeline_item = BertExperimentInputSerializerPipelineItem( - sample_rows_provider=sample_rows_provider, + rows_provider=sample_rows_provider, samples_io=samples_io, save_labels_func=lambda data_type: True, balance_func=lambda _: False, diff --git a/tests/tutorials/test_tutorial_pipeline_sampling_bert.py b/tests/tutorials/test_tutorial_pipeline_sampling_bert.py index 48b445eb..9b910c33 100644 --- a/tests/tutorials/test_tutorial_pipeline_sampling_bert.py +++ b/tests/tutorials/test_tutorial_pipeline_sampling_bert.py @@ -84,7 +84,7 @@ def test(self): if text_b_template is None else \ PairTextProvider(text_b_template, terms_mapper) - sample_rows_provider = BaseSampleRowProvider( + rows_provider = BaseSampleRowProvider( label_provider=MultipleLabelProvider(SentimentLabelScaler()), text_provider=text_provider) @@ -92,7 +92,7 @@ def test(self): samples_io = SamplesIO(self.__output_dir, writer, target_extension=".tsv.gz") pipeline_item = BertExperimentInputSerializerPipelineItem( - sample_rows_provider=sample_rows_provider, + rows_provider=rows_provider, samples_io=samples_io, save_labels_func=lambda data_type: True, balance_func=lambda data_type: data_type == DataType.Train,