Skip to content

Commit

Permalink
#399 done! (#189 and #390 related)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Nov 3, 2022
1 parent 23089ca commit 5d0f2b3
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 53 deletions.
32 changes: 5 additions & 27 deletions arekit/contrib/utils/cv/doc_stat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,13 @@ def _calc(self, news):

# region public methods

# TODO. depends on io, issue #189
def calculate_and_write_doc_stat(self, filepath, doc_ids_iter):
with open(filepath, 'w') as f:
for doc_id in doc_ids_iter:
doc = self.__doc_reader_func(doc_id)
s_count = self._calc(doc)
f.write("{}: {}\n".format(doc_id, s_count))

@staticmethod
def read_docs_stat(filepath, doc_ids_set):
"""
doc_ids_set: set
set of documents expected to be extracted
return:
list of the following pairs: (doc_id, sentences_count)
"""
assert(isinstance(doc_ids_set, set))

def calculate(self, doc_ids_iter):
docs_info = []
with open(filepath, 'r') as f:
for line in f.readlines():
args = [int(i) for i in line.split(':')]
doc_id, s_count = args

if doc_id not in doc_ids_set:
continue

docs_info.append((doc_id, s_count))
for doc_id in doc_ids_iter:
doc = self.__doc_reader_func(doc_id)
s_count = self._calc(doc)
docs_info.append((doc_id, s_count))

return docs_info

Expand Down
23 changes: 5 additions & 18 deletions arekit/contrib/utils/cv/splitters/statistical.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from os import path
import numpy as np
from arekit.contrib.utils.cv.doc_stat.base import BaseDocumentStatGenerator
from arekit.contrib.utils.cv.splitters.base import CrossValidationSplitter
Expand All @@ -8,12 +7,10 @@ class StatBasedCrossValidationSplitter(CrossValidationSplitter):
""" Sentence-based splitter.
"""

def __init__(self, docs_stat, docs_stat_filepath_func):
def __init__(self, docs_stat, doc_ids):
assert(isinstance(docs_stat, BaseDocumentStatGenerator))
assert(callable(docs_stat_filepath_func))
super(StatBasedCrossValidationSplitter, self).__init__()
self.__docs_stat = docs_stat
self.__docs_stat_filepath_func = docs_stat_filepath_func
self.__docs_info = docs_stat.calculate(doc_ids_iter=doc_ids)

# region private methods

Expand All @@ -39,22 +36,12 @@ def __calc_cv_group_delta(cv_group_size, item, g_index_to_add):
# endregion

def items_to_cv_pairs(self, doc_ids, cv_count):
"""
Separation with the specific separation, in terms of cv-classes size difference.
""" Separation with the specific separation, in terms of cv-classes size difference.
"""
assert(isinstance(doc_ids, set))
assert(isinstance(cv_count, int))

filepath = self.__docs_stat_filepath_func()

if not path.exists(filepath):
self.__docs_stat.calculate_and_write_doc_stat(filepath=filepath,
doc_ids_iter=doc_ids)

docs_info = self.__docs_stat.read_docs_stat(filepath=filepath,
doc_ids_set=doc_ids)

sorted_stat = reversed(sorted(docs_info, key=lambda pair: pair[1]))
sorted_stat = reversed(sorted(self.__docs_info, key=lambda pair: pair[1]))
cv_group_docs = [[] for _ in range(cv_count)]
cv_group_sizes = [[] for _ in range(cv_count)]

Expand All @@ -65,6 +52,6 @@ def items_to_cv_pairs(self, doc_ids, cv_count):

for g_index in range(len(cv_group_docs)):
small = cv_group_docs[g_index]
large = [doc_id for doc_id, _ in docs_info if doc_id not in small]
large = [doc_id for doc_id, _ in self.__docs_info if doc_id not in small]

yield large, small
2 changes: 0 additions & 2 deletions tests/tutorials/data/stat.txt

This file was deleted.

11 changes: 5 additions & 6 deletions tests/tutorials/test_tutorial_data_foldings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import unittest
from os.path import dirname, join

from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
Expand Down Expand Up @@ -43,15 +42,15 @@ def test(self):
splitter_simple = SimpleCrossValidationSplitter(shuffle=True, seed=1)

doc_ops = FooDocumentOperations()
doc_ids = list(range(2))

splitter_statistical = StatBasedCrossValidationSplitter(
docs_stat=SentenceBasedDocumentStatGenerator(
lambda doc_id: doc_ops.get_doc(doc_id)),
docs_stat_filepath_func=lambda: join(dirname(__file__), "data/stat.txt")
)
docs_stat=SentenceBasedDocumentStatGenerator(lambda doc_id: doc_ops.get_doc(doc_id)),
doc_ids=doc_ids)

cv_folding = TwoClassCVFolding(
supported_data_types=[DataType.Train, DataType.Test],
doc_ids_to_fold=list(range(10)),
doc_ids_to_fold=doc_ids,
cv_count=2,
splitter=splitter_statistical)

Expand Down

0 comments on commit 5d0f2b3

Please sign in to comment.