Skip to content

Commit

Permalink
#475 done, #476 related
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed May 18, 2023
1 parent 0af94a9 commit 2b606b1
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 79 deletions.
28 changes: 28 additions & 0 deletions arekit/contrib/networks/input/providers/sample.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import collections

from arekit.common.data.input.providers.label.base import LabelProvider
from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider
from arekit.common.entities.base import Entity
Expand All @@ -16,17 +18,43 @@ def __init__(self,
text_provider,
frames_connotation_provider,
frame_role_label_scaler,
term_embedding_pairs=None,
pos_terms_mapper=None):
""" term_embedding_pairs: dict or None
additional structure, utilized to collect all the embedding pairs during the
rows providing stage.
"""
assert(isinstance(label_provider, LabelProvider))
assert(isinstance(frame_role_label_scaler, SentimentLabelScaler))
assert(isinstance(pos_terms_mapper, PosTermsMapper) or pos_terms_mapper is None)
assert(isinstance(term_embedding_pairs, collections.OrderedDict) or term_embedding_pairs is None)

super(NetworkSampleRowProvider, self).__init__(label_provider=label_provider,
text_provider=text_provider)

self.__frames_connotation_provider = frames_connotation_provider
self.__frame_role_label_scaler = frame_role_label_scaler
self.__pos_terms_mapper = pos_terms_mapper
self.__term_embedding_pairs = term_embedding_pairs

@property
def HasEmbeddingPairs(self):
    """ Whether this provider was configured to collect term-embedding
        pairs (i.e. a pairs container was supplied at construction time).
    """
    pairs = self.__term_embedding_pairs
    return pairs is not None

def iter_term_embedding_pairs(self):
    """ Iterate over the collected (term, embedding-vector) pairs.
        NOTE(review): presumably callers check HasEmbeddingPairs first;
        with no pairs container this raises AttributeError.
    """
    pairs = self.__term_embedding_pairs
    return iter(pairs.items())

def clear_embedding_pairs(self):
    """ Drop every collected embedding pair, keeping the container itself
        so that collection may continue on the next providing round.
    """
    # No-op when pair collection was not requested at construction time.
    if self.__term_embedding_pairs is not None:
        self.__term_embedding_pairs.clear()

def _fill_row_core(self, row, text_opinion_linkage, index_in_linked, etalon_label,
parsed_news, sentence_ind, s_ind, t_ind):
Expand Down
Empty file.
Empty file.
83 changes: 83 additions & 0 deletions arekit/contrib/utils/nn/rows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import collections

from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.data.input.terms_mapper import OpinionContainingTextTermsMapper
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.contrib.networks.input.ctx_serialization import NetworkSerializationContext
from arekit.contrib.networks.input.formatters.pos_mapper import PosTermsMapper
from arekit.contrib.networks.input.providers.sample import NetworkSampleRowProvider
from arekit.contrib.networks.input.providers.text import NetworkSingleTextProvider
from arekit.contrib.networks.input.term_types import TermTypes
from arekit.contrib.networks.input.terms_mapping import StringWithEmbeddingNetworkTermMapping
from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper
from arekit.contrib.utils.resources import load_embedding_news_mystem_skipgram_1000_20_2015
from arekit.contrib.utils.vectorizers.bpe import BPEVectorizer
from arekit.contrib.utils.vectorizers.random_norm import RandomNormalVectorizer


def __add_term_embedding(dict_data, term, emb_vector):
    """ Register `emb_vector` for `term` in `dict_data`, keeping the first
        vector seen for a term (later duplicates are silently ignored).
    """
    # setdefault performs the membership test and the insert in one step;
    # the value is already computed by the caller, so semantics are unchanged.
    dict_data.setdefault(term, emb_vector)


def create_rows_provider(str_entity_fmt, ctx, vectorizers="default"):
    """ Default initialization of the rows provider for the data sampling pipeline.

        str_entity_fmt: StringEntitiesFormatter
            formatter applied to entity values in the output text.
        ctx: NetworkSerializationContext
            serialization context supplying label/frames providers and the POS tagger.
        vectorizers: None, "default", or dict
            None: no vectorization, just provide text (using BaseSingleTextProvider).
            "default": apply the stemmer-based embedding for the Russian language.
            dict: an assigned Vectorizer per term type, e.g.
                {
                    TermTypes.WORD: Vectorizer,
                    TermTypes.ENTITY: Vectorizer,
                    ...
                }

        returns: NetworkSampleRowProvider
            configured with a term-embedding pairs container iff vectorization is enabled.
    """
    assert(isinstance(str_entity_fmt, StringEntitiesFormatter))
    assert(isinstance(ctx, NetworkSerializationContext))
    assert(isinstance(vectorizers, dict) or vectorizers == "default" or vectorizers is None)

    term_embedding_pairs = None

    if vectorizers is not None:

        if vectorizers == "default":
            # Initialize the default vectorizers for the Russian language.
            embedding = load_embedding_news_mystem_skipgram_1000_20_2015(MystemWrapper())
            bpe_vectorizer = BPEVectorizer(embedding=embedding, max_part_size=3)
            norm_vectorizer = RandomNormalVectorizer(vector_size=embedding.VectorSize,
                                                     token_offset=12345)
            vectorizers = {
                TermTypes.WORD: bpe_vectorizer,
                TermTypes.ENTITY: bpe_vectorizer,
                TermTypes.FRAME: bpe_vectorizer,
                TermTypes.TOKEN: norm_vectorizer
            }

        # Setup the term-embedding pairs collection instance. An OrderedDict
        # preserves insertion order of the collected pairs.
        term_embedding_pairs = collections.OrderedDict()

        # Text provider that vectorizes terms and collects their embeddings.
        text_provider = NetworkSingleTextProvider(
            text_terms_mapper=StringWithEmbeddingNetworkTermMapping(
                vectorizers=vectorizers,
                string_entities_formatter=str_entity_fmt),
            pair_handling_func=lambda pair: __add_term_embedding(
                dict_data=term_embedding_pairs,
                term=pair[0],
                emb_vector=pair[1]))
    else:
        # Create a plain text provider without any vectorization.
        text_provider = BaseSingleTextProvider(
            text_terms_mapper=OpinionContainingTextTermsMapper(str_entity_fmt))

    return NetworkSampleRowProvider(
        label_provider=ctx.LabelProvider,
        text_provider=text_provider,
        frames_connotation_provider=ctx.FramesConnotationProvider,
        frame_role_label_scaler=ctx.FrameRolesLabelScaler,
        pos_terms_mapper=PosTermsMapper(ctx.PosTagger) if ctx.PosTagger is not None else None,
        term_embedding_pairs=term_embedding_pairs)
Empty file.
65 changes: 9 additions & 56 deletions arekit/contrib/utils/pipelines/items/sampling/networks.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
import collections

from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.data.input.terms_mapper import OpinionContainingTextTermsMapper
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.contrib.networks.input.ctx_serialization import NetworkSerializationContext
from arekit.contrib.networks.input.embedding.matrix import create_term_embedding_matrix
from arekit.contrib.networks.input.embedding.offsets import TermsEmbeddingOffsets
from arekit.contrib.networks.input.formatters.pos_mapper import PosTermsMapper
from arekit.contrib.networks.input.providers.sample import NetworkSampleRowProvider
from arekit.contrib.networks.input.providers.text import NetworkSingleTextProvider
from arekit.contrib.networks.input.terms_mapping import StringWithEmbeddingNetworkTermMapping
from arekit.contrib.networks.embedding import Embedding
from arekit.contrib.networks.input.providers.sample import NetworkSampleRowProvider
from arekit.contrib.utils.io_utils.embedding import NpEmbeddingIO
from arekit.contrib.utils.io_utils.samples import SamplesIO
from arekit.contrib.utils.utils_folding import folding_iter_states
Expand All @@ -24,8 +15,8 @@

class NetworksInputSerializerPipelineItem(BasePipelineItem):

def __init__(self, vectorizers, save_labels_func, str_entity_fmt, ctx,
samples_io, emb_io, balance_func, save_embedding, storage):
def __init__(self, save_labels_func, rows_provider, samples_io,
emb_io, balance_func, storage, save_embedding=True):
""" This pipeline item allows to perform a data preparation for neural network models.
considering a list of the whole data_types with the related pipelines,
Expand All @@ -36,65 +27,27 @@ def __init__(self, vectorizers, save_labels_func, str_entity_fmt, ctx,
balance: bool
declares whether there is a need to balance Train samples
vectorizers: dict in which for every type there is an assigned Vectorizer
vectorization of term types.
{
TermType.Word: Vectorizer,
TermType.Entity: Vectorizer,
...
}
save_labels_func: function
data_type -> bool
save_embedding: bool
save embedding and all the related information to it.
"""
assert(isinstance(ctx, NetworkSerializationContext))
assert(isinstance(samples_io, SamplesIO))
assert(isinstance(emb_io, NpEmbeddingIO))
assert(isinstance(str_entity_fmt, StringEntitiesFormatter))
assert(isinstance(vectorizers, dict) or vectorizers is None)
assert(isinstance(rows_provider, NetworkSampleRowProvider))
assert(isinstance(save_embedding, bool))
assert(callable(save_labels_func))
assert(callable(balance_func))
super(NetworksInputSerializerPipelineItem, self).__init__()

self.__emb_io = emb_io
self.__samples_io = samples_io
self.__save_embedding = save_embedding and vectorizers is not None
self.__save_embedding = save_embedding
self.__save_labels_func = save_labels_func
self.__balance_func = balance_func
self.__storage = storage

self.__term_embedding_pairs = collections.OrderedDict()

if vectorizers is not None:
text_provider = NetworkSingleTextProvider(
text_terms_mapper=StringWithEmbeddingNetworkTermMapping(
vectorizers=vectorizers,
string_entities_formatter=str_entity_fmt),
pair_handling_func=lambda pair: self.__add_term_embedding(
dict_data=self.__term_embedding_pairs,
term=pair[0],
emb_vector=pair[1]))
else:
# Create text provider which without vectorizers.
text_provider = BaseSingleTextProvider(
text_terms_mapper=OpinionContainingTextTermsMapper(str_entity_fmt))

self.__rows_provider = NetworkSampleRowProvider(
label_provider=ctx.LabelProvider,
text_provider=text_provider,
frames_connotation_provider=ctx.FramesConnotationProvider,
frame_role_label_scaler=ctx.FrameRolesLabelScaler,
pos_terms_mapper=PosTermsMapper(ctx.PosTagger) if ctx.PosTagger is not None else None)

@staticmethod
def __add_term_embedding(dict_data, term, emb_vector):
if term in dict_data:
return
dict_data[term] = emb_vector
self.__rows_provider = rows_provider

def __serialize_iteration(self, data_type, pipeline, rows_provider, data_folding):
assert(isinstance(data_type, DataType))
Expand Down Expand Up @@ -130,19 +83,19 @@ def __handle_iteration(self, data_type_pipelines, data_folding):
assert(isinstance(data_folding, BaseDataFolding))

# Prepare for the present iteration.
self.__term_embedding_pairs.clear()
self.__rows_provider.clear_embedding_pairs()

for data_type, pipeline in data_type_pipelines.items():
self.__serialize_iteration(pipeline=pipeline,
data_type=data_type,
rows_provider=self.__rows_provider,
data_folding=data_folding)

if not self.__save_embedding:
if not (self.__save_embedding and self.__rows_provider.HasEmbeddingPairs):
return

# Save embedding information additionally.
term_embedding = Embedding.from_word_embedding_pairs_iter(iter(self.__term_embedding_pairs.items()))
term_embedding = Embedding.from_word_embedding_pairs_iter(self.__rows_provider.iter_term_embedding_pairs())
embedding_matrix = create_term_embedding_matrix(term_embedding=term_embedding)
vocab = list(TermsEmbeddingOffsets.extract_vocab(words_embedding=term_embedding))

Expand Down
30 changes: 7 additions & 23 deletions tests/tutorials/test_tutorial_pipeline_sampling_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from arekit.common.pipeline.base import BasePipeline
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.networks.input.ctx_serialization import NetworkSerializationContext
from arekit.contrib.networks.input.term_types import TermTypes
from arekit.contrib.source.brat.entities.parser import BratTextEntitiesParser
from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection
from arekit.contrib.source.rusentiframes.labels_fmt import RuSentiFramesEffectLabelsFormatter, \
Expand All @@ -24,6 +23,7 @@
from arekit.contrib.utils.entities.formatters.str_simple_uppercase_fmt import SimpleUppercasedEntityFormatter
from arekit.contrib.utils.io_utils.embedding import NpEmbeddingIO
from arekit.contrib.utils.io_utils.samples import SamplesIO
from arekit.contrib.utils.nn.rows import create_rows_provider
from arekit.contrib.utils.pipelines.items.sampling.networks import NetworksInputSerializerPipelineItem
from arekit.contrib.utils.pipelines.items.text.frames_lemmatized import LemmasBasedFrameVariantsParser
from arekit.contrib.utils.pipelines.items.text.tokenizer import DefaultTextTokenizer
Expand All @@ -32,9 +32,6 @@
from arekit.contrib.utils.pipelines.text_opinion.filters.distance_based import DistanceLimitedTextOpinionFilter
from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper
from arekit.contrib.utils.processing.pos.mystem_wrap import POSMystemWrapper
from arekit.contrib.utils.resources import load_embedding_news_mystem_skipgram_1000_20_2015
from arekit.contrib.utils.vectorizers.bpe import BPEVectorizer
from arekit.contrib.utils.vectorizers.random_norm import RandomNormalVectorizer
from tests.tutorials.test_tutorial_pipeline_text_opinion_annotation import FooDocumentOperations


Expand Down Expand Up @@ -70,7 +67,6 @@ class TestSamplingNetwork(unittest.TestCase):
def test(self):

stemmer = MystemWrapper()
embedding = load_embedding_news_mystem_skipgram_1000_20_2015(stemmer)

frames_collection = RuSentiFramesCollection.read(
version=RuSentiFramesVersions.V20,
Expand All @@ -90,29 +86,17 @@ def test(self):
frames_connotation_provider=RuSentiFramesConnotationProvider(frames_collection))

writer = PandasCsvWriter(write_header=True)
samples_io = SamplesIO(self.__output_dir, writer, target_extension=".tsv.gz")

embedding_io = NpEmbeddingIO(target_dir=self.__output_dir)

bpe_vectorizer = BPEVectorizer(embedding=embedding, max_part_size=3)
norm_vectorizer = RandomNormalVectorizer(vector_size=embedding.VectorSize,
token_offset=12345)
vectorizers = {
TermTypes.WORD: bpe_vectorizer,
TermTypes.ENTITY: bpe_vectorizer,
TermTypes.FRAME: bpe_vectorizer,
TermTypes.TOKEN: norm_vectorizer
}
rows_provider = create_rows_provider(
str_entity_fmt=SimpleUppercasedEntityFormatter(),
ctx=ctx)

pipeline_item = NetworksInputSerializerPipelineItem(
vectorizers=vectorizers,
samples_io=samples_io,
emb_io=embedding_io,
ctx=ctx,
str_entity_fmt=SimpleUppercasedEntityFormatter(),
samples_io=SamplesIO(self.__output_dir, writer, target_extension=".tsv.gz"),
emb_io=NpEmbeddingIO(target_dir=self.__output_dir),
rows_provider=rows_provider,
balance_func=lambda data_type: data_type == DataType.Train,
save_labels_func=lambda data_type: data_type != DataType.Test,
save_embedding=True,
storage=PandasBasedRowsStorage())

pipeline = BasePipeline([
Expand Down

0 comments on commit 2b606b1

Please sign in to comment.