Skip to content

Commit

Permalink
Improve existing Python multi-lang SchemaTransform examples (#33361)
Browse files Browse the repository at this point in the history
* improve python multi-lang examples

* minor adjustments
  • Loading branch information
ahmedabu98 authored Dec 19, 2024
1 parent fe6b7aa commit 8fee3ca
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 87 deletions.
52 changes: 25 additions & 27 deletions examples/multi-language/python/wordcount_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import logging

import apache_beam as beam
from apache_beam.io import ReadFromText
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.transforms.external import BeamJarExpansionService
from apache_beam.transforms.external_transform_provider import ExternalTransformProvider
from apache_beam.typehints.row_type import RowTypeConstraint
"""A Python multi-language pipeline that counts words using multiple Java SchemaTransforms.
Expand Down Expand Up @@ -60,39 +60,35 @@
--expansion_service_port <PORT>
"""

# Original Java transform is in ExtractWordsProvider.java
EXTRACT_IDENTIFIER = "beam:schematransform:org.apache.beam:extract_words:v1"
# Original Java transform is in JavaCountProvider.java
COUNT_IDENTIFIER = "beam:schematransform:org.apache.beam:count:v1"
# Original Java transform is in WriteWordsProvider.java
WRITE_IDENTIFIER = "beam:schematransform:org.apache.beam:write_words:v1"


def run(input_path, output_path, expansion_service_port, pipeline_args):
pipeline_options = PipelineOptions(pipeline_args)

# Discover and get external transforms from this expansion service
provider = ExternalTransformProvider("localhost:" + expansion_service_port)
# Get transforms with identifiers, then use them as you would a regular
# native PTransform
Extract = provider.get_urn(EXTRACT_IDENTIFIER)
Count = provider.get_urn(COUNT_IDENTIFIER)
Write = provider.get_urn(WRITE_IDENTIFIER)

with beam.Pipeline(options=pipeline_options) as p:
lines = p | 'Read' >> ReadFromText(input_path)

words = (lines
| 'Prepare Rows' >> beam.Map(lambda line: beam.Row(line=line))
| 'Extract Words' >> Extract())
word_counts = words | 'Count Words' >> Count()
formatted_words = (
word_counts
| 'Format Text' >> beam.Map(lambda row: beam.Row(line="%s: %s" % (
row.word, row.count))).with_output_types(
RowTypeConstraint.from_fields([('line', str)])))

formatted_words | 'Write' >> Write(file_path_prefix=output_path)
expansion_service = BeamJarExpansionService(
"examples:multi-language:shadowJar")
if expansion_service_port:
expansion_service = "localhost:" + expansion_service_port

provider = ExternalTransformProvider(expansion_service)
# Retrieve portable transforms
Extract = provider.get_urn(EXTRACT_IDENTIFIER)
Count = provider.get_urn(COUNT_IDENTIFIER)
Write = provider.get_urn(WRITE_IDENTIFIER)

_ = (p
| 'Read' >> beam.io.ReadFromText(input_path)
| 'Prepare Rows' >> beam.Map(lambda line: beam.Row(line=line))
| 'Extract Words' >> Extract(drop=["king", "palace"])
| 'Count Words' >> Count()
| 'Format Text' >> beam.Map(lambda row: beam.Row(line="%s: %s" % (
row.word, row.count))).with_output_types(
RowTypeConstraint.from_fields([('line', str)]))
| 'Write' >> Write(file_path_prefix=output_path))


if __name__ == '__main__':
Expand All @@ -110,8 +106,10 @@ def run(input_path, output_path, expansion_service_port, pipeline_args):
help='Output file')
parser.add_argument('--expansion_service_port',
dest='expansion_service_port',
required=True,
help='Expansion service port')
required=False,
help='Expansion service port. If left empty, the '
'existing multi-language examples service will '
'be used by default.')
known_args, pipeline_args = parser.parse_known_args()

run(known_args.input, known_args.output, known_args.expansion_service_port,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
Expand All @@ -36,7 +39,6 @@
/** Splits a line into separate words and returns each word. */
@AutoService(SchemaTransformProvider.class)
public class ExtractWordsProvider extends TypedSchemaTransformProvider<Configuration> {
public static final Schema OUTPUT_SCHEMA = Schema.builder().addStringField("word").build();

@Override
public String identifier() {
Expand All @@ -45,32 +47,60 @@ public String identifier() {

@Override
protected SchemaTransform from(Configuration configuration) {
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
return PCollectionRowTuple.of(
"output",
input.get("input").apply(ParDo.of(new ExtractWordsFn())).setRowSchema(OUTPUT_SCHEMA));
}
};
return new ExtractWordsTransform(configuration);
}

static class ExtractWordsFn extends DoFn<Row, Row> {
@ProcessElement
public void processElement(@Element Row element, OutputReceiver<Row> receiver) {
// Split the line into words.
String line = Preconditions.checkStateNotNull(element.getString("line"));
String[] words = line.split("[^\\p{L}]+", -1);
static class ExtractWordsTransform extends SchemaTransform {
private static final Schema OUTPUT_SCHEMA = Schema.builder().addStringField("word").build();
private final List<String> drop;

for (String word : words) {
if (!word.isEmpty()) {
receiver.output(Row.withSchema(OUTPUT_SCHEMA).withFieldValue("word", word).build());
}
}
ExtractWordsTransform(Configuration configuration) {
this.drop = configuration.getDrop();
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
return PCollectionRowTuple.of(
"output",
input
.getSinglePCollection()
.apply(
ParDo.of(
new DoFn<Row, Row>() {
@ProcessElement
public void process(@Element Row element, OutputReceiver<Row> receiver) {
// Split the line into words.
String line = Preconditions.checkStateNotNull(element.getString("line"));
String[] words = line.split("[^\\p{L}]+", -1);
Arrays.stream(words)
.filter(w -> !drop.contains(w))
.forEach(
word ->
receiver.output(
Row.withSchema(OUTPUT_SCHEMA)
.withFieldValue("word", word)
.build()));
}
}))
.setRowSchema(OUTPUT_SCHEMA));
}
}

@DefaultSchema(AutoValueSchema.class)
@AutoValue
protected abstract static class Configuration {}
public abstract static class Configuration {
public static Builder builder() {
return new AutoValue_ExtractWordsProvider_Configuration.Builder();
}

@SchemaFieldDescription("List of words to drop.")
public abstract List<String> getDrop();

@AutoValue.Builder
public abstract static class Builder {
public abstract Builder setDrop(List<String> foo);

public abstract Configuration build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,35 +44,37 @@ public String identifier() {

@Override
protected SchemaTransform from(Configuration configuration) {
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
Schema outputSchema =
Schema.builder().addStringField("word").addInt64Field("count").build();
return new JavaCountTransform();
}

static class JavaCountTransform extends SchemaTransform {
static final Schema OUTPUT_SCHEMA =
Schema.builder().addStringField("word").addInt64Field("count").build();

PCollection<Row> wordCounts =
input
.get("input")
.apply(Count.perElement())
.apply(
MapElements.into(TypeDescriptors.rows())
.via(
kv ->
Row.withSchema(outputSchema)
.withFieldValue(
"word",
Preconditions.checkStateNotNull(
kv.getKey().getString("word")))
.withFieldValue("count", kv.getValue())
.build()))
.setRowSchema(outputSchema);
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
PCollection<Row> wordCounts =
input
.get("input")
.apply(Count.perElement())
.apply(
MapElements.into(TypeDescriptors.rows())
.via(
kv ->
Row.withSchema(OUTPUT_SCHEMA)
.withFieldValue(
"word",
Preconditions.checkStateNotNull(
kv.getKey().getString("word")))
.withFieldValue("count", kv.getValue())
.build()))
.setRowSchema(OUTPUT_SCHEMA);

return PCollectionRowTuple.of("output", wordCounts);
}
};
return PCollectionRowTuple.of("output", wordCounts);
}
}

@DefaultSchema(AutoValueSchema.class)
@AutoValue
protected abstract static class Configuration {}
public abstract static class Configuration {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,32 @@ public String identifier() {

@Override
protected SchemaTransform from(Configuration configuration) {
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
input
.get("input")
.apply(
MapElements.into(TypeDescriptors.strings())
.via(row -> Preconditions.checkStateNotNull(row.getString("line"))))
.apply(TextIO.write().to(configuration.getFilePathPrefix()));
return new WriteWordsTransform(configuration);
}

static class WriteWordsTransform extends SchemaTransform {
private final String filePathPrefix;

WriteWordsTransform(Configuration configuration) {
this.filePathPrefix = configuration.getFilePathPrefix();
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
input
.get("input")
.apply(
MapElements.into(TypeDescriptors.strings())
.via(row -> Preconditions.checkStateNotNull(row.getString("line"))))
.apply(TextIO.write().to(filePathPrefix));

return PCollectionRowTuple.empty(input.getPipeline());
}
};
return PCollectionRowTuple.empty(input.getPipeline());
}
}

@DefaultSchema(AutoValueSchema.class)
@AutoValue
protected abstract static class Configuration {
public abstract static class Configuration {
public static Builder builder() {
return new AutoValue_WriteWordsProvider_Configuration.Builder();
}
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/transforms/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ def dict_to_row(schema_proto, py_value):
extra = set(py_value.keys()) - set(row_type._fields)
if extra:
raise ValueError(
f"Unknown fields: {extra}. Valid fields: {row_type._fields}")
f"Transform '{self.identifier()}' was configured with unknown "
f"fields: {extra}. Valid fields: {set(row_type._fields)}")
return row_type(
*[
dict_to_row_recursive(
Expand Down

0 comments on commit 8fee3ca

Please sign in to comment.