Support sharding in WriteToFiles (tested for to_csv)
...
langner committed Jan 15, 2025
1 parent 98122f3 commit fae6ea8
Showing 3 changed files with 93 additions and 33 deletions.
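
Note: judging from the tests added below, the new sharding controls surface through the DataFrame API as to_csv arguments that are forwarded to WriteToFiles. A minimal usage sketch (the input data and output prefix here are illustrative, not taken from the commit):

import apache_beam as beam
import pandas as pd
from apache_beam.dataframe import convert
from apache_beam.io import fileio

# Illustrative input and output prefix.
data = pd.DataFrame({'label': ['11a', '37a', '389a'], 'rank': [0, 1, 2]})
output = '/tmp/sharded_csv/'

with beam.Pipeline() as p:
  df = convert.to_dataframe(p | beam.Create([data]), proxy=data[:0])
  # With num_shards=3 and default_file_naming, the new io_test.py test expects
  # files named out-00000-of-00003.csv ... out-00002-of-00003.csv under output.
  df.to_csv(
      output,
      num_shards=3,
      file_naming=fileio.default_file_naming('out', suffix='.csv'))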
21 changes: 18 additions & 3 deletions sdks/python/apache_beam/dataframe/io_test.py
@@ -126,16 +126,31 @@ def test_wide_csv_with_dtypes(self):
       pcoll = p | beam.io.ReadFromCsv(f'{input}tmp.csv', dtype=str)
       assert_that(pcoll | beam.Map(max), equal_to(['99']))
 
-  def test_sharding_parameters(self):
+  def test_sharding_parameters_for_single_file_naming(self):
     data = pd.DataFrame({'label': ['11a', '37a', '389a'], 'rank': [0, 1, 2]})
     output = self.temp_dir()
     with beam.Pipeline() as p:
-      df = convert.to_dataframe(p | beam.Create([data]), proxy=data[:0])
+      df = convert.to_dataframe(
+          p | beam.Create([data, data, data, data]), proxy=data[:0])
       df.to_csv(
           output,
           num_shards=1,
           file_naming=fileio.single_file_naming('out.csv'))
-    self.assertEqual(glob.glob(output + '*'), [output + 'out.csv'])
+    self.assertCountEqual(glob.glob(output + '*'), [output + 'out.csv'])
+
+  def test_sharding_parameters_multiple_shards(self):
+    data = pd.DataFrame({'label': ['11a', '37a', '389a'], 'rank': [0, 1, 2]})
+    output = self.temp_dir()
+    with beam.Pipeline() as p:
+      df = convert.to_dataframe(
+          p | beam.Create([data, data, data]), proxy=data[:0])
+      df.to_csv(
+          output,
+          num_shards=3,
+          file_naming=fileio.default_file_naming('out', suffix='.csv'))
+    self.assertCountEqual(
+        glob.glob(output + '*'),
+        [f'{output}out-0000{i}-of-00003.csv' for i in range(3)])
 
   @pytest.mark.uses_pyarrow
   @unittest.skipIf(
77 changes: 47 additions & 30 deletions sdks/python/apache_beam/io/fileio.py
@@ -522,7 +522,7 @@ class WriteToFiles(beam.PTransform):
   # Too many files will add memory pressure to the worker, so we let it be 20.
   MAX_NUM_WRITERS_PER_BUNDLE = 20
 
-  DEFAULT_SHARDING = 5
+  DEFAULT_SHARDING = 1
 
   def __init__(
       self,
@@ -567,6 +567,7 @@ class signature or an instance of FileSink to this parameter. If none is
     self.sink_fn = self._get_sink_fn(sink)
     self.shards = shards or WriteToFiles.DEFAULT_SHARDING
     self.output_fn = output_fn or (lambda x: x)
+    self.only_sharding = self.shards > 1 and destination is None
 
     self._max_num_writers_per_bundle = max_writers_per_bundle
 
@@ -603,35 +604,51 @@ def expand(self, pcoll):
           str, filesystems.FileSystems.join(temp_location, '.temp%s' % dir_uid))
       _LOGGER.info('Added temporary directory %s', self._temp_directory.get())
 
-    output = (
-        pcoll
-        | beam.ParDo(
-            _WriteUnshardedRecordsFn(
-                base_path=self._temp_directory,
-                destination_fn=self.destination_fn,
-                sink_fn=self.sink_fn,
-                max_writers_per_bundle=self._max_num_writers_per_bundle)).
-        with_outputs(
-            _WriteUnshardedRecordsFn.SPILLED_RECORDS,
-            _WriteUnshardedRecordsFn.WRITTEN_FILES))
-
-    written_files_pc = output[_WriteUnshardedRecordsFn.WRITTEN_FILES]
-    spilled_records_pc = output[_WriteUnshardedRecordsFn.SPILLED_RECORDS]
-
-    more_written_files_pc = (
-        spilled_records_pc
-        | beam.ParDo(
-            _AppendShardedDestination(self.destination_fn, self.shards))
-        | "GroupRecordsByDestinationAndShard" >> beam.GroupByKey()
-        | beam.ParDo(
-            _WriteShardedRecordsFn(
-                self._temp_directory, self.sink_fn, self.shards)))
-
-    files_by_destination_pc = (
-        (written_files_pc, more_written_files_pc)
-        | beam.Flatten()
-        | beam.Map(lambda file_result: (file_result.destination, file_result))
-        | "GroupTempFilesByDestination" >> beam.GroupByKey())
+    if self.only_sharding:
+      written_files_pc = (
+          pcoll
+          | beam.ParDo(
+              _AppendShardedDestination(self.destination_fn, self.shards))
+          | "GroupRecordsByDestinationAndShard" >> beam.GroupByKey()
+          | beam.ParDo(
+              _WriteShardedRecordsFn(
+                  self._temp_directory, self.sink_fn, self.shards)))
+
+      files_by_destination_pc = (
+          written_files_pc
+          |
+          beam.Map(lambda file_result: (file_result.destination, file_result))
+          | "GroupTempFilesByDestination" >> beam.GroupByKey())
+    else:
+      output = (
+          pcoll
+          | beam.ParDo(
+              _WriteUnshardedRecordsFn(
+                  base_path=self._temp_directory,
+                  destination_fn=self.destination_fn,
+                  sink_fn=self.sink_fn,
+                  max_writers_per_bundle=self._max_num_writers_per_bundle)).
+          with_outputs(
+              _WriteUnshardedRecordsFn.SPILLED_RECORDS,
+              _WriteUnshardedRecordsFn.WRITTEN_FILES))
+
+      written_files_pc = output[_WriteUnshardedRecordsFn.WRITTEN_FILES]
+      spilled_records_pc = output[_WriteUnshardedRecordsFn.SPILLED_RECORDS]
+      more_written_files_pc = (
+          spilled_records_pc
+          | beam.ParDo(
+              _AppendShardedDestination(self.destination_fn, self.shards))
+          | "GroupRecordsByDestinationAndShard" >> beam.GroupByKey()
+          | beam.ParDo(
+              _WriteShardedRecordsFn(
+                  self._temp_directory, self.sink_fn, self.shards)))
+
+      files_by_destination_pc = (
+          (written_files_pc, more_written_files_pc)
+          | beam.Flatten()
+          |
+          beam.Map(lambda file_result: (file_result.destination, file_result))
+          | "GroupTempFilesByDestination" >> beam.GroupByKey())
 
     # Now we should take the temporary files, and write them to the final
     # destination, with their proper names.
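
Note: the expand() change above routes records straight through the sharded-write path when only sharding (and no dynamic destination) is requested, skipping the per-bundle unsharded writers and their spill handling. Below is a simplified sketch of that path's keying-and-grouping shape, using hypothetical stand-in functions rather than the private _AppendShardedDestination and _WriteShardedRecordsFn DoFns:

import os
import random
import tempfile

import apache_beam as beam

NUM_SHARDS = 3                # stands in for WriteToFiles(..., shards=3)
OUT_DIR = tempfile.mkdtemp()  # stands in for the transform's temp directory

def append_shard(record):
  # Stand-in for _AppendShardedDestination: key each record by a
  # (destination, shard) pair so GroupByKey produces one group per shard.
  destination = ''  # a single static destination in the only_sharding case
  return ((destination, random.randint(0, NUM_SHARDS - 1)), record)

def write_shard(keyed_group):
  # Stand-in for _WriteShardedRecordsFn: write every record of one
  # (destination, shard) group into a single file.
  (destination, shard), records = keyed_group
  path = os.path.join(OUT_DIR, f'{destination}shard-{shard:05d}-of-{NUM_SHARDS:05d}')
  with open(path, 'w') as f:
    f.write('\n'.join(records))
  return path

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(['a', 'b', 'c', 'd'])
      | beam.Map(append_shard)
      | 'GroupRecordsByDestinationAndShard' >> beam.GroupByKey()
      | beam.Map(write_shard))

In the real transform, the grouped records are written through the configured FileSink into the temporary directory and later renamed according to the file_naming policy; the stand-ins above only mirror the grouping structure.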
28 changes: 28 additions & 0 deletions sdks/python/apache_beam/io/fileio_test.py
@@ -20,6 +20,7 @@
 # pytype: skip-file
 
 import csv
+import glob
 import io
 import json
 import logging
@@ -484,6 +485,31 @@ def test_write_to_single_file_batch(self):
| "Serialize" >> beam.Map(json.dumps)
| beam.io.fileio.WriteToFiles(path=dir))

assert len(glob.glob(os.path.join(dir, '*'))) == 1

with TestPipeline() as p:
result = (
p
| fileio.MatchFiles(FileSystems.join(dir, '*'))
| fileio.ReadMatches()
| beam.FlatMap(lambda f: f.read_utf8().strip().split('\n'))
| beam.Map(json.loads))

assert_that(result, equal_to([row for row in self.SIMPLE_COLLECTION]))

def test_write_to_multiple_shards(self):

dir = self._new_tempdir()

with TestPipeline() as p:
_ = (
p
| beam.Create(WriteFilesTest.SIMPLE_COLLECTION)
| "Serialize" >> beam.Map(json.dumps)
| beam.io.fileio.WriteToFiles(path=dir, shards=3))

assert len(glob.glob(os.path.join(dir, '*'))) == 3

with TestPipeline() as p:
result = (
p
@@ -515,6 +541,8 @@ def test_write_to_dynamic_destination(self):
               sink=sink,
               file_naming=fileio.destination_prefix_naming("test")))
 
+    assert len(glob.glob(os.path.join(dir, '*'))) == 2
+
     with TestPipeline() as p:
       result = (
           p

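Note: at the fileio level, the new test exercises sharding simply by passing shards=3 to WriteToFiles alongside JSON-serialized records. A minimal non-test sketch of the same usage (the output directory is illustrative; the default sink and file naming are used, as in the test above):

import json

import apache_beam as beam
from apache_beam.io import fileio

# Illustrative output directory.
output_dir = '/tmp/write_to_files_demo'

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'id': i} for i in range(100)])
      | beam.Map(json.dumps)
      | fileio.WriteToFiles(path=output_dir, shards=3))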