From 2d3e1bc3e5f5144b3b09debc57cb893cbf438b17 Mon Sep 17 00:00:00 2001 From: Reuven Lax Date: Mon, 17 Feb 2020 10:31:15 -0800 Subject: [PATCH 1/4] add better row builders --- .../beam/sdk/coders/RowCoderGenerator.java | 2 +- .../sdk/schemas/FieldAccessDescriptor.java | 64 +- .../schemas/GetterBasedSchemaProvider.java | 2 +- .../sdk/schemas/transforms/AddFields.java | 2 +- .../beam/sdk/schemas/transforms/CoGroup.java | 5 +- .../beam/sdk/schemas/transforms/Group.java | 6 +- .../sdk/schemas/transforms/RenameFields.java | 2 +- .../java/org/apache/beam/sdk/values/Row.java | 701 +++++++++++++++++- .../beam/sdk/schemas/JavaBeanSchemaTest.java | 3 +- .../org/apache/beam/sdk/values/RowTest.java | 208 ++++++ 10 files changed, 939 insertions(+), 56 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java index 15bfe6400bb4..6929fc409da6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java @@ -359,7 +359,7 @@ static Row decodeDelegate(Schema schema, Coder[] coders, InputStream inputStream // all values. Since we assume that decode is always being called on a previously-encoded // Row, the values should already be validated and of the correct type. So, we can save // some processing by simply transferring ownership of the list to the Row. - return Row.withSchema(schema).attachValues(fieldValues).build(); + return Row.withSchema(schema).attachValues(fieldValues); } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java index f61f49b2cb76..eac176016439 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java @@ -201,6 +201,58 @@ public static FieldAccessDescriptor withFieldNamesAs(Map fieldNa return union(fields); } + public static FieldAccessDescriptor withFieldNames( + FieldAccessDescriptor baseDescriptor, String... fieldNames) { + return withFieldNames(baseDescriptor, Arrays.asList(fieldNames)); + } + + public static FieldAccessDescriptor withFieldNames( + FieldAccessDescriptor baseDescriptor, Iterable fieldNames) { + if (baseDescriptor.getFieldsAccessed().isEmpty() + && baseDescriptor.getNestedFieldsAccessed().isEmpty()) { + return withFieldNames(fieldNames); + } + if (!baseDescriptor.getFieldsAccessed().isEmpty()) { + checkArgument(baseDescriptor.getNestedFieldsAccessed().isEmpty()); + FieldDescriptor fieldDescriptor = + Iterables.getOnlyElement(baseDescriptor.getFieldsAccessed()); + return FieldAccessDescriptor.create() + .withNestedField(fieldDescriptor, FieldAccessDescriptor.withFieldNames(fieldNames)); + } else { + checkArgument(baseDescriptor.getFieldsAccessed().isEmpty()); + Map.Entry entry = + Iterables.getOnlyElement(baseDescriptor.getNestedFieldsAccessed().entrySet()); + return FieldAccessDescriptor.create() + .withNestedField(entry.getKey(), withFieldNames(entry.getValue(), fieldNames)); + } + } + + public static FieldAccessDescriptor withFieldIds( + FieldAccessDescriptor baseDescriptor, Integer... 
fieldIds) { + return withFieldIds(baseDescriptor, Arrays.asList(fieldIds)); + } + + public static FieldAccessDescriptor withFieldIds( + FieldAccessDescriptor baseDescriptor, Iterable fieldIds) { + if (baseDescriptor.getFieldsAccessed().isEmpty() + && baseDescriptor.getNestedFieldsAccessed().isEmpty()) { + return withFieldIds(fieldIds); + } + if (!baseDescriptor.getFieldsAccessed().isEmpty()) { + checkArgument(baseDescriptor.getNestedFieldsAccessed().isEmpty()); + FieldDescriptor fieldDescriptor = + Iterables.getOnlyElement(baseDescriptor.getFieldsAccessed()); + return FieldAccessDescriptor.create() + .withNestedField(fieldDescriptor, FieldAccessDescriptor.withFieldIds(fieldIds)); + } else { + checkArgument(baseDescriptor.getFieldsAccessed().isEmpty()); + Map.Entry entry = + Iterables.getOnlyElement(baseDescriptor.getNestedFieldsAccessed().entrySet()); + return FieldAccessDescriptor.create() + .withNestedField(entry.getKey(), withFieldIds(entry.getValue(), fieldIds)); + } + } + /** * Return a descriptor that accesses the specified field names as nested subfields of the * baseDescriptor. @@ -610,12 +662,12 @@ private Schema getFieldDescriptorSchema(FieldDescriptor fieldDescriptor, Schema private static Schema getFieldSchema(FieldType type) { if (TypeName.ROW.equals(type.getTypeName())) { return type.getRowSchema(); - } else if (type.getTypeName().isCollectionType() - && TypeName.ROW.equals(type.getCollectionElementType().getTypeName())) { - return type.getCollectionElementType().getRowSchema(); - } else if (TypeName.MAP.equals(type.getTypeName()) - && TypeName.ROW.equals(type.getMapValueType().getTypeName())) { - return type.getMapValueType().getRowSchema(); + } else if (type.getTypeName().isCollectionType()) { + return getFieldSchema(type.getCollectionElementType()); + } else if (TypeName.MAP.equals(type.getTypeName())) { + return getFieldSchema(type.getMapValueType()); + } else if (TypeName.LOGICAL_TYPE.equals(type.getTypeName())) { + return getFieldSchema(type.getLogicalType().getBaseType()); } else { throw new IllegalArgumentException( "FieldType " + type + " must be either a row or a container containing rows"); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java index 6b2ebd543e3d..f69c0097c360 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java @@ -56,7 +56,7 @@ public ToRowWithValueGetters(Schema schema) { @Override public Row apply(T input) { - return Row.withSchema(schema).withFieldValueGetters(getterFactory, input).build(); + return Row.withSchema(schema).withFieldValueGetters(getterFactory, input); } private GetterBasedSchemaProvider getOuter() { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/AddFields.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/AddFields.java index 1dbfdb3965dc..4b13bf8c2855 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/AddFields.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/AddFields.java @@ -373,7 +373,7 @@ private static Row fillNewFields(Row row, AddFieldsInformation addFieldsInformat } } - return Row.withSchema(outputSchema).attachValues(newValues).build(); + return Row.withSchema(outputSchema).attachValues(newValues); } private static Object 
fillNewFields( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java index 2a461c1a201f..39b9de67025d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java @@ -548,7 +548,7 @@ void outputUnexpandedRow(Schema outputSchema, OutputReceiver o) { List fields = Lists.newArrayListWithCapacity(getIterables().size() + 1); fields.add(getKey()); fields.addAll(getIterables()); - o.output(Row.withSchema(outputSchema).attachValues(fields).build()); + o.output(Row.withSchema(outputSchema).attachValues(fields)); } static void verifyExpandedArgs(JoinInformation joinInformation, JoinArguments joinArgs) { @@ -633,8 +633,7 @@ private void crossProductHelper( // Bottom of recursive call, so output the row we've accumulated. Row row = Row.withSchema(getOutputSchema()) - .attachValues(Lists.newArrayList(accumulatedRows)) - .build(); + .attachValues(Lists.newArrayList(accumulatedRows)); o.output(row); } else { crossProduct(tagIndex + 1, accumulatedRows, iterables, o); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java index bf6187426124..76c1e89b1203 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java @@ -859,8 +859,7 @@ public PCollection expand(PCollection input) { public void process(@Element KV> e, OutputReceiver o) { o.output( Row.withSchema(outputSchema) - .attachValues(Lists.newArrayList(e.getKey(), e.getValue())) - .build()); + .attachValues(Lists.newArrayList(e.getKey(), e.getValue()))); } })) .setRowSchema(outputSchema); @@ -1140,8 +1139,7 @@ public void process(@Element KV element, OutputReceiver o) { o.output( Row.withSchema(outputSchema) .attachValues( - Lists.newArrayList(element.getKey(), element.getValue())) - .build()); + Lists.newArrayList(element.getKey(), element.getValue()))); } })) .setRowSchema(outputSchema); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/RenameFields.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/RenameFields.java index 5e2eb97f36b7..e0405bba8602 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/RenameFields.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/RenameFields.java @@ -180,7 +180,7 @@ public PCollection expand(PCollection input) { new DoFn() { @ProcessElement public void processElement(@Element Row row, OutputReceiver o) { - o.output(Row.withSchema(outputSchema).attachValues(row.getValues()).build()); + o.output(Row.withSchema(outputSchema).attachValues(row.getValues())); } })) .setRowSchema(outputSchema); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 3d37b7388d38..165d56834273 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -18,11 +18,13 @@ package org.apache.beam.sdk.values; import static org.apache.beam.sdk.values.SchemaVerification.verifyRowValues; +import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import java.io.Serializable; import java.math.BigDecimal; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -30,30 +32,61 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collector; import java.util.stream.Collectors; import java.util.stream.IntStream; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.Factory; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.joda.time.DateTime; +import org.joda.time.Instant; import org.joda.time.ReadableDateTime; import org.joda.time.ReadableInstant; +import org.joda.time.base.AbstractInstant; /** * {@link Row} is an immutable tuple-like schema to represent one element in a {@link PCollection}. * The fields are described with a {@link Schema}. * - *

<p>{@link Schema} contains the names for each field and the coder for the whole record,
- * {see @link Schema#getRowCoder()}.
+ * <p>{@link Schema} contains the names and types for each field.
+ *
+ * <p>There are several ways to build a new Row object. To build a row from scratch using a schema
+ * object, {@link Row#withSchema} can be used. Schema fields can be specified by name, and nested
+ * fields can be specified using the field selection syntax. For example:
+ *
+ * <pre>{@code
+ * Row row = Row.withSchema(schema)
+ *              .withFieldValue("userId", "user1")
+ *              .withFieldValue("location.city", "seattle")
+ *              .withFieldValue("location.state", "wa")
+ *              .build();
+ * }</pre>
+ *
+ * <p>The {@link Row#fromRow} builder can be used to base a row off of another row. The builder can
+ * be used to specify values for specific fields, and all the remaining values will be taken from
+ * the original row. For example, the following produces a row identical to the above row except
+ * for the location.city field.
+ *
+ * <pre>{@code
+ * Row modifiedRow =
+ *     Row.fromRow(row)
+ *        .withFieldValue("location.city", "tacoma")
+ *        .build();
+ * }</pre>
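+ *
+ * <p>Multiple fields can also be overridden in a single call from a map of field names to values.
+ * The following is an illustrative sketch (it reuses the {@code row} variable from the example
+ * above) of the map-based {@code withFieldValues} method added in this change:
+ *
+ * <pre>{@code
+ * Row relocatedRow =
+ *     Row.fromRow(row)
+ *        .withFieldValues(ImmutableMap.of(
+ *            "location.city", "tacoma",
+ *            "location.state", "wa"))
+ *        .build();
+ * }</pre>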
*/ @Experimental(Kind.SCHEMAS) public abstract class Row implements Serializable { @@ -72,6 +105,7 @@ public abstract class Row implements Serializable { /** Return the size of data fields. */ public abstract int getFieldCount(); + /** Return the list of data values. */ public abstract List getValues(); @@ -585,37 +619,151 @@ public String toString() { } /** - * Creates a record builder with specified {@link #getSchema()}. {@link Builder#build()} will - * throw an {@link IllegalArgumentException} if number of fields in {@link #getSchema()} does not - * match the number of fields specified. + * Creates a row builder with specified {@link #getSchema()}. {@link Builder#build()} will throw + * an {@link IllegalArgumentException} if number of fields in {@link #getSchema()} does not match + * the number of fields specified. If any of the arguments don't match the expected types for the + * schema fields, {@link Builder#build()} will throw a {@link ClassCastException}. */ public static Builder withSchema(Schema schema) { return new Builder(schema); } + /** + * Creates a row builder based on the specified row. Field values in the new row can be explicitly + * set using {@link FieldValueBuilder#withFieldValue}. Any values not so overridden will be the + * same as the values in the original row. + */ + public static FieldValueBuilder fromRow(Row row) { + return new FieldValueBuilder(row.getSchema(), row, false); + } + + /** Builder for {@link Row} that bases a row on another row. */ + public static class FieldValueBuilder { + private final Schema schema; + private final @Nullable Row sourceRow; + private final Map fieldValues = Maps.newHashMap(); + private final boolean onlyOverrides; + + private FieldValueBuilder(Schema schema, @Nullable Row sourceRow, boolean onlyOverrides) { + this.schema = schema; + this.sourceRow = sourceRow; + this.onlyOverrides = onlyOverrides; + } + + public Schema getSchema() { + return schema; + } + + /** + * Set a field value using the field name. Nested values can be set using the field selection + * syntax. + */ + public FieldValueBuilder withFieldValue(String fieldName, Object value) { + return withFieldValue(FieldAccessDescriptor.withFieldNames(fieldName), value); + } + + /** Set a field value using the field id. */ + public FieldValueBuilder withFieldValue(Integer fieldId, Object value) { + return withFieldValue(FieldAccessDescriptor.withFieldIds(fieldId), value); + } + + /** Set a field value using a FieldAccessDescriptor. */ + public FieldValueBuilder withFieldValue( + FieldAccessDescriptor fieldAccessDescriptor, Object value) { + FieldAccessDescriptor fieldAccess = fieldAccessDescriptor.resolve(getSchema()); + checkArgument(fieldAccess.referencesSingleField(), ""); + fieldValues.put(fieldAccess, new FieldOverride(value)); + return this; + } + + /** + * Sets field values using the field names. Nested values can be set using the field selection + * syntax. + */ + public FieldValueBuilder withFieldValues(Map values) { + fieldValues.putAll( + values.entrySet().stream() + .collect( + Collectors.toMap( + e -> FieldAccessDescriptor.withFieldNames(e.getKey()), + e -> new FieldOverride(e.getValue())))); + return this; + } + + /** + * Sets field values using the FieldAccessDescriptors. Nested values can be set using the field + * selection syntax. 
+ */ + public FieldValueBuilder withFieldAccessDescriptors(Map values) { + fieldValues.putAll( + values.entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey(), e -> new FieldOverride(e.getValue())))); + return this; + } + + public Row build() { + Row row = + (Row) + new RowFieldMatcher() + .match( + new CapturingRowCases(getSchema(), this.fieldValues, onlyOverrides), + FieldType.row(getSchema()), + FieldAccessDescriptor.create(), + sourceRow); + return row; + } + } + /** Builder for {@link Row}. */ public static class Builder { private List values = Lists.newArrayList(); - private boolean attached = false; - @Nullable private Factory> fieldValueGetterFactory; - @Nullable private Object getterTarget; - private Schema schema; + private final Schema schema; Builder(Schema schema) { this.schema = schema; } - public int nextFieldId() { - if (fieldValueGetterFactory != null) { - throw new RuntimeException("Not supported"); - } - return values.size(); - } - + /** Return the schema for the row being built. */ public Schema getSchema() { return schema; } + /** + * Set a field value using the field name. Nested values can be set using the field selection + * syntax. + */ + public FieldValueBuilder withFieldValue(String fieldName, Object value) { + checkState(values.isEmpty()); + return new FieldValueBuilder(schema, null, true).withFieldValue(fieldName, value); + } + + /** Set a field value using the field id. */ + public FieldValueBuilder withFieldValue(Integer fieldId, Object value) { + checkState(values.isEmpty()); + return new FieldValueBuilder(schema, null, true).withFieldValue(fieldId, value); + } + + /** Set a field value using a FieldAccessDescriptor. */ + public FieldValueBuilder withFieldValue( + FieldAccessDescriptor fieldAccessDescriptor, Object value) { + checkState(values.isEmpty()); + return new FieldValueBuilder(schema, null, true).withFieldValue(fieldAccessDescriptor, value); + } + /** + * Sets field values using the field names. Nested values can be set using the field selection + * syntax. + */ + public FieldValueBuilder withFieldValues(Map values) { + checkState(values.isEmpty()); + return new FieldValueBuilder(schema, null, true).withFieldValues(values); + } + + // The following methods allow appending a list of values to the Builder object. The values must + // be in the same + // order as the fields in the row. These methods cannot be used in conjunction with + // withFieldValue or + // withFieldValues. + public Builder addValue(@Nullable Object values) { this.values.add(values); return this; @@ -645,40 +793,63 @@ public Builder addIterable(Iterable values) { return this; } - // Values are attached. No verification is done, and no conversions are done. LogicalType - // values must be specified as the base type. - public Builder attachValues(List values) { - this.attached = true; - this.values = values; - return this; + // Values are attached. No verification is done, and no conversions are done. LogicalType values + // must be specified as the base type. This method should be used with great care, as no + // validation is done. If + // incorrect values are passed in, it could result in strange errors later in the pipeline. This + // method is largely + // used internal to Beam. + @Internal + public Row attachValues(List attachedValues) { + checkState(this.values.isEmpty()); + return new RowWithStorage(schema, attachedValues); } - public Builder attachValues(Object... 
values) { - return attachValues(Arrays.asList(values)); + public int nextFieldId() { + return values.size(); } - public Builder withFieldValueGetters( + @Internal + public Row withFieldValueGetters( Factory> fieldValueGetterFactory, Object getterTarget) { - this.fieldValueGetterFactory = fieldValueGetterFactory; - this.getterTarget = getterTarget; - return this; + checkState(getterTarget != null, "getters require withGetterTarget."); + return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); } public Row build() { checkNotNull(schema); - if (!this.values.isEmpty() && fieldValueGetterFactory != null) { - throw new IllegalArgumentException("Cannot specify both values and getters."); + + if (!values.isEmpty() && values.size() != schema.getFieldCount()) { + throw new IllegalArgumentException( + "Row expected " + + schema.getFieldCount() + + " fields. initialized with " + + values.size() + + " fields."); + } + + Map fieldValues = + Maps.newHashMapWithExpectedSize(this.values.size()); + for (int i = 0; i < this.values.size(); ++i) { + FieldAccessDescriptor fieldAccessDescriptor = + FieldAccessDescriptor.withFieldIds(i).resolve(schema); + fieldValues.putIfAbsent(fieldAccessDescriptor, new FieldOverride(this.values.get(i))); } - if (!this.values.isEmpty()) { - List storageValues = attached ? this.values : verifyRowValues(schema, this.values); - checkState(getterTarget == null, "withGetterTarget requires getters."); - return new RowWithStorage(schema, storageValues); - } else if (fieldValueGetterFactory != null) { - checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + + Row row; + if (!fieldValues.isEmpty()) { + row = + (Row) + new RowFieldMatcher() + .match( + new CapturingRowCases(schema, fieldValues, true), + FieldType.row(schema), + FieldAccessDescriptor.create(), + null); } else { - return new RowWithStorage(schema, Collections.emptyList()); + row = new RowWithStorage(schema, Collections.emptyList()); } + return row; } } @@ -700,4 +871,458 @@ public static Row nullRow(Schema schema) { .addValues(Collections.nCopies(schema.getFieldCount(), null)) .build(); } + + // Subclasses of this interface implement process methods for each schema type. Each process + // method is invoked as + // a RowFieldMatcher walks down the schema tree. The FieldAccessDescriptor passed into each method + // identifies the + // current element of the schema being processed. 
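+  // For example (an illustrative sketch, not actual code): matching a row whose schema is
+  // {f_int: INT32, f_nested: ROW{f_str: STRING}} produces a call sequence roughly like
+  //   cases.processRow([root], topSchema, row, matcher)
+  //     -> cases.processInt32([f_int], 42, matcher)
+  //     -> cases.processRow([f_nested], nestedSchema, nestedRow, matcher)
+  //          -> cases.processString([f_nested.f_str], "abc", matcher)
+  // where the bracketed arguments stand for the resolved FieldAccessDescriptor at each position.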
+ private interface RowCases { + Row processRow( + FieldAccessDescriptor fieldAccessDescriptor, + Schema schema, + Row value, + RowFieldMatcher matcher); + + Collection processArray( + FieldAccessDescriptor fieldAccessDescriptor, + FieldType collectionElementType, + Collection values, + RowFieldMatcher matcher); + + Iterable processIterable( + FieldAccessDescriptor fieldAccessDescriptor, + FieldType collectionElementType, + Iterable values, + RowFieldMatcher matcher); + + Map processMap( + FieldAccessDescriptor fieldAccessDescriptor, + FieldType keyType, + FieldType valueType, + Map valueMap, + RowFieldMatcher matcher); + + Object processLogicalType( + FieldAccessDescriptor fieldAccessDescriptor, + LogicalType logicalType, + Object baseType, + RowFieldMatcher matcher); + + Instant processDateTime( + FieldAccessDescriptor fieldAccessDescriptor, + AbstractInstant instant, + RowFieldMatcher matcher); + + Byte processByte( + FieldAccessDescriptor fieldAccessDescriptor, Byte value, RowFieldMatcher matcher); + + Short processInt16( + FieldAccessDescriptor fieldAccessDescriptor, Short value, RowFieldMatcher matcher); + + Integer processInt32( + FieldAccessDescriptor fieldAccessDescriptor, Integer value, RowFieldMatcher matcher); + + Long processInt64( + FieldAccessDescriptor fieldAccessDescriptor, Long value, RowFieldMatcher matcher); + + BigDecimal processDecimal( + FieldAccessDescriptor fieldAccessDescriptor, BigDecimal value, RowFieldMatcher matcher); + + Float processFloat( + FieldAccessDescriptor fieldAccessDescriptor, Float value, RowFieldMatcher matcher); + + Double processDouble( + FieldAccessDescriptor fieldAccessDescriptor, Double value, RowFieldMatcher matcher); + + String processString( + FieldAccessDescriptor fieldAccessDescriptor, String value, RowFieldMatcher matcher); + + Boolean processBoolean( + FieldAccessDescriptor fieldAccessDescriptor, Boolean value, RowFieldMatcher matcher); + + byte[] processBytes( + FieldAccessDescriptor fieldAccessDescriptor, byte[] value, RowFieldMatcher matcher); + } + + // Given a Row field, delegates processing to the correct process method on the RowCases + // parameter. 
+ private static class RowFieldMatcher { + public Object match( + RowCases cases, + FieldType fieldType, + FieldAccessDescriptor fieldAccessDescriptor, + Object value) { + Object processedValue = null; + switch (fieldType.getTypeName()) { + case ARRAY: + processedValue = + cases.processArray( + fieldAccessDescriptor, + fieldType.getCollectionElementType(), + (Collection) value, + this); + break; + case ITERABLE: + processedValue = + cases.processIterable( + fieldAccessDescriptor, + fieldType.getCollectionElementType(), + (Iterable) value, + this); + break; + case MAP: + processedValue = + cases.processMap( + fieldAccessDescriptor, + fieldType.getMapKeyType(), + fieldType.getMapValueType(), + (Map) value, + this); + break; + case ROW: + processedValue = + cases.processRow(fieldAccessDescriptor, fieldType.getRowSchema(), (Row) value, this); + break; + case LOGICAL_TYPE: + LogicalType logicalType = fieldType.getLogicalType(); + processedValue = + cases.processLogicalType(fieldAccessDescriptor, logicalType, value, this); + break; + case DATETIME: + processedValue = + cases.processDateTime(fieldAccessDescriptor, (AbstractInstant) value, this); + break; + case BYTE: + processedValue = cases.processByte(fieldAccessDescriptor, (Byte) value, this); + break; + case BYTES: + processedValue = cases.processBytes(fieldAccessDescriptor, (byte[]) value, this); + break; + case INT16: + processedValue = cases.processInt16(fieldAccessDescriptor, (Short) value, this); + break; + case INT32: + processedValue = cases.processInt32(fieldAccessDescriptor, (Integer) value, this); + break; + case INT64: + processedValue = cases.processInt64(fieldAccessDescriptor, (Long) value, this); + break; + case DECIMAL: + processedValue = cases.processDecimal(fieldAccessDescriptor, (BigDecimal) value, this); + break; + case FLOAT: + processedValue = cases.processFloat(fieldAccessDescriptor, (Float) value, this); + break; + case DOUBLE: + processedValue = cases.processDouble(fieldAccessDescriptor, (Double) value, this); + break; + case STRING: + processedValue = cases.processString(fieldAccessDescriptor, (String) value, this); + break; + case BOOLEAN: + processedValue = cases.processBoolean(fieldAccessDescriptor, (Boolean) value, this); + break; + default: + // Shouldn't actually get here, but we need this case to satisfy linters. + throw new IllegalArgumentException( + String.format( + "Not a primitive type for field name %s: %s", fieldAccessDescriptor, fieldType)); + } + if (processedValue == null) { + if (!fieldType.getNullable()) { + throw new IllegalArgumentException( + String.format("%s is not nullable in field %s", fieldType, fieldAccessDescriptor)); + } + } + return processedValue; + } + } + + static class FieldOverride { + FieldOverride(Object overrideValue) { + this.overrideValue = Optional.ofNullable(overrideValue); + alreadyUsed = false; + } + + void setAlreadyUsed() { + this.alreadyUsed = true; + } + + boolean getAlreadyUsed() { + return alreadyUsed; + } + + Optional getOverrideValue() { + return (Optional) overrideValue; + } + + final Optional overrideValue; + boolean alreadyUsed; + } + // This implementation of RowCases captures a Row into a new Row. It also has the effect of + // validating all the + // field parameters. + // A Map of field values can also be passed in, and those field values will be used to override + // the values in the + // passed-in row. 
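+  // For example (illustrative): when basing a new row off a source row {f_str: "a", f_int: 1}
+  // with the override map {f_int -> 2}, the capture produces {f_str: "a", f_int: 2}. Fields with
+  // no override are copied (and validated) from the source row, while overridden fields take the
+  // value from the map.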
+ private static class CapturingRowCases implements RowCases { + private final Schema topSchema; + private final Map fieldValueOverrides; + private final boolean onlyOverrides; + + private CapturingRowCases( + Schema topSchema, + Map fieldValueOverrides, + boolean onlyOverrides) { + this.topSchema = topSchema; + this.fieldValueOverrides = fieldValueOverrides; + this.onlyOverrides = onlyOverrides; + } + + private @Nullable FieldOverride override(FieldAccessDescriptor fieldAccessDescriptor) { + return fieldValueOverrides.get(fieldAccessDescriptor); + } + + private Optional overrideOrReturn(FieldAccessDescriptor fieldAccessDescriptor, T value) { + FieldOverride fieldOverride = override(fieldAccessDescriptor); + // null return means the item isn't in the map. + if (fieldOverride == null) { + // return onlyOverrides ? null : Optional.of(value); + return Optional.ofNullable(value); + } else { + return fieldOverride.getOverrideValue(); + } + } + + @Override + public Row processRow( + FieldAccessDescriptor fieldAccessDescriptor, + Schema schema, + Row value, + RowFieldMatcher matcher) { + Optional retValue = Optional.empty(); + FieldOverride override = override(fieldAccessDescriptor); + if (override == null) { + // Not in map. + if (value != null || onlyOverrides) { + List values = Lists.newArrayListWithCapacity(schema.getFieldCount()); + for (int i = 0; i < schema.getFieldCount(); ++i) { + FieldAccessDescriptor nestedDescriptor = + FieldAccessDescriptor.withFieldIds(fieldAccessDescriptor, i).resolve(topSchema); + Object fieldValue = onlyOverrides ? null : value.getValue(i); + values.add( + matcher.match(this, schema.getField(i).getType(), nestedDescriptor, fieldValue)); + } + retValue = Optional.of(new RowWithStorage(schema, values)); + } + } else { + retValue = override.getOverrideValue(); + } + return retValue.orElse(null); + } + + @Override + public Collection processArray( + FieldAccessDescriptor fieldAccessDescriptor, + FieldType collectionElementType, + Collection values, + RowFieldMatcher matcher) { + Optional> retValue = Optional.empty(); + FieldOverride override = override(fieldAccessDescriptor); + if (override == null) { + // Not in map of overrides. + if (onlyOverrides) { + retValue = Optional.of(Collections.emptyList()); + } else if (values != null) { + retValue = Optional.of(captureIterable(fieldAccessDescriptor, collectionElementType, values, true, matcher)); + } + } else { + retValue = override.getOverrideValue() + .map(o -> captureIterable(fieldAccessDescriptor, collectionElementType, (Iterable)o, false, matcher)); + } + return retValue.orElse(null); + } + + private Collection captureIterable(FieldAccessDescriptor fieldAccessDescriptor, + FieldType collectionElementType, Iterable values, + boolean recurseNestedRows, + RowFieldMatcher matcher) { + List captured = Lists.newArrayListWithCapacity(Iterables.size(values)); + for (Object listValue : values) { + boolean recurse = !collectionElementType.getTypeName().isCompositeType() || recurseNestedRows; + Object capturedElement = recurse ? 
+ matcher.match(this, collectionElementType, fieldAccessDescriptor, listValue) + : listValue; + captured.add(capturedElement); + } + return captured; + } + + @Override + public Iterable processIterable( + FieldAccessDescriptor fieldAccessDescriptor, + FieldType collectionElementType, + Iterable values, + RowFieldMatcher matcher) { + Optional> retValue = Optional.empty(); + FieldOverride override = override(fieldAccessDescriptor); + if (override == null) { + if (onlyOverrides) { + retValue = Optional.of(Collections.emptyList()); + } else if (values != null) { + List capturedValues = Lists.newArrayListWithCapacity(Iterables.size(values)); + for (Object listValue : values) { + Object capturedElement = + matcher.match(this, collectionElementType, fieldAccessDescriptor, listValue); + capturedValues.add(capturedElement); + } + retValue = Optional.of(captureIterable(fieldAccessDescriptor, collectionElementType, + values, true, matcher)); + } + } else { + retValue = override.getOverrideValue() + .map(o -> captureIterable(fieldAccessDescriptor, collectionElementType, (Iterable) o, false, matcher)); + + } + return retValue.orElse(null); + } + + @Override + public Map processMap( + FieldAccessDescriptor fieldAccessDescriptor, + FieldType keyType, + FieldType valueType, + Map valueMap, + RowFieldMatcher matcher) { + Optional> retValue = Optional.empty(); + FieldOverride override = override(fieldAccessDescriptor); + if (override == null) { + if (onlyOverrides) { + retValue = Optional.of(Collections.emptyMap()); + } else if (valueMap != null) { + retValue = Optional.of(Maps.newHashMapWithExpectedSize(valueMap.size())); + for (Entry kv : valueMap.entrySet()) { + retValue + .get() + .put( + matcher.match(this, keyType, fieldAccessDescriptor, kv.getKey()), + matcher.match(this, valueType, fieldAccessDescriptor, kv.getValue())); + } + } + } else { + retValue = override.getOverrideValue(); + } + return retValue.orElse(null); + } + + @Override + public Object processLogicalType( + FieldAccessDescriptor fieldAccessDescriptor, + LogicalType logicalType, + Object value, + RowFieldMatcher matcher) { + Optional retValue = Optional.empty(); + FieldOverride override = override(fieldAccessDescriptor); + if (override == null) { + if (onlyOverrides || value != null) { + // If not an override, then this is coming from an already-built row, so no need to + // convert to base type. + retValue = + Optional.of( + matcher.match(this, logicalType.getBaseType(), fieldAccessDescriptor, value)); + } + } else { + // This is the override case. We assume the override is given as the logical type, not the + // base type. + retValue = + override + .getOverrideValue() + .map( + o -> + !logicalType.getBaseType().getTypeName().isCompositeType() + ? matcher.match( + this, + logicalType.getBaseType(), + fieldAccessDescriptor, + logicalType.toBaseType(o)) + : logicalType.toBaseType(o)); + } + return retValue.orElse(null); + } + + @Override + public Instant processDateTime( + FieldAccessDescriptor fieldAccessDescriptor, + AbstractInstant value, + RowFieldMatcher matcher) { + Optional retValue = Optional.empty(); + if (onlyOverrides || value != null) { + Instant instantValue = (value != null) ? 
value.toInstant() : null; + retValue = overrideOrReturn(fieldAccessDescriptor, instantValue); + } + return retValue.map(AbstractInstant::toInstant).orElse(null); + } + + @Override + public Byte processByte( + FieldAccessDescriptor fieldAccessDescriptor, Byte value, RowFieldMatcher matcher) { + return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + } + + @Override + public Short processInt16( + FieldAccessDescriptor fieldAccessDescriptor, Short value, RowFieldMatcher matcher) { + return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + } + + @Override + public Integer processInt32( + FieldAccessDescriptor fieldAccessDescriptor, Integer value, RowFieldMatcher matcher) { + return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + } + + @Override + public Long processInt64( + FieldAccessDescriptor fieldAccessDescriptor, Long value, RowFieldMatcher matcher) { + return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + } + + @Override + public BigDecimal processDecimal( + FieldAccessDescriptor fieldAccessDescriptor, BigDecimal value, RowFieldMatcher matcher) { + return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + } + + @Override + public Float processFloat( + FieldAccessDescriptor fieldAccessDescriptor, Float value, RowFieldMatcher matcher) { + return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + } + + @Override + public Double processDouble( + FieldAccessDescriptor fieldAccessDescriptor, Double value, RowFieldMatcher matcher) { + return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + } + + @Override + public String processString( + FieldAccessDescriptor fieldAccessDescriptor, String value, RowFieldMatcher matcher) { + return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + } + + @Override + public Boolean processBoolean( + FieldAccessDescriptor fieldAccessDescriptor, Boolean value, RowFieldMatcher matcher) { + return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + } + + @Override + public byte[] processBytes( + FieldAccessDescriptor fieldAccessDescriptor, byte[] value, RowFieldMatcher matcher) { + Object retValue = overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + return (retValue instanceof ByteBuffer) ? 
((ByteBuffer) retValue).array() : (byte[]) retValue; + } + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java index 5ec69e5090b0..81c5ad92f491 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java @@ -461,7 +461,8 @@ public void testFromRowIterable() throws NoSuchSchemaException { SchemaTestUtils.assertSchemaEquivalent(ITERABLE_BEAM_SCHEMA, schema); List list = Lists.newArrayList("one", "two"); - Row iterableRow = Row.withSchema(ITERABLE_BEAM_SCHEMA).addIterable(list).build(); + Row iterableRow = + Row.withSchema(ITERABLE_BEAM_SCHEMA).attachValues(ImmutableList.of((Object) list)); IterableBean converted = registry.getFromRowFunction(IterableBean.class).apply(iterableRow); assertEquals(list, Lists.newArrayList(converted.getStrings())); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java index 02adaa0cf6ce..ddd00c282223 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java @@ -33,6 +33,7 @@ import java.util.stream.Stream; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; @@ -477,6 +478,213 @@ public void testCreateMapWithRowValue() { assertEquals(data, row.getMap("map")); } + @Test + public void testLogicalTypeWithRowValue() { + EnumerationType enumerationType = EnumerationType.create("zero", "one", "two"); + Schema type = + Stream.of(Schema.Field.of("f1_enum", FieldType.logicalType(enumerationType))) + .collect(toSchema()); + Row row = Row.withSchema(type).addValue(enumerationType.valueOf("zero")).build(); + assertEquals(0, (int) row.getValue(0)); + assertEquals( + enumerationType.valueOf("zero"), row.getLogicalTypeValue(0, EnumerationType.Value.class)); + } + + @Test + public void testLogicalTypeWithRowValueName() { + EnumerationType enumerationType = EnumerationType.create("zero", "one", "two"); + Schema type = + Stream.of(Schema.Field.of("f1_enum", FieldType.logicalType(enumerationType))) + .collect(toSchema()); + Row row = + Row.withSchema(type).withFieldValue("f1_enum", enumerationType.valueOf("zero")).build(); + assertEquals(0, (int) row.getValue(0)); + assertEquals( + enumerationType.valueOf("zero"), row.getLogicalTypeValue(0, EnumerationType.Value.class)); + } + + @Test + public void testLogicalTypeWithRowValueOverride() { + EnumerationType enumerationType = EnumerationType.create("zero", "one", "two"); + Schema type = + Stream.of(Schema.Field.of("f1_enum", FieldType.logicalType(enumerationType))) + .collect(toSchema()); + Row row = + Row.withSchema(type).withFieldValue("f1_enum", enumerationType.valueOf("zero")).build(); + Row overriddenRow = + Row.fromRow(row).withFieldValue("f1_enum", enumerationType.valueOf("one")).build(); + assertEquals(1, (int) overriddenRow.getValue(0)); + assertEquals( + enumerationType.valueOf("one"), + overriddenRow.getLogicalTypeValue(0, 
EnumerationType.Value.class)); + } + + @Test + public void testCreateWithNames() { + Schema type = + Stream.of( + Schema.Field.of("f_str", FieldType.STRING), + Schema.Field.of("f_byte", FieldType.BYTE), + Schema.Field.of("f_short", FieldType.INT16), + Schema.Field.of("f_int", FieldType.INT32), + Schema.Field.of("f_long", FieldType.INT64), + Schema.Field.of("f_float", FieldType.FLOAT), + Schema.Field.of("f_double", FieldType.DOUBLE), + Schema.Field.of("f_decimal", FieldType.DECIMAL), + Schema.Field.of("f_boolean", FieldType.BOOLEAN), + Schema.Field.of("f_datetime", FieldType.DATETIME), + Schema.Field.of("f_bytes", FieldType.BYTES), + Schema.Field.of("f_array", FieldType.array(FieldType.STRING)), + Schema.Field.of("f_iterable", FieldType.iterable(FieldType.STRING)), + Schema.Field.of("f_map", FieldType.map(FieldType.STRING, FieldType.STRING))) + .collect(toSchema()); + + DateTime dateTime = + new DateTime().withDate(1979, 03, 14).withTime(1, 2, 3, 4).withZone(DateTimeZone.UTC); + byte[] bytes = new byte[] {1, 2, 3, 4}; + + Row row = + Row.withSchema(type) + .withFieldValue("f_str", "str1") + .withFieldValue("f_byte", (byte) 42) + .withFieldValue("f_short", (short) 43) + .withFieldValue("f_int", (int) 44) + .withFieldValue("f_long", (long) 45) + .withFieldValue("f_float", (float) 3.14) + .withFieldValue("f_double", (double) 3.141) + .withFieldValue("f_decimal", new BigDecimal(3.1415)) + .withFieldValue("f_boolean", true) + .withFieldValue("f_datetime", dateTime) + .withFieldValue("f_bytes", bytes) + .withFieldValue("f_array", Lists.newArrayList("one", "two")) + .withFieldValue("f_iterable", Lists.newArrayList("one", "two", "three")) + .withFieldValue("f_map", ImmutableMap.of("hello", "goodbye", "here", "there")) + .build(); + + Row expectedRow = + Row.withSchema(type) + .addValues( + "str1", + (byte) 42, + (short) 43, + (int) 44, + (long) 45, + (float) 3.14, + (double) 3.141, + new BigDecimal(3.1415), + true, + dateTime, + bytes, + Lists.newArrayList("one", "two"), + Lists.newArrayList("one", "two", "three"), + ImmutableMap.of("hello", "goodbye", "here", "there")) + .build(); + assertEquals(expectedRow, row); + } + + @Test + public void testCreateWithNestedNames() { + Schema nestedType = + Stream.of( + Schema.Field.of("f_str", FieldType.STRING), + Schema.Field.of("f_int", FieldType.INT32)) + .collect(toSchema()); + Schema topType = + Stream.of( + Schema.Field.of("top_int", FieldType.INT32), + Schema.Field.of("f_nested", FieldType.row(nestedType))) + .collect(toSchema()); + Row row = + Row.withSchema(topType) + .withFieldValue("top_int", 42) + .withFieldValue("f_nested.f_str", "string") + .withFieldValue("f_nested.f_int", 43) + .build(); + + Row expectedRow = + Row.withSchema(topType) + .addValues(42, Row.withSchema(nestedType).addValues("string", 43).build()) + .build(); + assertEquals(expectedRow, row); + } + + @Test + public void testCreateWithCollectionNames() { + Schema type = + Stream.of( + Schema.Field.of("f_array", FieldType.array(FieldType.INT32)), + Schema.Field.of("f_iterable", FieldType.iterable(FieldType.INT32)), + Schema.Field.of("f_map", FieldType.map(FieldType.STRING, FieldType.STRING))) + .collect(toSchema()); + + Row row = + Row.withSchema(type) + .withFieldValue("f_array", ImmutableList.of(1, 2, 3)) + .withFieldValue("f_iterable", ImmutableList.of(1, 2, 3)) + .withFieldValue("f_map", ImmutableMap.of("one", "two")) + .build(); + + Row expectedRow = + Row.withSchema(type) + .addValues( + ImmutableList.of(1, 2, 3), ImmutableList.of(1, 2, 3), ImmutableMap.of("one", "two")) + 
.build(); + assertEquals(expectedRow, row); + } + + @Test + public void testOverrideRow() { + Schema type = + Stream.of( + Schema.Field.of("f_str", FieldType.STRING), + Schema.Field.of("f_int", FieldType.INT32)) + .collect(toSchema()); + Row sourceRow = + Row.withSchema(type).withFieldValue("f_str", "string").withFieldValue("f_int", 42).build(); + + Row modifiedRow = Row.fromRow(sourceRow).withFieldValue("f_str", "modifiedString").build(); + + Row expectedRow = + Row.withSchema(type) + .withFieldValue("f_str", "modifiedString") + .withFieldValue("f_int", 42) + .build(); + assertEquals(expectedRow, modifiedRow); + } + + @Test + public void testOverrideNestedRow() { + Schema nestedType = + Stream.of( + Schema.Field.of("f_str", FieldType.STRING), + Schema.Field.of("f_int", FieldType.INT32)) + .collect(toSchema()); + Schema topType = + Stream.of( + Schema.Field.of("top_int", FieldType.INT32), + Schema.Field.of("f_nested", FieldType.row(nestedType))) + .collect(toSchema()); + Row sourceRow = + Row.withSchema(topType) + .withFieldValue("top_int", 42) + .withFieldValue("f_nested.f_str", "string") + .withFieldValue("f_nested.f_int", 43) + .build(); + Row modifiedRow = + Row.fromRow(sourceRow) + .withFieldValue("f_nested.f_str", "modifiedString") + .withFieldValue("f_nested.f_int", 143) + .build(); + + Row expectedRow = + Row.withSchema(topType) + .withFieldValue("top_int", 42) + .withFieldValue("f_nested.f_str", "modifiedString") + .withFieldValue("f_nested.f_int", 143) + .build(); + assertEquals(expectedRow, modifiedRow); + } + @Test public void testCollector() { Schema type = From f478de782b26afeaac066b269b90442cd213f1fb Mon Sep 17 00:00:00 2001 From: Reuven Lax Date: Wed, 25 Mar 2020 10:27:17 -0700 Subject: [PATCH 2/4] foo --- .../sdk/schemas/FieldAccessDescriptor.java | 52 ---- .../java/org/apache/beam/sdk/values/Row.java | 258 +++++++----------- 2 files changed, 102 insertions(+), 208 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java index eac176016439..319e1061124a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldAccessDescriptor.java @@ -201,58 +201,6 @@ public static FieldAccessDescriptor withFieldNamesAs(Map fieldNa return union(fields); } - public static FieldAccessDescriptor withFieldNames( - FieldAccessDescriptor baseDescriptor, String... 
fieldNames) { - return withFieldNames(baseDescriptor, Arrays.asList(fieldNames)); - } - - public static FieldAccessDescriptor withFieldNames( - FieldAccessDescriptor baseDescriptor, Iterable fieldNames) { - if (baseDescriptor.getFieldsAccessed().isEmpty() - && baseDescriptor.getNestedFieldsAccessed().isEmpty()) { - return withFieldNames(fieldNames); - } - if (!baseDescriptor.getFieldsAccessed().isEmpty()) { - checkArgument(baseDescriptor.getNestedFieldsAccessed().isEmpty()); - FieldDescriptor fieldDescriptor = - Iterables.getOnlyElement(baseDescriptor.getFieldsAccessed()); - return FieldAccessDescriptor.create() - .withNestedField(fieldDescriptor, FieldAccessDescriptor.withFieldNames(fieldNames)); - } else { - checkArgument(baseDescriptor.getFieldsAccessed().isEmpty()); - Map.Entry entry = - Iterables.getOnlyElement(baseDescriptor.getNestedFieldsAccessed().entrySet()); - return FieldAccessDescriptor.create() - .withNestedField(entry.getKey(), withFieldNames(entry.getValue(), fieldNames)); - } - } - - public static FieldAccessDescriptor withFieldIds( - FieldAccessDescriptor baseDescriptor, Integer... fieldIds) { - return withFieldIds(baseDescriptor, Arrays.asList(fieldIds)); - } - - public static FieldAccessDescriptor withFieldIds( - FieldAccessDescriptor baseDescriptor, Iterable fieldIds) { - if (baseDescriptor.getFieldsAccessed().isEmpty() - && baseDescriptor.getNestedFieldsAccessed().isEmpty()) { - return withFieldIds(fieldIds); - } - if (!baseDescriptor.getFieldsAccessed().isEmpty()) { - checkArgument(baseDescriptor.getNestedFieldsAccessed().isEmpty()); - FieldDescriptor fieldDescriptor = - Iterables.getOnlyElement(baseDescriptor.getFieldsAccessed()); - return FieldAccessDescriptor.create() - .withNestedField(fieldDescriptor, FieldAccessDescriptor.withFieldIds(fieldIds)); - } else { - checkArgument(baseDescriptor.getFieldsAccessed().isEmpty()); - Map.Entry entry = - Iterables.getOnlyElement(baseDescriptor.getNestedFieldsAccessed().entrySet()); - return FieldAccessDescriptor.create() - .withNestedField(entry.getKey(), withFieldIds(entry.getValue(), fieldIds)); - } - } - /** * Return a descriptor that accesses the specified field names as nested subfields of the * baseDescriptor. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 165d56834273..2062e8ea9fb4 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.values; -import static org.apache.beam.sdk.values.SchemaVerification.verifyRowValues; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; @@ -634,7 +633,7 @@ public static Builder withSchema(Schema schema) { * same as the values in the original row. */ public static FieldValueBuilder fromRow(Row row) { - return new FieldValueBuilder(row.getSchema(), row, false); + return new FieldValueBuilder(row.getSchema(), row); } /** Builder for {@link Row} that bases a row on another row. 
*/ @@ -642,12 +641,10 @@ public static class FieldValueBuilder { private final Schema schema; private final @Nullable Row sourceRow; private final Map fieldValues = Maps.newHashMap(); - private final boolean onlyOverrides; - private FieldValueBuilder(Schema schema, @Nullable Row sourceRow, boolean onlyOverrides) { + private FieldValueBuilder(Schema schema, @Nullable Row sourceRow) { this.schema = schema; this.sourceRow = sourceRow; - this.onlyOverrides = onlyOverrides; } public Schema getSchema() { @@ -706,7 +703,7 @@ public Row build() { (Row) new RowFieldMatcher() .match( - new CapturingRowCases(getSchema(), this.fieldValues, onlyOverrides), + new CapturingRowCases(getSchema(), this.fieldValues), FieldType.row(getSchema()), FieldAccessDescriptor.create(), sourceRow); @@ -718,6 +715,7 @@ public Row build() { public static class Builder { private List values = Lists.newArrayList(); private final Schema schema; + private Row nullRow; Builder(Schema schema) { this.schema = schema; @@ -728,26 +726,33 @@ public Schema getSchema() { return schema; } + Row nullRow() { + if (nullRow == null) { + this.nullRow = Row.withSchema(schema).attachValues(Collections.nCopies(schema.getFieldCount(), null)); + } + return nullRow; + } + /** * Set a field value using the field name. Nested values can be set using the field selection * syntax. */ public FieldValueBuilder withFieldValue(String fieldName, Object value) { checkState(values.isEmpty()); - return new FieldValueBuilder(schema, null, true).withFieldValue(fieldName, value); + return new FieldValueBuilder(schema, nullRow()).withFieldValue(fieldName, value); } /** Set a field value using the field id. */ public FieldValueBuilder withFieldValue(Integer fieldId, Object value) { checkState(values.isEmpty()); - return new FieldValueBuilder(schema, null, true).withFieldValue(fieldId, value); + return new FieldValueBuilder(schema, nullRow()).withFieldValue(fieldId, value); } /** Set a field value using a FieldAccessDescriptor. */ public FieldValueBuilder withFieldValue( FieldAccessDescriptor fieldAccessDescriptor, Object value) { checkState(values.isEmpty()); - return new FieldValueBuilder(schema, null, true).withFieldValue(fieldAccessDescriptor, value); + return new FieldValueBuilder(schema, nullRow()).withFieldValue(fieldAccessDescriptor, value); } /** * Sets field values using the field names. Nested values can be set using the field selection @@ -755,7 +760,7 @@ public FieldValueBuilder withFieldValue( */ public FieldValueBuilder withFieldValues(Map values) { checkState(values.isEmpty()); - return new FieldValueBuilder(schema, null, true).withFieldValues(values); + return new FieldValueBuilder(schema, nullRow()).withFieldValues(values); } // The following methods allow appending a list of values to the Builder object. 
The values must @@ -842,10 +847,10 @@ public Row build() { (Row) new RowFieldMatcher() .match( - new CapturingRowCases(schema, fieldValues, true), + new CapturingRowCases(schema, fieldValues), FieldType.row(schema), FieldAccessDescriptor.create(), - null); + nullRow()); } else { row = new RowWithStorage(schema, Collections.emptyList()); } @@ -1041,25 +1046,16 @@ public Object match( static class FieldOverride { FieldOverride(Object overrideValue) { - this.overrideValue = Optional.ofNullable(overrideValue); - alreadyUsed = false; - } - - void setAlreadyUsed() { - this.alreadyUsed = true; - } - - boolean getAlreadyUsed() { - return alreadyUsed; + this.overrideValue = overrideValue; } - Optional getOverrideValue() { - return (Optional) overrideValue; + Object getOverrideValue() { + return overrideValue; } - final Optional overrideValue; - boolean alreadyUsed; + final Object overrideValue; } + // This implementation of RowCases captures a Row into a new Row. It also has the effect of // validating all the // field parameters. @@ -1069,79 +1065,85 @@ Optional getOverrideValue() { private static class CapturingRowCases implements RowCases { private final Schema topSchema; private final Map fieldValueOverrides; - private final boolean onlyOverrides; + private static class FieldAccessNode { + + } private CapturingRowCases( - Schema topSchema, - Map fieldValueOverrides, - boolean onlyOverrides) { + Schema topSchema, + Map fieldValueOverrides) { this.topSchema = topSchema; this.fieldValueOverrides = fieldValueOverrides; - this.onlyOverrides = onlyOverrides; } - private @Nullable FieldOverride override(FieldAccessDescriptor fieldAccessDescriptor) { + private @Nullable + FieldOverride override(FieldAccessDescriptor fieldAccessDescriptor) { return fieldValueOverrides.get(fieldAccessDescriptor); } - private Optional overrideOrReturn(FieldAccessDescriptor fieldAccessDescriptor, T value) { + private T overrideOrReturn(FieldAccessDescriptor fieldAccessDescriptor, T value) { FieldOverride fieldOverride = override(fieldAccessDescriptor); // null return means the item isn't in the map. - if (fieldOverride == null) { - // return onlyOverrides ? null : Optional.of(value); - return Optional.ofNullable(value); - } else { - return fieldOverride.getOverrideValue(); - } + return (fieldOverride != null) ? (T) fieldOverride.getOverrideValue() : null; } @Override public Row processRow( - FieldAccessDescriptor fieldAccessDescriptor, - Schema schema, - Row value, - RowFieldMatcher matcher) { - Optional retValue = Optional.empty(); + FieldAccessDescriptor fieldAccessDescriptor, + Schema schema, + Row value, + RowFieldMatcher matcher) { FieldOverride override = override(fieldAccessDescriptor); - if (override == null) { - // Not in map. - if (value != null || onlyOverrides) { - List values = Lists.newArrayListWithCapacity(schema.getFieldCount()); - for (int i = 0; i < schema.getFieldCount(); ++i) { - FieldAccessDescriptor nestedDescriptor = - FieldAccessDescriptor.withFieldIds(fieldAccessDescriptor, i).resolve(topSchema); - Object fieldValue = onlyOverrides ? null : value.getValue(i); - values.add( - matcher.match(this, schema.getField(i).getType(), nestedDescriptor, fieldValue)); - } - retValue = Optional.of(new RowWithStorage(schema, values)); + Row retValue = null; + if (override != null) { + retValue = (Row) override.getOverrideValue(); + } else if (value != null) { + // TODO: We should only recurse if there is an overidden subfield. 
+ List values = Lists.newArrayListWithCapacity(schema.getFieldCount()); + for (int i = 0; i < schema.getFieldCount(); ++i) { + FieldAccessDescriptor nestedDescriptor = + FieldAccessDescriptor.withFieldIds(fieldAccessDescriptor, i).resolve(topSchema); + Object fieldValue = value.getValue(i); + values.add( + matcher.match(this, schema.getField(i).getType(), nestedDescriptor, fieldValue)); } - } else { - retValue = override.getOverrideValue(); + retValue = new RowWithStorage(schema, values); } - return retValue.orElse(null); + return retValue; } @Override public Collection processArray( - FieldAccessDescriptor fieldAccessDescriptor, - FieldType collectionElementType, - Collection values, - RowFieldMatcher matcher) { - Optional> retValue = Optional.empty(); + FieldAccessDescriptor fieldAccessDescriptor, + FieldType collectionElementType, + Collection values, + RowFieldMatcher matcher) { + Collection retValue = null; FieldOverride override = override(fieldAccessDescriptor); - if (override == null) { - // Not in map of overrides. - if (onlyOverrides) { - retValue = Optional.of(Collections.emptyList()); - } else if (values != null) { - retValue = Optional.of(captureIterable(fieldAccessDescriptor, collectionElementType, values, true, matcher)); - } - } else { - retValue = override.getOverrideValue() - .map(o -> captureIterable(fieldAccessDescriptor, collectionElementType, (Iterable)o, false, matcher)); + if (override != null) { + retValue = (Collection) override.getOverrideValue(); + } else if (values != null) { + retValue = captureIterable(fieldAccessDescriptor, collectionElementType, values, false, matcher); + } - return retValue.orElse(null); + return retValue; + } + + @Override + public Iterable processIterable( + FieldAccessDescriptor fieldAccessDescriptor, + FieldType collectionElementType, + Iterable values, + RowFieldMatcher matcher) { + Iterable retValue = null; + FieldOverride override = override(fieldAccessDescriptor); + if (override != null) { + retValue = (Iterable) override.getOverrideValue(); + } else if (values != null) { + retValue = captureIterable(fieldAccessDescriptor, collectionElementType, + values, true, matcher); + } + return retValue; } private Collection captureIterable(FieldAccessDescriptor fieldAccessDescriptor, @@ -1151,7 +1153,7 @@ private Collection captureIterable(FieldAccessDescriptor fieldAccessDesc List captured = Lists.newArrayListWithCapacity(Iterables.size(values)); for (Object listValue : values) { boolean recurse = !collectionElementType.getTypeName().isCompositeType() || recurseNestedRows; - Object capturedElement = recurse ? + Object capturedElement = recurse ? 
matcher.match(this, collectionElementType, fieldAccessDescriptor, listValue) : listValue; captured.add(capturedElement); @@ -1159,34 +1161,6 @@ private Collection captureIterable(FieldAccessDescriptor fieldAccessDesc return captured; } - @Override - public Iterable processIterable( - FieldAccessDescriptor fieldAccessDescriptor, - FieldType collectionElementType, - Iterable values, - RowFieldMatcher matcher) { - Optional> retValue = Optional.empty(); - FieldOverride override = override(fieldAccessDescriptor); - if (override == null) { - if (onlyOverrides) { - retValue = Optional.of(Collections.emptyList()); - } else if (values != null) { - List capturedValues = Lists.newArrayListWithCapacity(Iterables.size(values)); - for (Object listValue : values) { - Object capturedElement = - matcher.match(this, collectionElementType, fieldAccessDescriptor, listValue); - capturedValues.add(capturedElement); - } - retValue = Optional.of(captureIterable(fieldAccessDescriptor, collectionElementType, - values, true, matcher)); - } - } else { - retValue = override.getOverrideValue() - .map(o -> captureIterable(fieldAccessDescriptor, collectionElementType, (Iterable) o, false, matcher)); - - } - return retValue.orElse(null); - } @Override public Map processMap( @@ -1195,25 +1169,20 @@ public Map processMap( FieldType valueType, Map valueMap, RowFieldMatcher matcher) { - Optional> retValue = Optional.empty(); + Map retValue = null; FieldOverride override = override(fieldAccessDescriptor); - if (override == null) { - if (onlyOverrides) { - retValue = Optional.of(Collections.emptyMap()); - } else if (valueMap != null) { - retValue = Optional.of(Maps.newHashMapWithExpectedSize(valueMap.size())); + if (override != null) { + retValue = (Map) override.getOverrideValue(); + } else if (valueMap != null) { + retValue = Maps.newHashMapWithExpectedSize(valueMap.size()); for (Entry kv : valueMap.entrySet()) { retValue - .get() .put( matcher.match(this, keyType, fieldAccessDescriptor, kv.getKey()), matcher.match(this, valueType, fieldAccessDescriptor, kv.getValue())); } - } - } else { - retValue = override.getOverrideValue(); } - return retValue.orElse(null); + return retValue; } @Override @@ -1222,33 +1191,14 @@ public Object processLogicalType( LogicalType logicalType, Object value, RowFieldMatcher matcher) { - Optional retValue = Optional.empty(); + Object retValue = null; FieldOverride override = override(fieldAccessDescriptor); - if (override == null) { - if (onlyOverrides || value != null) { - // If not an override, then this is coming from an already-built row, so no need to - // convert to base type. - retValue = - Optional.of( - matcher.match(this, logicalType.getBaseType(), fieldAccessDescriptor, value)); - } - } else { - // This is the override case. We assume the override is given as the logical type, not the - // base type. - retValue = - override - .getOverrideValue() - .map( - o -> - !logicalType.getBaseType().getTypeName().isCompositeType() - ? 
matcher.match( - this, - logicalType.getBaseType(), - fieldAccessDescriptor, - logicalType.toBaseType(o)) - : logicalType.toBaseType(o)); + if (override != null) { + retValue = override.getOverrideValue(); + } else if (value != null) { + retValue = logicalType.toInputType(logicalType.toBaseType(value)); } - return retValue.orElse(null); + return retValue; } @Override @@ -1256,72 +1206,68 @@ public Instant processDateTime( FieldAccessDescriptor fieldAccessDescriptor, AbstractInstant value, RowFieldMatcher matcher) { - Optional retValue = Optional.empty(); - if (onlyOverrides || value != null) { - Instant instantValue = (value != null) ? value.toInstant() : null; - retValue = overrideOrReturn(fieldAccessDescriptor, instantValue); - } - return retValue.map(AbstractInstant::toInstant).orElse(null); + AbstractInstant instantValue = overrideOrReturn(fieldAccessDescriptor, value); + return (instantValue != null) ? instantValue.toInstant() : null; } @Override public Byte processByte( FieldAccessDescriptor fieldAccessDescriptor, Byte value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + return overrideOrReturn(fieldAccessDescriptor, value); } @Override public Short processInt16( FieldAccessDescriptor fieldAccessDescriptor, Short value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + return overrideOrReturn(fieldAccessDescriptor, value); } @Override public Integer processInt32( FieldAccessDescriptor fieldAccessDescriptor, Integer value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + return overrideOrReturn(fieldAccessDescriptor, value); } @Override public Long processInt64( FieldAccessDescriptor fieldAccessDescriptor, Long value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + return overrideOrReturn(fieldAccessDescriptor, value); } @Override public BigDecimal processDecimal( FieldAccessDescriptor fieldAccessDescriptor, BigDecimal value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + return overrideOrReturn(fieldAccessDescriptor, value); } @Override public Float processFloat( FieldAccessDescriptor fieldAccessDescriptor, Float value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + return overrideOrReturn(fieldAccessDescriptor, value); } @Override public Double processDouble( FieldAccessDescriptor fieldAccessDescriptor, Double value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + return overrideOrReturn(fieldAccessDescriptor, value); } @Override public String processString( FieldAccessDescriptor fieldAccessDescriptor, String value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + return overrideOrReturn(fieldAccessDescriptor, value); } @Override public Boolean processBoolean( FieldAccessDescriptor fieldAccessDescriptor, Boolean value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + return overrideOrReturn(fieldAccessDescriptor, value); } @Override public byte[] processBytes( FieldAccessDescriptor fieldAccessDescriptor, byte[] value, RowFieldMatcher matcher) { - Object retValue = overrideOrReturn(fieldAccessDescriptor, value).orElse(null); + Object retValue = overrideOrReturn(fieldAccessDescriptor, value); return (retValue instanceof ByteBuffer) ? 
((ByteBuffer) retValue).array() : (byte[]) retValue; } } From 52695aa68322b8bbd5fd90a9e99cd6ad72168bee Mon Sep 17 00:00:00 2001 From: Reuven Lax Date: Thu, 26 Mar 2020 15:38:16 -0700 Subject: [PATCH 3/4] fix failing tests --- .../beam/sdk/schemas/transforms/CoGroup.java | 3 +- .../schemas/utils/SelectByteBuddyHelpers.java | 5 - .../java/org/apache/beam/sdk/values/Row.java | 469 +------------- .../org/apache/beam/sdk/values/RowUtils.java | 591 ++++++++++++++++++ .../beam/sdk/values/SchemaVerification.java | 239 +------ .../org/apache/beam/sdk/values/RowTest.java | 6 +- 6 files changed, 639 insertions(+), 674 deletions(-) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java index 39b9de67025d..c4366e2318cf 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/CoGroup.java @@ -632,8 +632,7 @@ private void crossProductHelper( if (atBottom) { // Bottom of recursive call, so output the row we've accumulated. Row row = - Row.withSchema(getOutputSchema()) - .attachValues(Lists.newArrayList(accumulatedRows)); + Row.withSchema(getOutputSchema()).attachValues(Lists.newArrayList(accumulatedRows)); o.output(row); } else { crossProduct(tagIndex + 1, accumulatedRows, iterables, o); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SelectByteBuddyHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SelectByteBuddyHelpers.java index 3bf3069dc618..c5f29a5611b3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SelectByteBuddyHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SelectByteBuddyHelpers.java @@ -368,11 +368,6 @@ public ByteCodeAppender appender(final Target implementationTarget) { ElementMatchers.named("attachValues") .and(ElementMatchers.takesArguments(Object[].class))) .getOnly()), - MethodInvocation.invoke( - new ForLoadedType(Row.Builder.class) - .getDeclaredMethods() - .filter(ElementMatchers.named("build")) - .getOnly()), MethodReturn.REFERENCE); size = size.aggregate(attachToRow.apply(methodVisitor, implementationContext)); return new Size(size.getMaximalSize(), localVariables.getTotalNumVariables()); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 2062e8ea9fb4..097c2a035f0f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -23,7 +23,6 @@ import java.io.Serializable; import java.math.BigDecimal; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -31,9 +30,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Objects; -import java.util.Optional; import java.util.stream.Collector; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -46,16 +43,17 @@ import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; +import 
org.apache.beam.sdk.values.RowUtils.CapturingRowCases; +import org.apache.beam.sdk.values.RowUtils.FieldOverride; +import org.apache.beam.sdk.values.RowUtils.FieldOverrides; +import org.apache.beam.sdk.values.RowUtils.RowFieldMatcher; +import org.apache.beam.sdk.values.RowUtils.RowPosition; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.joda.time.DateTime; -import org.joda.time.Instant; import org.joda.time.ReadableDateTime; import org.joda.time.ReadableInstant; -import org.joda.time.base.AbstractInstant; /** * {@link Row} is an immutable tuple-like schema to represent one element in a {@link PCollection}. @@ -640,11 +638,12 @@ public static FieldValueBuilder fromRow(Row row) { public static class FieldValueBuilder { private final Schema schema; private final @Nullable Row sourceRow; - private final Map fieldValues = Maps.newHashMap(); + private final FieldOverrides fieldOverrides; private FieldValueBuilder(Schema schema, @Nullable Row sourceRow) { this.schema = schema; this.sourceRow = sourceRow; + this.fieldOverrides = new FieldOverrides(schema); } public Schema getSchema() { @@ -669,7 +668,7 @@ public FieldValueBuilder withFieldValue( FieldAccessDescriptor fieldAccessDescriptor, Object value) { FieldAccessDescriptor fieldAccess = fieldAccessDescriptor.resolve(getSchema()); checkArgument(fieldAccess.referencesSingleField(), ""); - fieldValues.put(fieldAccess, new FieldOverride(value)); + fieldOverrides.addOverride(fieldAccess, new FieldOverride(value)); return this; } @@ -678,12 +677,12 @@ public FieldValueBuilder withFieldValue( * syntax. */ public FieldValueBuilder withFieldValues(Map values) { - fieldValues.putAll( - values.entrySet().stream() - .collect( - Collectors.toMap( - e -> FieldAccessDescriptor.withFieldNames(e.getKey()), - e -> new FieldOverride(e.getValue())))); + values.entrySet().stream() + .forEach( + e -> + fieldOverrides.addOverride( + FieldAccessDescriptor.withFieldNames(e.getKey()), + new FieldOverride(e.getValue()))); return this; } @@ -692,9 +691,8 @@ public FieldValueBuilder withFieldValues(Map values) { * selection syntax. */ public FieldValueBuilder withFieldAccessDescriptors(Map values) { - fieldValues.putAll( - values.entrySet().stream() - .collect(Collectors.toMap(e -> e.getKey(), e -> new FieldOverride(e.getValue())))); + values.entrySet().stream() + .forEach(e -> fieldOverrides.addOverride(e.getKey(), new FieldOverride(e.getValue()))); return this; } @@ -703,9 +701,9 @@ public Row build() { (Row) new RowFieldMatcher() .match( - new CapturingRowCases(getSchema(), this.fieldValues), + new CapturingRowCases(getSchema(), this.fieldOverrides), FieldType.row(getSchema()), - FieldAccessDescriptor.create(), + new RowPosition(FieldAccessDescriptor.create()), sourceRow); return row; } @@ -726,33 +724,26 @@ public Schema getSchema() { return schema; } - Row nullRow() { - if (nullRow == null) { - this.nullRow = Row.withSchema(schema).attachValues(Collections.nCopies(schema.getFieldCount(), null)); - } - return nullRow; - } - /** * Set a field value using the field name. Nested values can be set using the field selection * syntax. 
*/ public FieldValueBuilder withFieldValue(String fieldName, Object value) { checkState(values.isEmpty()); - return new FieldValueBuilder(schema, nullRow()).withFieldValue(fieldName, value); + return new FieldValueBuilder(schema, null).withFieldValue(fieldName, value); } /** Set a field value using the field id. */ public FieldValueBuilder withFieldValue(Integer fieldId, Object value) { checkState(values.isEmpty()); - return new FieldValueBuilder(schema, nullRow()).withFieldValue(fieldId, value); + return new FieldValueBuilder(schema, null).withFieldValue(fieldId, value); } /** Set a field value using a FieldAccessDescriptor. */ public FieldValueBuilder withFieldValue( FieldAccessDescriptor fieldAccessDescriptor, Object value) { checkState(values.isEmpty()); - return new FieldValueBuilder(schema, nullRow()).withFieldValue(fieldAccessDescriptor, value); + return new FieldValueBuilder(schema, null).withFieldValue(fieldAccessDescriptor, value); } /** * Sets field values using the field names. Nested values can be set using the field selection @@ -760,7 +751,7 @@ public FieldValueBuilder withFieldValue( */ public FieldValueBuilder withFieldValues(Map values) { checkState(values.isEmpty()); - return new FieldValueBuilder(schema, nullRow()).withFieldValues(values); + return new FieldValueBuilder(schema, null).withFieldValues(values); } // The following methods allow appending a list of values to the Builder object. The values must @@ -810,6 +801,10 @@ public Row attachValues(List attachedValues) { return new RowWithStorage(schema, attachedValues); } + public Row attachValues(Object... values) { + return attachValues(Arrays.asList(values)); + } + public int nextFieldId() { return values.size(); } @@ -833,24 +828,19 @@ public Row build() { + " fields."); } - Map fieldValues = - Maps.newHashMapWithExpectedSize(this.values.size()); - for (int i = 0; i < this.values.size(); ++i) { - FieldAccessDescriptor fieldAccessDescriptor = - FieldAccessDescriptor.withFieldIds(i).resolve(schema); - fieldValues.putIfAbsent(fieldAccessDescriptor, new FieldOverride(this.values.get(i))); - } + FieldOverrides fieldOverrides = new FieldOverrides(schema); + fieldOverrides.setOverrides(this.values); Row row; - if (!fieldValues.isEmpty()) { + if (!fieldOverrides.isEmpty()) { row = (Row) new RowFieldMatcher() .match( - new CapturingRowCases(schema, fieldValues), + new CapturingRowCases(schema, fieldOverrides), FieldType.row(schema), - FieldAccessDescriptor.create(), - nullRow()); + new RowPosition(FieldAccessDescriptor.create()), + null); } else { row = new RowWithStorage(schema, Collections.emptyList()); } @@ -876,399 +866,4 @@ public static Row nullRow(Schema schema) { .addValues(Collections.nCopies(schema.getFieldCount(), null)) .build(); } - - // Subclasses of this interface implement process methods for each schema type. Each process - // method is invoked as - // a RowFieldMatcher walks down the schema tree. The FieldAccessDescriptor passed into each method - // identifies the - // current element of the schema being processed. 
- private interface RowCases { - Row processRow( - FieldAccessDescriptor fieldAccessDescriptor, - Schema schema, - Row value, - RowFieldMatcher matcher); - - Collection processArray( - FieldAccessDescriptor fieldAccessDescriptor, - FieldType collectionElementType, - Collection values, - RowFieldMatcher matcher); - - Iterable processIterable( - FieldAccessDescriptor fieldAccessDescriptor, - FieldType collectionElementType, - Iterable values, - RowFieldMatcher matcher); - - Map processMap( - FieldAccessDescriptor fieldAccessDescriptor, - FieldType keyType, - FieldType valueType, - Map valueMap, - RowFieldMatcher matcher); - - Object processLogicalType( - FieldAccessDescriptor fieldAccessDescriptor, - LogicalType logicalType, - Object baseType, - RowFieldMatcher matcher); - - Instant processDateTime( - FieldAccessDescriptor fieldAccessDescriptor, - AbstractInstant instant, - RowFieldMatcher matcher); - - Byte processByte( - FieldAccessDescriptor fieldAccessDescriptor, Byte value, RowFieldMatcher matcher); - - Short processInt16( - FieldAccessDescriptor fieldAccessDescriptor, Short value, RowFieldMatcher matcher); - - Integer processInt32( - FieldAccessDescriptor fieldAccessDescriptor, Integer value, RowFieldMatcher matcher); - - Long processInt64( - FieldAccessDescriptor fieldAccessDescriptor, Long value, RowFieldMatcher matcher); - - BigDecimal processDecimal( - FieldAccessDescriptor fieldAccessDescriptor, BigDecimal value, RowFieldMatcher matcher); - - Float processFloat( - FieldAccessDescriptor fieldAccessDescriptor, Float value, RowFieldMatcher matcher); - - Double processDouble( - FieldAccessDescriptor fieldAccessDescriptor, Double value, RowFieldMatcher matcher); - - String processString( - FieldAccessDescriptor fieldAccessDescriptor, String value, RowFieldMatcher matcher); - - Boolean processBoolean( - FieldAccessDescriptor fieldAccessDescriptor, Boolean value, RowFieldMatcher matcher); - - byte[] processBytes( - FieldAccessDescriptor fieldAccessDescriptor, byte[] value, RowFieldMatcher matcher); - } - - // Given a Row field, delegates processing to the correct process method on the RowCases - // parameter. 
- private static class RowFieldMatcher { - public Object match( - RowCases cases, - FieldType fieldType, - FieldAccessDescriptor fieldAccessDescriptor, - Object value) { - Object processedValue = null; - switch (fieldType.getTypeName()) { - case ARRAY: - processedValue = - cases.processArray( - fieldAccessDescriptor, - fieldType.getCollectionElementType(), - (Collection) value, - this); - break; - case ITERABLE: - processedValue = - cases.processIterable( - fieldAccessDescriptor, - fieldType.getCollectionElementType(), - (Iterable) value, - this); - break; - case MAP: - processedValue = - cases.processMap( - fieldAccessDescriptor, - fieldType.getMapKeyType(), - fieldType.getMapValueType(), - (Map) value, - this); - break; - case ROW: - processedValue = - cases.processRow(fieldAccessDescriptor, fieldType.getRowSchema(), (Row) value, this); - break; - case LOGICAL_TYPE: - LogicalType logicalType = fieldType.getLogicalType(); - processedValue = - cases.processLogicalType(fieldAccessDescriptor, logicalType, value, this); - break; - case DATETIME: - processedValue = - cases.processDateTime(fieldAccessDescriptor, (AbstractInstant) value, this); - break; - case BYTE: - processedValue = cases.processByte(fieldAccessDescriptor, (Byte) value, this); - break; - case BYTES: - processedValue = cases.processBytes(fieldAccessDescriptor, (byte[]) value, this); - break; - case INT16: - processedValue = cases.processInt16(fieldAccessDescriptor, (Short) value, this); - break; - case INT32: - processedValue = cases.processInt32(fieldAccessDescriptor, (Integer) value, this); - break; - case INT64: - processedValue = cases.processInt64(fieldAccessDescriptor, (Long) value, this); - break; - case DECIMAL: - processedValue = cases.processDecimal(fieldAccessDescriptor, (BigDecimal) value, this); - break; - case FLOAT: - processedValue = cases.processFloat(fieldAccessDescriptor, (Float) value, this); - break; - case DOUBLE: - processedValue = cases.processDouble(fieldAccessDescriptor, (Double) value, this); - break; - case STRING: - processedValue = cases.processString(fieldAccessDescriptor, (String) value, this); - break; - case BOOLEAN: - processedValue = cases.processBoolean(fieldAccessDescriptor, (Boolean) value, this); - break; - default: - // Shouldn't actually get here, but we need this case to satisfy linters. - throw new IllegalArgumentException( - String.format( - "Not a primitive type for field name %s: %s", fieldAccessDescriptor, fieldType)); - } - if (processedValue == null) { - if (!fieldType.getNullable()) { - throw new IllegalArgumentException( - String.format("%s is not nullable in field %s", fieldType, fieldAccessDescriptor)); - } - } - return processedValue; - } - } - - static class FieldOverride { - FieldOverride(Object overrideValue) { - this.overrideValue = overrideValue; - } - - Object getOverrideValue() { - return overrideValue; - } - - final Object overrideValue; - } - - // This implementation of RowCases captures a Row into a new Row. It also has the effect of - // validating all the - // field parameters. - // A Map of field values can also be passed in, and those field values will be used to override - // the values in the - // passed-in row. 
- private static class CapturingRowCases implements RowCases { - private final Schema topSchema; - private final Map fieldValueOverrides; - - private static class FieldAccessNode { - - } - private CapturingRowCases( - Schema topSchema, - Map fieldValueOverrides) { - this.topSchema = topSchema; - this.fieldValueOverrides = fieldValueOverrides; - } - - private @Nullable - FieldOverride override(FieldAccessDescriptor fieldAccessDescriptor) { - return fieldValueOverrides.get(fieldAccessDescriptor); - } - - private T overrideOrReturn(FieldAccessDescriptor fieldAccessDescriptor, T value) { - FieldOverride fieldOverride = override(fieldAccessDescriptor); - // null return means the item isn't in the map. - return (fieldOverride != null) ? (T) fieldOverride.getOverrideValue() : null; - } - - @Override - public Row processRow( - FieldAccessDescriptor fieldAccessDescriptor, - Schema schema, - Row value, - RowFieldMatcher matcher) { - FieldOverride override = override(fieldAccessDescriptor); - Row retValue = null; - if (override != null) { - retValue = (Row) override.getOverrideValue(); - } else if (value != null) { - // TODO: We should only recurse if there is an overidden subfield. - List values = Lists.newArrayListWithCapacity(schema.getFieldCount()); - for (int i = 0; i < schema.getFieldCount(); ++i) { - FieldAccessDescriptor nestedDescriptor = - FieldAccessDescriptor.withFieldIds(fieldAccessDescriptor, i).resolve(topSchema); - Object fieldValue = value.getValue(i); - values.add( - matcher.match(this, schema.getField(i).getType(), nestedDescriptor, fieldValue)); - } - retValue = new RowWithStorage(schema, values); - } - return retValue; - } - - @Override - public Collection processArray( - FieldAccessDescriptor fieldAccessDescriptor, - FieldType collectionElementType, - Collection values, - RowFieldMatcher matcher) { - Collection retValue = null; - FieldOverride override = override(fieldAccessDescriptor); - if (override != null) { - retValue = (Collection) override.getOverrideValue(); - } else if (values != null) { - retValue = captureIterable(fieldAccessDescriptor, collectionElementType, values, false, matcher); - - } - return retValue; - } - - @Override - public Iterable processIterable( - FieldAccessDescriptor fieldAccessDescriptor, - FieldType collectionElementType, - Iterable values, - RowFieldMatcher matcher) { - Iterable retValue = null; - FieldOverride override = override(fieldAccessDescriptor); - if (override != null) { - retValue = (Iterable) override.getOverrideValue(); - } else if (values != null) { - retValue = captureIterable(fieldAccessDescriptor, collectionElementType, - values, true, matcher); - } - return retValue; - } - - private Collection captureIterable(FieldAccessDescriptor fieldAccessDescriptor, - FieldType collectionElementType, Iterable values, - boolean recurseNestedRows, - RowFieldMatcher matcher) { - List captured = Lists.newArrayListWithCapacity(Iterables.size(values)); - for (Object listValue : values) { - boolean recurse = !collectionElementType.getTypeName().isCompositeType() || recurseNestedRows; - Object capturedElement = recurse ? 
- matcher.match(this, collectionElementType, fieldAccessDescriptor, listValue) - : listValue; - captured.add(capturedElement); - } - return captured; - } - - - @Override - public Map processMap( - FieldAccessDescriptor fieldAccessDescriptor, - FieldType keyType, - FieldType valueType, - Map valueMap, - RowFieldMatcher matcher) { - Map retValue = null; - FieldOverride override = override(fieldAccessDescriptor); - if (override != null) { - retValue = (Map) override.getOverrideValue(); - } else if (valueMap != null) { - retValue = Maps.newHashMapWithExpectedSize(valueMap.size()); - for (Entry kv : valueMap.entrySet()) { - retValue - .put( - matcher.match(this, keyType, fieldAccessDescriptor, kv.getKey()), - matcher.match(this, valueType, fieldAccessDescriptor, kv.getValue())); - } - } - return retValue; - } - - @Override - public Object processLogicalType( - FieldAccessDescriptor fieldAccessDescriptor, - LogicalType logicalType, - Object value, - RowFieldMatcher matcher) { - Object retValue = null; - FieldOverride override = override(fieldAccessDescriptor); - if (override != null) { - retValue = override.getOverrideValue(); - } else if (value != null) { - retValue = logicalType.toInputType(logicalType.toBaseType(value)); - } - return retValue; - } - - @Override - public Instant processDateTime( - FieldAccessDescriptor fieldAccessDescriptor, - AbstractInstant value, - RowFieldMatcher matcher) { - AbstractInstant instantValue = overrideOrReturn(fieldAccessDescriptor, value); - return (instantValue != null) ? instantValue.toInstant() : null; - } - - @Override - public Byte processByte( - FieldAccessDescriptor fieldAccessDescriptor, Byte value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value); - } - - @Override - public Short processInt16( - FieldAccessDescriptor fieldAccessDescriptor, Short value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value); - } - - @Override - public Integer processInt32( - FieldAccessDescriptor fieldAccessDescriptor, Integer value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value); - } - - @Override - public Long processInt64( - FieldAccessDescriptor fieldAccessDescriptor, Long value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value); - } - - @Override - public BigDecimal processDecimal( - FieldAccessDescriptor fieldAccessDescriptor, BigDecimal value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value); - } - - @Override - public Float processFloat( - FieldAccessDescriptor fieldAccessDescriptor, Float value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value); - } - - @Override - public Double processDouble( - FieldAccessDescriptor fieldAccessDescriptor, Double value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value); - } - - @Override - public String processString( - FieldAccessDescriptor fieldAccessDescriptor, String value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value); - } - - @Override - public Boolean processBoolean( - FieldAccessDescriptor fieldAccessDescriptor, Boolean value, RowFieldMatcher matcher) { - return overrideOrReturn(fieldAccessDescriptor, value); - } - - @Override - public byte[] processBytes( - FieldAccessDescriptor fieldAccessDescriptor, byte[] value, RowFieldMatcher matcher) { - Object retValue = overrideOrReturn(fieldAccessDescriptor, value); - return (retValue instanceof ByteBuffer) ? 
((ByteBuffer) retValue).array() : (byte[]) retValue; - } - } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java new file mode 100644 index 000000000000..8e955bb0a9ff --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java @@ -0,0 +1,591 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.values; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import javax.annotation.Nullable; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor.FieldDescriptor; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor.FieldDescriptor.ListQualifier; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor.FieldDescriptor.MapQualifier; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor.FieldDescriptor.Qualifier; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.LogicalType; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.joda.time.Instant; +import org.joda.time.base.AbstractInstant; + +class RowUtils { + static class RowPosition { + FieldAccessDescriptor descriptor; + List qualifiers; + + RowPosition(FieldAccessDescriptor descriptor) { + this(descriptor, Collections.emptyList()); + } + + RowPosition(FieldAccessDescriptor descriptor, List qualifiers) { + this.descriptor = descriptor; + this.qualifiers = qualifiers; + } + + RowPosition withArrayQualifier() { + List newQualifiers = Lists.newArrayListWithCapacity(qualifiers.size() + 1); + newQualifiers.addAll(qualifiers); + newQualifiers.add(Qualifier.of(ListQualifier.ALL)); + return new RowPosition(descriptor, newQualifiers); + } + + RowPosition withMapQualifier() { + List newQualifiers = Lists.newArrayListWithCapacity(qualifiers.size() + 1); + newQualifiers.addAll(qualifiers); + newQualifiers.add(Qualifier.of(MapQualifier.ALL)); + return new RowPosition(descriptor, newQualifiers); + } + } + + // Subclasses of this interface implement process methods for each schema type. Each process + // method is invoked as + // a RowFieldMatcher walks down the schema tree. The FieldAccessDescriptor passed into each method + // identifies the + // current element of the schema being processed. 
+ interface RowCases { + Row processRow(RowPosition rowPosition, Schema schema, Row value, RowFieldMatcher matcher); + + Collection processArray( + RowPosition rowPosition, + FieldType collectionElementType, + Collection values, + RowFieldMatcher matcher); + + Iterable processIterable( + RowPosition rowPosition, + FieldType collectionElementType, + Iterable values, + RowFieldMatcher matcher); + + Map processMap( + RowPosition rowPosition, + FieldType keyType, + FieldType valueType, + Map valueMap, + RowFieldMatcher matcher); + + Object processLogicalType( + RowPosition rowPosition, LogicalType logicalType, Object baseType, RowFieldMatcher matcher); + + Instant processDateTime( + RowPosition rowPosition, AbstractInstant instant, RowFieldMatcher matcher); + + Byte processByte(RowPosition rowPosition, Byte value, RowFieldMatcher matcher); + + Short processInt16(RowPosition rowPosition, Short value, RowFieldMatcher matcher); + + Integer processInt32(RowPosition rowPosition, Integer value, RowFieldMatcher matcher); + + Long processInt64(RowPosition rowPosition, Long value, RowFieldMatcher matcher); + + BigDecimal processDecimal(RowPosition rowPosition, BigDecimal value, RowFieldMatcher matcher); + + Float processFloat(RowPosition rowPosition, Float value, RowFieldMatcher matcher); + + Double processDouble(RowPosition rowPosition, Double value, RowFieldMatcher matcher); + + String processString(RowPosition rowPosition, String value, RowFieldMatcher matcher); + + Boolean processBoolean(RowPosition rowPosition, Boolean value, RowFieldMatcher matcher); + + byte[] processBytes(RowPosition rowPosition, byte[] value, RowFieldMatcher matcher); + } + + // Given a Row field, delegates processing to the correct process method on the RowCases + // parameter. + static class RowFieldMatcher { + public Object match( + RowCases cases, FieldType fieldType, RowPosition rowPosition, Object value) { + Object processedValue = null; + switch (fieldType.getTypeName()) { + case ARRAY: + processedValue = + cases.processArray( + rowPosition, + fieldType.getCollectionElementType(), + (Collection) value, + this); + break; + case ITERABLE: + processedValue = + cases.processIterable( + rowPosition, + fieldType.getCollectionElementType(), + (Iterable) value, + this); + break; + case MAP: + processedValue = + cases.processMap( + rowPosition, + fieldType.getMapKeyType(), + fieldType.getMapValueType(), + (Map) value, + this); + break; + case ROW: + processedValue = + cases.processRow(rowPosition, fieldType.getRowSchema(), (Row) value, this); + break; + case LOGICAL_TYPE: + LogicalType logicalType = fieldType.getLogicalType(); + processedValue = cases.processLogicalType(rowPosition, logicalType, value, this); + break; + case DATETIME: + processedValue = cases.processDateTime(rowPosition, (AbstractInstant) value, this); + break; + case BYTE: + processedValue = cases.processByte(rowPosition, (Byte) value, this); + break; + case BYTES: + processedValue = cases.processBytes(rowPosition, (byte[]) value, this); + break; + case INT16: + processedValue = cases.processInt16(rowPosition, (Short) value, this); + break; + case INT32: + processedValue = cases.processInt32(rowPosition, (Integer) value, this); + break; + case INT64: + processedValue = cases.processInt64(rowPosition, (Long) value, this); + break; + case DECIMAL: + processedValue = cases.processDecimal(rowPosition, (BigDecimal) value, this); + break; + case FLOAT: + processedValue = cases.processFloat(rowPosition, (Float) value, this); + break; + case DOUBLE: + processedValue = 
cases.processDouble(rowPosition, (Double) value, this); + break; + case STRING: + processedValue = cases.processString(rowPosition, (String) value, this); + break; + case BOOLEAN: + processedValue = cases.processBoolean(rowPosition, (Boolean) value, this); + break; + default: + // Shouldn't actually get here, but we need this case to satisfy linters. + throw new IllegalArgumentException( + String.format( + "Not a primitive type for field name %s: %s", rowPosition.descriptor, fieldType)); + } + if (processedValue == null) { + if (!fieldType.getNullable()) { + throw new IllegalArgumentException( + String.format("%s is not nullable in field %s", fieldType, rowPosition.descriptor)); + } + } + return processedValue; + } + } + + static class FieldOverride { + FieldOverride(Object overrideValue) { + this.overrideValue = overrideValue; + } + + Object getOverrideValue() { + return overrideValue; + } + + final Object overrideValue; + } + + static class FieldOverrides { + private FieldAccessNode topNode; + private Schema rootSchema; + + FieldOverrides(Schema rootSchema) { + this.topNode = new FieldAccessNode(rootSchema); + this.rootSchema = rootSchema; + } + + boolean isEmpty() { + return topNode.isEmpty(); + } + + void addOverride(FieldAccessDescriptor fieldAccessDescriptor, FieldOverride fieldOverride) { + topNode.addOverride(fieldAccessDescriptor, fieldOverride, rootSchema); + } + + void setOverrides(List values) { + List overrides = Lists.newArrayListWithExpectedSize(values.size()); + for (Object value : values) { + overrides.add(new FieldOverride(value)); + } + topNode.setOverrides(overrides); + } + + @Nullable + FieldOverride getOverride(FieldAccessDescriptor fieldAccessDescriptor) { + return topNode.getOverride(fieldAccessDescriptor); + } + + boolean hasOverrideBelow(FieldAccessDescriptor fieldAccessDescriptor) { + return topNode.hasOverrideBelow(fieldAccessDescriptor); + } + + private static class FieldAccessNode { + List fieldOverrides; + List nestedAccess; + + FieldAccessNode(Schema schema) { + fieldOverrides = Lists.newArrayListWithExpectedSize(schema.getFieldCount()); + nestedAccess = Lists.newArrayList(); + } + + boolean isEmpty() { + return fieldOverrides.isEmpty() && nestedAccess.isEmpty(); + } + + void addOverride( + FieldAccessDescriptor fieldAccessDescriptor, + FieldOverride fieldOverride, + Schema currentSchema) { + if (!fieldAccessDescriptor.getFieldsAccessed().isEmpty()) { + FieldDescriptor fieldDescriptor = + Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()); + int aheadPosition = fieldDescriptor.getFieldId() - fieldOverrides.size() + 1; + if (aheadPosition > 0) { + fieldOverrides.addAll(Collections.nCopies(aheadPosition, null)); + } + fieldOverrides.set(fieldDescriptor.getFieldId(), fieldOverride); + } else if (!fieldAccessDescriptor.getNestedFieldsAccessed().isEmpty()) { + Map.Entry entry = + Iterables.getOnlyElement(fieldAccessDescriptor.getNestedFieldsAccessed().entrySet()); + int aheadPosition = entry.getKey().getFieldId() - nestedAccess.size() + 1; + if (aheadPosition > 0) { + nestedAccess.addAll(Collections.nCopies(aheadPosition, null)); + } + + Schema nestedSchema = + currentSchema.getField(entry.getKey().getFieldId()).getType().getRowSchema(); + FieldAccessNode node = nestedAccess.get(entry.getKey().getFieldId()); + if (node == null) { + node = new FieldAccessNode(nestedSchema); + nestedAccess.set(entry.getKey().getFieldId(), node); + } + node.addOverride(entry.getValue(), fieldOverride, nestedSchema); + } + } + + void setOverrides(List overrides) { + 
this.fieldOverrides = overrides; + } + + @Nullable + FieldOverride getOverride(FieldAccessDescriptor fieldAccessDescriptor) { + FieldOverride override = null; + if (!fieldAccessDescriptor.getFieldsAccessed().isEmpty()) { + FieldDescriptor fieldDescriptor = + Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()); + if (fieldDescriptor.getFieldId() < fieldOverrides.size()) { + override = fieldOverrides.get(fieldDescriptor.getFieldId()); + } + } else if (!fieldAccessDescriptor.getNestedFieldsAccessed().isEmpty()) { + Map.Entry entry = + Iterables.getOnlyElement(fieldAccessDescriptor.getNestedFieldsAccessed().entrySet()); + if (entry.getKey().getFieldId() < nestedAccess.size()) { + FieldAccessNode node = nestedAccess.get(entry.getKey().getFieldId()); + if (node != null) { + override = node.getOverride(entry.getValue()); + } + } + } + return override; + } + + boolean hasOverrideBelow(FieldAccessDescriptor fieldAccessDescriptor) { + if (!fieldAccessDescriptor.getFieldsAccessed().isEmpty()) { + FieldDescriptor fieldDescriptor = + Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()); + return (((fieldDescriptor.getFieldId() < fieldOverrides.size())) + && fieldOverrides.get(fieldDescriptor.getFieldId()) != null) + || (((fieldDescriptor.getFieldId() < nestedAccess.size())) + && nestedAccess.get(fieldDescriptor.getFieldId()) != null); + } else if (!fieldAccessDescriptor.getNestedFieldsAccessed().isEmpty()) { + Map.Entry entry = + Iterables.getOnlyElement(fieldAccessDescriptor.getNestedFieldsAccessed().entrySet()); + if (entry.getKey().getFieldId() < nestedAccess.size()) { + FieldAccessNode node = nestedAccess.get(entry.getKey().getFieldId()); + if (node != null) { + return node.hasOverrideBelow(entry.getValue()); + } + } + } else { + return true; + } + return false; + } + } + } + + // This implementation of RowCases captures a Row into a new Row. It also has the effect of + // validating all the + // field parameters. + // A Map of field values can also be passed in, and those field values will be used to override + // the values in the + // passed-in row. + static class CapturingRowCases implements RowCases { + private final Schema topSchema; + private final FieldOverrides fieldOverrides; + + CapturingRowCases(Schema topSchema, FieldOverrides fieldOverrides) { + this.topSchema = topSchema; + this.fieldOverrides = fieldOverrides; + } + + private @Nullable FieldOverride override(RowPosition rowPosition) { + if (!rowPosition.qualifiers.isEmpty()) { + // Currently we only support overriding named schema fields. Individual array/map elements + // or nested collections + // cannot be overriden without overriding the entire schema fields. + return null; + } else { + return fieldOverrides.getOverride(rowPosition.descriptor); + } + } + + private T overrideOrReturn(RowPosition rowPosition, T value) { + FieldOverride fieldOverride = override(rowPosition); + // null return means the item isn't in the map. + return (fieldOverride != null) ? 
(T) fieldOverride.getOverrideValue() : value; + } + + @Override + public Row processRow( + RowPosition rowPosition, Schema schema, Row value, RowFieldMatcher matcher) { + FieldOverride override = override(rowPosition); + Row retValue = value; + if (override != null) { + retValue = (Row) override.getOverrideValue(); + } else if (fieldOverrides.hasOverrideBelow(rowPosition.descriptor)) { + List values = Lists.newArrayListWithCapacity(schema.getFieldCount()); + for (int i = 0; i < schema.getFieldCount(); ++i) { + FieldAccessDescriptor nestedDescriptor = + FieldAccessDescriptor.withFieldIds(rowPosition.descriptor, i).resolve(topSchema); + Object fieldValue = (value != null) ? value.getValue(i) : null; + values.add( + matcher.match( + this, + schema.getField(i).getType(), + new RowPosition(nestedDescriptor), + fieldValue)); + } + retValue = new RowWithStorage(schema, values); + } + return retValue; + } + + @Override + public Collection processArray( + RowPosition rowPosition, + FieldType collectionElementType, + Collection values, + RowFieldMatcher matcher) { + Collection retValue = null; + FieldOverride override = override(rowPosition); + if (override != null) { + retValue = + captureIterable( + rowPosition, + collectionElementType, + (Collection) override.getOverrideValue(), + matcher); + } else if (values != null) { + retValue = captureIterable(rowPosition, collectionElementType, values, matcher); + } + return retValue; + } + + @Override + public Iterable processIterable( + RowPosition rowPosition, + FieldType collectionElementType, + Iterable values, + RowFieldMatcher matcher) { + Iterable retValue = null; + FieldOverride override = override(rowPosition); + if (override != null) { + retValue = + captureIterable( + rowPosition, + collectionElementType, + (Iterable) override.getOverrideValue(), + matcher); + } else if (values != null) { + retValue = captureIterable(rowPosition, collectionElementType, values, matcher); + } + return retValue; + } + + private Collection captureIterable( + RowPosition rowPosition, + FieldType collectionElementType, + Iterable values, + RowFieldMatcher matcher) { + if (values == null) { + return null; + } + + List captured = Lists.newArrayListWithCapacity(Iterables.size(values)); + + RowPosition elementPosition = rowPosition.withArrayQualifier(); + for (Object listValue : values) { + if (listValue == null) { + if (!collectionElementType.getNullable()) { + throw new IllegalArgumentException( + String.format( + "%s is not nullable in Array field %s", + collectionElementType, rowPosition.descriptor)); + } + captured.add(null); + } else { + Object capturedElement = + matcher.match(this, collectionElementType, elementPosition, listValue); + captured.add(capturedElement); + } + } + return captured; + } + + @Override + public Map processMap( + RowPosition rowPosition, + FieldType keyType, + FieldType valueType, + Map valueMap, + RowFieldMatcher matcher) { + Map retValue = null; + FieldOverride override = override(rowPosition); + if (override != null) { + valueMap = (Map) override.getOverrideValue(); + } + + if (valueMap != null) { + RowPosition elementPosition = rowPosition.withMapQualifier(); + + retValue = Maps.newHashMapWithExpectedSize(valueMap.size()); + for (Entry kv : valueMap.entrySet()) { + if (kv.getValue() == null) { + if (!valueType.getNullable()) { + throw new IllegalArgumentException( + String.format( + "%s is not nullable in Map field %s", valueType, rowPosition.descriptor)); + } + retValue.put(matcher.match(this, keyType, elementPosition, kv.getKey()), 
null); + } else { + retValue.put( + matcher.match(this, keyType, elementPosition, kv.getKey()), + matcher.match(this, valueType, elementPosition, kv.getValue())); + } + } + } + return retValue; + } + + @Override + public Object processLogicalType( + RowPosition rowPosition, LogicalType logicalType, Object value, RowFieldMatcher matcher) { + Object retValue = null; + FieldOverride override = override(rowPosition); + if (override != null) { + retValue = override.getOverrideValue(); + } else if (value != null) { + retValue = logicalType.toInputType(logicalType.toBaseType(value)); + } + return retValue; + } + + @Override + public Instant processDateTime( + RowPosition rowPosition, AbstractInstant value, RowFieldMatcher matcher) { + AbstractInstant instantValue = overrideOrReturn(rowPosition, value); + return (instantValue != null) ? instantValue.toInstant() : null; + } + + @Override + public Byte processByte(RowPosition rowPosition, Byte value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Short processInt16(RowPosition rowPosition, Short value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Integer processInt32(RowPosition rowPosition, Integer value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Long processInt64(RowPosition rowPosition, Long value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public BigDecimal processDecimal( + RowPosition rowPosition, BigDecimal value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Float processFloat(RowPosition rowPosition, Float value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Double processDouble(RowPosition rowPosition, Double value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public String processString(RowPosition rowPosition, String value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public Boolean processBoolean(RowPosition rowPosition, Boolean value, RowFieldMatcher matcher) { + return overrideOrReturn(rowPosition, value); + } + + @Override + public byte[] processBytes(RowPosition rowPosition, byte[] value, RowFieldMatcher matcher) { + Object retValue = overrideOrReturn(rowPosition, value); + return (retValue instanceof ByteBuffer) ? 
((ByteBuffer) retValue).array() : (byte[]) retValue; + } + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java index b38051075020..3d1d0baa5677 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java @@ -18,239 +18,24 @@ package org.apache.beam.sdk.values; import java.io.Serializable; -import java.math.BigDecimal; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.Schema.LogicalType; -import org.apache.beam.sdk.schemas.Schema.TypeName; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; -import org.joda.time.Instant; -import org.joda.time.base.AbstractInstant; +import org.apache.beam.sdk.values.RowUtils.CapturingRowCases; +import org.apache.beam.sdk.values.RowUtils.FieldOverrides; +import org.apache.beam.sdk.values.RowUtils.RowFieldMatcher; +import org.apache.beam.sdk.values.RowUtils.RowPosition; @Experimental public abstract class SchemaVerification implements Serializable { - - static List verifyRowValues(Schema schema, List values) { - List verifiedValues = Lists.newArrayListWithCapacity(values.size()); - if (schema.getFieldCount() != values.size()) { - throw new IllegalArgumentException( - String.format( - "Field count in Schema (%s) (%d) and values (%s) (%d) must match", - schema.getFieldNames(), schema.getFieldCount(), values, values.size())); - } - for (int i = 0; i < values.size(); ++i) { - Object value = values.get(i); - Schema.Field field = schema.getField(i); - if (value == null) { - if (!field.getType().getNullable()) { - throw new IllegalArgumentException( - String.format("Field %s is not nullable", field.getName())); - } - verifiedValues.add(null); - } else { - verifiedValues.add(verifyFieldValue(value, field.getType(), field.getName())); - } - } - return verifiedValues; - } - public static Object verifyFieldValue(Object value, FieldType type, String fieldName) { - if (TypeName.ARRAY.equals(type.getTypeName())) { - return verifyArray(value, type.getCollectionElementType(), fieldName); - } else if (TypeName.ITERABLE.equals(type.getTypeName())) { - return verifyIterable(value, type.getCollectionElementType(), fieldName); - } - if (TypeName.MAP.equals(type.getTypeName())) { - return verifyMap(value, type.getMapKeyType(), type.getMapValueType(), fieldName); - } else if (TypeName.ROW.equals(type.getTypeName())) { - return verifyRow(value, fieldName); - } else if (TypeName.LOGICAL_TYPE.equals(type.getTypeName())) { - return verifyLogicalType(value, type.getLogicalType(), fieldName); - } else { - return verifyPrimitiveType(value, type.getTypeName(), fieldName); - } - } - - private static Object verifyLogicalType(Object value, LogicalType logicalType, String fieldName) { - // TODO: this isn't guaranteed to clone the object. 
- return logicalType.toInputType(logicalType.toBaseType(value)); - } - - private static List verifyArray( - Object value, FieldType collectionElementType, String fieldName) { - boolean collectionElementTypeNullable = collectionElementType.getNullable(); - if (!(value instanceof List)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and array type expected List class. Instead " - + "class type was %s.", - fieldName, value.getClass())); - } - List valueList = (List) value; - List verifiedList = Lists.newArrayListWithCapacity(valueList.size()); - for (Object listValue : valueList) { - if (listValue == null) { - if (!collectionElementTypeNullable) { - throw new IllegalArgumentException( - String.format( - "%s is not nullable in Array field %s", collectionElementType, fieldName)); - } - verifiedList.add(null); - } else { - verifiedList.add(verifyFieldValue(listValue, collectionElementType, fieldName)); - } - } - return verifiedList; - } - - private static Iterable verifyIterable( - Object value, FieldType collectionElementType, String fieldName) { - boolean collectionElementTypeNullable = collectionElementType.getNullable(); - if (!(value instanceof Iterable)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and iterable type expected class extending Iterable. Instead " - + "class type was %s.", - fieldName, value.getClass())); - } - Iterable valueList = (Iterable) value; - for (Object listValue : valueList) { - if (listValue == null) { - if (!collectionElementTypeNullable) { - throw new IllegalArgumentException( - String.format( - "%s is not nullable in Array field %s", collectionElementType, fieldName)); - } - } else { - verifyFieldValue(listValue, collectionElementType, fieldName); - } - } - return valueList; - } - - private static Map verifyMap( - Object value, FieldType keyType, FieldType valueType, String fieldName) { - boolean valueTypeNullable = valueType.getNullable(); - if (!(value instanceof Map)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and map type expected Map class. Instead " + "class type was %s.", - fieldName, value.getClass())); - } - Map valueMap = (Map) value; - Map verifiedMap = Maps.newHashMapWithExpectedSize(valueMap.size()); - for (Entry kv : valueMap.entrySet()) { - if (kv.getValue() == null) { - if (!valueTypeNullable) { - throw new IllegalArgumentException( - String.format("%s is not nullable in Map field %s", valueType, fieldName)); - } - verifiedMap.put(verifyFieldValue(kv.getKey(), keyType, fieldName), null); - } else { - verifiedMap.put( - verifyFieldValue(kv.getKey(), keyType, fieldName), - verifyFieldValue(kv.getValue(), valueType, fieldName)); - } - } - return verifiedMap; - } - - private static Row verifyRow(Object value, String fieldName) { - if (!(value instanceof Row)) { - throw new IllegalArgumentException( - String.format( - "For field name %s expected Row type. " + "Instead class type was %s.", - fieldName, value.getClass())); - } - // No need to recursively validate the nested Row, since there's no way to build the - // Row object without it validating. 
- return (Row) value; - } - - private static Object verifyPrimitiveType(Object value, TypeName type, String fieldName) { - if (type.isDateType()) { - return verifyDateTime(value, fieldName); - } else { - switch (type) { - case BYTE: - if (value instanceof Byte) { - return value; - } - break; - case BYTES: - if (value instanceof ByteBuffer) { - return ((ByteBuffer) value).array(); - } else if (value instanceof byte[]) { - return (byte[]) value; - } - break; - case INT16: - if (value instanceof Short) { - return value; - } - break; - case INT32: - if (value instanceof Integer) { - return value; - } - break; - case INT64: - if (value instanceof Long) { - return value; - } - break; - case DECIMAL: - if (value instanceof BigDecimal) { - return value; - } - break; - case FLOAT: - if (value instanceof Float) { - return value; - } - break; - case DOUBLE: - if (value instanceof Double) { - return value; - } - break; - case STRING: - if (value instanceof String) { - return value; - } - break; - case BOOLEAN: - if (value instanceof Boolean) { - return value; - } - break; - default: - // Shouldn't actually get here, but we need this case to satisfy linters. - throw new IllegalArgumentException( - String.format("Not a primitive type for field name %s: %s", fieldName, type)); - } - throw new IllegalArgumentException( - String.format( - "For field name %s and type %s found incorrect class type %s", - fieldName, type, value.getClass())); - } - } - - private static Instant verifyDateTime(Object value, String fieldName) { - // We support the following classes for datetimes. - if (value instanceof AbstractInstant) { - return ((AbstractInstant) value).toInstant(); - } else { - throw new IllegalArgumentException( - String.format( - "For field name %s and DATETIME type got unexpected class %s ", - fieldName, value.getClass())); - } + Schema schema = Schema.builder().addField(fieldName, type).build(); + return new RowFieldMatcher() + .match( + new CapturingRowCases(schema, new FieldOverrides(schema)), + type, + new RowPosition(FieldAccessDescriptor.withFieldIds(0)), + value); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java index ddd00c282223..a9fbc7955e78 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/RowTest.java @@ -485,7 +485,7 @@ public void testLogicalTypeWithRowValue() { Stream.of(Schema.Field.of("f1_enum", FieldType.logicalType(enumerationType))) .collect(toSchema()); Row row = Row.withSchema(type).addValue(enumerationType.valueOf("zero")).build(); - assertEquals(0, (int) row.getValue(0)); + assertEquals(enumerationType.valueOf(0), row.getValue(0)); assertEquals( enumerationType.valueOf("zero"), row.getLogicalTypeValue(0, EnumerationType.Value.class)); } @@ -498,7 +498,7 @@ public void testLogicalTypeWithRowValueName() { .collect(toSchema()); Row row = Row.withSchema(type).withFieldValue("f1_enum", enumerationType.valueOf("zero")).build(); - assertEquals(0, (int) row.getValue(0)); + assertEquals(enumerationType.valueOf(0), row.getValue(0)); assertEquals( enumerationType.valueOf("zero"), row.getLogicalTypeValue(0, EnumerationType.Value.class)); } @@ -513,7 +513,7 @@ public void testLogicalTypeWithRowValueOverride() { Row.withSchema(type).withFieldValue("f1_enum", enumerationType.valueOf("zero")).build(); Row overriddenRow = Row.fromRow(row).withFieldValue("f1_enum", 
enumerationType.valueOf("one")).build(); - assertEquals(1, (int) overriddenRow.getValue(0)); + assertEquals(enumerationType.valueOf(1), overriddenRow.getValue(0)); assertEquals( enumerationType.valueOf("one"), overriddenRow.getLogicalTypeValue(0, EnumerationType.Value.class)); From d50f358619fc67c603f78c87ba04206aeea14540 Mon Sep 17 00:00:00 2001 From: Reuven Lax Date: Fri, 27 Mar 2020 11:47:25 -0700 Subject: [PATCH 4/4] fix tests --- .../src/main/java/org/apache/beam/sdk/values/RowUtils.java | 7 ++----- .../org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java index 8e955bb0a9ff..3a7c860cf10c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowUtils.java @@ -331,10 +331,8 @@ boolean hasOverrideBelow(FieldAccessDescriptor fieldAccessDescriptor) { if (!fieldAccessDescriptor.getFieldsAccessed().isEmpty()) { FieldDescriptor fieldDescriptor = Iterables.getOnlyElement(fieldAccessDescriptor.getFieldsAccessed()); - return (((fieldDescriptor.getFieldId() < fieldOverrides.size())) - && fieldOverrides.get(fieldDescriptor.getFieldId()) != null) - || (((fieldDescriptor.getFieldId() < nestedAccess.size())) - && nestedAccess.get(fieldDescriptor.getFieldId()) != null); + return (((fieldDescriptor.getFieldId() < nestedAccess.size())) + && nestedAccess.get(fieldDescriptor.getFieldId()) != null); } else if (!fieldAccessDescriptor.getNestedFieldsAccessed().isEmpty()) { Map.Entry entry = Iterables.getOnlyElement(fieldAccessDescriptor.getNestedFieldsAccessed().entrySet()); @@ -461,7 +459,6 @@ private Collection captureIterable( } List captured = Lists.newArrayListWithCapacity(Iterables.size(values)); - RowPosition elementPosition = rowPosition.withArrayQualifier(); for (Object listValue : values) { if (listValue == null) { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java index af3aae015b75..b642109a21c7 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java @@ -578,7 +578,7 @@ public void testIterableFieldFromRow() throws NoSuchSchemaException { SchemaTestUtils.assertSchemaEquivalent(POJO_WITH_ITERABLE, schema); List list = Lists.newArrayList("one", "two"); - Row iterableRow = Row.withSchema(POJO_WITH_ITERABLE).addIterable(list).build(); + Row iterableRow = Row.withSchema(POJO_WITH_ITERABLE).attachValues((Object) list); PojoWithIterable converted = registry.getFromRowFunction(PojoWithIterable.class).apply(iterableRow); assertEquals(list, Lists.newArrayList(converted.strings));