Skip to content

Commit

Permalink
#237 refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 24, 2021
1 parent 1efa91b commit 7db5340
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 65 deletions.
Empty file.
19 changes: 19 additions & 0 deletions arekit/common/data/pipeline/item_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import collections

from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.item import BasePipelineItem


class HandleIterPipelineItem(BasePipelineItem):
    """ Terminal pipeline item: consumes the "src" iterable of the pipeline
        context by invoking a handler on every element.

        The "src" value is exhausted as a side effect; nothing is written
        back into the context.
    """

    def __init__(self, handle_func=None):
        # handle_func: callable(item) -> None, invoked once per element.
        assert(callable(handle_func))
        self.__handle_func = handle_func

    def apply(self, pipeline_ctx):
        # NOTE: the alias `collections.Iterable` was deprecated in Python 3.3
        # and removed in Python 3.10; `collections.abc.Iterable` is the
        # supported location, hence the explicit local import.
        from collections.abc import Iterable

        assert(isinstance(pipeline_ctx, PipelineContext))
        items_iter = pipeline_ctx.provide("src")
        assert(isinstance(items_iter, Iterable))

        for item in items_iter:
            self.__handle_func(item)
17 changes: 17 additions & 0 deletions arekit/common/data/pipeline/item_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import collections

from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.item import BasePipelineItem


class FilterPipelineItem(BasePipelineItem):
    """ Pipeline item that lazily filters the "src" iterable of the pipeline
        context, keeping only elements for which filter_func is truthy.

        The resulting (lazy) filter iterator is written back into "src".
    """

    def __init__(self, filter_func=None):
        # filter_func: callable(item) -> bool, predicate passed to filter().
        assert(callable(filter_func))
        self.__filter_func = filter_func

    def apply(self, pipeline_ctx):
        # NOTE: the alias `collections.Iterable` was deprecated in Python 3.3
        # and removed in Python 3.10; `collections.abc.Iterable` is the
        # supported location, hence the explicit local import.
        from collections.abc import Iterable

        assert(isinstance(pipeline_ctx, PipelineContext))
        iter_data = pipeline_ctx.provide("src")
        assert(isinstance(iter_data, Iterable))
        pipeline_ctx.update(param="src", value=filter(self.__filter_func, iter_data))
17 changes: 17 additions & 0 deletions arekit/common/data/pipeline/item_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import collections

from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.item import BasePipelineItem


class MapPipelineItem(BasePipelineItem):
    """ Pipeline item that lazily maps map_func over the "src" iterable of
        the pipeline context and writes the resulting iterator back into
        "src".
    """

    def __init__(self, map_func=None):
        # map_func: callable(item) -> transformed item, passed to map().
        assert(callable(map_func))
        self.__map_func = map_func

    def apply(self, pipeline_ctx):
        # NOTE: the alias `collections.Iterable` was deprecated in Python 3.3
        # and removed in Python 3.10; `collections.abc.Iterable` is the
        # supported location, hence the explicit local import.
        from collections.abc import Iterable

        assert(isinstance(pipeline_ctx, PipelineContext))
        iter_data = pipeline_ctx.provide("src")
        assert(isinstance(iter_data, Iterable))
        # BUGFIX: the original applied `filter(self.__map_func, iter_data)`
        # (copy-paste from FilterPipelineItem), which discards elements
        # instead of transforming them. A map item must use map().
        pipeline_ctx.update(param="src", value=map(self.__map_func, iter_data))
49 changes: 14 additions & 35 deletions arekit/common/data/views/ouput_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,6 @@ def __iter_doc_opinion_ids(self, doc_df):
return [self._ids_provider.parse_opinion_in_opinion_id(row_id)
for row_id in doc_df[const.ID]]

def __iter_opinion_linkages(self, doc_df, opinions_view):
    """ Yields an OpinionsLinkage per linked group of opinion rows found
        within a single document's frame.

        doc_df: pd.DataFrame -- rows of one document (filtered by DOC_ID
            upstream; presumably contains the `const.ID` column -- see
            __iter_doc_opinion_ids).
        opinions_view: BaseOpinionStorageView -- used to resolve each row
            into an opinion via self._iter_by_opinions.
    """
    assert (isinstance(doc_df, pd.DataFrame))
    assert (isinstance(opinions_view, BaseOpinionStorageView))

    # Parse opinion ids of the document rows and turn them into row-id
    # patterns that group linked opinions together.
    doc_opin_ids = self.__iter_doc_opinion_ids(doc_df)
    doc_opin_id_patterns = self.__iter_id_patterns(doc_opin_ids)
    linkages_df = self.__iter_opinion_linkages_df(doc_df=doc_df,
                                                  row_ids=doc_opin_id_patterns)

    # Each per-linkage frame becomes a lazy iterator of opinions wrapped
    # into an OpinionsLinkage container.
    for df_linkage in linkages_df:
        assert (isinstance(df_linkage, pd.DataFrame))

        opinions_iter = self._iter_by_opinions(linked_df=df_linkage,
                                               opinions_view=opinions_view)

        yield OpinionsLinkage(linked_data=opinions_iter)

def __iter_doc_ids(self):
    # Unique document ids present in the backing storage (DOC_ID column).
    return set(self._storage.iter_column_values(column_name=const.DOC_ID))

Expand Down Expand Up @@ -87,28 +70,24 @@ def _compose_opinion_by_opinion_id(self, sample_id, opinions_view, calc_label_fu

# region public methods

def iter_opinion_collections(self, opinions_view, keep_doc_id_func, to_collection_func):
assert(isinstance(opinions_view, BaseOpinionStorageView))
assert(callable(keep_doc_id_func))
assert(callable(to_collection_func))

# TODO. #237 __iter_doc_ids() should be utilized outside as a part of the pipeline.
for doc_id in self.__iter_doc_ids():
def iter_doc_ids(self):
    # Public accessor: unique document ids of the storage, exposed so the
    # caller-side pipeline can drive the per-document iteration (#237).
    return self.__iter_doc_ids()

# TODO. #237 keep_doc_id_func(doc_id) should be utilized outside as a part of the pipeline.
if not keep_doc_id_func(doc_id):
continue
def iter_opinion_linkages(self, doc_id, opinions_view):
assert(isinstance(opinions_view, BaseOpinionStorageView))
doc_df = self._storage.find_by_value(column_name=const.DOC_ID, value=doc_id)

# TODO. #237 find_by_value(doc_id) should be utilized outside + the latter should return Storage!
doc_df = self._storage.find_by_value(column_name=const.DOC_ID,
value=doc_id)
doc_opin_ids = self.__iter_doc_opinion_ids(doc_df)
doc_opin_id_patterns = self.__iter_id_patterns(doc_opin_ids)
linkages_df = self.__iter_opinion_linkages_df(doc_df=doc_df,
row_ids=doc_opin_id_patterns)

linkages_iter = self.__iter_opinion_linkages(doc_df=doc_df,
opinions_view=opinions_view)
for df_linkage in linkages_df:
assert (isinstance(df_linkage, pd.DataFrame))

# TODO. #237 This to_collection_func(linkages_iter) should be outside and a part of the pipeline.
collection = to_collection_func(linkages_iter)
opinions_iter = self._iter_by_opinions(linked_df=df_linkage,
opinions_view=opinions_view)

yield doc_id, collection
yield OpinionsLinkage(linked_data=opinions_iter)

# endregion
62 changes: 43 additions & 19 deletions arekit/contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import logging
from os.path import exists, join

from arekit.common.data.pipeline.item_handle import HandleIterPipelineItem
from arekit.common.data.pipeline.item_iter import FilterPipelineItem
from arekit.common.data.pipeline.item_map import MapPipelineItem
from arekit.common.data.views.output_multiple import MulticlassOutputView
from arekit.common.experiment.api.ctx_training import TrainingData
from arekit.common.experiment.api.enums import BaseDocumentTag
Expand All @@ -10,6 +13,8 @@
from arekit.common.model.labeling.modes import LabelCalculationMode
from arekit.common.model.labeling.single import SingleLabelsHelper
from arekit.common.opinions.base import Opinion
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.utils import join_dir_with_subfolder_name
from arekit.contrib.bert.callback import Callback
from arekit.contrib.bert.output.eval_helper import EvalHelper
Expand Down Expand Up @@ -54,6 +59,20 @@ def __get_target_dir(self):
original_target_dir = self._experiment.ExperimentIO.get_target_dir()
return self.__eval_helper.get_results_dir(original_target_dir)

def __save_opinion_collection(self, doc_id, collection, epoch_index):
    """ Persists a single document's opinion collection for the given
        epoch via the experiment's IO: resolves the output target and
        writes the collection with the configured labels formatter.
    """

    exp_io = self._experiment.ExperimentIO

    # Target location depends on data type, epoch and document id.
    target = exp_io.create_result_opinion_collection_target(
        data_type=self.__data_type,
        epoch_index=epoch_index,
        doc_id=doc_id)

    exp_io.write_opinion_collection(
        collection=collection,
        labels_formatter=self.__labels_formatter,
        target=target)

def _handle_iteration(self, iter_index):
exp_io = self._experiment.ExperimentIO
exp_data = self._experiment.DataIO
Expand Down Expand Up @@ -121,25 +140,30 @@ def _handle_iteration(self, iter_index):
labels_scaler=self.__label_scaler,
storage=storage)

# iterate opinion collections.
collections_iter = output_view.iter_opinion_collections(
opinions_view=exp_io.create_opinions_view(self.__data_type),
keep_doc_id_func=lambda doc_id: doc_id in cmp_doc_ids_set,
to_collection_func=lambda linked_iter: self.__create_opinion_collection(
supported_labels=exp_data.SupportedCollectionLabels,
linked_iter=linked_iter))

for doc_id, collection in collections_iter:

target = exp_io.create_result_opinion_collection_target(
data_type=self.__data_type,
epoch_index=epoch_index,
doc_id=doc_id)

exp_io.write_opinion_collection(
collection=collection,
labels_formatter=self.__labels_formatter,
target=target)
# Opinion collections iterator pipeline.
pipeline_save_collections = BasePipeline([
FilterPipelineItem(filter_func=lambda doc_id: doc_id in cmp_doc_ids_set),
MapPipelineItem(lambda doc_id:
(doc_id, output_view.iter_opinion_linkages(
doc_id=doc_id,
opinions_view=exp_io.create_opinions_view(self.__data_type)))
),
MapPipelineItem(lambda doc_id, linkages_iter:
(doc_id,
self.__create_opinion_collection(
supported_labels=exp_data.SupportedCollectionLabels,
linked_iter=linkages_iter)
)),
HandleIterPipelineItem(lambda doc_id, collection:
self.__save_opinion_collection(
doc_id=doc_id,
collection=collection,
epoch_index=epoch_index))
])

# Executing pipeline.
pipeline_ctx = PipelineContext({"src": output_view.iter_doc_ids()})
pipeline_save_collections.run(pipeline_ctx)

# evaluate
result = self._experiment.evaluate(data_type=self.__data_type,
Expand Down
38 changes: 27 additions & 11 deletions arekit/contrib/networks/core/callback/utils_model_eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging

from arekit.common.data import const
from arekit.common.data.pipeline.item_iter import FilterPipelineItem
from arekit.common.data.pipeline.item_map import MapPipelineItem
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.data.views.output_multiple import MulticlassOutputView
from arekit.common.experiment.api.enums import BaseDocumentTag
Expand All @@ -12,6 +14,8 @@
from arekit.common.model.labeling.modes import LabelCalculationMode
from arekit.common.model.labeling.single import SingleLabelsHelper
from arekit.common.opinions.base import Opinion
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.utils import progress_bar_iter
from arekit.contrib.networks.core.callback.utils_hidden_states import save_minibatch_all_input_dependent_hidden_values
from arekit.contrib.networks.core.ctx_predict_log import NetworkInputDependentVariables
Expand Down Expand Up @@ -77,7 +81,7 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
result = experiment.evaluate(data_type=data_type,
epoch_index=epoch_index)

# optionally save input-dependend hidden parameters.
# optionally save input-dependent hidden parameters.
if save_hidden_params:
save_minibatch_all_input_dependent_hidden_values(
predict_log=idhp,
Expand Down Expand Up @@ -106,18 +110,30 @@ def __convert_output_to_opinion_collections(exp_io, opin_ops, doc_ops, labels_sc
output_view = MulticlassOutputView(labels_scaler=labels_scaler,
storage=output_storage)

# Extract iterator.
collections_iter = output_view.iter_opinion_collections(
opinions_view=exp_io.create_opinions_view(data_type),
keep_doc_id_func=lambda doc_id: doc_id in cmp_doc_ids_set,
to_collection_func=lambda linked_iter: __create_opinion_collection(
linked_iter=linked_iter,
supported_labels=supported_collection_labels,
create_opinion_collection=opin_ops.create_opinion_collection,
label_scaler=labels_scaler))
# Opinion collections iterator pipeline.
collections_iter_pipeline = BasePipeline([
FilterPipelineItem(filter_func=lambda doc_id: doc_id in cmp_doc_ids_set),
MapPipelineItem(lambda doc_id:
(doc_id,
output_view.iter_opinion_linkages(
doc_id=doc_id,
opinions_view=exp_io.create_opinions_view(data_type)))
),
MapPipelineItem(lambda doc_id, linkages_iter:
(doc_id,
__create_opinion_collection(
linked_iter=linkages_iter,
supported_labels=supported_collection_labels,
create_opinion_collection=opin_ops.create_opinion_collection,
label_scaler=labels_scaler))),
])

# Executing pipeline.
pipeline_ctx = PipelineContext({"src": output_view.iter_doc_ids()})
collections_iter_pipeline.run(pipeline_ctx)

# Save collection.
for doc_id, collection in __log_wrap_collections_conversion_iter(collections_iter):
for doc_id, collection in pipeline_ctx.provide("src"):

target = exp_io.create_result_opinion_collection_target(
data_type=data_type,
Expand Down

0 comments on commit 7db5340

Please sign in to comment.