Skip to content

Commit

Permalink
Merge pull request apache#17608 from ihji/BEAM-14430
Browse files Browse the repository at this point in the history
[BEAM-14430] Adding a logical type support for Python callables to Row schema
  • Loading branch information
ihji authored May 13, 2022
2 parents 9085345 + 2d36feb commit 2d57753
Show file tree
Hide file tree
Showing 14 changed files with 314 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1046,12 +1046,8 @@ message StandardCoders {
// Nullable types in container types (ArrayType, MapType) per the
// encoding described for general Nullable types below.
//
// Well known logical types:
// beam:logical_type:micros_instant:v1
// - Representation type: ROW<seconds: INT64, micros: INT64>
// - A timestamp without a timezone where seconds + micros represents the
// amount of time since the epoch.
//
// Logical types understood by all SDKs should be defined in schema.proto.
// Example of well known logical types:
// beam:logical_type:schema:v1
// - Representation type: BYTES
// - A Beam Schema stored as a serialized proto.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ option go_package = "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v
option java_package = "org.apache.beam.model.pipeline.v1";
option java_outer_classname = "SchemaApi";

import "org/apache/beam/model/pipeline/v1/beam_runner_api.proto";

message Schema {
// List of fields for this schema. Two fields may not share a name.
repeated Field fields = 1;
Expand Down Expand Up @@ -110,6 +112,27 @@ message LogicalType {
FieldValue argument = 5;
}

// Universally defined Logical types for Row schemas.
// These logical types are supposed to be understood by all SDKs.
message LogicalTypes {
enum Enum {
// A URN for Python Callable logical type
// - Representation type: STRING
// - Language type: In Python SDK, PythonCallableWithSource.
// In any other SDKs, a wrapper object for a string which
// can be evaluated to a Python Callable object.
PYTHON_CALLABLE = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) =
"beam:logical_type:python_callable:v1"];

// A URN for MicrosInstant type
// - Representation type: ROW<seconds: INT64, micros: INT64>
// - A timestamp without a timezone where seconds + micros represents the
// amount of time since the epoch.
MICROS_INSTANT = 1 [(org.apache.beam.model.pipeline.v1.beam_urn) =
"beam:logical_type:micros_instant:v1"];
}
}

message Option {
// REQUIRED. Identifier for the option.
string name = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.beam.sdk.schemas.Schema.LogicalType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant;
import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
import org.apache.beam.sdk.schemas.logicaltypes.SchemaLogicalType;
import org.apache.beam.sdk.schemas.logicaltypes.UnknownLogicalType;
import org.apache.beam.sdk.util.SerializableUtils;
Expand Down Expand Up @@ -74,6 +75,7 @@ public class SchemaTranslation {
ImmutableMap.<String, Class<? extends LogicalType<?, ?>>>builder()
.put(MicrosInstant.IDENTIFIER, MicrosInstant.class)
.put(SchemaLogicalType.IDENTIFIER, SchemaLogicalType.class)
.put(PythonCallable.IDENTIFIER, PythonCallable.class)
.build();

public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLogicalType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.beam.sdk.schemas.logicaltypes;

import java.time.Instant;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.Row;

Expand All @@ -36,7 +38,11 @@
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public class MicrosInstant implements Schema.LogicalType<Instant, Row> {
public static final String IDENTIFIER = "beam:logical_type:micros_instant:v1";
public static final String IDENTIFIER =
SchemaApi.LogicalTypes.Enum.MICROS_INSTANT
.getValueDescriptor()
.getOptions()
.getExtension(RunnerApi.beamUrn);
// TODO(BEAM-10878): This should be a constant
private final Schema schema;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.schemas.logicaltypes;

import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.LogicalType;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;

/** A logical type for PythonCallableSource objects. */
@Experimental(Experimental.Kind.SCHEMAS)
public class PythonCallable implements LogicalType<PythonCallableSource, String> {
public static final String IDENTIFIER =
SchemaApi.LogicalTypes.Enum.PYTHON_CALLABLE
.getValueDescriptor()
.getOptions()
.getExtension(RunnerApi.beamUrn);

@Override
public String getIdentifier() {
return IDENTIFIER;
}

@Override
public Schema.@Nullable FieldType getArgumentType() {
return null;
}

@Override
public Schema.FieldType getBaseType() {
return Schema.FieldType.STRING;
}

@Override
public @NonNull String toBaseType(@NonNull PythonCallableSource input) {
return input.getPythonCallableCode();
}

@Override
public @NonNull PythonCallableSource toInputType(@NonNull String base) {
return PythonCallableSource.of(base);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ public static Schema.FieldType fieldFromType(
return fieldFromType(type, fieldValueTypeSupplier, new HashMap<Class, Schema>());
}

// TODO(BEAM-14458): support type inference for logical types
private static Schema.FieldType fieldFromType(
TypeDescriptor type,
FieldValueTypeSupplier fieldValueTypeSupplier,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.util;

import java.io.Serializable;

/**
* A wrapper object storing a Python code that can be evaluated to Python callables in Python SDK.
*/
public class PythonCallableSource implements Serializable {
private final String pythonCallableCode;

private PythonCallableSource(String pythonCallableCode) {
this.pythonCallableCode = pythonCallableCode;
}

public static PythonCallableSource of(String pythonCallableCode) {
// TODO(BEAM-14457): check syntactic correctness of Python code if possible
return new PythonCallableSource(pythonCallableCode);
}

public String getPythonCallableCode() {
return pythonCallableCode;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.beam.sdk.schemas.logicaltypes.DateTime;
import org.apache.beam.sdk.schemas.logicaltypes.FixedBytes;
import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant;
import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
import org.apache.beam.sdk.schemas.logicaltypes.SchemaLogicalType;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
Expand Down Expand Up @@ -132,6 +133,7 @@ public static Iterable<Schema> data() {
Field.of("decimal", FieldType.DECIMAL), Field.of("datetime", FieldType.DATETIME)))
.add(Schema.of(Field.of("fixed_bytes", FieldType.logicalType(FixedBytes.of(24)))))
.add(Schema.of(Field.of("micros_instant", FieldType.logicalType(new MicrosInstant()))))
.add(Schema.of(Field.of("python_callable", FieldType.logicalType(new PythonCallable()))))
.add(
Schema.of(
Field.of("field_with_option_atomic", FieldType.STRING)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.sdk.extensions.python;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
Expand All @@ -33,10 +34,12 @@
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.schemas.logicaltypes.PythonCallable;
import org.apache.beam.sdk.schemas.utils.StaticSchemaInference;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
Expand Down Expand Up @@ -64,6 +67,7 @@ public class PythonExternalTransform<InputT extends PInput, OutputT extends POut
// We preseve the order here since Schema's care about order of fields but the order will not
// matter when applying kwargs at the Python side.
private SortedMap<String, Object> kwargsMap;
private Map<java.lang.Class<?>, Schema.FieldType> typeHints;

private @Nullable Object @NonNull [] argsArray;
private @Nullable Row providedKwargsRow;
Expand All @@ -72,6 +76,11 @@ private PythonExternalTransform(String fullyQualifiedName, String expansionServi
this.fullyQualifiedName = fullyQualifiedName;
this.expansionService = expansionService;
this.kwargsMap = new TreeMap<>();
this.typeHints = new HashMap<>();
// TODO(BEAM-14458): remove a default type hint for PythonCallableSource when BEAM-14458 is
// resolved
this.typeHints.put(
PythonCallableSource.class, Schema.FieldType.logicalType(new PythonCallable()));
argsArray = new Object[] {};
}

Expand Down Expand Up @@ -162,6 +171,26 @@ public PythonExternalTransform<InputT, OutputT> withKwargs(Row kwargs) {
return this;
}

/**
* Specifies the field type of arguments.
*
* <p>Type hints are especially useful for logical types since type inference does not work well
* for logical types.
*
* @param argType A class object for the argument type.
* @param fieldType A schema field type for the argument.
* @return updated wrapper for the cross-language transform.
*/
public PythonExternalTransform<InputT, OutputT> withTypeHint(
java.lang.Class<?> argType, Schema.FieldType fieldType) {
if (typeHints.containsKey(argType)) {
throw new IllegalArgumentException(
String.format("typehint for arg type %s already exists", argType));
}
typeHints.put(argType, fieldType);
return this;
}

@VisibleForTesting
Row buildOrGetKwargsRow() {
if (providedKwargsRow != null) {
Expand All @@ -179,16 +208,18 @@ Row buildOrGetKwargsRow() {
// Types that are not one of following are considered custom types.
// * Java primitives
// * Type String
// * Any Type explicitly annotated by withTypeHint()
// * Type Row
private static boolean isCustomType(java.lang.Class<?> type) {
private boolean isCustomType(java.lang.Class<?> type) {
boolean val =
!(ClassUtils.isPrimitiveOrWrapper(type)
|| type == String.class
|| typeHints.containsKey(type)
|| Row.class.isAssignableFrom(type));
return val;
}

// If the custom type has a registered schema, we use that. OTherwise we try to register it using
// If the custom type has a registered schema, we use that. Otherwise, we try to register it using
// 'JavaFieldSchema'.
private Row convertCustomValue(Object value) {
SerializableFunction<Object, Row> toRowFunc;
Expand Down Expand Up @@ -239,6 +270,8 @@ private Schema generateSchemaDirectly(
if (field instanceof Row) {
// Rows are used as is but other types are converted to proper field types.
builder.addRowField(fieldName, ((Row) field).getSchema());
} else if (typeHints.containsKey(field.getClass())) {
builder.addField(fieldName, typeHints.get(field.getClass()));
} else {
builder.addField(
fieldName,
Expand Down
Loading

0 comments on commit 2d57753

Please sign in to comment.