diff --git a/esrally/track/params.py b/esrally/track/params.py
index 9c131dab8..c15ce0130 100644
--- a/esrally/track/params.py
+++ b/esrally/track/params.py
@@ -554,9 +554,11 @@ def __init__(self, corpora, partition_index, total_partitions, batch_size, bulk_
         self.ingest_percentage = ingest_percentage
         self.id_conflicts = id_conflicts
         self.pipeline = pipeline
+        # this is only intended for unit-testing
+        create_reader = original_params.pop("__create_reader", create_default_reader)
         self.internal_params = bulk_data_based(total_partitions, partition_index, corpora, batch_size,
                                                bulk_size, id_conflicts, conflict_probability, on_conflict, recency,
-                                               pipeline, original_params)
+                                               pipeline, original_params, create_reader)
         self.current_bulk = 0
         all_bulks = number_of_bulks(self.corpora, self.partition_index, self.total_partitions, self.bulk_size)
         self.total_bulks = math.ceil((all_bulks * self.ingest_percentage) / 100)
@@ -566,6 +568,10 @@ def partition(self, partition_index, total_partitions):
         raise exceptions.RallyError("Cannot partition a PartitionBulkIndexParamSource further")
 
     def params(self):
+        # self.internal_params always reads all files. This is necessary to ensure we terminate early in case
+        # the user has specified ingest percentage.
+        if self.current_bulk == self.total_bulks:
+            raise StopIteration
         self.current_bulk += 1
         return next(self.internal_params)
 
diff --git a/tests/track/params_test.py b/tests/track/params_test.py
index ac0a0a133..d7a686617 100644
--- a/tests/track/params_test.py
+++ b/tests/track/params_test.py
@@ -23,6 +23,28 @@
 from esrally.utils import io
 
 
+class StaticBulkReader:
+    def __init__(self, index_name, type_name, bulks):
+        self.index_name = index_name
+        self.type_name = type_name
+        self.bulks = iter(bulks)
+
+    def __enter__(self):
+        return self
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        batch = []
+        bulk = next(self.bulks)
+        batch.append((len(bulk), bulk))
+        return self.index_name, self.type_name, batch
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        return False
+
+
 class SliceTests(TestCase):
     def test_slice_with_source_larger_than_slice(self):
         source = params.Slice(io.StringAsFileSource, 2, 5)
@@ -865,6 +887,27 @@ def test_ingests_all_documents_by_default(self):
         self.assertEqual(100, partition.total_bulks)
 
     def test_restricts_number_of_bulks_if_required(self):
+        def create_unit_test_reader(*args):
+            return StaticBulkReader("idx", "doc", bulks=[
+                ['{"location" : [-0.1485188, 51.5250666]}'],
+                ['{"location" : [-0.1479949, 51.5252071]}'],
+                ['{"location" : [-0.1458559, 51.5289059]}'],
+                ['{"location" : [-0.1498551, 51.5282564]}'],
+                ['{"location" : [-0.1487043, 51.5254843]}'],
+                ['{"location" : [-0.1533367, 51.5261779]}'],
+                ['{"location" : [-0.1543018, 51.5262398]}'],
+                ['{"location" : [-0.1522118, 51.5266564]}'],
+                ['{"location" : [-0.1529092, 51.5263360]}'],
+                ['{"location" : [-0.1537008, 51.5265365]}'],
+            ])
+
+        def schedule(param_source):
+            while True:
+                try:
+                    yield param_source.params()
+                except StopIteration:
+                    return
+
         corpora = [
             track.DocumentCorpus(name="default", documents=[
                 track.Documents(source_format=track.Documents.SOURCE_FORMAT_BULK,
@@ -886,12 +929,14 @@
                                             track=track.Track(name="unit-test", corpora=corpora),
                                             params={
                                                 "bulk-size": 10000,
-                                                "ingest-percentage": 2.5
+                                                "ingest-percentage": 2.5,
+                                                "__create_reader": create_unit_test_reader
                                             })
 
         partition = source.partition(0, 1)
         # should issue three bulks of size 10.000
         self.assertEqual(3, partition.total_bulks)
+        self.assertEqual(3, len(list(schedule(partition))))
 
     def test_create_with_conflict_probability_zero(self):
         params.BulkIndexParamSource(track=track.Track(name="unit-test"), params={
@@ -932,31 +977,11 @@ def test_create_with_conflict_probability_not_numeric(self):
 
 
 class BulkDataGeneratorTests(TestCase):
-    class TestBulkReader:
-        def __init__(self, index_name, type_name, bulks):
-            self.index_name = index_name
-            self.type_name = type_name
-            self.bulks = iter(bulks)
-
-        def __enter__(self):
-            return self
-
-        def __iter__(self):
-            return self
-
-        def __next__(self):
-            batch = []
-            bulk = next(self.bulks)
-            batch.append((len(bulk), bulk))
-            return self.index_name, self.type_name, batch
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            return False
 
     @classmethod
     def create_test_reader(cls, batches):
         def inner_create_test_reader(docs, *args):
-            return BulkDataGeneratorTests.TestBulkReader(docs.target_index, docs.target_type, batches)
+            return StaticBulkReader(docs.target_index, docs.target_type, batches)
         return inner_create_test_reader
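
Note (not part of the patch): StaticBulkReader stands in for the readers that create_default_reader would otherwise build, so it fulfills the same protocol that bulk_data_based consumes: a context manager that iterates, yielding (index name, type name, batch) tuples where each batch entry pairs a document count with the bulk lines. A minimal sketch of that contract, reusing the StaticBulkReader defined in the test diff above:

# Sketch only: demonstrates the reader protocol assumed by bulk_data_based,
# using the StaticBulkReader class introduced in params_test.py above.
reader = StaticBulkReader("idx", "doc", bulks=[['{"location" : [-0.1, 51.5]}']])
with reader:
    index, doc_type, batch = next(reader)
    assert index == "idx" and doc_type == "doc"
    # each batch entry is (number of docs in the bulk, the bulk lines)
    assert batch == [(1, ['{"location" : [-0.1, 51.5]}'])]
# once the supplied bulks are exhausted, __next__ propagates StopIteration,
# which is the same termination signal params() now raises when the
# ingest-percentage limit is reached.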