diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/YamlUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/YamlUtils.java
index 122f2d1963b9..e631e166e8be 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/YamlUtils.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/YamlUtils.java
@@ -20,6 +20,7 @@
 import static org.apache.beam.sdk.values.Row.toRow;
 
 import java.math.BigDecimal;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
@@ -181,4 +182,11 @@ public static String yamlStringFromMap(@Nullable Map<String, Object> map) {
     }
     return new Yaml().dumpAsMap(map);
   }
+
+  public static Map<String, Object> yamlStringToMap(@Nullable String yaml) {
+    if (yaml == null || yaml.isEmpty()) {
+      return Collections.emptyMap();
+    }
+    return new Yaml().load(yaml);
+  }
 }
diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle
index 269ddb3f5eb2..3e095a2bacca 100644
--- a/sdks/java/io/kafka/build.gradle
+++ b/sdks/java/io/kafka/build.gradle
@@ -90,6 +90,7 @@ dependencies {
   provided library.java.everit_json_schema
   testImplementation project(path: ":sdks:java:core", configuration: "shadowTest")
   testImplementation project(":sdks:java:io:synthetic")
+  testImplementation project(":sdks:java:managed")
  testImplementation project(path: ":sdks:java:extensions:avro", configuration: "testRuntimeMigration")
  testImplementation project(path: ":sdks:java:extensions:protobuf", configuration: "testRuntimeMigration")
  testImplementation project(path: ":sdks:java:io:common", configuration: "testRuntimeMigration")
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
index 2776c388f7cc..13240ea9dc40 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
@@ -151,11 +151,10 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
           }
         };
       }
-
-      if (format.equals("RAW")) {
+      if ("RAW".equals(format)) {
        beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
        valueMapper = getRawBytesToRowFunction(beamSchema);
-      } else if (format.equals("PROTO")) {
+      } else if ("PROTO".equals(format)) {
        String fileDescriptorPath = configuration.getFileDescriptorPath();
        String messageName = configuration.getMessageName();
        if (fileDescriptorPath != null) {
@@ -165,7 +164,7 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
          beamSchema = ProtoByteUtils.getBeamSchemaFromProtoSchema(inputSchema, messageName);
          valueMapper = ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(inputSchema, messageName);
        }
-      } else if (format.equals("JSON")) {
+      } else if ("JSON".equals(format)) {
        beamSchema = JsonUtils.beamSchemaFromJsonSchema(inputSchema);
        valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
      } else {
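For readers following the RAW branch above: a minimal, self-contained sketch of a bytes-to-Row mapper for the single-field `payload` schema that the new code builds. This is not the provider's actual `getRawBytesToRowFunction` (its implementation is not part of this diff); the wrapper class and `main` method are illustrative only.

```java
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.Row;

// Hypothetical stand-in for getRawBytesToRowFunction: wraps every Kafka record value
// into a Row with a single BYTES field named "payload".
public class RawPayloadMapperSketch {

  static SimpleFunction<byte[], Row> rawBytesToRow(Schema beamSchema) {
    return new SimpleFunction<byte[], Row>() {
      @Override
      public Row apply(byte[] input) {
        return Row.withSchema(beamSchema).withFieldValue("payload", input).build();
      }
    };
  }

  public static void main(String[] args) {
    Schema beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
    Row row = rawBytesToRow(beamSchema).apply(new byte[] {1, 2, 3});
    System.out.println(row.getSchema()); // prints the single-field payload schema
  }
}
```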
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
index f6e231c758a5..d5962a737baf 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
@@ -18,16 +18,25 @@
 package org.apache.beam.sdk.io.kafka;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.ServiceLoader;
 import java.util.stream.Collectors;
 import java.util.stream.StreamSupport;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.managed.Managed;
+import org.apache.beam.sdk.managed.ManagedTransformConstants;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.schemas.utils.YamlUtils;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
@@ -51,7 +60,7 @@ public class KafkaReadSchemaTransformProviderTest {
           + "  string name = 2;\n"
           + "  bool active = 3;\n"
           + "\n"
-          + "  // Nested field\n"
+          + "  // Nested field\n\n"
           + "  message Address {\n"
           + "    string street = 1;\n"
           + "    string city = 2;\n"
@@ -284,4 +293,46 @@ public void testBuildTransformWithoutProtoSchemaFormat() {
                 .setMessageName("MyMessage")
                 .build()));
   }
+
+  @Test
+  public void testBuildTransformWithManaged() {
+    List<String> configs =
+        Arrays.asList(
+            "topic: topic_1\n" + "bootstrap_servers: some bootstrap\n" + "data_format: RAW",
+            "topic: topic_2\n"
+                + "bootstrap_servers: some bootstrap\n"
+                + "schema: '{\"type\":\"record\",\"name\":\"my_record\",\"fields\":[{\"name\":\"bool\",\"type\":\"boolean\"}]}'",
+            "topic: topic_3\n"
+                + "bootstrap_servers: some bootstrap\n"
+                + "schema_registry_url: some-url\n"
+                + "schema_registry_subject: some-subject\n"
+                + "data_format: RAW",
+            "topic: topic_4\n"
+                + "bootstrap_servers: some bootstrap\n"
+                + "data_format: PROTO\n"
+                + "schema: '"
+                + PROTO_SCHEMA
+                + "'\n"
+                + "message_name: MyMessage");
+
+    for (String config : configs) {
+      // Kafka Read SchemaTransform gets built in ManagedSchemaTransformProvider's expand
+      Managed.read(Managed.KAFKA)
+          .withConfig(YamlUtils.yamlStringToMap(config))
+          .expand(PCollectionRowTuple.empty(Pipeline.create()));
+    }
+  }
+
+  @Test
+  public void testManagedMappings() {
+    KafkaReadSchemaTransformProvider provider = new KafkaReadSchemaTransformProvider();
+    Map<String, String> mapping = ManagedTransformConstants.MAPPINGS.get(provider.identifier());
+
+    assertNotNull(mapping);
+
+    List<String> configSchemaFieldNames = provider.configurationSchema().getFieldNames();
+    for (String paramName : mapping.values()) {
+      assertTrue(configSchemaFieldNames.contains(paramName));
+    }
+  }
 }
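A usage-oriented sketch of the path the new read test exercises: one of the YAML strings above is parsed with the new YamlUtils.yamlStringToMap and handed to Managed.read, which resolves to this Kafka read SchemaTransform. The broker address is made up, and the pipeline is only constructed, never run.

```java
import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.values.PCollectionRowTuple;

public class ManagedKafkaReadSketch {
  public static void main(String[] args) {
    // Same shape as the "topic_1" config in the test above; the bootstrap address is illustrative.
    String yaml = "topic: topic_1\n" + "bootstrap_servers: localhost:9092\n" + "data_format: RAW";

    // yamlStringToMap is the helper added to YamlUtils in this change.
    Map<String, Object> config = YamlUtils.yamlStringToMap(yaml);

    Pipeline pipeline = Pipeline.create();
    // Managed.read(Managed.KAFKA) looks up the Kafka read SchemaTransform URN and builds the
    // underlying transform at expansion time, which is what testBuildTransformWithManaged checks.
    PCollectionRowTuple output =
        PCollectionRowTuple.empty(pipeline).apply(Managed.read(Managed.KAFKA).withConfig(config));
    // "output" holds the Kafka records as Rows under the transform's output tag.
  }
}
```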
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
index 48d463a8f436..60bff89b3555 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
@@ -18,17 +18,24 @@
 package org.apache.beam.sdk.io.kafka;
 
 import static org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.getRowToRawBytesFunction;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 
 import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
 import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform.ErrorCounterFn;
+import org.apache.beam.sdk.managed.Managed;
+import org.apache.beam.sdk.managed.ManagedTransformConstants;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
 import org.apache.beam.sdk.schemas.utils.JsonUtils;
+import org.apache.beam.sdk.schemas.utils.YamlUtils;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
@@ -36,6 +43,7 @@
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TupleTag;
@@ -185,4 +193,55 @@ public void testKafkaErrorFnProtoSuccess() {
     output.get(ERROR_TAG).setRowSchema(errorSchema);
     p.run().waitUntilFinish();
   }
+
+  private static final String PROTO_SCHEMA =
+      "syntax = \"proto3\";\n"
+          + "\n"
+          + "message MyMessage {\n"
+          + "  int32 id = 1;\n"
+          + "  string name = 2;\n"
+          + "  bool active = 3;\n"
+          + "}";
+
+  @Test
+  public void testBuildTransformWithManaged() {
+    List<String> configs =
+        Arrays.asList(
+            "topic: topic_1\n" + "bootstrap_servers: some bootstrap\n" + "data_format: RAW",
+            "topic: topic_2\n"
+                + "bootstrap_servers: some bootstrap\n"
+                + "producer_config_updates: {\"foo\": \"bar\"}\n"
+                + "data_format: AVRO",
+            "topic: topic_3\n"
+                + "bootstrap_servers: some bootstrap\n"
+                + "data_format: PROTO\n"
+                + "schema: '"
+                + PROTO_SCHEMA
+                + "'\n"
+                + "message_name: MyMessage");
+
+    for (String config : configs) {
+      // Kafka Write SchemaTransform gets built in ManagedSchemaTransformProvider's expand
+      Managed.write(Managed.KAFKA)
+          .withConfig(YamlUtils.yamlStringToMap(config))
+          .expand(
+              PCollectionRowTuple.of(
+                  "input",
+                  Pipeline.create()
+                      .apply(Create.empty(Schema.builder().addByteArrayField("bytes").build()))));
+    }
+  }
+
+  @Test
+  public void testManagedMappings() {
+    KafkaWriteSchemaTransformProvider provider = new KafkaWriteSchemaTransformProvider();
+    Map<String, String> mapping = ManagedTransformConstants.MAPPINGS.get(provider.identifier());
+
+    assertNotNull(mapping);
+
+    List<String> configSchemaFieldNames = provider.configurationSchema().getFieldNames();
+    for (String paramName : mapping.values()) {
+      assertTrue(configSchemaFieldNames.contains(paramName));
+    }
+  }
 }
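The write-side counterpart, sketched under the same assumptions (illustrative broker address, pipeline constructed but not run): an empty Row PCollection with a single byte-array field is tagged as "input", matching what testBuildTransformWithManaged above feeds to the transform.

```java
import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;

public class ManagedKafkaWriteSketch {
  public static void main(String[] args) {
    String yaml = "topic: topic_1\n" + "bootstrap_servers: localhost:9092\n" + "data_format: RAW";
    Map<String, Object> config = YamlUtils.yamlStringToMap(yaml);

    Pipeline pipeline = Pipeline.create();
    Schema schema = Schema.builder().addByteArrayField("bytes").build();
    // Create.empty(schema) yields an empty PCollection<Row> with that schema -- the same trick
    // the write test uses to satisfy the transform's "input" tag.
    PCollection<Row> input = pipeline.apply(Create.empty(schema));

    PCollectionRowTuple.of("input", input).apply(Managed.write(Managed.KAFKA).withConfig(config));
  }
}
```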
diff --git a/sdks/java/managed/build.gradle b/sdks/java/managed/build.gradle
index f06df27429b1..add0d7f3cc0d 100644
--- a/sdks/java/managed/build.gradle
+++ b/sdks/java/managed/build.gradle
@@ -29,6 +29,7 @@ ext.summary = """Library that provides managed IOs."""
 dependencies {
   implementation project(path: ":sdks:java:core", configuration: "shadow")
   implementation library.java.vendored_guava_32_1_2_jre
+  implementation library.java.slf4j_api
 
   testImplementation library.java.junit
   testRuntimeOnly "org.yaml:snakeyaml:2.0"
diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java
index d24a3fd88ddc..da4a0853fb39 100644
--- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java
+++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java
@@ -17,9 +17,6 @@
  */
 package org.apache.beam.sdk.managed;
 
-import static org.apache.beam.sdk.managed.ManagedTransformConstants.ICEBERG_READ;
-import static org.apache.beam.sdk.managed.ManagedTransformConstants.ICEBERG_WRITE;
-
 import com.google.auto.value.AutoValue;
 import java.util.ArrayList;
 import java.util.List;
@@ -80,12 +77,19 @@ public class Managed {
 
   // TODO: Dynamically generate a list of supported transforms
   public static final String ICEBERG = "iceberg";
+  public static final String KAFKA = "kafka";
 
   // Supported SchemaTransforms
   public static final Map<String, String> READ_TRANSFORMS =
-      ImmutableMap.<String, String>builder().put(ICEBERG, ICEBERG_READ).build();
+      ImmutableMap.<String, String>builder()
+          .put(ICEBERG, ManagedTransformConstants.ICEBERG_READ)
+          .put(KAFKA, ManagedTransformConstants.KAFKA_READ)
+          .build();
   public static final Map<String, String> WRITE_TRANSFORMS =
-      ImmutableMap.<String, String>builder().put(ICEBERG, ICEBERG_WRITE).build();
+      ImmutableMap.<String, String>builder()
+          .put(ICEBERG, ManagedTransformConstants.ICEBERG_WRITE)
+          .put(KAFKA, ManagedTransformConstants.KAFKA_WRITE)
+          .build();
 
   /**
    * Instantiates a {@link Managed.ManagedTransform} transform for the specified source. The
diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java
index ff08e79e5eac..54e1404c650c 100644
--- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java
+++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.managed;
 
+import static org.apache.beam.sdk.managed.ManagedTransformConstants.MAPPINGS;
 import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
 
 import com.google.auto.service.AutoService;
@@ -49,10 +50,13 @@
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicates;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 @AutoService(SchemaTransformProvider.class)
 public class ManagedSchemaTransformProvider
     extends TypedSchemaTransformProvider<ManagedConfig> {
+  private static final Logger LOG = LoggerFactory.getLogger(ManagedSchemaTransformProvider.class);
 
   @Override
   public String identifier() {
@@ -179,6 +183,11 @@ static class ManagedSchemaTransform extends SchemaTransform {
 
     @Override
     public PCollectionRowTuple expand(PCollectionRowTuple input) {
+      LOG.debug(
+          "Building transform \"{}\" with Row configuration: {}",
+          underlyingTransformProvider.identifier(),
+          underlyingTransformConfig);
+
       return input.apply(underlyingTransformProvider.from(underlyingTransformConfig));
     }
 
@@ -205,7 +214,26 @@ Row getConfigurationRow() {
   static Row getRowConfig(ManagedConfig config, Schema transformSchema) {
     // May return an empty row (perhaps the underlying transform doesn't have any required
     // parameters)
-    return YamlUtils.toBeamRow(config.resolveUnderlyingConfig(), transformSchema, false);
+    String yamlConfig = config.resolveUnderlyingConfig();
+    Map<String, Object> configMap = YamlUtils.yamlStringToMap(yamlConfig);
+
+    // The config Row object will be used to build the underlying SchemaTransform.
+    // If a mapping for the SchemaTransform exists, we use it to update parameter names and align
+    // with the underlying config schema
+    Map<String, String> mapping = MAPPINGS.get(config.getTransformIdentifier());
+    if (mapping != null && configMap != null) {
+      Map<String, Object> remappedConfig = new HashMap<>();
+      for (Map.Entry<String, Object> entry : configMap.entrySet()) {
+        String paramName = entry.getKey();
+        if (mapping.containsKey(paramName)) {
+          paramName = mapping.get(paramName);
+        }
+        remappedConfig.put(paramName, entry.getValue());
+      }
+      configMap = remappedConfig;
+    }
+
+    return YamlUtils.toBeamRow(configMap, transformSchema, false);
   }
 
   Map<String, SchemaTransformProvider> getAllProviders() {
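To make the key-renaming step in getRowConfig above concrete, here is a self-contained sketch with hypothetical values; only the rename-or-pass-through behaviour mirrors the new code, and the mapping entries are just examples taken from the Kafka read mapping defined below.

```java
import java.util.HashMap;
import java.util.Map;

public class ConfigRemapSketch {
  public static void main(String[] args) {
    // Managed-facing parameter name -> underlying SchemaTransform parameter name.
    Map<String, String> mapping = new HashMap<>();
    mapping.put("data_format", "format");
    mapping.put("bootstrap_servers", "bootstrapServers");

    // Config as it would come out of YamlUtils.yamlStringToMap(...).
    Map<String, Object> configMap = new HashMap<>();
    configMap.put("topic", "topic_1");
    configMap.put("bootstrap_servers", "localhost:9092");
    configMap.put("data_format", "RAW");

    // Same loop shape as getRowConfig: keys with a mapping entry are renamed, the rest pass through.
    Map<String, Object> remappedConfig = new HashMap<>();
    for (Map.Entry<String, Object> entry : configMap.entrySet()) {
      String paramName = mapping.getOrDefault(entry.getKey(), entry.getKey());
      remappedConfig.put(paramName, entry.getValue());
    }

    // Prints something like {bootstrapServers=localhost:9092, format=RAW, topic=topic_1}.
    System.out.println(remappedConfig);
  }
}
```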
diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java
index 48735d8c33a3..8165633cf15e 100644
--- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java
+++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java
@@ -17,9 +17,59 @@
  */
 package org.apache.beam.sdk.managed;
 
-/** This class contains constants for supported managed transform identifiers. */
+import java.util.Map;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+
+/**
+ * This class contains constants for supported managed transforms, including:
+ *
+ * <ul>
+ *   <li>Identifiers of supported transforms
+ *   <li>Configuration parameter renaming
+ * </ul>
+ *
+ * <p>Configuration parameter names exposed via Managed interface may differ from the parameter
+ * names in the underlying SchemaTransform implementation.
+ *
+ * <p>Any naming differences are laid out in {@link ManagedTransformConstants#MAPPINGS} to update
+ * the configuration object before it's used to build the underlying transform.
+ *
+ * <p>Mappings don't need to include ALL underlying parameter names, as we may not want to expose
+ * every single parameter through the Managed interface.
+ */
 public class ManagedTransformConstants {
   public static final String ICEBERG_READ = "beam:schematransform:org.apache.beam:iceberg_read:v1";
   public static final String ICEBERG_WRITE =
       "beam:schematransform:org.apache.beam:iceberg_write:v1";
+  public static final String KAFKA_READ = "beam:schematransform:org.apache.beam:kafka_read:v1";
+  public static final String KAFKA_WRITE = "beam:schematransform:org.apache.beam:kafka_write:v1";
+
+  private static final Map<String, String> KAFKA_READ_MAPPINGS =
+      ImmutableMap.<String, String>builder()
+          .put("topic", "topic")
+          .put("bootstrap_servers", "bootstrapServers")
+          .put("consumer_config_updates", "consumerConfigUpdates")
+          .put("confluent_schema_registry_url", "confluentSchemaRegistryUrl")
+          .put("confluent_schema_registry_subject", "confluentSchemaRegistrySubject")
+          .put("data_format", "format")
+          .put("schema", "schema")
+          .put("file_descriptor_path", "fileDescriptorPath")
+          .put("message_name", "messageName")
+          .build();
+
+  private static final Map<String, String> KAFKA_WRITE_MAPPINGS =
+      ImmutableMap.<String, String>builder()
+          .put("topic", "topic")
+          .put("bootstrap_servers", "bootstrapServers")
+          .put("producer_config_updates", "producerConfigUpdates")
+          .put("data_format", "format")
+          .put("file_descriptor_path", "fileDescriptorPath")
+          .put("message_name", "messageName")
+          .build();
+
+  public static final Map<String, Map<String, String>> MAPPINGS =
+      ImmutableMap.<String, Map<String, String>>builder()
+          .put(KAFKA_READ, KAFKA_READ_MAPPINGS)
+          .put(KAFKA_WRITE, KAFKA_WRITE_MAPPINGS)
+          .build();
 }
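Finally, a small sketch of how these constants fit together: MAPPINGS is keyed by transform URN, and each value translates Managed-facing snake_case names into the field names of the underlying configuration class, which is exactly what the new testManagedMappings tests assert. The printed values come from the entries defined above.

```java
import java.util.Map;
import org.apache.beam.sdk.managed.ManagedTransformConstants;

public class MappingsLookupSketch {
  public static void main(String[] args) {
    // Look up the Kafka read mapping by its URN, as ManagedSchemaTransformProvider does.
    Map<String, String> readMapping =
        ManagedTransformConstants.MAPPINGS.get(ManagedTransformConstants.KAFKA_READ);

    // Managed-facing name -> underlying Kafka read configuration field name.
    System.out.println(readMapping.get("data_format")); // format
    System.out.println(readMapping.get("bootstrap_servers")); // bootstrapServers
  }
}
```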