diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java index 15bfe6400bb4..6929fc409da6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java @@ -359,7 +359,7 @@ static Row decodeDelegate(Schema schema, Coder[] coders, InputStream inputStream // all values. Since we assume that decode is always being called on a previously-encoded // Row, the values should already be validated and of the correct type. So, we can save // some processing by simply transferring ownership of the list to the Row. - return Row.withSchema(schema).attachValues(fieldValues).build(); + return Row.withSchema(schema).attachValues(fieldValues); } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java index f61f49b2cb76..319e1061124a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java @@ -610,12 +610,12 @@ private Schema getFieldDescriptorSchema(FieldDescriptor fieldDescriptor, Schema private static Schema getFieldSchema(FieldType type) { if (TypeName.ROW.equals(type.getTypeName())) { return type.getRowSchema(); - } else if (type.getTypeName().isCollectionType() - && TypeName.ROW.equals(type.getCollectionElementType().getTypeName())) { - return type.getCollectionElementType().getRowSchema(); - } else if (TypeName.MAP.equals(type.getTypeName()) - && TypeName.ROW.equals(type.getMapValueType().getTypeName())) { - return type.getMapValueType().getRowSchema(); + } else if (type.getTypeName().isCollectionType()) { + return getFieldSchema(type.getCollectionElementType()); + } else if (TypeName.MAP.equals(type.getTypeName())) { + return getFieldSchema(type.getMapValueType()); + } else if (TypeName.LOGICAL_TYPE.equals(type.getTypeName())) { + return getFieldSchema(type.getLogicalType().getBaseType()); } else { throw new IllegalArgumentException( "FieldType " + type + " must be either a row or a container containing rows"); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java index 6b2ebd543e3d..f69c0097c360 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java @@ -56,7 +56,7 @@ public ToRowWithValueGetters(Schema schema) { @Override public Row apply(T input) { - return Row.withSchema(schema).withFieldValueGetters(getterFactory, input).build(); + return Row.withSchema(schema).withFieldValueGetters(getterFactory, input); } private GetterBasedSchemaProvider getOuter() { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/AddFields.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/AddFields.java index 1dbfdb3965dc..4b13bf8c2855 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/AddFields.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/AddFields.java @@ -373,7 +373,7 @@ private static Row fillNewFields(Row row, AddFieldsInformation addFieldsInformat } } - return Row.withSchema(outputSchema).attachValues(newValues).build(); + return Row.withSchema(outputSchema).attachValues(newValues); } private static Object fillNewFields( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java index 2a461c1a201f..c4366e2318cf 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java @@ -548,7 +548,7 @@ void outputUnexpandedRow(Schema outputSchema, OutputReceiver o) { List fields = Lists.newArrayListWithCapacity(getIterables().size() + 1); fields.add(getKey()); fields.addAll(getIterables()); - o.output(Row.withSchema(outputSchema).attachValues(fields).build()); + o.output(Row.withSchema(outputSchema).attachValues(fields)); } static void verifyExpandedArgs(JoinInformation joinInformation, JoinArguments joinArgs) { @@ -632,9 +632,7 @@ private void crossProductHelper( if (atBottom) { // Bottom of recursive call, so output the row we've accumulated. Row row = - Row.withSchema(getOutputSchema()) - .attachValues(Lists.newArrayList(accumulatedRows)) - .build(); + Row.withSchema(getOutputSchema()).attachValues(Lists.newArrayList(accumulatedRows)); o.output(row); } else { crossProduct(tagIndex + 1, accumulatedRows, iterables, o); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java index bf6187426124..76c1e89b1203 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java @@ -859,8 +859,7 @@ public PCollection expand(PCollection input) { public void process(@Element KV> e, OutputReceiver o) { o.output( Row.withSchema(outputSchema) - .attachValues(Lists.newArrayList(e.getKey(), e.getValue())) - .build()); + .attachValues(Lists.newArrayList(e.getKey(), e.getValue()))); } })) .setRowSchema(outputSchema); @@ -1140,8 +1139,7 @@ public void process(@Element KV element, OutputReceiver o) { o.output( Row.withSchema(outputSchema) .attachValues( - Lists.newArrayList(element.getKey(), element.getValue())) - .build()); + Lists.newArrayList(element.getKey(), element.getValue()))); } })) .setRowSchema(outputSchema); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/RenameFields.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/RenameFields.java index 5e2eb97f36b7..e0405bba8602 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/RenameFields.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/RenameFields.java @@ -180,7 +180,7 @@ public PCollection expand(PCollection input) { new DoFn() { @ProcessElement public void processElement(@Element Row row, OutputReceiver o) { - o.output(Row.withSchema(outputSchema).attachValues(row.getValues()).build()); + o.output(Row.withSchema(outputSchema).attachValues(row.getValues())); } })) .setRowSchema(outputSchema); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SelectByteBuddyHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SelectByteBuddyHelpers.java index 3bf3069dc618..c5f29a5611b3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SelectByteBuddyHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SelectByteBuddyHelpers.java @@ -368,11 +368,6 @@ public ByteCodeAppender appender(final Target implementationTarget) { ElementMatchers.named("attachValues") .and(ElementMatchers.takesArguments(Object[].class))) .getOnly()), - MethodInvocation.invoke( - new ForLoadedType(Row.Builder.class) - .getDeclaredMethods() - .filter(ElementMatchers.named("build")) - .getOnly()), MethodReturn.REFERENCE); size = size.aggregate(attachToRow.apply(methodVisitor, implementationContext)); return new Size(size.getMaximalSize(), localVariables.getTotalNumVariables()); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 3d37b7388d38..097c2a035f0f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -17,7 +17,7 @@ */ package org.apache.beam.sdk.values; -import static org.apache.beam.sdk.values.SchemaVerification.verifyRowValues; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; @@ -37,11 +37,18 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.Factory; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.sdk.values.RowUtils.CapturingRowCases; +import org.apache.beam.sdk.values.RowUtils.FieldOverride; +import org.apache.beam.sdk.values.RowUtils.FieldOverrides; +import org.apache.beam.sdk.values.RowUtils.RowFieldMatcher; +import org.apache.beam.sdk.values.RowUtils.RowPosition; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.joda.time.DateTime; @@ -52,8 +59,31 @@ * {@link Row} is an immutable tuple-like schema to represent one element in a {@link PCollection}. * The fields are described with a {@link Schema}. * - *

{@link Schema} contains the names for each field and the coder for the whole record, - * {see @link Schema#getRowCoder()}. + *

{@link Schema} contains the names and types for each field. + * + *

There are several ways to build a new Row object. To build a row from scratch using a schema + * object, {@link Row#withSchema} can be used. Schema fields can be specified by name, and nested + * fields can be specified using the field selection syntax. For example: + * + *

{@code
+ * Row row = Row.withSchema(schema)
+ *              .withFieldValue("userId", "user1)
+ *              .withFieldValue("location.city", "seattle")
+ *              .withFieldValue("location.state", "wa")
+ *              .build();
+ * }
+ * + *

The {@link Row#fromRow} builder can be used to base a row off of another row. The builder can + * be used to specify values for specific fields, and all the remaining values will be taken from + * the original row. For example, the following produces a row identical to the above row except for + * the location.city field. + * + *

{@code
+ * Row modifiedRow =
+ *     Row.fromRow(row)
+ *        .withFieldValue("location.city", "tacoma")
+ *        .build();
+ * }
*/ @Experimental(Kind.SCHEMAS) public abstract class Row implements Serializable { @@ -72,6 +102,7 @@ public abstract class Row implements Serializable { /** Return the size of data fields. */ public abstract int getFieldCount(); + /** Return the list of data values. */ public abstract List getValues(); @@ -585,37 +616,150 @@ public String toString() { } /** - * Creates a record builder with specified {@link #getSchema()}. {@link Builder#build()} will - * throw an {@link IllegalArgumentException} if number of fields in {@link #getSchema()} does not - * match the number of fields specified. + * Creates a row builder with specified {@link #getSchema()}. {@link Builder#build()} will throw + * an {@link IllegalArgumentException} if number of fields in {@link #getSchema()} does not match + * the number of fields specified. If any of the arguments don't match the expected types for the + * schema fields, {@link Builder#build()} will throw a {@link ClassCastException}. */ public static Builder withSchema(Schema schema) { return new Builder(schema); } + /** + * Creates a row builder based on the specified row. Field values in the new row can be explicitly + * set using {@link FieldValueBuilder#withFieldValue}. Any values not so overridden will be the + * same as the values in the original row. + */ + public static FieldValueBuilder fromRow(Row row) { + return new FieldValueBuilder(row.getSchema(), row); + } + + /** Builder for {@link Row} that bases a row on another row. */ + public static class FieldValueBuilder { + private final Schema schema; + private final @Nullable Row sourceRow; + private final FieldOverrides fieldOverrides; + + private FieldValueBuilder(Schema schema, @Nullable Row sourceRow) { + this.schema = schema; + this.sourceRow = sourceRow; + this.fieldOverrides = new FieldOverrides(schema); + } + + public Schema getSchema() { + return schema; + } + + /** + * Set a field value using the field name. Nested values can be set using the field selection + * syntax. + */ + public FieldValueBuilder withFieldValue(String fieldName, Object value) { + return withFieldValue(FieldAccessDescriptor.withFieldNames(fieldName), value); + } + + /** Set a field value using the field id. */ + public FieldValueBuilder withFieldValue(Integer fieldId, Object value) { + return withFieldValue(FieldAccessDescriptor.withFieldIds(fieldId), value); + } + + /** Set a field value using a FieldAccessDescriptor. */ + public FieldValueBuilder withFieldValue( + FieldAccessDescriptor fieldAccessDescriptor, Object value) { + FieldAccessDescriptor fieldAccess = fieldAccessDescriptor.resolve(getSchema()); + checkArgument(fieldAccess.referencesSingleField(), ""); + fieldOverrides.addOverride(fieldAccess, new FieldOverride(value)); + return this; + } + + /** + * Sets field values using the field names. Nested values can be set using the field selection + * syntax. + */ + public FieldValueBuilder withFieldValues(Map values) { + values.entrySet().stream() + .forEach( + e -> + fieldOverrides.addOverride( + FieldAccessDescriptor.withFieldNames(e.getKey()), + new FieldOverride(e.getValue()))); + return this; + } + + /** + * Sets field values using the FieldAccessDescriptors. Nested values can be set using the field + * selection syntax. + */ + public FieldValueBuilder withFieldAccessDescriptors(Map values) { + values.entrySet().stream() + .forEach(e -> fieldOverrides.addOverride(e.getKey(), new FieldOverride(e.getValue()))); + return this; + } + + public Row build() { + Row row = + (Row) + new RowFieldMatcher() + .match( + new CapturingRowCases(getSchema(), this.fieldOverrides), + FieldType.row(getSchema()), + new RowPosition(FieldAccessDescriptor.create()), + sourceRow); + return row; + } + } + /** Builder for {@link Row}. */ public static class Builder { private List values = Lists.newArrayList(); - private boolean attached = false; - @Nullable private Factory> fieldValueGetterFactory; - @Nullable private Object getterTarget; - private Schema schema; + private final Schema schema; + private Row nullRow; Builder(Schema schema) { this.schema = schema; } - public int nextFieldId() { - if (fieldValueGetterFactory != null) { - throw new RuntimeException("Not supported"); - } - return values.size(); - } - + /** Return the schema for the row being built. */ public Schema getSchema() { return schema; } + /** + * Set a field value using the field name. Nested values can be set using the field selection + * syntax. + */ + public FieldValueBuilder withFieldValue(String fieldName, Object value) { + checkState(values.isEmpty()); + return new FieldValueBuilder(schema, null).withFieldValue(fieldName, value); + } + + /** Set a field value using the field id. */ + public FieldValueBuilder withFieldValue(Integer fieldId, Object value) { + checkState(values.isEmpty()); + return new FieldValueBuilder(schema, null).withFieldValue(fieldId, value); + } + + /** Set a field value using a FieldAccessDescriptor. */ + public FieldValueBuilder withFieldValue( + FieldAccessDescriptor fieldAccessDescriptor, Object value) { + checkState(values.isEmpty()); + return new FieldValueBuilder(schema, null).withFieldValue(fieldAccessDescriptor, value); + } + /** + * Sets field values using the field names. Nested values can be set using the field selection + * syntax. + */ + public FieldValueBuilder withFieldValues(Map values) { + checkState(values.isEmpty()); + return new FieldValueBuilder(schema, null).withFieldValues(values); + } + + // The following methods allow appending a list of values to the Builder object. The values must + // be in the same + // order as the fields in the row. These methods cannot be used in conjunction with + // withFieldValue or + // withFieldValues. + public Builder addValue(@Nullable Object values) { this.values.add(values); return this; @@ -645,40 +789,62 @@ public Builder addIterable(Iterable values) { return this; } - // Values are attached. No verification is done, and no conversions are done. LogicalType - // values must be specified as the base type. - public Builder attachValues(List values) { - this.attached = true; - this.values = values; - return this; + // Values are attached. No verification is done, and no conversions are done. LogicalType values + // must be specified as the base type. This method should be used with great care, as no + // validation is done. If + // incorrect values are passed in, it could result in strange errors later in the pipeline. This + // method is largely + // used internal to Beam. + @Internal + public Row attachValues(List attachedValues) { + checkState(this.values.isEmpty()); + return new RowWithStorage(schema, attachedValues); } - public Builder attachValues(Object... values) { + public Row attachValues(Object... values) { return attachValues(Arrays.asList(values)); } - public Builder withFieldValueGetters( + public int nextFieldId() { + return values.size(); + } + + @Internal + public Row withFieldValueGetters( Factory> fieldValueGetterFactory, Object getterTarget) { - this.fieldValueGetterFactory = fieldValueGetterFactory; - this.getterTarget = getterTarget; - return this; + checkState(getterTarget != null, "getters require withGetterTarget."); + return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); } public Row build() { checkNotNull(schema); - if (!this.values.isEmpty() && fieldValueGetterFactory != null) { - throw new IllegalArgumentException("Cannot specify both values and getters."); + + if (!values.isEmpty() && values.size() != schema.getFieldCount()) { + throw new IllegalArgumentException( + "Row expected " + + schema.getFieldCount() + + " fields. initialized with " + + values.size() + + " fields."); } - if (!this.values.isEmpty()) { - List storageValues = attached ? this.values : verifyRowValues(schema, this.values); - checkState(getterTarget == null, "withGetterTarget requires getters."); - return new RowWithStorage(schema, storageValues); - } else if (fieldValueGetterFactory != null) { - checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + + FieldOverrides fieldOverrides = new FieldOverrides(schema); + fieldOverrides.setOverrides(this.values); + + Row row; + if (!fieldOverrides.isEmpty()) { + row = + (Row) + new RowFieldMatcher() + .match( + new CapturingRowCases(schema, fieldOverrides), + FieldType.row(schema), + new RowPosition(FieldAccessDescriptor.create()), + null); } else { - return new RowWithStorage(schema, Collections.emptyList()); + row = new RowWithStorage(schema, Collections.emptyList()); } + return row; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java new file mode 100644 index 000000000000..3a7c860cf10c --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java @@ -0,0 +1,588 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.values; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import javax.annotation.Nullable; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor.FieldDescriptor; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor.FieldDescriptor.ListQualifier; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor.FieldDescriptor.MapQualifier; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor.FieldDescriptor.Qualifier; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.LogicalType; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.joda.time.Instant; +import org.joda.time.base.AbstractInstant; + +class RowUtils { + static class RowPosition { + FieldAccessDescriptor descriptor; + List qualifiers; + + RowPosition(FieldAccessDescriptor descriptor) { + this(descriptor, Collections.emptyList()); + } + + RowPosition(FieldAccessDescriptor descriptor, List qualifiers) { + this.descriptor = descriptor; + this.qualifiers = qualifiers; + } + + RowPosition withArrayQualifier() { + List newQualifiers = Lists.newArrayListWithCapacity(qualifiers.size() + 1); + newQualifiers.addAll(qualifiers); + newQualifiers.add(Qualifier.of(ListQualifier.ALL)); + return new RowPosition(descriptor, newQualifiers); + } + + RowPosition withMapQualifier() { + List newQualifiers = Lists.newArrayListWithCapacity(qualifiers.size() + 1); + newQualifiers.addAll(qualifiers); + newQualifiers.add(Qualifier.of(MapQualifier.ALL)); + return new RowPosition(descriptor, newQualifiers); + } + } + + // Subclasses of this interface implement process methods for each schema type. Each process + // method is invoked as + // a RowFieldMatcher walks down the schema tree. The FieldAccessDescriptor passed into each method + // identifies the + // current element of the schema being processed. + interface RowCases { + Row processRow(RowPosition rowPosition, Schema schema, Row value, RowFieldMatcher matcher); + + Collection processArray( + RowPosition rowPosition, + FieldType collectionElementType, + Collection values, + RowFieldMatcher matcher); + + Iterable processIterable( + RowPosition rowPosition, + FieldType collectionElementType, + Iterable values, + RowFieldMatcher matcher); + + Map processMap( + RowPosition rowPosition, + FieldType keyType, + FieldType valueType, + Map valueMap, + RowFieldMatcher matcher); + + Object processLogicalType( + RowPosition rowPosition, LogicalType logicalType, Object baseType, RowFieldMatcher matcher); + + Instant processDateTime( + RowPosition rowPosition, AbstractInstant instant, RowFieldMatcher matcher); + + Byte processByte(RowPosition rowPosition, Byte value, RowFieldMatcher matcher); + + Short processInt16(RowPosition rowPosition, Short value, RowFieldMatcher matcher); + + Integer processInt32(RowPosition rowPosition, Integer value, RowFieldMatcher matcher); + + Long processInt64(RowPosition rowPosition, Long value, RowFieldMatcher matcher); + + BigDecimal processDecimal(RowPosition rowPosition, BigDecimal value, RowFieldMatcher matcher); + + Float processFloat(RowPosition rowPosition, Float value, RowFieldMatcher matcher); + + Double processDouble(RowPosition rowPosition, Double value, RowFieldMatcher matcher); + + String processString(RowPosition rowPosition, String value, RowFieldMatcher matcher); + + Boolean processBoolean(RowPosition rowPosition, Boolean value, RowFieldMatcher matcher); + + byte[] processBytes(RowPosition rowPosition, byte[] value, RowFieldMatcher matcher); + } + + // Given a Row field, delegates processing to the correct process method on the RowCases + // parameter. + static class RowFieldMatcher { + public Object match( + RowCases cases, FieldType fieldType, RowPosition rowPosition, Object value) { + Object processedValue = null; + switch (fieldType.getTypeName()) { + case ARRAY: + processedValue = + cases.processArray( + rowPosition, + fieldType.getCollectionElementType(), + (Collection) value, + this); + break; + case ITERABLE: + processedValue = + cases.processIterable( + rowPosition, + fieldType.getCollectionElementType(), + (Iterable) value, + this); + break; + case MAP: + processedValue = + cases.processMap( + rowPosition, + fieldType.getMapKeyType(), + fieldType.getMapValueType(), + (Map) value, + this); + break; + case ROW: + processedValue = + cases.processRow(rowPosition, fieldType.getRowSchema(), (Row) value, this); + break; + case LOGICAL_TYPE: + LogicalType logicalType = fieldType.getLogicalType(); + processedValue = cases.processLogicalType(rowPosition, logicalType, value, this); + break; + case DATETIME: + processedValue = cases.processDateTime(rowPosition, (AbstractInstant) value, this); + break; + case BYTE: + processedValue = cases.processByte(rowPosition, (Byte) value, this); + break; + case BYTES: + processedValue = cases.processBytes(rowPosition, (byte[]) value, this); + break; + case INT16: + processedValue = cases.processInt16(rowPosition, (Short) value, this); + break; + case INT32: + processedValue = cases.processInt32(rowPosition, (Integer) value, this); + break; + case INT64: + processedValue = cases.processInt64(rowPosition, (Long) value, this); + break; + case DECIMAL: + processedValue = cases.processDecimal(rowPosition, (BigDecimal) value, this); + break; + case FLOAT: + processedValue = cases.processFloat(rowPosition, (Float) value, this); + break; + case DOUBLE: + processedValue = cases.processDouble(rowPosition, (Double) value, this); + break; + case STRING: + processedValue = cases.processString(rowPosition, (String) value, this); + break; + case BOOLEAN: + processedValue = cases.processBoolean(rowPosition, (Boolean) value, this); + break; + default: + // Shouldn't actually get here, but we need this case to satisfy linters. + throw new IllegalArgumentException( + String.format( + "Not a primitive type for field name %s: %s", rowPosition.descriptor, fieldType)); + } + if (processedValue == null) { + if (!fieldType.getNullable()) { + throw new IllegalArgumentException( + String.format("%s is not nullable in field %s", fieldType, rowPosition.descriptor)); + } + } + return processedValue; + } + } + + static class FieldOverride { + FieldOverride(Object overrideValue) { + this.overrideValue = overrideValue; + } + + Object getOverrideValue() { + return overrideValue; + } + + final Object overrideValue; + } + + static class FieldOverrides { + private FieldAccessNode topNode; + private Schema rootSchema; + + FieldOverrides(Schema rootSchema) { + this.topNode = new FieldAccessNode(rootSchema); + this.rootSchema = rootSchema; + } + + boolean isEmpty() { + return topNode.isEmpty(); + } + + void addOverride(FieldAccessDescriptor fieldAccessDescriptor, FieldOverride fieldOverride) { + topNode.addOverride(fieldAccessDescriptor, fieldOverride, rootSchema); + } + + void setOverrides(List values) { + List overrides = Lists.newArrayListWithExpectedSize(values.size()); + for (Object value : values) { + overrides.add(new FieldOverride(value)); + } + topNode.setOverrides(overrides); + } + + @Nullable + FieldOverride getOverride(FieldAccessDescriptor fieldAccessDescriptor) { + return topNode.getOverride(fieldAccessDescriptor); + } + + boolean hasOverrideBelow(FieldAccessDescriptor fieldAccessDescriptor) { + return topNode.hasOverrideBelow(fieldAccessDescriptor); + } + + private static class FieldAccessNode { + List fieldOverrides; + List nestedAccess; + + FieldAccessNode(Schema schema) { + fieldOverrides = Lists.newArrayListWithExpectedSize(schema.getFieldCount()); + nestedAccess = Lists.newArrayList(); + } + + boolean isEmpty() { + return fieldOverrides.isEmpty() && nestedAccess.isEmpty(); + } + + void addOverride( + FieldAccessDescriptor fieldAccessDescriptor, + FieldOverride fieldOverride, + Schema currentSchema) { + if (!fieldAccessDescriptor.getFieldsAccessed().isEmpty()) { + FieldDescriptor fieldDescriptor = + Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()); + int aheadPosition = fieldDescriptor.getFieldId() - fieldOverrides.size() + 1; + if (aheadPosition > 0) { + fieldOverrides.addAll(Collections.nCopies(aheadPosition, null)); + } + fieldOverrides.set(fieldDescriptor.getFieldId(), fieldOverride); + } else if (!fieldAccessDescriptor.getNestedFieldsAccessed().isEmpty()) { + Map.Entry entry = + Iterables.getOnlyElement(fieldAccessDescriptor.getNestedFieldsAccessed().entrySet()); + int aheadPosition = entry.getKey().getFieldId() - nestedAccess.size() + 1; + if (aheadPosition > 0) { + nestedAccess.addAll(Collections.nCopies(aheadPosition, null)); + } + + Schema nestedSchema = + currentSchema.getField(entry.getKey().getFieldId()).getType().getRowSchema(); + FieldAccessNode node = nestedAccess.get(entry.getKey().getFieldId()); + if (node == null) { + node = new FieldAccessNode(nestedSchema); + nestedAccess.set(entry.getKey().getFieldId(), node); + } + node.addOverride(entry.getValue(), fieldOverride, nestedSchema); + } + } + + void setOverrides(List overrides) { + this.fieldOverrides = overrides; + } + + @Nullable + FieldOverride getOverride(FieldAccessDescriptor fieldAccessDescriptor) { + FieldOverride override = null; + if (!fieldAccessDescriptor.getFieldsAccessed().isEmpty()) { + FieldDescriptor fieldDescriptor = + Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()); + if (fieldDescriptor.getFieldId() < fieldOverrides.size()) { + override = fieldOverrides.get(fieldDescriptor.getFieldId()); + } + } else if (!fieldAccessDescriptor.getNestedFieldsAccessed().isEmpty()) { + Map.Entry entry = + Iterables.getOnlyElement(fieldAccessDescriptor.getNestedFieldsAccessed().entrySet()); + if (entry.getKey().getFieldId() < nestedAccess.size()) { + FieldAccessNode node = nestedAccess.get(entry.getKey().getFieldId()); + if (node != null) { + override = node.getOverride(entry.getValue()); + } + } + } + return override; + } + + boolean hasOverrideBelow(FieldAccessDescriptor fieldAccessDescriptor) { + if (!fieldAccessDescriptor.getFieldsAccessed().isEmpty()) { + FieldDescriptor fieldDescriptor = + Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()); + return (((fieldDescriptor.getFieldId() < nestedAccess.size())) + && nestedAccess.get(fieldDescriptor.getFieldId()) != null); + } else if (!fieldAccessDescriptor.getNestedFieldsAccessed().isEmpty()) { + Map.Entry entry = + Iterables.getOnlyElement(fieldAccessDescriptor.getNestedFieldsAccessed().entrySet()); + if (entry.getKey().getFieldId() < nestedAccess.size()) { + FieldAccessNode node = nestedAccess.get(entry.getKey().getFieldId()); + if (node != null) { + return node.hasOverrideBelow(entry.getValue()); + } + } + } else { + return true; + } + return false; + } + } + } + + // This implementation of RowCases captures a Row into a new Row. It also has the effect of + // validating all the + // field parameters. + // A Map of field values can also be passed in, and those field values will be used to override + // the values in the + // passed-in row. + static class CapturingRowCases implements RowCases { + private final Schema topSchema; + private final FieldOverrides fieldOverrides; + + CapturingRowCases(Schema topSchema, FieldOverrides fieldOverrides) { + this.topSchema = topSchema; + this.fieldOverrides = fieldOverrides; + } + + private @Nullable FieldOverride override(RowPosition rowPosition) { + if (!rowPosition.qualifiers.isEmpty()) { + // Currently we only support overriding named schema fields. Individual array/map elements + // or nested collections + // cannot be overriden without overriding the entire schema fields. + return null; + } else { + return fieldOverrides.getOverride(rowPosition.descriptor); + } + } + + private T overrideOrReturn(RowPosition rowPosition, T value) { + FieldOverride fieldOverride = override(rowPosition); + // null return means the item isn't in the map. + return (fieldOverride != null) ? (T) fieldOverride.getOverrideValue() : value; + } + + @Override + public Row processRow( + RowPosition rowPosition, Schema schema, Row value, RowFieldMatcher matcher) { + FieldOverride override = override(rowPosition); + Row retValue = value; + if (override != null) { + retValue = (Row) override.getOverrideValue(); + } else if (fieldOverrides.hasOverrideBelow(rowPosition.descriptor)) { + List values = Lists.newArrayListWithCapacity(schema.getFieldCount()); + for (int i = 0; i < schema.getFieldCount(); ++i) { + FieldAccessDescriptor nestedDescriptor = + FieldAccessDescriptor.withFieldIds(rowPosition.descriptor, i).resolve(topSchema); + Object fieldValue = (value != null) ? value.getValue(i) : null; + values.add( + matcher.match( + this, + schema.getField(i).getType(), + new RowPosition(nestedDescriptor), + fieldValue)); + } + retValue = new RowWithStorage(schema, values); + } + return retValue; + } + + @Override + public Collection processArray( + RowPosition rowPosition, + FieldType collectionElementType, + Collection values, + RowFieldMatcher matcher) { + Collection retValue = null; + FieldOverride override = override(rowPosition); + if (override != null) { + retValue = + captureIterable( + rowPosition, + collectionElementType, + (Collection) override.getOverrideValue(), + matcher); + } else if (values != null) { + retValue = captureIterable(rowPosition, collectionElementType, values, matcher); + } + return retValue; + } + + @Override + public Iterable processIterable( + RowPosition rowPosition, + FieldType collectionElementType, + Iterable values, + RowFieldMatcher matcher) { + Iterable retValue = null; + FieldOverride override = override(rowPosition); + if (override != null) { + retValue = + captureIterable( + rowPosition, + collectionElementType, + (Iterable) override.getOverrideValue(), + matcher); + } else if (values != null) { + retValue = captureIterable(rowPosition, collectionElementType, values, matcher); + } + return retValue; + } + + private Collection captureIterable( + RowPosition rowPosition, + FieldType collectionElementType, + Iterable values, + RowFieldMatcher matcher) { + if (values == null) { + return null; + } + + List captured = Lists.newArrayListWithCapacity(Iterables.size(values)); + RowPosition elementPosition = rowPosition.withArrayQualifier(); + for (Object listValue : values) { + if (listValue == null) { + if (!collectionElementType.getNullable()) { + throw new IllegalArgumentException( + String.format( + "%s is not nullable in Array field %s", + collectionElementType, rowPosition.descriptor)); + } + captured.add(null); + } else { + Object capturedElement = + matcher.match(this, collectionElementType, elementPosition, listValue); + captured.add(capturedElement); + } + } + return captured; + } + + @Override + public Map processMap( + RowPosition rowPosition, + FieldType keyType, + FieldType valueType, + Map valueMap, + RowFieldMatcher matcher) { + Map retValue = null; + FieldOverride override = override(rowPosition); + if (override != null) { + valueMap = (Map) override.getOverrideValue(); + } + + if (valueMap != null) { + RowPosition elementPosition = rowPosition.withMapQualifier(); + + retValue = Maps.newHashMapWithExpectedSize(valueMap.size()); + for (Entry kv : valueMap.entrySet()) { + if (kv.getValue() == null) { + if (!valueType.getNullable()) { + throw new IllegalArgumentException( + String.format( + "%s is not nullable in Map field %s", valueType, rowPosition.descriptor)); + } + retValue.put(matcher.match(this, keyType, elementPosition, kv.getKey()), null); + } else { + retValue.put( + matcher.match(this, keyType, elementPosition, kv.getKey()), + matcher.match(this, valueType, elementPosition, kv.getValue())); + } + } + } + return retValue; + } + + @Override + public Object processLogicalType( + RowPosition rowPosition, LogicalType logicalType, Object value, RowFieldMatcher matcher) { + Object retValue = null; + FieldOverride override = override(rowPosition); + if (override != null) { + retValue = override.getOverrideValue(); + } else if (value != null) { + retValue = logicalType.toInputType(logicalType.toBaseType(value)); + } + return retValue; + } + + @Override + public Instant processDateTime( + RowPosition rowPosition, AbstractInstant value, RowFieldMatcher matcher) { + AbstractInstant instantValue = overrideOrReturn(rowPosition, value); + return (instantValue != null) ? instantValue.toInstant() : null; + } + + @Override + public Byte processByte(RowPosition rowPosition, Byte value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Short processInt16(RowPosition rowPosition, Short value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Integer processInt32(RowPosition rowPosition, Integer value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Long processInt64(RowPosition rowPosition, Long value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public BigDecimal processDecimal( + RowPosition rowPosition, BigDecimal value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Float processFloat(RowPosition rowPosition, Float value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Double processDouble(RowPosition rowPosition, Double value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public String processString(RowPosition rowPosition, String value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Boolean processBoolean(RowPosition rowPosition, Boolean value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public byte[] processBytes(RowPosition rowPosition, byte[] value, RowFieldMatcher matcher) { + Object retValue = overrideOrReturn(rowPosition, value); + return (retValue instanceof ByteBuffer) ? ((ByteBuffer) retValue).array() : (byte[]) retValue; + } + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java index b38051075020..3d1d0baa5677 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java @@ -18,239 +18,24 @@ package org.apache.beam.sdk.values; import java.io.Serializable; -import java.math.BigDecimal; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.Schema.LogicalType; -import org.apache.beam.sdk.schemas.Schema.TypeName; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; -import org.joda.time.Instant; -import org.joda.time.base.AbstractInstant; +import org.apache.beam.sdk.values.RowUtils.CapturingRowCases; +import org.apache.beam.sdk.values.RowUtils.FieldOverrides; +import org.apache.beam.sdk.values.RowUtils.RowFieldMatcher; +import org.apache.beam.sdk.values.RowUtils.RowPosition; @Experimental public abstract class SchemaVerification implements Serializable { - - static List verifyRowValues(Schema schema, List values) { - List verifiedValues = Lists.newArrayListWithCapacity(values.size()); - if (schema.getFieldCount() != values.size()) { - throw new IllegalArgumentException( - String.format( - "Field count in Schema (%s) (%d) and values (%s) (%d) must match", - schema.getFieldNames(), schema.getFieldCount(), values, values.size())); - } - for (int i = 0; i < values.size(); ++i) { - Object value = values.get(i); - Schema.Field field = schema.getField(i); - if (value == null) { - if (!field.getType().getNullable()) { - throw new IllegalArgumentException( - String.format("Field %s is not nullable", field.getName())); - } - verifiedValues.add(null); - } else { - verifiedValues.add(verifyFieldValue(value, field.getType(), field.getName())); - } - } - return verifiedValues; - } - public static Object verifyFieldValue(Object value, FieldType type, String fieldName) { - if (TypeName.ARRAY.equals(type.getTypeName())) { - return verifyArray(value, type.getCollectionElementType(), fieldName); - } else if (TypeName.ITERABLE.equals(type.getTypeName())) { - return verifyIterable(value, type.getCollectionElementType(), fieldName); - } - if (TypeName.MAP.equals(type.getTypeName())) { - return verifyMap(value, type.getMapKeyType(), type.getMapValueType(), fieldName); - } else if (TypeName.ROW.equals(type.getTypeName())) { - return verifyRow(value, fieldName); - } else if (TypeName.LOGICAL_TYPE.equals(type.getTypeName())) { - return verifyLogicalType(value, type.getLogicalType(), fieldName); - } else { - return verifyPrimitiveType(value, type.getTypeName(), fieldName); - } - } - - private static Object verifyLogicalType(Object value, LogicalType logicalType, String fieldName) { - // TODO: this isn't guaranteed to clone the object. - return logicalType.toInputType(logicalType.toBaseType(value)); - } - - private static List verifyArray( - Object value, FieldType collectionElementType, String fieldName) { - boolean collectionElementTypeNullable = collectionElementType.getNullable(); - if (!(value instanceof List)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and array type expected List class. Instead " - + "class type was %s.", - fieldName, value.getClass())); - } - List valueList = (List) value; - List verifiedList = Lists.newArrayListWithCapacity(valueList.size()); - for (Object listValue : valueList) { - if (listValue == null) { - if (!collectionElementTypeNullable) { - throw new IllegalArgumentException( - String.format( - "%s is not nullable in Array field %s", collectionElementType, fieldName)); - } - verifiedList.add(null); - } else { - verifiedList.add(verifyFieldValue(listValue, collectionElementType, fieldName)); - } - } - return verifiedList; - } - - private static Iterable verifyIterable( - Object value, FieldType collectionElementType, String fieldName) { - boolean collectionElementTypeNullable = collectionElementType.getNullable(); - if (!(value instanceof Iterable)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and iterable type expected class extending Iterable. Instead " - + "class type was %s.", - fieldName, value.getClass())); - } - Iterable valueList = (Iterable) value; - for (Object listValue : valueList) { - if (listValue == null) { - if (!collectionElementTypeNullable) { - throw new IllegalArgumentException( - String.format( - "%s is not nullable in Array field %s", collectionElementType, fieldName)); - } - } else { - verifyFieldValue(listValue, collectionElementType, fieldName); - } - } - return valueList; - } - - private static Map verifyMap( - Object value, FieldType keyType, FieldType valueType, String fieldName) { - boolean valueTypeNullable = valueType.getNullable(); - if (!(value instanceof Map)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and map type expected Map class. Instead " + "class type was %s.", - fieldName, value.getClass())); - } - Map valueMap = (Map) value; - Map verifiedMap = Maps.newHashMapWithExpectedSize(valueMap.size()); - for (Entry kv : valueMap.entrySet()) { - if (kv.getValue() == null) { - if (!valueTypeNullable) { - throw new IllegalArgumentException( - String.format("%s is not nullable in Map field %s", valueType, fieldName)); - } - verifiedMap.put(verifyFieldValue(kv.getKey(), keyType, fieldName), null); - } else { - verifiedMap.put( - verifyFieldValue(kv.getKey(), keyType, fieldName), - verifyFieldValue(kv.getValue(), valueType, fieldName)); - } - } - return verifiedMap; - } - - private static Row verifyRow(Object value, String fieldName) { - if (!(value instanceof Row)) { - throw new IllegalArgumentException( - String.format( - "For field name %s expected Row type. " + "Instead class type was %s.", - fieldName, value.getClass())); - } - // No need to recursively validate the nested Row, since there's no way to build the - // Row object without it validating. - return (Row) value; - } - - private static Object verifyPrimitiveType(Object value, TypeName type, String fieldName) { - if (type.isDateType()) { - return verifyDateTime(value, fieldName); - } else { - switch (type) { - case BYTE: - if (value instanceof Byte) { - return value; - } - break; - case BYTES: - if (value instanceof ByteBuffer) { - return ((ByteBuffer) value).array(); - } else if (value instanceof byte[]) { - return (byte[]) value; - } - break; - case INT16: - if (value instanceof Short) { - return value; - } - break; - case INT32: - if (value instanceof Integer) { - return value; - } - break; - case INT64: - if (value instanceof Long) { - return value; - } - break; - case DECIMAL: - if (value instanceof BigDecimal) { - return value; - } - break; - case FLOAT: - if (value instanceof Float) { - return value; - } - break; - case DOUBLE: - if (value instanceof Double) { - return value; - } - break; - case STRING: - if (value instanceof String) { - return value; - } - break; - case BOOLEAN: - if (value instanceof Boolean) { - return value; - } - break; - default: - // Shouldn't actually get here, but we need this case to satisfy linters. - throw new IllegalArgumentException( - String.format("Not a primitive type for field name %s: %s", fieldName, type)); - } - throw new IllegalArgumentException( - String.format( - "For field name %s and type %s found incorrect class type %s", - fieldName, type, value.getClass())); - } - } - - private static Instant verifyDateTime(Object value, String fieldName) { - // We support the following classes for datetimes. - if (value instanceof AbstractInstant) { - return ((AbstractInstant) value).toInstant(); - } else { - throw new IllegalArgumentException( - String.format( - "For field name %s and DATETIME type got unexpected class %s ", - fieldName, value.getClass())); - } + Schema schema = Schema.builder().addField(fieldName, type).build(); + return new RowFieldMatcher() + .match( + new CapturingRowCases(schema, new FieldOverrides(schema)), + type, + new RowPosition(FieldAccessDescriptor.withFieldIds(0)), + value); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java index 5ec69e5090b0..81c5ad92f491 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java @@ -461,7 +461,8 @@ public void testFromRowIterable() throws NoSuchSchemaException { SchemaTestUtils.assertSchemaEquivalent(ITERABLE_BEAM_SCHEMA, schema); List list = Lists.newArrayList("one", "two"); - Row iterableRow = Row.withSchema(ITERABLE_BEAM_SCHEMA).addIterable(list).build(); + Row iterableRow = + Row.withSchema(ITERABLE_BEAM_SCHEMA).attachValues(ImmutableList.of((Object) list)); IterableBean converted = registry.getFromRowFunction(IterableBean.class).apply(iterableRow); assertEquals(list, Lists.newArrayList(converted.getStrings())); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java index af3aae015b75..b642109a21c7 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java @@ -578,7 +578,7 @@ public void testIterableFieldFromRow() throws NoSuchSchemaException { SchemaTestUtils.assertSchemaEquivalent(POJO_WITH_ITERABLE, schema); List list = Lists.newArrayList("one", "two"); - Row iterableRow = Row.withSchema(POJO_WITH_ITERABLE).addIterable(list).build(); + Row iterableRow = Row.withSchema(POJO_WITH_ITERABLE).attachValues((Object) list); PojoWithIterable converted = registry.getFromRowFunction(PojoWithIterable.class).apply(iterableRow); assertEquals(list, Lists.newArrayList(converted.strings)); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java index 02adaa0cf6ce..a9fbc7955e78 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java @@ -33,6 +33,7 @@ import java.util.stream.Stream; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; @@ -477,6 +478,213 @@ public void testCreateMapWithRowValue() { assertEquals(data, row.getMap("map")); } + @Test + public void testLogicalTypeWithRowValue() { + EnumerationType enumerationType = EnumerationType.create("zero", "one", "two"); + Schema type = + Stream.of(Schema.Field.of("f1_enum", FieldType.logicalType(enumerationType))) + .collect(toSchema()); + Row row = Row.withSchema(type).addValue(enumerationType.valueOf("zero")).build(); + assertEquals(enumerationType.valueOf(0), row.getValue(0)); + assertEquals( + enumerationType.valueOf("zero"), row.getLogicalTypeValue(0, EnumerationType.Value.class)); + } + + @Test + public void testLogicalTypeWithRowValueName() { + EnumerationType enumerationType = EnumerationType.create("zero", "one", "two"); + Schema type = + Stream.of(Schema.Field.of("f1_enum", FieldType.logicalType(enumerationType))) + .collect(toSchema()); + Row row = + Row.withSchema(type).withFieldValue("f1_enum", enumerationType.valueOf("zero")).build(); + assertEquals(enumerationType.valueOf(0), row.getValue(0)); + assertEquals( + enumerationType.valueOf("zero"), row.getLogicalTypeValue(0, EnumerationType.Value.class)); + } + + @Test + public void testLogicalTypeWithRowValueOverride() { + EnumerationType enumerationType = EnumerationType.create("zero", "one", "two"); + Schema type = + Stream.of(Schema.Field.of("f1_enum", FieldType.logicalType(enumerationType))) + .collect(toSchema()); + Row row = + Row.withSchema(type).withFieldValue("f1_enum", enumerationType.valueOf("zero")).build(); + Row overriddenRow = + Row.fromRow(row).withFieldValue("f1_enum", enumerationType.valueOf("one")).build(); + assertEquals(enumerationType.valueOf(1), overriddenRow.getValue(0)); + assertEquals( + enumerationType.valueOf("one"), + overriddenRow.getLogicalTypeValue(0, EnumerationType.Value.class)); + } + + @Test + public void testCreateWithNames() { + Schema type = + Stream.of( + Schema.Field.of("f_str", FieldType.STRING), + Schema.Field.of("f_byte", FieldType.BYTE), + Schema.Field.of("f_short", FieldType.INT16), + Schema.Field.of("f_int", FieldType.INT32), + Schema.Field.of("f_long", FieldType.INT64), + Schema.Field.of("f_float", FieldType.FLOAT), + Schema.Field.of("f_double", FieldType.DOUBLE), + Schema.Field.of("f_decimal", FieldType.DECIMAL), + Schema.Field.of("f_boolean", FieldType.BOOLEAN), + Schema.Field.of("f_datetime", FieldType.DATETIME), + Schema.Field.of("f_bytes", FieldType.BYTES), + Schema.Field.of("f_array", FieldType.array(FieldType.STRING)), + Schema.Field.of("f_iterable", FieldType.iterable(FieldType.STRING)), + Schema.Field.of("f_map", FieldType.map(FieldType.STRING, FieldType.STRING))) + .collect(toSchema()); + + DateTime dateTime = + new DateTime().withDate(1979, 03, 14).withTime(1, 2, 3, 4).withZone(DateTimeZone.UTC); + byte[] bytes = new byte[] {1, 2, 3, 4}; + + Row row = + Row.withSchema(type) + .withFieldValue("f_str", "str1") + .withFieldValue("f_byte", (byte) 42) + .withFieldValue("f_short", (short) 43) + .withFieldValue("f_int", (int) 44) + .withFieldValue("f_long", (long) 45) + .withFieldValue("f_float", (float) 3.14) + .withFieldValue("f_double", (double) 3.141) + .withFieldValue("f_decimal", new BigDecimal(3.1415)) + .withFieldValue("f_boolean", true) + .withFieldValue("f_datetime", dateTime) + .withFieldValue("f_bytes", bytes) + .withFieldValue("f_array", Lists.newArrayList("one", "two")) + .withFieldValue("f_iterable", Lists.newArrayList("one", "two", "three")) + .withFieldValue("f_map", ImmutableMap.of("hello", "goodbye", "here", "there")) + .build(); + + Row expectedRow = + Row.withSchema(type) + .addValues( + "str1", + (byte) 42, + (short) 43, + (int) 44, + (long) 45, + (float) 3.14, + (double) 3.141, + new BigDecimal(3.1415), + true, + dateTime, + bytes, + Lists.newArrayList("one", "two"), + Lists.newArrayList("one", "two", "three"), + ImmutableMap.of("hello", "goodbye", "here", "there")) + .build(); + assertEquals(expectedRow, row); + } + + @Test + public void testCreateWithNestedNames() { + Schema nestedType = + Stream.of( + Schema.Field.of("f_str", FieldType.STRING), + Schema.Field.of("f_int", FieldType.INT32)) + .collect(toSchema()); + Schema topType = + Stream.of( + Schema.Field.of("top_int", FieldType.INT32), + Schema.Field.of("f_nested", FieldType.row(nestedType))) + .collect(toSchema()); + Row row = + Row.withSchema(topType) + .withFieldValue("top_int", 42) + .withFieldValue("f_nested.f_str", "string") + .withFieldValue("f_nested.f_int", 43) + .build(); + + Row expectedRow = + Row.withSchema(topType) + .addValues(42, Row.withSchema(nestedType).addValues("string", 43).build()) + .build(); + assertEquals(expectedRow, row); + } + + @Test + public void testCreateWithCollectionNames() { + Schema type = + Stream.of( + Schema.Field.of("f_array", FieldType.array(FieldType.INT32)), + Schema.Field.of("f_iterable", FieldType.iterable(FieldType.INT32)), + Schema.Field.of("f_map", FieldType.map(FieldType.STRING, FieldType.STRING))) + .collect(toSchema()); + + Row row = + Row.withSchema(type) + .withFieldValue("f_array", ImmutableList.of(1, 2, 3)) + .withFieldValue("f_iterable", ImmutableList.of(1, 2, 3)) + .withFieldValue("f_map", ImmutableMap.of("one", "two")) + .build(); + + Row expectedRow = + Row.withSchema(type) + .addValues( + ImmutableList.of(1, 2, 3), ImmutableList.of(1, 2, 3), ImmutableMap.of("one", "two")) + .build(); + assertEquals(expectedRow, row); + } + + @Test + public void testOverrideRow() { + Schema type = + Stream.of( + Schema.Field.of("f_str", FieldType.STRING), + Schema.Field.of("f_int", FieldType.INT32)) + .collect(toSchema()); + Row sourceRow = + Row.withSchema(type).withFieldValue("f_str", "string").withFieldValue("f_int", 42).build(); + + Row modifiedRow = Row.fromRow(sourceRow).withFieldValue("f_str", "modifiedString").build(); + + Row expectedRow = + Row.withSchema(type) + .withFieldValue("f_str", "modifiedString") + .withFieldValue("f_int", 42) + .build(); + assertEquals(expectedRow, modifiedRow); + } + + @Test + public void testOverrideNestedRow() { + Schema nestedType = + Stream.of( + Schema.Field.of("f_str", FieldType.STRING), + Schema.Field.of("f_int", FieldType.INT32)) + .collect(toSchema()); + Schema topType = + Stream.of( + Schema.Field.of("top_int", FieldType.INT32), + Schema.Field.of("f_nested", FieldType.row(nestedType))) + .collect(toSchema()); + Row sourceRow = + Row.withSchema(topType) + .withFieldValue("top_int", 42) + .withFieldValue("f_nested.f_str", "string") + .withFieldValue("f_nested.f_int", 43) + .build(); + Row modifiedRow = + Row.fromRow(sourceRow) + .withFieldValue("f_nested.f_str", "modifiedString") + .withFieldValue("f_nested.f_int", 143) + .build(); + + Row expectedRow = + Row.withSchema(topType) + .withFieldValue("top_int", 42) + .withFieldValue("f_nested.f_str", "modifiedString") + .withFieldValue("f_nested.f_int", 143) + .build(); + assertEquals(expectedRow, modifiedRow); + } + @Test public void testCollector() { Schema type =