Skip to content

Commit

Permalink
#476 same implementation of most parts for different items
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed May 18, 2023
1 parent 2b606b1 commit e844ec6
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 15 deletions.
9 changes: 5 additions & 4 deletions arekit/contrib/utils/pipelines/items/sampling/bert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.contrib.utils.utils_folding import folding_iter_states
Expand All @@ -9,7 +10,7 @@

class BertExperimentInputSerializerPipelineItem(BasePipelineItem):

def __init__(self, sample_rows_provider, samples_io, save_labels_func, balance_func, storage):
def __init__(self, rows_provider, samples_io, save_labels_func, balance_func, storage):
""" sample_rows_formatter:
how we format input texts for a BERT model, for example:
- single text
Expand All @@ -21,7 +22,7 @@ def __init__(self, sample_rows_provider, samples_io, save_labels_func, balance_f
assert(isinstance(samples_io, BaseSamplesIO))
super(BertExperimentInputSerializerPipelineItem, self).__init__()

self.__sample_rows_provider = sample_rows_provider
self.__rows_provider = rows_provider
self.__balance_func = balance_func
self.__samples_io = samples_io
self.__save_labels_func = save_labels_func
Expand All @@ -31,11 +32,12 @@ def __init__(self, sample_rows_provider, samples_io, save_labels_func, balance_f

def __serialize_iteration(self, data_type, pipeline, data_folding):
assert(isinstance(data_type, DataType))
assert(isinstance(pipeline, BasePipeline))

repos = {
"sample": InputDataSerializationHelper.create_samples_repo(
keep_labels=self.__save_labels_func(data_type),
rows_provider=self.__sample_rows_provider,
rows_provider=self.__rows_provider,
storage=self.__storage),
}

Expand All @@ -46,7 +48,6 @@ def __serialize_iteration(self, data_type, pipeline, data_folding):
}

for description, repo in repos.items():

InputDataSerializationHelper.fill_and_write(
repo=repo,
pipeline=pipeline,
Expand Down
13 changes: 5 additions & 8 deletions arekit/contrib/utils/pipelines/items/sampling/networks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
from arekit.common.pipeline.base import BasePipeline
Expand All @@ -8,7 +9,6 @@
from arekit.contrib.networks.embedding import Embedding
from arekit.contrib.networks.input.providers.sample import NetworkSampleRowProvider
from arekit.contrib.utils.io_utils.embedding import NpEmbeddingIO
from arekit.contrib.utils.io_utils.samples import SamplesIO
from arekit.contrib.utils.utils_folding import folding_iter_states
from arekit.contrib.utils.serializer import InputDataSerializationHelper

Expand All @@ -33,7 +33,7 @@ def __init__(self, save_labels_func, rows_provider, samples_io,
save_embedding: bool
save embedding and all the related information to it.
"""
assert(isinstance(samples_io, SamplesIO))
assert(isinstance(samples_io, BaseSamplesIO))
assert(isinstance(emb_io, NpEmbeddingIO))
assert(isinstance(rows_provider, NetworkSampleRowProvider))
assert(isinstance(save_embedding, bool))
Expand All @@ -49,14 +49,14 @@ def __init__(self, save_labels_func, rows_provider, samples_io,
self.__storage = storage
self.__rows_provider = rows_provider

def __serialize_iteration(self, data_type, pipeline, rows_provider, data_folding):
def __serialize_iteration(self, data_type, pipeline, data_folding):
assert(isinstance(data_type, DataType))
assert(isinstance(pipeline, BasePipeline))

repos = {
"sample": InputDataSerializationHelper.create_samples_repo(
keep_labels=self.__save_labels_func(data_type),
rows_provider=rows_provider,
rows_provider=self.__rows_provider,
storage=self.__storage),
}

Expand Down Expand Up @@ -86,10 +86,7 @@ def __handle_iteration(self, data_type_pipelines, data_folding):
self.__rows_provider.clear_embedding_pairs()

for data_type, pipeline in data_type_pipelines.items():
self.__serialize_iteration(pipeline=pipeline,
data_type=data_type,
rows_provider=self.__rows_provider,
data_folding=data_folding)
self.__serialize_iteration(data_type=data_type, pipeline=pipeline, data_folding=data_folding)

if not (self.__save_embedding and self.__rows_provider.HasEmbeddingPairs):
return
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/utils/test_csv_stream_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __launch(self, writer, target_extention):
samples_io = SamplesIO(self.__output_dir, writer, target_extension=target_extention)

pipeline_item = BertExperimentInputSerializerPipelineItem(
sample_rows_provider=sample_rows_provider,
rows_provider=sample_rows_provider,
samples_io=samples_io,
save_labels_func=lambda data_type: True,
balance_func=lambda _: False,
Expand Down
4 changes: 2 additions & 2 deletions tests/tutorials/test_tutorial_pipeline_sampling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ def test(self):
if text_b_template is None else \
PairTextProvider(text_b_template, terms_mapper)

sample_rows_provider = BaseSampleRowProvider(
rows_provider = BaseSampleRowProvider(
label_provider=MultipleLabelProvider(SentimentLabelScaler()),
text_provider=text_provider)

writer = PandasCsvWriter(write_header=True)
samples_io = SamplesIO(self.__output_dir, writer, target_extension=".tsv.gz")

pipeline_item = BertExperimentInputSerializerPipelineItem(
sample_rows_provider=sample_rows_provider,
rows_provider=rows_provider,
samples_io=samples_io,
save_labels_func=lambda data_type: True,
balance_func=lambda data_type: data_type == DataType.Train,
Expand Down

0 comments on commit e844ec6

Please sign in to comment.