Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-14430] Adding a logical type support for Python callables to Row schema #17608

Merged
merged 7 commits into from
May 13, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ option go_package = "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v
option java_package = "org.apache.beam.model.pipeline.v1";
option java_outer_classname = "SchemaApi";

import "org/apache/beam/model/pipeline/v1/beam_runner_api.proto";
ihji marked this conversation as resolved.
Show resolved Hide resolved

message Schema {
// List of fields for this schema. Two fields may not share a name.
repeated Field fields = 1;
Expand Down Expand Up @@ -110,6 +112,27 @@ message LogicalType {
FieldValue argument = 5;
}

// Universally defined Logical types for Row schemas.
// These logical types are supposed to be understood by all SDKs.
message LogicalTypes {
enum Enum {
// A URN for Python Callable logical type
// - Representation type: STRING
// - Language type: In the Python SDK, PythonCallableWithSource.
// In any other SDK, a wrapper object for a string which
// can be evaluated to a Python Callable object.
PYTHON_CALLABLE = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) =
"beam:logical_type:python_callable:v1"];

// A URN for MicrosInstant type
// - Representation type: ROW<seconds: INT64, micros: INT64>
// - A timestamp without a timezone where seconds + micros represents the
// amount of time since the epoch.
ihji marked this conversation as resolved.
Show resolved Hide resolved
MICROS_INSTANT = 1 [(org.apache.beam.model.pipeline.v1.beam_urn) =
"beam:logical_type:micros_instant:v1"];
}
}

message Option {
// REQUIRED. Identifier for the option.
string name = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.beam.sdk.schemas.Schema.LogicalType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant;
import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
import org.apache.beam.sdk.schemas.logicaltypes.SchemaLogicalType;
import org.apache.beam.sdk.schemas.logicaltypes.UnknownLogicalType;
import org.apache.beam.sdk.util.SerializableUtils;
Expand Down Expand Up @@ -74,6 +75,7 @@ public class SchemaTranslation {
ImmutableMap.<String, Class<? extends LogicalType<?, ?>>>builder()
.put(MicrosInstant.IDENTIFIER, MicrosInstant.class)
.put(SchemaLogicalType.IDENTIFIER, SchemaLogicalType.class)
.put(PythonCallable.IDENTIFIER, PythonCallable.class)
.build();

public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLogicalType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.beam.sdk.schemas.logicaltypes;

import java.time.Instant;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.Row;

Expand All @@ -36,7 +38,11 @@
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public class MicrosInstant implements Schema.LogicalType<Instant, Row> {
public static final String IDENTIFIER = "beam:logical_type:micros_instant:v1";
public static final String IDENTIFIER =
SchemaApi.LogicalTypes.Enum.MICROS_INSTANT
.getValueDescriptor()
.getOptions()
.getExtension(RunnerApi.beamUrn);
// TODO(BEAM-10878): This should be a constant
private final Schema schema;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.schemas.logicaltypes;

import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.LogicalType;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Schema logical type that carries a {@link PythonCallableSource} as its string source code.
 *
 * <p>The type takes no argument and uses {@code STRING} as its base (representation) type.
 */
@Experimental(Experimental.Kind.SCHEMAS)
public class PythonCallable implements LogicalType<PythonCallableSource, String> {
  /** URN of this logical type, read from the {@code beam_urn} option of the proto enum value. */
  public static final String IDENTIFIER =
      SchemaApi.LogicalTypes.Enum.PYTHON_CALLABLE
          .getValueDescriptor()
          .getOptions()
          .getExtension(RunnerApi.beamUrn);

  /** Wraps the stored string back into a {@link PythonCallableSource}. */
  @Override
  public @NonNull PythonCallableSource toInputType(@NonNull String base) {
    final PythonCallableSource wrapped = PythonCallableSource.of(base);
    return wrapped;
  }

  /** Unwraps the Python source code held by {@code input}. */
  @Override
  public @NonNull String toBaseType(@NonNull PythonCallableSource input) {
    final String code = input.getPythonCallableCode();
    return code;
  }

  /** The representation type is a plain string. */
  @Override
  public Schema.FieldType getBaseType() {
    return Schema.FieldType.STRING;
  }

  /** This logical type is not parameterized, so there is no argument type. */
  @Override
  public Schema.@Nullable FieldType getArgumentType() {
    return null;
  }

  /** Returns {@link #IDENTIFIER}. */
  @Override
  public String getIdentifier() {
    return IDENTIFIER;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ public static Schema.FieldType fieldFromType(
return fieldFromType(type, fieldValueTypeSupplier, new HashMap<Class, Schema>());
}

// TODO(BEAM-14458): support type inference for logical types
private static Schema.FieldType fieldFromType(
TypeDescriptor type,
FieldValueTypeSupplier fieldValueTypeSupplier,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.util;

import java.io.Serializable;
import java.util.Objects;

/**
* A wrapper object storing a Python code that can be evaluated to Python callables in Python SDK.
*/
public class PythonCallableSource implements Serializable {
private final String pythonCallableCode;

private PythonCallableSource(String pythonCallableCode) {
this.pythonCallableCode = pythonCallableCode;
}

public static PythonCallableSource of(String pythonCallableCode) {
// TODO(BEAM-14457): check syntactic correctness of Python code if possible
return new PythonCallableSource(pythonCallableCode);
}

public String getPythonCallableCode() {
return pythonCallableCode;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.beam.sdk.schemas.logicaltypes.DateTime;
import org.apache.beam.sdk.schemas.logicaltypes.FixedBytes;
import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant;
import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
import org.apache.beam.sdk.schemas.logicaltypes.SchemaLogicalType;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
Expand Down Expand Up @@ -132,6 +133,7 @@ public static Iterable<Schema> data() {
Field.of("decimal", FieldType.DECIMAL), Field.of("datetime", FieldType.DATETIME)))
.add(Schema.of(Field.of("fixed_bytes", FieldType.logicalType(FixedBytes.of(24)))))
.add(Schema.of(Field.of("micros_instant", FieldType.logicalType(new MicrosInstant()))))
.add(Schema.of(Field.of("python_callable", FieldType.logicalType(new PythonCallable()))))
.add(
Schema.of(
Field.of("field_with_option_atomic", FieldType.STRING)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.sdk.extensions.python;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
Expand All @@ -33,10 +34,12 @@
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
import org.apache.beam.sdk.schemas.utils.StaticSchemaInference;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
Expand Down Expand Up @@ -64,6 +67,7 @@ public class PythonExternalTransform<InputT extends PInput, OutputT extends POut
// We preserve the order here since Schemas care about the order of fields, but the order will not
// matter when applying kwargs at the Python side.
private SortedMap<String, Object> kwargsMap;
private Map<java.lang.Class<?>, Schema.FieldType> typeHints;

private @Nullable Object @NonNull [] argsArray;
private @Nullable Row providedKwargsRow;
Expand All @@ -72,6 +76,11 @@ private PythonExternalTransform(String fullyQualifiedName, String expansionServi
this.fullyQualifiedName = fullyQualifiedName;
this.expansionService = expansionService;
this.kwargsMap = new TreeMap<>();
this.typeHints = new HashMap<>();
// TODO(BEAM-14458): remove a default type hint for PythonCallableSource when BEAM-14458 is
// resolved
this.typeHints.put(
PythonCallableSource.class, Schema.FieldType.logicalType(new PythonCallable()));
argsArray = new Object[] {};
}

Expand Down Expand Up @@ -162,6 +171,26 @@ public PythonExternalTransform<InputT, OutputT> withKwargs(Row kwargs) {
return this;
}

/**
* Specifies the field type of arguments.
*
* <p>Type hints are especially useful for logical types since type inference does not work well
* for logical types.
*
* @param argType A class object for the argument type.
* @param fieldType A schema field type for the argument.
* @return updated wrapper for the cross-language transform.
* @throws IllegalArgumentException if a type hint is already registered for {@code argType}.
*/
public PythonExternalTransform<InputT, OutputT> withTypeHint(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be per arg instead of per type ? In other words, can the same class map to different schema types ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per type makes more sense to me. Do you have any specific per arg use-case in mind?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't. Just wasn't sure. We can keep this per type if Robert and Brian are OK.

java.lang.Class<?> argType, Schema.FieldType fieldType) {
if (typeHints.containsKey(argType)) {
throw new IllegalArgumentException(
String.format("typehint for arg type %s already exists", argType));
}
typeHints.put(argType, fieldType);
return this;
}

@VisibleForTesting
Row buildOrGetKwargsRow() {
if (providedKwargsRow != null) {
Expand All @@ -179,16 +208,18 @@ Row buildOrGetKwargsRow() {
// Types that are not one of the following are considered custom types.
// * Java primitives
// * Type String
// * Any Type explicitly annotated by withTypeHint()
// * Type Row
private static boolean isCustomType(java.lang.Class<?> type) {
private boolean isCustomType(java.lang.Class<?> type) {
boolean val =
!(ClassUtils.isPrimitiveOrWrapper(type)
|| type == String.class
|| typeHints.containsKey(type)
|| Row.class.isAssignableFrom(type));
return val;
}

// If the custom type has a registered schema, we use that. OTherwise we try to register it using
// If the custom type has a registered schema, we use that. Otherwise, we try to register it using
// 'JavaFieldSchema'.
private Row convertCustomValue(Object value) {
SerializableFunction<Object, Row> toRowFunc;
Expand Down Expand Up @@ -239,6 +270,8 @@ private Schema generateSchemaDirectly(
if (field instanceof Row) {
// Rows are used as is but other types are converted to proper field types.
builder.addRowField(fieldName, ((Row) field).getSchema());
} else if (typeHints.containsKey(field.getClass())) {
builder.addField(fieldName, typeHints.get(field.getClass()));
} else {
builder.addField(
fieldName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@
import static org.junit.Assert.assertTrue;

import java.io.Serializable;
import java.time.Instant;
import java.util.Map;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
Expand All @@ -41,7 +44,7 @@
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class ExternalPythonTransformTest implements Serializable {
public class PythonExternalTransformTest implements Serializable {
@Ignore("BEAM-14148")
@Test
public void trivialPythonTransform() {
Expand Down Expand Up @@ -184,6 +187,29 @@ public void generateArgsWithCustomType() {
assertEquals(456, (int) receivedRow.getRow("field1").getInt32("intField"));
}

@Test
public void generateArgsWithPythonCallableSource() {
  // A PythonCallableSource is passed positionally; no withTypeHint() call is made in this test.
  PythonCallableSource callable = PythonCallableSource.of("dummy data");
  PythonExternalTransform<?, ?> underTest =
      PythonExternalTransform
          .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
              "DummyTransform")
          .withArgs(callable);

  // The first positional argument is stored under the generated name "field0".
  Object firstArg = underTest.buildOrGetArgsRow().getValue("field0");
  assertTrue(firstArg instanceof PythonCallableSource);
}

@Test
public void generateArgsWithTypeHint() {
  // java.time.Instant is hinted explicitly as the MicrosInstant logical type.
  Schema.FieldType microsInstantType = Schema.FieldType.logicalType(new MicrosInstant());
  PythonExternalTransform<?, ?> underTest =
      PythonExternalTransform
          .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
              "DummyTransform")
          .withArgs(Instant.ofEpochSecond(0))
          .withTypeHint(Instant.class, microsInstantType);

  // The first positional argument is stored under the generated name "field0".
  Object firstArg = underTest.buildOrGetArgsRow().getValue("field0");
  assertTrue(firstArg instanceof Instant);
}

@Test
public void generateKwargsEmpty() {
PythonExternalTransform<?, ?> transform =
Expand Down Expand Up @@ -274,6 +300,29 @@ public void generateKwargsWithCustomType() {
assertEquals(456, (int) receivedRow.getRow("customField1").getInt32("intField"));
}

@Test
public void generateKwargsWithPythonCallableSource() {
  // A PythonCallableSource is passed as a keyword argument; no withTypeHint() call is made here.
  PythonCallableSource callable = PythonCallableSource.of("dummy data");
  PythonExternalTransform<?, ?> underTest =
      PythonExternalTransform
          .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
              "DummyTransform")
          .withKwarg("customField0", callable);

  Object kwargValue = underTest.buildOrGetKwargsRow().getValue("customField0");
  assertTrue(kwargValue instanceof PythonCallableSource);
}

@Test
public void generateKwargsWithTypeHint() {
  // java.time.Instant is hinted explicitly as the MicrosInstant logical type.
  Schema.FieldType microsInstantType = Schema.FieldType.logicalType(new MicrosInstant());
  PythonExternalTransform<?, ?> underTest =
      PythonExternalTransform
          .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
              "DummyTransform")
          .withKwarg("customField0", Instant.ofEpochSecond(0))
          .withTypeHint(Instant.class, microsInstantType);

  Object kwargValue = underTest.buildOrGetKwargsRow().getValue("customField0");
  assertTrue(kwargValue instanceof Instant);
}

@Test
public void generateKwargsFromMap() {
Map<String, Object> kwargsMap =
Expand Down
Loading