Skip to content

Commit

Permalink
Memoize some dataframes analysis operations. (#31377)
Browse files Browse the repository at this point in the history
ReadFromCsv with an explicit dtype produced graphs that had quadratic
traversal (though the computed results, which are sets, were always correct).

This fixes #31152 and should help
other deep expressions with common references as well.
  • Loading branch information
robertwb authored May 28, 2024
1 parent fd4368f commit e488f41
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sdks/python/apache_beam/dataframe/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,10 @@ def __init__(
self._preserves_partition_by = preserves_partition_by

def placeholders(self):
# NOTE(review): this span is a flattened diff hunk; the +/- markers and the
# indentation were lost. Per the hunk header (-365,8 +365,10) and the
# "4 additions & 2 deletions" summary, the next two lines are the removed
# (pre-commit) body, which rebuilt the union of the arguments' placeholders
# on every call.
return frozenset.union(
frozenset(), *[arg.placeholders() for arg in self.args()])
# The next four lines are the added (post-commit) body: the same union is
# computed only when the instance has no `_placeholders` attribute yet, is
# stored on the instance, and is returned from that attribute afterwards.
if not hasattr(self, '_placeholders'):
self._placeholders = frozenset.union(
frozenset(), *[arg.placeholders() for arg in self.args()])
return self._placeholders

def args(self):
# Accessor for the stored `_args`; placeholders() iterates this result and
# calls .placeholders() on each element.
return self._args
Expand Down
9 changes: 9 additions & 0 deletions sdks/python/apache_beam/dataframe/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ def test_read_write_csv(self):
self.assertCountEqual(['a,b,c', '1,2,3', '3,4,7'],
set(self.read_all_lines(output + 'out.csv*')))

def test_wide_csv_with_dtypes(self):
# Verify https://github.com/apache/beam/issues/31152 is resolved.
# Builds a CSV with a header row of 123 column names (col0..col122) and one
# data row holding 0..122, then reads it with an explicit dtype -- the case
# the commit message describes as producing quadratic graph traversal.
cols = ','.join(f'col{ix}' for ix in range(123))
data = ','.join(str(ix) for ix in range(123))
# NOTE(review): `input` shadows the builtin of the same name in this test.
input = self.temp_dir({'tmp.csv': f'{cols}\n{data}'})
with beam.Pipeline() as p:
pcoll = p | beam.io.ReadFromCsv(f'{input}tmp.csv', dtype=str)
# With dtype=str the values compare as strings, so the largest of
# '0'..'122' is '99' (lexicographic order), not '122'.
# Presumably each element is a row whose fields max() iterates over --
# TODO confirm against ReadFromCsv's output type.
assert_that(pcoll | beam.Map(max), equal_to(['99']))

def test_sharding_parameters(self):
data = pd.DataFrame({'label': ['11a', '37a', '389a'], 'rank': [0, 1, 2]})
output = self.temp_dir()
Expand Down
1 change: 1 addition & 0 deletions sdks/python/apache_beam/dataframe/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def __repr__(self, indent=0):
self.outputs))

# First define some helper functions.
@_memoize
def output_partitioning_in_stage(expr, stage):
"""Return the output partitioning of expr when computed in stage,
or returns None if the expression cannot be computed in this stage.
Expand Down

0 comments on commit e488f41

Please sign in to comment.