diff --git a/arekit/contrib/networks/input/providers/sample.py b/arekit/contrib/networks/input/providers/sample.py index b660eae4..415ce347 100644 --- a/arekit/contrib/networks/input/providers/sample.py +++ b/arekit/contrib/networks/input/providers/sample.py @@ -1,3 +1,5 @@ +import collections + from arekit.common.data.input.providers.label.base import LabelProvider from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider from arekit.common.entities.base import Entity @@ -16,10 +18,16 @@ def __init__(self, text_provider, frames_connotation_provider, frame_role_label_scaler, + term_embedding_pairs=None, pos_terms_mapper=None): + """ term_embedding_pairs: dict or None + additional structure, utilized to collect all the embedding pairs during the + rows providing stage. + """ assert(isinstance(label_provider, LabelProvider)) assert(isinstance(frame_role_label_scaler, SentimentLabelScaler)) assert(isinstance(pos_terms_mapper, PosTermsMapper) or pos_terms_mapper is None) + assert(isinstance(term_embedding_pairs, collections.OrderedDict) or term_embedding_pairs is None) super(NetworkSampleRowProvider, self).__init__(label_provider=label_provider, text_provider=text_provider) @@ -27,6 +35,26 @@ def __init__(self, self.__frames_connotation_provider = frames_connotation_provider self.__frame_role_label_scaler = frame_role_label_scaler self.__pos_terms_mapper = pos_terms_mapper + self.__term_embedding_pairs = term_embedding_pairs + + @property + def HasEmbeddingPairs(self): + return self.__term_embedding_pairs is not None + + def iter_term_embedding_pairs(self): + """ Provide the contents of the embedded pairs. + """ + return iter(self.__term_embedding_pairs.items()) + + def clear_embedding_pairs(self): + """ Release the contents of the collected embedding pairs. + """ + + # Check whether we actually collect embedding pairs. + if self.__term_embedding_pairs is None: + return + + self.__term_embedding_pairs.clear() def _fill_row_core(self, row, text_opinion_linkage, index_in_linked, etalon_label, parsed_news, sentence_ind, s_ind, t_ind): diff --git a/arekit/contrib/utils/bert/rows.py b/arekit/contrib/utils/bert/rows.py new file mode 100644 index 00000000..e69de29b diff --git a/arekit/contrib/utils/nn/__init__.py b/arekit/contrib/utils/nn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/arekit/contrib/utils/nn/rows.py b/arekit/contrib/utils/nn/rows.py new file mode 100644 index 00000000..bffce85f --- /dev/null +++ b/arekit/contrib/utils/nn/rows.py @@ -0,0 +1,83 @@ +import collections + +from arekit.common.data.input.providers.text.single import BaseSingleTextProvider +from arekit.common.data.input.terms_mapper import OpinionContainingTextTermsMapper +from arekit.common.entities.str_fmt import StringEntitiesFormatter +from arekit.contrib.networks.input.ctx_serialization import NetworkSerializationContext +from arekit.contrib.networks.input.formatters.pos_mapper import PosTermsMapper +from arekit.contrib.networks.input.providers.sample import NetworkSampleRowProvider +from arekit.contrib.networks.input.providers.text import NetworkSingleTextProvider +from arekit.contrib.networks.input.term_types import TermTypes +from arekit.contrib.networks.input.terms_mapping import StringWithEmbeddingNetworkTermMapping +from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper +from arekit.contrib.utils.resources import load_embedding_news_mystem_skipgram_1000_20_2015 +from arekit.contrib.utils.vectorizers.bpe import BPEVectorizer +from arekit.contrib.utils.vectorizers.random_norm import RandomNormalVectorizer + + +def __add_term_embedding(dict_data, term, emb_vector): + if term in dict_data: + return + dict_data[term] = emb_vector + + +def create_rows_provider(str_entity_fmt, ctx, vectorizers="default"): + """ This method is corresponds to the default initialization of + the rows provider for data sampling pipeline. + + vectorizers: + NONE: no need to vectorize, just provide text (using SingleTextProvider). + DEFAULT: we consider an application of stemmer for Russian Language. + DICT: in which for every type there is an assigned Vectorizer + vectorization of term types. + { + TermType.Word: Vectorizer, + TermType.Entity: Vectorizer, + ... + } + """ + assert(isinstance(str_entity_fmt, StringEntitiesFormatter)) + assert(isinstance(ctx, NetworkSerializationContext)) + assert(isinstance(vectorizers, dict) or vectorizers == "default" or vectorizers is None) + + term_embedding_pairs = None + + if vectorizers is not None: + + if vectorizers == "default": + # initialize default vectorizer for Russian language. + embedding = load_embedding_news_mystem_skipgram_1000_20_2015(MystemWrapper()) + bpe_vectorizer = BPEVectorizer(embedding=embedding, max_part_size=3) + norm_vectorizer = RandomNormalVectorizer(vector_size=embedding.VectorSize, + token_offset=12345) + vectorizers = { + TermTypes.WORD: bpe_vectorizer, + TermTypes.ENTITY: bpe_vectorizer, + TermTypes.FRAME: bpe_vectorizer, + TermTypes.TOKEN: norm_vectorizer + } + + # Setup term-embedding pairs collection instance. + term_embedding_pairs = collections.OrderedDict() + + # Use text provider with vectorizers. + text_provider = NetworkSingleTextProvider( + text_terms_mapper=StringWithEmbeddingNetworkTermMapping( + vectorizers=vectorizers, + string_entities_formatter=str_entity_fmt), + pair_handling_func=lambda pair: __add_term_embedding( + dict_data=term_embedding_pairs, + term=pair[0], + emb_vector=pair[1])) + else: + # Create text provider which without vectorizers. + text_provider = BaseSingleTextProvider( + text_terms_mapper=OpinionContainingTextTermsMapper(str_entity_fmt)) + + return NetworkSampleRowProvider( + label_provider=ctx.LabelProvider, + text_provider=text_provider, + frames_connotation_provider=ctx.FramesConnotationProvider, + frame_role_label_scaler=ctx.FrameRolesLabelScaler, + pos_terms_mapper=PosTermsMapper(ctx.PosTagger) if ctx.PosTagger is not None else None, + term_embedding_pairs=term_embedding_pairs) diff --git a/arekit/contrib/utils/pipelines/items/sampling/base.py b/arekit/contrib/utils/pipelines/items/sampling/base.py new file mode 100644 index 00000000..e69de29b diff --git a/arekit/contrib/utils/pipelines/items/sampling/networks.py b/arekit/contrib/utils/pipelines/items/sampling/networks.py index d406f004..9144329a 100644 --- a/arekit/contrib/utils/pipelines/items/sampling/networks.py +++ b/arekit/contrib/utils/pipelines/items/sampling/networks.py @@ -1,21 +1,12 @@ -import collections - -from arekit.common.data.input.providers.text.single import BaseSingleTextProvider -from arekit.common.data.input.terms_mapper import OpinionContainingTextTermsMapper -from arekit.common.entities.str_fmt import StringEntitiesFormatter from arekit.common.experiment.data_type import DataType from arekit.common.folding.base import BaseDataFolding from arekit.common.pipeline.base import BasePipeline from arekit.common.pipeline.context import PipelineContext from arekit.common.pipeline.items.base import BasePipelineItem -from arekit.contrib.networks.input.ctx_serialization import NetworkSerializationContext from arekit.contrib.networks.input.embedding.matrix import create_term_embedding_matrix from arekit.contrib.networks.input.embedding.offsets import TermsEmbeddingOffsets -from arekit.contrib.networks.input.formatters.pos_mapper import PosTermsMapper -from arekit.contrib.networks.input.providers.sample import NetworkSampleRowProvider -from arekit.contrib.networks.input.providers.text import NetworkSingleTextProvider -from arekit.contrib.networks.input.terms_mapping import StringWithEmbeddingNetworkTermMapping from arekit.contrib.networks.embedding import Embedding +from arekit.contrib.networks.input.providers.sample import NetworkSampleRowProvider from arekit.contrib.utils.io_utils.embedding import NpEmbeddingIO from arekit.contrib.utils.io_utils.samples import SamplesIO from arekit.contrib.utils.utils_folding import folding_iter_states @@ -24,8 +15,8 @@ class NetworksInputSerializerPipelineItem(BasePipelineItem): - def __init__(self, vectorizers, save_labels_func, str_entity_fmt, ctx, - samples_io, emb_io, balance_func, save_embedding, storage): + def __init__(self, save_labels_func, rows_provider, samples_io, + emb_io, balance_func, storage, save_embedding=True): """ This pipeline item allows to perform a data preparation for neural network models. considering a list of the whole data_types with the related pipelines, @@ -36,25 +27,15 @@ def __init__(self, vectorizers, save_labels_func, str_entity_fmt, ctx, balance: bool declares whethere there is a need to balance Train samples - vectorizers: dict in which for every type there is an assigned Vectorizer - vectorization of term types. - { - TermType.Word: Vectorizer, - TermType.Entity: Vectorizer, - ... - } - save_labels_func: function data_type -> bool save_embedding: bool save embedding and all the related information to it. """ - assert(isinstance(ctx, NetworkSerializationContext)) assert(isinstance(samples_io, SamplesIO)) assert(isinstance(emb_io, NpEmbeddingIO)) - assert(isinstance(str_entity_fmt, StringEntitiesFormatter)) - assert(isinstance(vectorizers, dict) or vectorizers is None) + assert(isinstance(rows_provider, NetworkSampleRowProvider)) assert(isinstance(save_embedding, bool)) assert(callable(save_labels_func)) assert(callable(balance_func)) @@ -62,39 +43,11 @@ def __init__(self, vectorizers, save_labels_func, str_entity_fmt, ctx, self.__emb_io = emb_io self.__samples_io = samples_io - self.__save_embedding = save_embedding and vectorizers is not None + self.__save_embedding = save_embedding self.__save_labels_func = save_labels_func self.__balance_func = balance_func self.__storage = storage - - self.__term_embedding_pairs = collections.OrderedDict() - - if vectorizers is not None: - text_provider = NetworkSingleTextProvider( - text_terms_mapper=StringWithEmbeddingNetworkTermMapping( - vectorizers=vectorizers, - string_entities_formatter=str_entity_fmt), - pair_handling_func=lambda pair: self.__add_term_embedding( - dict_data=self.__term_embedding_pairs, - term=pair[0], - emb_vector=pair[1])) - else: - # Create text provider which without vectorizers. - text_provider = BaseSingleTextProvider( - text_terms_mapper=OpinionContainingTextTermsMapper(str_entity_fmt)) - - self.__rows_provider = NetworkSampleRowProvider( - label_provider=ctx.LabelProvider, - text_provider=text_provider, - frames_connotation_provider=ctx.FramesConnotationProvider, - frame_role_label_scaler=ctx.FrameRolesLabelScaler, - pos_terms_mapper=PosTermsMapper(ctx.PosTagger) if ctx.PosTagger is not None else None) - - @staticmethod - def __add_term_embedding(dict_data, term, emb_vector): - if term in dict_data: - return - dict_data[term] = emb_vector + self.__rows_provider = rows_provider def __serialize_iteration(self, data_type, pipeline, rows_provider, data_folding): assert(isinstance(data_type, DataType)) @@ -130,7 +83,7 @@ def __handle_iteration(self, data_type_pipelines, data_folding): assert(isinstance(data_folding, BaseDataFolding)) # Prepare for the present iteration. - self.__term_embedding_pairs.clear() + self.__rows_provider.clear_embedding_pairs() for data_type, pipeline in data_type_pipelines.items(): self.__serialize_iteration(pipeline=pipeline, @@ -138,11 +91,11 @@ def __handle_iteration(self, data_type_pipelines, data_folding): rows_provider=self.__rows_provider, data_folding=data_folding) - if not self.__save_embedding: + if not (self.__save_embedding and self.__rows_provider.HasEmbeddingPairs): return # Save embedding information additionally. - term_embedding = Embedding.from_word_embedding_pairs_iter(iter(self.__term_embedding_pairs.items())) + term_embedding = Embedding.from_word_embedding_pairs_iter(self.__rows_provider.iter_term_embedding_pairs()) embedding_matrix = create_term_embedding_matrix(term_embedding=term_embedding) vocab = list(TermsEmbeddingOffsets.extract_vocab(words_embedding=term_embedding)) diff --git a/tests/tutorials/test_tutorial_pipeline_sampling_network.py b/tests/tutorials/test_tutorial_pipeline_sampling_network.py index 338b0415..678b0285 100644 --- a/tests/tutorials/test_tutorial_pipeline_sampling_network.py +++ b/tests/tutorials/test_tutorial_pipeline_sampling_network.py @@ -11,7 +11,6 @@ from arekit.common.pipeline.base import BasePipeline from arekit.common.text.parser import BaseTextParser from arekit.contrib.networks.input.ctx_serialization import NetworkSerializationContext -from arekit.contrib.networks.input.term_types import TermTypes from arekit.contrib.source.brat.entities.parser import BratTextEntitiesParser from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection from arekit.contrib.source.rusentiframes.labels_fmt import RuSentiFramesEffectLabelsFormatter, \ @@ -24,6 +23,7 @@ from arekit.contrib.utils.entities.formatters.str_simple_uppercase_fmt import SimpleUppercasedEntityFormatter from arekit.contrib.utils.io_utils.embedding import NpEmbeddingIO from arekit.contrib.utils.io_utils.samples import SamplesIO +from arekit.contrib.utils.nn.rows import create_rows_provider from arekit.contrib.utils.pipelines.items.sampling.networks import NetworksInputSerializerPipelineItem from arekit.contrib.utils.pipelines.items.text.frames_lemmatized import LemmasBasedFrameVariantsParser from arekit.contrib.utils.pipelines.items.text.tokenizer import DefaultTextTokenizer @@ -32,9 +32,6 @@ from arekit.contrib.utils.pipelines.text_opinion.filters.distance_based import DistanceLimitedTextOpinionFilter from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper from arekit.contrib.utils.processing.pos.mystem_wrap import POSMystemWrapper -from arekit.contrib.utils.resources import load_embedding_news_mystem_skipgram_1000_20_2015 -from arekit.contrib.utils.vectorizers.bpe import BPEVectorizer -from arekit.contrib.utils.vectorizers.random_norm import RandomNormalVectorizer from tests.tutorials.test_tutorial_pipeline_text_opinion_annotation import FooDocumentOperations @@ -70,7 +67,6 @@ class TestSamplingNetwork(unittest.TestCase): def test(self): stemmer = MystemWrapper() - embedding = load_embedding_news_mystem_skipgram_1000_20_2015(stemmer) frames_collection = RuSentiFramesCollection.read( version=RuSentiFramesVersions.V20, @@ -90,29 +86,17 @@ def test(self): frames_connotation_provider=RuSentiFramesConnotationProvider(frames_collection)) writer = PandasCsvWriter(write_header=True) - samples_io = SamplesIO(self.__output_dir, writer, target_extension=".tsv.gz") - embedding_io = NpEmbeddingIO(target_dir=self.__output_dir) - - bpe_vectorizer = BPEVectorizer(embedding=embedding, max_part_size=3) - norm_vectorizer = RandomNormalVectorizer(vector_size=embedding.VectorSize, - token_offset=12345) - vectorizers = { - TermTypes.WORD: bpe_vectorizer, - TermTypes.ENTITY: bpe_vectorizer, - TermTypes.FRAME: bpe_vectorizer, - TermTypes.TOKEN: norm_vectorizer - } + rows_provider = create_rows_provider( + str_entity_fmt=SimpleUppercasedEntityFormatter(), + ctx=ctx) pipeline_item = NetworksInputSerializerPipelineItem( - vectorizers=vectorizers, - samples_io=samples_io, - emb_io=embedding_io, - ctx=ctx, - str_entity_fmt=SimpleUppercasedEntityFormatter(), + samples_io=SamplesIO(self.__output_dir, writer, target_extension=".tsv.gz"), + emb_io=NpEmbeddingIO(target_dir=self.__output_dir), + rows_provider=rows_provider, balance_func=lambda data_type: data_type == DataType.Train, save_labels_func=lambda data_type: data_type != DataType.Test, - save_embedding=True, storage=PandasBasedRowsStorage()) pipeline = BasePipeline([