Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve existing Python multi-lang SchemaTransform examples #33361

Merged
merged 2 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 10 additions & 20 deletions examples/multi-language/python/wordcount_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import logging

import apache_beam as beam
from apache_beam.io import ReadFromText
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.transforms.external_transform_provider import ExternalTransformProvider
from apache_beam.typehints.row_type import RowTypeConstraint
Expand Down Expand Up @@ -60,39 +59,30 @@
--expansion_service_port <PORT>
"""

# URNs identifying the Java SchemaTransforms registered with the expansion
# service; each is resolved to a portable transform via ExternalTransformProvider.
# Original Java transform is in ExtractWordsProvider.java
EXTRACT_IDENTIFIER = "beam:schematransform:org.apache.beam:extract_words:v1"
# Original Java transform is in JavaCountProvider.java
COUNT_IDENTIFIER = "beam:schematransform:org.apache.beam:count:v1"
# Original Java transform is in WriteWordsProvider.java
WRITE_IDENTIFIER = "beam:schematransform:org.apache.beam:write_words:v1"


def run(input_path, output_path, expansion_service_port, pipeline_args):
  """Runs a wordcount pipeline built from external Java SchemaTransforms.

  Reads text from ``input_path``, extracts words (keeping only the words the
  Java Extract transform is configured with), counts them, formats each
  (word, count) pair as a line of text, and writes the result via the Java
  Write transform.

  Args:
    input_path: Text file (or glob) to read.
    output_path: File path prefix the Java write transform writes to.
    expansion_service_port: Port of a locally running expansion service.
    pipeline_args: Remaining args forwarded to PipelineOptions.
  """
  pipeline_options = PipelineOptions(pipeline_args)

  # Discover and get external transforms from this expansion service
  provider = ExternalTransformProvider("localhost:" + expansion_service_port)
  # Retrieve portable transforms; they are used like regular native PTransforms.
  Extract = provider.get_urn(EXTRACT_IDENTIFIER)
  Count = provider.get_urn(COUNT_IDENTIFIER)
  Write = provider.get_urn(WRITE_IDENTIFIER)

  with beam.Pipeline(options=pipeline_options) as p:
    _ = (p
         | 'Read' >> beam.io.ReadFromText(input_path)
         # The external Extract transform expects Rows with a 'line' field.
         | 'Prepare Rows' >> beam.Map(lambda line: beam.Row(line=line))
         # NOTE(review): despite the name, 'filter' is a keep-list — only these
         # words are retained (see ExtractWordsProvider.java).
         | 'Extract Words' >> Extract(filter=["king", "palace"])
         | 'Count Words' >> Count()
         # The external Write transform expects Rows with a single 'line' field.
         | 'Format Text' >> beam.Map(lambda row: beam.Row(line="%s: %s" % (
             row.word, row.count))).with_output_types(
             RowTypeConstraint.from_fields([('line', str)]))
         | 'Write' >> Write(file_path_prefix=output_path))


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
Expand All @@ -36,7 +39,6 @@
/** Splits a line into separate words and returns each word. */
@AutoService(SchemaTransformProvider.class)
public class ExtractWordsProvider extends TypedSchemaTransformProvider<Configuration> {
public static final Schema OUTPUT_SCHEMA = Schema.builder().addStringField("word").build();

@Override
public String identifier() {
Expand All @@ -45,32 +47,60 @@ public String identifier() {

@Override
protected SchemaTransform from(Configuration configuration) {
  // Delegate to a named transform class instead of an anonymous SchemaTransform.
  return new ExtractWordsTransform(configuration);
}

static class ExtractWordsFn extends DoFn<Row, Row> {
@ProcessElement
public void processElement(@Element Row element, OutputReceiver<Row> receiver) {
// Split the line into words.
String line = Preconditions.checkStateNotNull(element.getString("line"));
String[] words = line.split("[^\\p{L}]+", -1);
static class ExtractWordsTransform extends SchemaTransform {
private static final Schema OUTPUT_SCHEMA = Schema.builder().addStringField("word").build();
private final List<String> filter;

for (String word : words) {
if (!word.isEmpty()) {
receiver.output(Row.withSchema(OUTPUT_SCHEMA).withFieldValue("word", word).build());
}
}
ExtractWordsTransform(Configuration configuration) {
this.filter = configuration.getFilter();
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
return PCollectionRowTuple.of(
"output",
input
.getSinglePCollection()
.apply(
ParDo.of(
new DoFn<Row, Row>() {
@ProcessElement
public void process(@Element Row element, OutputReceiver<Row> receiver) {
// Split the line into words.
String line = Preconditions.checkStateNotNull(element.getString("line"));
String[] words = line.split("[^\\p{L}]+", -1);
Arrays.stream(words)
.filter(filter::contains)
.forEach(
word ->
receiver.output(
Row.withSchema(OUTPUT_SCHEMA)
.withFieldValue("word", word)
.build()));
}
}))
.setRowSchema(OUTPUT_SCHEMA));
}
}

@DefaultSchema(AutoValueSchema.class)
@AutoValue
protected abstract static class Configuration {}
public abstract static class Configuration {
public static Builder builder() {
return new AutoValue_ExtractWordsProvider_Configuration.Builder();
}

@SchemaFieldDescription("List of words to filter out.")
public abstract List<String> getFilter();

@AutoValue.Builder
public abstract static class Builder {
public abstract Builder setFilter(List<String> foo);

public abstract Configuration build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,35 +44,37 @@ public String identifier() {

@Override
protected SchemaTransform from(Configuration configuration) {
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
Schema outputSchema =
Schema.builder().addStringField("word").addInt64Field("count").build();
return new JavaCountTransform();
}

static class JavaCountTransform extends SchemaTransform {
static final Schema OUTPUT_SCHEMA =
Schema.builder().addStringField("word").addInt64Field("count").build();

PCollection<Row> wordCounts =
input
.get("input")
.apply(Count.perElement())
.apply(
MapElements.into(TypeDescriptors.rows())
.via(
kv ->
Row.withSchema(outputSchema)
.withFieldValue(
"word",
Preconditions.checkStateNotNull(
kv.getKey().getString("word")))
.withFieldValue("count", kv.getValue())
.build()))
.setRowSchema(outputSchema);
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
PCollection<Row> wordCounts =
input
.get("input")
.apply(Count.perElement())
.apply(
MapElements.into(TypeDescriptors.rows())
.via(
kv ->
Row.withSchema(OUTPUT_SCHEMA)
.withFieldValue(
"word",
Preconditions.checkStateNotNull(
kv.getKey().getString("word")))
.withFieldValue("count", kv.getValue())
.build()))
.setRowSchema(OUTPUT_SCHEMA);

return PCollectionRowTuple.of("output", wordCounts);
}
};
return PCollectionRowTuple.of("output", wordCounts);
}
}

@DefaultSchema(AutoValueSchema.class)
@AutoValue
protected abstract static class Configuration {}
public abstract static class Configuration {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,32 @@ public String identifier() {

@Override
protected SchemaTransform from(Configuration configuration) {
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
input
.get("input")
.apply(
MapElements.into(TypeDescriptors.strings())
.via(row -> Preconditions.checkStateNotNull(row.getString("line"))))
.apply(TextIO.write().to(configuration.getFilePathPrefix()));
return new WriteWordsTransform(configuration);
}

static class WriteWordsTransform extends SchemaTransform {
private final String filePathPrefix;

WriteWordsTransform(Configuration configuration) {
this.filePathPrefix = configuration.getFilePathPrefix();
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
input
.get("input")
.apply(
MapElements.into(TypeDescriptors.strings())
.via(row -> Preconditions.checkStateNotNull(row.getString("line"))))
.apply(TextIO.write().to(filePathPrefix));

return PCollectionRowTuple.empty(input.getPipeline());
}
};
return PCollectionRowTuple.empty(input.getPipeline());
}
}

@DefaultSchema(AutoValueSchema.class)
@AutoValue
protected abstract static class Configuration {
public abstract static class Configuration {
public static Builder builder() {
return new AutoValue_WriteWordsProvider_Configuration.Builder();
}
Expand Down
Loading