Skip to content

Commit

Permalink
#282 related fix. #252 related in terms of simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Aug 1, 2022
1 parent 5db2391 commit b9e2d84
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 21 deletions.
2 changes: 1 addition & 1 deletion arekit/common/experiment/api/base_samples_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ class BaseSamplesIO(object):
results -- evaluation of experiments.
"""

def create_view(self, data_type, data_folding):
def create_view(self, target):
""" For viewing/reading
"""
raise NotImplementedError()
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/networks/pipelines/items/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from arekit.contrib.networks.embedding import Embedding
from arekit.contrib.utils.io_utils.embedding import NpzEmbeddingIOUtils

from arekit.contrib.utils.io_utils.samples import SamplesIOUtils
from arekit.contrib.utils.io_utils.samples import SamplesIO
from arekit.contrib.utils.utils_folding import folding_iter_states
from arekit.contrib.utils.serializer import InputDataSerializationHelper

Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self, vectorizers, save_labels_func, str_entity_fmt, exp_ctx,
save embedding and all the related information to it.
"""
assert(isinstance(exp_ctx, NetworkSerializationContext))
assert(isinstance(samples_io, SamplesIOUtils))
assert(isinstance(samples_io, SamplesIO))
assert(isinstance(emb_io, NpzEmbeddingIOUtils))
assert(isinstance(str_entity_fmt, StringEntitiesFormatter))
assert(isinstance(vectorizers, dict))
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/networks/pipelines/items/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from arekit.contrib.networks.shapes import NetworkInputShapes
from arekit.contrib.networks.utils import rm_dir_contents
from arekit.contrib.utils.io_utils.embedding import NpzEmbeddingIOUtils
from arekit.contrib.utils.io_utils.samples import SamplesIOUtils
from arekit.contrib.utils.io_utils.samples import SamplesIO
from arekit.contrib.utils.utils_folding import folding_iter_states


Expand All @@ -29,7 +29,7 @@ def __init__(self, bags_collection_type, model_io, samples_io, emb_io,
load_model, config, create_network_func, training_epochs,
labels_count, network_callbacks, prepare_model_root=True, seed=None):
assert(callable(create_network_func))
assert(isinstance(samples_io, SamplesIOUtils))
assert(isinstance(samples_io, SamplesIO))
assert(isinstance(emb_io, NpzEmbeddingIOUtils))
assert(isinstance(config, DefaultNetworkConfig))
assert(issubclass(bags_collection_type, BagsCollection))
Expand Down
10 changes: 6 additions & 4 deletions arekit/contrib/utils/io_utils/opinions.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
from os.path import join

from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.contrib.utils.data.views.opinions import BaseOpinionStorageView
from arekit.contrib.utils.io_utils.utils import filename_template


class OpinionsIOUtils(object):
class OpinionsIO(BaseSamplesIO):

def __init__(self, target_dir, target_extension=".tsv.gz"):
def __init__(self, target_dir, prefix="opinion", target_extension=".tsv.gz"):
self.__target_dir = target_dir
self.__prefix = prefix
self.__target_extension = target_extension

def create_view(self, target):
storage = BaseRowsStorage.from_tsv(filepath=target)
return BaseOpinionStorageView(storage)

def create_writer_target(self, data_type, data_folding):
def create_target(self, data_type, data_folding):
return self.__get_input_opinions_target(data_type, data_folding=data_folding)

def __get_input_opinions_target(self, data_type, data_folding):
template = filename_template(data_type=data_type, data_folding=data_folding)
return self.__get_filepath(out_dir=self.__target_dir,
template=template,
prefix="opinion",
prefix=self.__prefix,
extension=self.__target_extension)

@staticmethod
Expand Down
11 changes: 6 additions & 5 deletions arekit/contrib/utils/io_utils/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,27 @@
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.data.views.samples import BaseSampleStorageView
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.contrib.utils.io_utils.utils import filename_template, check_targets_existence

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class SamplesIOUtils(BaseSamplesIO):
class SamplesIO(BaseSamplesIO):
""" Samples default IO utils for samples.
Sample is a text part which include pair of attitude participants.
This class allows to provide saver and loader for such entries, bubbed as samples.
Samples required for machine learning training/inferring.
"""

def __init__(self, target_dir,
samples_writer=TsvWriter(write_header=True),
writer=TsvWriter(write_header=True),
prefix="sample",
target_extension=".tsv.gz"):
assert(isinstance(samples_writer, BaseWriter))
assert(isinstance(writer, BaseWriter))
self.__target_dir = target_dir
self.__samples_writer = samples_writer
self.__writer = writer
self.__target_extension = target_extension
self.__prefix = prefix

Expand All @@ -36,7 +37,7 @@ def create_view(self, target):
row_ids_provider=MultipleIDProvider())

def create_writer(self):
return self.__samples_writer
return self.__writer

def create_target(self, data_type, data_folding):
return self.__get_input_sample_target(data_type, data_folding=data_folding)
Expand Down
14 changes: 7 additions & 7 deletions arekit/contrib/utils/pipelines/items/to_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,27 @@
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.common.pipeline.items.handle import HandleIterPipelineItem
from arekit.contrib.utils.data.views.linkages.multilabel import MultilableOpinionLinkagesView
from arekit.contrib.utils.io_utils.opinions import OpinionsIOUtils
from arekit.contrib.utils.io_utils.opinions import OpinionsIO
from arekit.contrib.utils.utils_folding import folding_iter_states, experiment_iter_index
from arekit.contrib.utils.pipelines.opinion_collections import \
text_opinion_linkages_to_opinion_collections_pipeline_part


class TextOpinionLinkagesToOpinionConverterPipelineItem(BasePipelineItem):

def __init__(self, opinion_samples_io, create_opinion_collection_func,
def __init__(self, opinions_io, create_opinion_collection_func,
opinion_collection_writer, label_scaler, labels_formatter):
""" create_opinion_collection_func: func
func () -> OpinionCollection (empty)
"""
assert(isinstance(opinion_samples_io, OpinionsIOUtils))
assert(isinstance(opinions_io, OpinionsIO))
assert(callable(create_opinion_collection_func))
assert(isinstance(label_scaler, BaseLabelScaler))
assert(isinstance(labels_formatter, StringLabelsFormatter))
assert(isinstance(opinion_collection_writer, OpinionCollectionWriter))
super(TextOpinionLinkagesToOpinionConverterPipelineItem, self).__init__()

self.__opinion_samples_io = opinion_samples_io
self.__opinions_io = opinions_io
self.__labels_formatter = labels_formatter
self.__label_scaler = label_scaler
self.__create_opinion_collection_func = create_opinion_collection_func
Expand All @@ -52,12 +52,12 @@ def __convert(self, data_folding, output_storage, target_func, data_type):
linkages_view = MultilableOpinionLinkagesView(labels_scaler=self.__label_scaler,
storage=output_storage)

target = self.__opinion_samples_io.create_writer_target(data_type=data_type,
data_folding=data_folding)
target = self.__opinions_io.create_target(data_type=data_type,
data_folding=data_folding)

converter_part = text_opinion_linkages_to_opinion_collections_pipeline_part(
iter_opinion_linkages_func=lambda doc_id: linkages_view.iter_opinion_linkages(
doc_id=doc_id, opinions_view=self.__opinion_samples_io.create_view(target)),
doc_id=doc_id, opinions_view=self.__opinions_io.create_view(target)),
doc_ids_set=set(data_folding.fold_doc_ids_set()[data_type]),
create_opinion_collection_func=self.__create_opinion_collection_func,
labels_scaler=self.__label_scaler,
Expand Down

0 comments on commit b9e2d84

Please sign in to comment.