This repository has been archived by the owner on Jan 2, 2025. It is now read-only.

Fix BatchesFromExecutions (#111)
* Fix BatchesFromExecutions: size estimation overflow + filter by destination type in process func
* Add explanation about estimate_size
diogoaihara authored Oct 20, 2022
1 parent ae533d1 commit 105e6ee
Showing 2 changed files with 101 additions and 8 deletions.
34 changes: 26 additions & 8 deletions megalista_dataflow/sources/batches_from_executions.py
@@ -34,12 +34,9 @@

_LOGGER_NAME = 'megalista.BatchesFromExecutions'

-
-def _convert_row_to_dict(row):
-    dict = {}
-    for key, value in row.items():
-        dict[key] = value
-    return dict
+# Maximum 32-bit signed integer (2**31 - 1).
+# Used to avoid overflow when casting from str to int in the underlying C code.
+_INT_MAX = 2147483647

class ExecutionCoder(coders.Coder):
    """A custom coder for the Execution class."""
@@ -78,6 +75,23 @@ def decode(self, s):
    def is_deterministic(self):
        return True

+    def estimate_size(self, o):
+        """Estimates the size of a DataRowsGroupedBySource element (in bytes).
+        - Called by Dataflow / Apache Beam.
+        - The estimate is truncated to _INT_MAX to avoid overflow
+          when casting from str to int (in the underlying C code)."""
+        amount_of_rows = len(o.rows)
+        row_size = 0
+        if amount_of_rows > 0:
+            row_size = len(json.dumps(o.rows[0]).encode('utf-8'))
+        estimate = amount_of_rows * row_size
+        # There is an overflow error if the estimated size exceeds _INT_MAX.
+        if estimate > _INT_MAX:
+            estimate = _INT_MAX
+        return estimate


class BatchesFromExecutions(beam.PTransform):
    """
@@ -117,9 +131,13 @@ def process(self, grouped_elements):
            yield Batch(execution, batch, iteration)

    class _BreakIntoExecutions(beam.DoFn):
+        def __init__(self, destination_type: DestinationType):
+            self._destination_type = destination_type
+
        def process(self, el):
            for item in el:
-                yield item
+                # Only keep items whose execution targets this destination type.
+                if item[0].destination.destination_type == self._destination_type:
+                    yield item

    def __init__(
        self,
@@ -149,6 +167,6 @@ def expand(self, executions):
            )
            | beam.ParDo(self._ReadDataSource(self._transactional_type, self._dataflow_options, self._error_handler))
            | beam.Map(lambda el: [(execution, el.rows) for execution in iter(el.executions.executions)])
-            | beam.ParDo(self._BreakIntoExecutions())
+            | beam.ParDo(self._BreakIntoExecutions(self._destination_type))
            | beam.ParDo(self._BatchElements(self._batch_size, self._error_handler))
        )
75 changes: 75 additions & 0 deletions megalista_dataflow/sources/batches_from_executions_test.py
@@ -0,0 +1,75 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sources.batches_from_executions import BatchesFromExecutions, DataRowsGroupedBySourceCoder, _INT_MAX
from models.execution import AccountConfig, DataRow, DataRowsGroupedBySource, SourceType, DestinationType, TransactionalType, Execution, Source, Destination, ExecutionsGroupedBySource

from typing import List
import pytest

def _get_execution() -> Execution:
    return Execution(
        AccountConfig('', False, '', '', ''),
        Source(
            'source_name',
            SourceType.BIG_QUERY,
            []
        ),
        Destination(
            'destination_name',
            DestinationType.ADS_CUSTOMER_MATCH_CONTACT_INFO_UPLOAD,
            []
        )
    )

@pytest.fixture
def execution() -> Execution:
    return _get_execution()

@pytest.fixture
def executions_grouped_by_source() -> ExecutionsGroupedBySource:
    return ExecutionsGroupedBySource(
        'source_name',
        [_get_execution()]
    )

@pytest.fixture
def data_rows_grouped_by_source_coder() -> DataRowsGroupedBySourceCoder:
    return DataRowsGroupedBySourceCoder()

def test_data_rows_grouped_by_source_estimate_size_zero(mocker, data_rows_grouped_by_source_coder: DataRowsGroupedBySourceCoder, executions_grouped_by_source: ExecutionsGroupedBySource):
    data_rows: List[DataRow] = []
    o = DataRowsGroupedBySource(executions_grouped_by_source, data_rows)
    assert data_rows_grouped_by_source_coder.estimate_size(o) == 0

def test_data_rows_grouped_by_source_estimate_size_overflow(mocker, data_rows_grouped_by_source_coder: DataRowsGroupedBySourceCoder, executions_grouped_by_source: ExecutionsGroupedBySource):
    item: DataRow = DataRow({
        'phone': '5ecdb1fcdba73c56fc682fceb87166537e7d3990cbefcadb31ee23fe0add6322'
    })
    # 100,000,000 copies of this row make rows * row_size exceed _INT_MAX,
    # so the estimate must be clamped.
    data_rows: List[DataRow] = [item for _ in range(100000000)]

    o = DataRowsGroupedBySource(executions_grouped_by_source, data_rows)
    assert data_rows_grouped_by_source_coder.estimate_size(o) == _INT_MAX

def test_batch_elements(mocker, execution):
    item: DataRow = DataRow({
        'phone': '5ecdb1fcdba73c56fc682fceb87166537e7d3990cbefcadb31ee23fe0add6322'
    })
    data_rows: List[DataRow] = [item for _ in range(11)]
    batch_elements = BatchesFromExecutions._BatchElements(2, None)
    grouped_elements = (execution, data_rows)
    amount_of_batches = 0
    for _ in batch_elements.process(grouped_elements):
        amount_of_batches += 1
    # 11 rows with a batch size of 2 yield ceil(11 / 2) = 6 batches.
    assert amount_of_batches == 6
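The destination-type filter itself is not exercised by these tests. As a closing illustration, here is a hedged sketch (not part of the commit) of the behaviour _BreakIntoExecutions.process now implements; _Dest and _Exec are hypothetical stubs standing in for the project's Destination and Execution models, with plain strings in place of the DestinationType enum.

from dataclasses import dataclass

@dataclass
class _Dest:
    destination_type: str

@dataclass
class _Exec:
    destination: _Dest

def break_into_executions(el, destination_type):
    # Mirrors the fixed process(): keep only (execution, rows) pairs
    # whose execution targets this transform's destination type.
    for item in el:
        if item[0].destination.destination_type == destination_type:
            yield item

ads = _Exec(_Dest('ADS_CUSTOMER_MATCH_CONTACT_INFO_UPLOAD'))
ga = _Exec(_Dest('GA_MEASUREMENT_PROTOCOL'))
el = [(ads, ['row_a']), (ga, ['row_b'])]

kept = list(break_into_executions(el, 'ADS_CUSTOMER_MATCH_CONTACT_INFO_UPLOAD'))
assert len(kept) == 1  # only the Customer Match pair survives

Before this fix, process yielded every item unconditionally, so rows grouped for one destination type could end up batched for a transform configured for a different one.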
