From 0645fd790f8abac18fd35c2de7595654d3b711f5 Mon Sep 17 00:00:00 2001 From: scwhittle Date: Tue, 24 Sep 2024 20:24:44 +0200 Subject: [PATCH] Set a UUID when building a Schema object. (#32399) * Set a UUID when building a Schema object. Schemas are immutable so this meets the guarantee that a consistent UUID is used for matching schemas. Cleanup some cases setting a random UUID after creating Schema. Fix case where same UUID was assigned to different Schema after sorting. Use Immutable data structures to enforce immutability. Update OneOfType which using serialized proto equality which was incorrect if there was uuid, or encoding positions. Instead we can use a null Row using the generated oneof schema. --- .../org/apache/beam/sdk/schemas/Schema.java | 75 ++++++++++--------- .../beam/sdk/schemas/SchemaTranslation.java | 67 +++++++++++------ .../sdk/schemas/logicaltypes/OneOfType.java | 3 +- .../sdk/schemas/SchemaTranslationTest.java | 46 ++++++++++-- .../logicaltypes/LogicalTypesTest.java | 10 +++ .../python/PythonExternalTransform.java | 3 - ...ManagedSchemaTransformTranslationTest.java | 49 +++++++----- 7 files changed, 168 insertions(+), 85 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java index 255d411028f9..5af59356b174 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java @@ -41,6 +41,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.BiMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBiMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableBiMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; @@ -90,7 +91,12 @@ public String toString() { } } // A mapping between field names an indices. - private final BiMap fieldIndices = HashBiMap.create(); + private final BiMap fieldIndices; + + // Encoding positions can be used to maintain encoded byte compatibility between schemas with + // different field ordering or with added/removed fields. Such positions affect the encoding + // and decoding of Rows performed by RowCoderGenerator. They are stored within Schemas to + // facilitate plumbing to coders, display data etc but do not affect schema equality / uuid etc. private Map encodingPositions = Maps.newHashMap(); private boolean encodingPositionsOverridden = false; @@ -312,17 +318,20 @@ public Schema(List fields) { } public Schema(List fields, Options options) { - this.fields = fields; + this.fields = ImmutableList.copyOf(fields); int index = 0; - for (Field field : fields) { + BiMap fieldIndicesMutable = HashBiMap.create(); + for (Field field : this.fields) { Preconditions.checkArgument( - fieldIndices.get(field.getName()) == null, + fieldIndicesMutable.get(field.getName()) == null, "Duplicate field " + field.getName() + " added to schema"); encodingPositions.put(field.getName(), index); - fieldIndices.put(field.getName(), index++); + fieldIndicesMutable.put(field.getName(), index++); } - this.hashCode = Objects.hash(fieldIndices, fields); + this.fieldIndices = ImmutableBiMap.copyOf(fieldIndicesMutable); this.options = options; + this.hashCode = Objects.hash(this.fieldIndices, this.fields, this.options); + this.uuid = UUID.randomUUID(); } public static Schema of(Field... fields) { @@ -334,29 +343,24 @@ public static Schema of(Field... fields) { * fields. */ public Schema sorted() { - // Create a new schema and copy over the appropriate Schema object attributes: - // {fields, uuid, options} - // Note: encoding positions are not copied over because generally they should align with the - // ordering of field indices. Otherwise, problems may occur when encoding/decoding Rows of - // this schema. - Schema sortedSchema = - this.fields.stream() - .sorted(Comparator.comparing(Field::getName)) - .map( - field -> { - FieldType innerType = field.getType(); - if (innerType.getRowSchema() != null) { - Schema innerSortedSchema = innerType.getRowSchema().sorted(); - innerType = innerType.toBuilder().setRowSchema(innerSortedSchema).build(); - return field.toBuilder().setType(innerType).build(); - } - return field; - }) - .collect(Schema.toSchema()) - .withOptions(getOptions()); - sortedSchema.setUUID(getUUID()); - - return sortedSchema; + // Create a new schema and copy over the appropriate Schema object attributes: {fields, options} + // Note: uuid is not copied as the Schema field ordering is changed. encoding positions are not + // copied over because generally they should align with the ordering of field indices. + // Otherwise, problems may occur when encoding/decoding Rows of this schema. + return this.fields.stream() + .sorted(Comparator.comparing(Field::getName)) + .map( + field -> { + FieldType innerType = field.getType(); + if (innerType.getRowSchema() != null) { + Schema innerSortedSchema = innerType.getRowSchema().sorted(); + innerType = innerType.toBuilder().setRowSchema(innerSortedSchema).build(); + return field.toBuilder().setType(innerType).build(); + } + return field; + }) + .collect(Schema.toSchema()) + .withOptions(getOptions()); } /** Returns a copy of the Schema with the options set. */ @@ -405,11 +409,14 @@ public boolean equals(@Nullable Object o) { return false; } Schema other = (Schema) o; - // If both schemas have a UUID set, we can simply compare the UUIDs. - if (uuid != null && other.uuid != null) { - if (Objects.equals(uuid, other.uuid)) { - return true; - } + // If both schemas have a UUID set, we can short-circuit deep comparison if the + // UUIDs are equal. + if (uuid != null && other.uuid != null && Objects.equals(uuid, other.uuid)) { + return true; + } + // Utilize hash-code pre-calculation for cheap negative comparison. + if (this.hashCode != other.hashCode) { + return false; } return Objects.equals(fieldIndices, other.fieldIndices) && Objects.equals(getFields(), other.getFields()) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java index 1d3f3348f1ed..5253f82d15b9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java @@ -115,7 +115,12 @@ private static String getLogicalTypeUrn(String identifier) { .build(); public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLogicalType) { - String uuid = schema.getUUID() != null ? schema.getUUID().toString() : ""; + return schemaToProto(schema, serializeLogicalType, true); + } + + public static SchemaApi.Schema schemaToProto( + Schema schema, boolean serializeLogicalType, boolean serializeUUID) { + String uuid = schema.getUUID() != null && serializeUUID ? schema.getUUID().toString() : ""; SchemaApi.Schema.Builder builder = SchemaApi.Schema.newBuilder().setId(uuid); for (Field field : schema.getFields()) { SchemaApi.Field protoField = @@ -123,7 +128,8 @@ public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLog field, schema.indexOf(field.getName()), schema.getEncodingPositions().get(field.getName()), - serializeLogicalType); + serializeLogicalType, + serializeUUID); builder.addFields(protoField); } builder.addAllOptions(optionsToProto(schema.getOptions())); @@ -131,11 +137,11 @@ public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLog } private static SchemaApi.Field fieldToProto( - Field field, int fieldId, int position, boolean serializeLogicalType) { + Field field, int fieldId, int position, boolean serializeLogicalType, boolean serializeUUID) { return SchemaApi.Field.newBuilder() .setName(field.getName()) .setDescription(field.getDescription()) - .setType(fieldTypeToProto(field.getType(), serializeLogicalType)) + .setType(fieldTypeToProto(field.getType(), serializeLogicalType, serializeUUID)) .setId(fieldId) .setEncodingPosition(position) .addAllOptions(optionsToProto(field.getOptions())) @@ -143,34 +149,46 @@ private static SchemaApi.Field fieldToProto( } @VisibleForTesting - static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean serializeLogicalType) { + static SchemaApi.FieldType fieldTypeToProto( + FieldType fieldType, boolean serializeLogicalType, boolean serializeUUID) { SchemaApi.FieldType.Builder builder = SchemaApi.FieldType.newBuilder(); switch (fieldType.getTypeName()) { case ROW: builder.setRowType( SchemaApi.RowType.newBuilder() - .setSchema(schemaToProto(fieldType.getRowSchema(), serializeLogicalType))); + .setSchema( + schemaToProto(fieldType.getRowSchema(), serializeLogicalType, serializeUUID))); break; case ARRAY: builder.setArrayType( SchemaApi.ArrayType.newBuilder() .setElementType( - fieldTypeToProto(fieldType.getCollectionElementType(), serializeLogicalType))); + fieldTypeToProto( + fieldType.getCollectionElementType(), + serializeLogicalType, + serializeUUID))); break; case ITERABLE: builder.setIterableType( SchemaApi.IterableType.newBuilder() .setElementType( - fieldTypeToProto(fieldType.getCollectionElementType(), serializeLogicalType))); + fieldTypeToProto( + fieldType.getCollectionElementType(), + serializeLogicalType, + serializeUUID))); break; case MAP: builder.setMapType( SchemaApi.MapType.newBuilder() - .setKeyType(fieldTypeToProto(fieldType.getMapKeyType(), serializeLogicalType)) - .setValueType(fieldTypeToProto(fieldType.getMapValueType(), serializeLogicalType)) + .setKeyType( + fieldTypeToProto( + fieldType.getMapKeyType(), serializeLogicalType, serializeUUID)) + .setValueType( + fieldTypeToProto( + fieldType.getMapValueType(), serializeLogicalType, serializeUUID)) .build()); break; @@ -186,12 +204,14 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali .setUrn(logicalType.getIdentifier()) .setPayload(ByteString.copyFrom(((UnknownLogicalType) logicalType).getPayload())) .setRepresentation( - fieldTypeToProto(logicalType.getBaseType(), serializeLogicalType)); + fieldTypeToProto( + logicalType.getBaseType(), serializeLogicalType, serializeUUID)); if (logicalType.getArgumentType() != null) { logicalTypeBuilder .setArgumentType( - fieldTypeToProto(logicalType.getArgumentType(), serializeLogicalType)) + fieldTypeToProto( + logicalType.getArgumentType(), serializeLogicalType, serializeUUID)) .setArgument( fieldValueToProto(logicalType.getArgumentType(), logicalType.getArgument())); } @@ -200,13 +220,15 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali logicalTypeBuilder = SchemaApi.LogicalType.newBuilder() .setRepresentation( - fieldTypeToProto(logicalType.getBaseType(), serializeLogicalType)) + fieldTypeToProto( + logicalType.getBaseType(), serializeLogicalType, serializeUUID)) .setUrn(urn); if (logicalType.getArgumentType() != null) { logicalTypeBuilder = logicalTypeBuilder .setArgumentType( - fieldTypeToProto(logicalType.getArgumentType(), serializeLogicalType)) + fieldTypeToProto( + logicalType.getArgumentType(), serializeLogicalType, serializeUUID)) .setArgument( fieldValueToProto( logicalType.getArgumentType(), logicalType.getArgument())); @@ -226,7 +248,8 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali builder.setLogicalType( SchemaApi.LogicalType.newBuilder() .setUrn(URN_BEAM_LOGICAL_MILLIS_INSTANT) - .setRepresentation(fieldTypeToProto(FieldType.INT64, serializeLogicalType)) + .setRepresentation( + fieldTypeToProto(FieldType.INT64, serializeLogicalType, serializeUUID)) .build()); break; case DECIMAL: @@ -235,7 +258,8 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali builder.setLogicalType( SchemaApi.LogicalType.newBuilder() .setUrn(URN_BEAM_LOGICAL_DECIMAL) - .setRepresentation(fieldTypeToProto(FieldType.BYTES, serializeLogicalType)) + .setRepresentation( + fieldTypeToProto(FieldType.BYTES, serializeLogicalType, serializeUUID)) .build()); break; case BYTE: @@ -288,14 +312,14 @@ public static Schema schemaFromProto(SchemaApi.Schema protoSchema) { Schema schema = builder.build(); Preconditions.checkState(encodingLocationMap.size() == schema.getFieldCount()); - long dinstictEncodingPositions = encodingLocationMap.values().stream().distinct().count(); - Preconditions.checkState(dinstictEncodingPositions <= schema.getFieldCount()); - if (dinstictEncodingPositions < schema.getFieldCount() && schema.getFieldCount() > 0) { + long distinctEncodingPositions = encodingLocationMap.values().stream().distinct().count(); + Preconditions.checkState(distinctEncodingPositions <= schema.getFieldCount()); + if (distinctEncodingPositions < schema.getFieldCount() && schema.getFieldCount() > 0) { // This means that encoding positions were not specified in the proto. Generally, we don't // expect this to happen, // but if it does happen, we expect none to be specified - in which case the should all be // zero. - Preconditions.checkState(dinstictEncodingPositions == 1); + Preconditions.checkState(distinctEncodingPositions == 1); } else if (protoSchema.getEncodingPositionsSet()) { schema.setEncodingPositions(encodingLocationMap); } @@ -771,7 +795,8 @@ private static List optionsToProto(Schema.Options options) { protoOptions.add( SchemaApi.Option.newBuilder() .setName(name) - .setType(fieldTypeToProto(Objects.requireNonNull(options.getType(name)), false)) + .setType( + fieldTypeToProto(Objects.requireNonNull(options.getType(name)), false, false)) .setValue( fieldValueToProto( Objects.requireNonNull(options.getType(name)), options.getValue(name))) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java index 4a7573b036e2..31b6c8db2fed 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java @@ -65,7 +65,8 @@ private OneOfType(List fields, @Nullable Map enumMap) { enumerationType = EnumerationType.create(enumValues); } oneOfSchema = Schema.builder().addFields(nullableFields).build(); - schemaProtoRepresentation = SchemaTranslation.schemaToProto(oneOfSchema, false).toByteArray(); + schemaProtoRepresentation = + SchemaTranslation.schemaToProto(oneOfSchema, false, false).toByteArray(); } /** Create an {@link OneOfType} logical type. */ diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTranslationTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTranslationTest.java index 3b22addbf545..b082e2bb68ee 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTranslationTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTranslationTest.java @@ -214,6 +214,7 @@ public void toAndFromProto() throws Exception { public static class FromProtoToProtoTest { @Parameters(name = "{index}: {0}") public static Iterable data() { + ImmutableList.Builder listBuilder = ImmutableList.builder(); SchemaApi.Schema.Builder builder = SchemaApi.Schema.newBuilder(); // A go 'int' builder.addFields( @@ -232,6 +233,9 @@ public static Iterable data() { .setId(0) .setEncodingPosition(0) .build()); + SchemaApi.Schema singleFieldSchema = builder.build(); + listBuilder.add(singleFieldSchema); + // A pickled python object builder.addFields( SchemaApi.Field.newBuilder() @@ -294,21 +298,51 @@ public static Iterable data() { .setId(2) .setEncodingPosition(2) .build()); - SchemaApi.Schema unknownLogicalTypeSchema = builder.build(); + SchemaApi.Schema multipleFieldSchema = builder.build(); + listBuilder.add(multipleFieldSchema); - return ImmutableList.builder().add(unknownLogicalTypeSchema).build(); + builder.clear(); + builder.addFields( + SchemaApi.Field.newBuilder() + .setName("nested") + .setType( + SchemaApi.FieldType.newBuilder() + .setRowType( + SchemaApi.RowType.newBuilder().setSchema(singleFieldSchema).build()) + .build()) + .build()); + SchemaApi.Schema nestedSchema = builder.build(); + listBuilder.add(nestedSchema); + + return listBuilder.build(); } @Parameter(0) public SchemaApi.Schema schemaProto; + private void clearIds(SchemaApi.Schema.Builder builder) { + builder.clearId(); + for (SchemaApi.Field.Builder field : builder.getFieldsBuilderList()) { + if (field.hasType() + && field.getType().hasRowType() + && field.getType().getRowType().hasSchema()) { + clearIds(field.getTypeBuilder().getRowTypeBuilder().getSchemaBuilder()); + } + } + } + @Test public void fromProtoAndToProto() throws Exception { Schema decodedSchema = SchemaTranslation.schemaFromProto(schemaProto); SchemaApi.Schema reencodedSchemaProto = SchemaTranslation.schemaToProto(decodedSchema, true); + SchemaApi.Schema.Builder builder = reencodedSchemaProto.toBuilder(); + clearIds(builder); + assertThat(builder.build(), equalTo(schemaProto)); - assertThat(reencodedSchemaProto, equalTo(schemaProto)); + SchemaApi.Schema reencodedSchemaProtoWithoutUUID = + SchemaTranslation.schemaToProto(decodedSchema, true, false); + assertThat(reencodedSchemaProtoWithoutUUID, equalTo(schemaProto)); } } @@ -432,8 +466,8 @@ public static Iterable data() { public Schema.FieldType fieldType; @Test - public void testLogicalTypeSerializeDeserilizeCorrectly() { - SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, true); + public void testLogicalTypeSerializeDeserializeCorrectly() { + SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, true, false); Schema.FieldType translated = SchemaTranslation.fieldTypeFromProto(proto); assertThat( @@ -451,7 +485,7 @@ public void testLogicalTypeSerializeDeserilizeCorrectly() { @Test public void testLogicalTypeFromToProtoCorrectly() { - SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, false); + SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, false, false); Schema.FieldType translated = SchemaTranslation.fieldTypeFromProto(proto); if (STANDARD_LOGICAL_TYPES.containsKey(translated.getLogicalType().getIdentifier())) { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/logicaltypes/LogicalTypesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/logicaltypes/LogicalTypesTest.java index fc264c8104c4..e1590408021a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/logicaltypes/LogicalTypesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/logicaltypes/LogicalTypesTest.java @@ -76,6 +76,16 @@ public void testOneOf() { union = intOneOf.getLogicalTypeValue(0, OneOfType.Value.class); assertEquals("int32", oneOf.getCaseEnumType().toString(union.getCaseType())); assertEquals(42, (int) union.getValue()); + + // Validate schema equality. + OneOfType oneOf2 = + OneOfType.create(Field.of("string", FieldType.STRING), Field.of("int32", FieldType.INT32)); + assertEquals(oneOf.getOneOfSchema(), oneOf2.getOneOfSchema()); + Schema schema2 = Schema.builder().addLogicalTypeField("union", oneOf2).build(); + assertEquals(schema, schema2); + Row stringOneOf2 = + Row.withSchema(schema2).addValue(oneOf.createValue("string", "stringValue")).build(); + assertEquals(stringOneOf, stringOneOf2); } @Test diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java index c23a771f3cc8..d5f1745a9a2c 100644 --- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java +++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java @@ -311,7 +311,6 @@ Row buildOrGetKwargsRow() { Schema schema = generateSchemaFromFieldValues( kwargsMap.values().toArray(), kwargsMap.keySet().toArray(new String[] {})); - schema.setUUID(UUID.randomUUID()); return Row.withSchema(schema) .addValues(convertComplexTypesToRows(kwargsMap.values().toArray())) .build(); @@ -367,7 +366,6 @@ private Object[] convertComplexTypesToRows(@Nullable Object @NonNull [] values) @VisibleForTesting Row buildOrGetArgsRow() { Schema schema = generateSchemaFromFieldValues(argsArray, null); - schema.setUUID(UUID.randomUUID()); Object[] convertedValues = convertComplexTypesToRows(argsArray); return Row.withSchema(schema).addValues(convertedValues).build(); } @@ -421,7 +419,6 @@ ExternalTransforms.ExternalConfigurationPayload generatePayload() { schemaBuilder.addRowField("kwargs", kwargsRow.getSchema()); } Schema payloadSchema = schemaBuilder.build(); - payloadSchema.setUUID(UUID.randomUUID()); Row.Builder payloadRowBuilder = Row.withSchema(payloadSchema); payloadRowBuilder.addValue(fullyQualifiedName); if (argsRow.getValues().size() > 0) { diff --git a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java index 0d122646d899..c0f324c25606 100644 --- a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java +++ b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java @@ -38,6 +38,7 @@ import java.util.Map; import java.util.stream.Collectors; import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.SchemaApi; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.managed.testing.TestSchemaTransformProvider; @@ -54,6 +55,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.junit.Test; public class ManagedSchemaTransformTranslationTest { @@ -169,26 +171,33 @@ public void testProtoTranslation() throws Exception { .withFieldValue("transform_identifier", TestSchemaTransformProvider.IDENTIFIER) .withFieldValue("config", yamlStringConfig) .build(); - Map expectedAnnotations = - ImmutableMap.builder() - .put( - BeamUrns.getConstant(SCHEMATRANSFORM_URN_KEY), - ByteString.copyFromUtf8(MANAGED_TRANSFORM_URN)) - .put( - BeamUrns.getConstant(MANAGED_UNDERLYING_TRANSFORM_URN_KEY), - ByteString.copyFromUtf8(TestSchemaTransformProvider.IDENTIFIER)) - .put( - BeamUrns.getConstant(CONFIG_ROW_KEY), - ByteString.copyFrom( - CoderUtils.encodeToByteArray( - RowCoder.of(PROVIDER.configurationSchema()), managedConfigRow))) - .put( - BeamUrns.getConstant(CONFIG_ROW_SCHEMA_KEY), - ByteString.copyFrom( - SchemaTranslation.schemaToProto(PROVIDER.configurationSchema(), true) - .toByteArray())) - .build(); - assertEquals(expectedAnnotations, convertedTransform.getAnnotationsMap()); + assertEquals( + ImmutableSet.of( + BeamUrns.getConstant(SCHEMATRANSFORM_URN_KEY), + BeamUrns.getConstant(MANAGED_UNDERLYING_TRANSFORM_URN_KEY), + BeamUrns.getConstant(CONFIG_ROW_KEY), + BeamUrns.getConstant(CONFIG_ROW_SCHEMA_KEY)), + convertedTransform.getAnnotationsMap().keySet()); + assertEquals( + ByteString.copyFromUtf8(MANAGED_TRANSFORM_URN), + convertedTransform.getAnnotationsMap().get(BeamUrns.getConstant(SCHEMATRANSFORM_URN_KEY))); + assertEquals( + ByteString.copyFromUtf8(TestSchemaTransformProvider.IDENTIFIER), + convertedTransform + .getAnnotationsMap() + .get(BeamUrns.getConstant(MANAGED_UNDERLYING_TRANSFORM_URN_KEY))); + Schema annotationSchema = + SchemaTranslation.schemaFromProto( + SchemaApi.Schema.parseFrom( + convertedTransform + .getAnnotationsMap() + .get(BeamUrns.getConstant(CONFIG_ROW_SCHEMA_KEY)))); + assertEquals(PROVIDER.configurationSchema(), annotationSchema); + assertEquals( + managedConfigRow, + CoderUtils.decodeFromByteString( + RowCoder.of(annotationSchema), + convertedTransform.getAnnotationsMap().get(BeamUrns.getConstant(CONFIG_ROW_KEY)))); // Check that the spec proto contains correct values RunnerApi.FunctionSpec spec = convertedTransform.getSpec();