Skip to content

Commit

Permalink
Support Kafka Managed IO (#31172)
Browse files Browse the repository at this point in the history
* managed kafka read

* managed kafka write
  • Loading branch information
ahmedabu98 authored May 9, 2024
1 parent e0bc8e7 commit 365c2d9
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.apache.beam.sdk.values.Row.toRow;

import java.math.BigDecimal;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
Expand Down Expand Up @@ -181,4 +182,11 @@ public static String yamlStringFromMap(@Nullable Map<String, Object> map) {
}
return new Yaml().dumpAsMap(map);
}

/**
 * Parses a YAML document into a map of its top-level keys and values.
 *
 * @param yaml the YAML text to parse; may be null
 * @return the parsed top-level mapping, or an empty map when {@code yaml} is null, empty, or has
 *     no content (for example only whitespace or comments)
 * @throws IllegalArgumentException if the root of the YAML document is not a mapping
 */
public static Map<String, Object> yamlStringToMap(@Nullable String yaml) {
  if (yaml == null || yaml.isEmpty()) {
    return Collections.emptyMap();
  }

  // SnakeYAML yields null for a content-free document and a scalar or list for a non-mapping
  // root; neither satisfies the declared return type, so both are handled explicitly.
  Object parsed = new Yaml().load(yaml);
  if (parsed == null) {
    return Collections.emptyMap();
  }
  if (!(parsed instanceof Map)) {
    // The YAML text itself is left out of the message because configs may carry credentials.
    throw new IllegalArgumentException(
        "Expected a YAML mapping at the document root but found "
            + parsed.getClass().getSimpleName());
  }

  @SuppressWarnings("unchecked") // same key/value assumption the implicit cast made before
  Map<String, Object> result = (Map<String, Object>) parsed;
  return result;
}
}
1 change: 1 addition & 0 deletions sdks/java/io/kafka/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ dependencies {
provided library.java.everit_json_schema
testImplementation project(path: ":sdks:java:core", configuration: "shadowTest")
testImplementation project(":sdks:java:io:synthetic")
testImplementation project(":sdks:java:managed")
testImplementation project(path: ":sdks:java:extensions:avro", configuration: "testRuntimeMigration")
testImplementation project(path: ":sdks:java:extensions:protobuf", configuration: "testRuntimeMigration")
testImplementation project(path: ":sdks:java:io:common", configuration: "testRuntimeMigration")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,10 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
}
};
}

if (format.equals("RAW")) {
if ("RAW".equals(format)) {
beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
valueMapper = getRawBytesToRowFunction(beamSchema);
} else if (format.equals("PROTO")) {
} else if ("PROTO".equals(format)) {
String fileDescriptorPath = configuration.getFileDescriptorPath();
String messageName = configuration.getMessageName();
if (fileDescriptorPath != null) {
Expand All @@ -165,7 +164,7 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
beamSchema = ProtoByteUtils.getBeamSchemaFromProtoSchema(inputSchema, messageName);
valueMapper = ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(inputSchema, messageName);
}
} else if (format.equals("JSON")) {
} else if ("JSON".equals(format)) {
beamSchema = JsonUtils.beamSchemaFromJsonSchema(inputSchema);
valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,25 @@
package org.apache.beam.sdk.io.kafka;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.ServiceLoader;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.managed.ManagedTransformConstants;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
Expand All @@ -51,7 +60,7 @@ public class KafkaReadSchemaTransformProviderTest {
+ " string name = 2;\n"
+ " bool active = 3;\n"
+ "\n"
+ " // Nested field\n"
+ " // Nested field\n\n"
+ " message Address {\n"
+ " string street = 1;\n"
+ " string city = 2;\n"
Expand Down Expand Up @@ -284,4 +293,46 @@ public void testBuildTransformWithoutProtoSchemaFormat() {
.setMessageName("MyMessage")
.build()));
}

/**
 * Builds the Kafka read transform through the Managed API for several YAML configs. The test
 * passes when every config expands without throwing.
 */
@Test
public void testBuildTransformWithManaged() {
  String rawFormatConfig =
      "topic: topic_1\n" + "bootstrap_servers: some bootstrap\n" + "data_format: RAW";

  String inlineSchemaConfig =
      "topic: topic_2\n"
          + "bootstrap_servers: some bootstrap\n"
          + "schema: '{\"type\":\"record\",\"name\":\"my_record\",\"fields\":[{\"name\":\"bool\",\"type\":\"boolean\"}]}'";

  // NOTE(review): the Kafka read mapping in ManagedTransformConstants declares
  // confluent_schema_registry_url / confluent_schema_registry_subject; the schema_registry_*
  // keys used here are not in that mapping. Confirm they reach the underlying config as intended.
  String schemaRegistryConfig =
      "topic: topic_3\n"
          + "bootstrap_servers: some bootstrap\n"
          + "schema_registry_url: some-url\n"
          + "schema_registry_subject: some-subject\n"
          + "data_format: RAW";

  String protoConfig =
      "topic: topic_4\n"
          + "bootstrap_servers: some bootstrap\n"
          + "data_format: PROTO\n"
          + "schema: '"
          + PROTO_SCHEMA
          + "'\n"
          + "message_name: MyMessage";

  List<String> yamlConfigs =
      Arrays.asList(rawFormatConfig, inlineSchemaConfig, schemaRegistryConfig, protoConfig);

  // Kafka Read SchemaTransform gets built in ManagedSchemaTransformProvider's expand
  yamlConfigs.forEach(
      yamlConfig ->
          Managed.read(Managed.KAFKA)
              .withConfig(YamlUtils.yamlStringToMap(yamlConfig))
              .expand(PCollectionRowTuple.empty(Pipeline.create())));
}

/**
 * Checks that every underlying parameter name in the Managed mapping for Kafka read is a field
 * of the provider's configuration schema.
 */
@Test
public void testManagedMappings() {
  KafkaReadSchemaTransformProvider readProvider = new KafkaReadSchemaTransformProvider();
  Map<String, String> managedToUnderlying =
      ManagedTransformConstants.MAPPINGS.get(readProvider.identifier());

  // A mapping must be registered under this provider's identifier.
  assertNotNull(managedToUnderlying);

  List<String> underlyingFieldNames = readProvider.configurationSchema().getFieldNames();
  managedToUnderlying
      .values()
      .forEach(underlyingName -> assertTrue(underlyingFieldNames.contains(underlyingName)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,32 @@
package org.apache.beam.sdk.io.kafka;

import static org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.getRowToRawBytesFunction;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
import org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform.ErrorCounterFn;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.managed.ManagedTransformConstants;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.schemas.utils.JsonUtils;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
Expand Down Expand Up @@ -185,4 +193,55 @@ public void testKafkaErrorFnProtoSuccess() {
output.get(ERROR_TAG).setRowSchema(errorSchema);
p.run().waitUntilFinish();
}

/**
 * Inline proto3 schema declaring a single message, {@code MyMessage}, with three scalar fields.
 * It is spliced into the {@code schema} entry of the PROTO-format YAML config in {@code
 * testBuildTransformWithManaged}.
 */
private static final String PROTO_SCHEMA =
"syntax = \"proto3\";\n"
+ "\n"
+ "message MyMessage {\n"
+ " int32 id = 1;\n"
+ " string name = 2;\n"
+ " bool active = 3;\n"
+ "}";

/**
 * Builds the Kafka write transform through the Managed API for several YAML configs. The test
 * passes when every config expands without throwing.
 */
@Test
public void testBuildTransformWithManaged() {
  String rawFormatConfig =
      "topic: topic_1\n" + "bootstrap_servers: some bootstrap\n" + "data_format: RAW";

  String avroWithProducerUpdatesConfig =
      "topic: topic_2\n"
          + "bootstrap_servers: some bootstrap\n"
          + "producer_config_updates: {\"foo\": \"bar\"}\n"
          + "data_format: AVRO";

  String protoConfig =
      "topic: topic_3\n"
          + "bootstrap_servers: some bootstrap\n"
          + "data_format: PROTO\n"
          + "schema: '"
          + PROTO_SCHEMA
          + "'\n"
          + "message_name: MyMessage";

  List<String> yamlConfigs =
      Arrays.asList(rawFormatConfig, avroWithProducerUpdatesConfig, protoConfig);

  // Kafka Write SchemaTransform gets built in ManagedSchemaTransformProvider's expand
  yamlConfigs.forEach(
      yamlConfig ->
          Managed.write(Managed.KAFKA)
              .withConfig(YamlUtils.yamlStringToMap(yamlConfig))
              .expand(
                  PCollectionRowTuple.of(
                      "input",
                      Pipeline.create()
                          .apply(
                              Create.empty(
                                  Schema.builder().addByteArrayField("bytes").build())))));
}

/**
 * Checks that every underlying parameter name in the Managed mapping for Kafka write is a field
 * of the provider's configuration schema.
 */
@Test
public void testManagedMappings() {
  KafkaWriteSchemaTransformProvider writeProvider = new KafkaWriteSchemaTransformProvider();
  Map<String, String> managedToUnderlying =
      ManagedTransformConstants.MAPPINGS.get(writeProvider.identifier());

  // A mapping must be registered under this provider's identifier.
  assertNotNull(managedToUnderlying);

  List<String> underlyingFieldNames = writeProvider.configurationSchema().getFieldNames();
  managedToUnderlying
      .values()
      .forEach(underlyingName -> assertTrue(underlyingFieldNames.contains(underlyingName)));
}
}
1 change: 1 addition & 0 deletions sdks/java/managed/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ ext.summary = """Library that provides managed IOs."""
dependencies {
implementation project(path: ":sdks:java:core", configuration: "shadow")
implementation library.java.vendored_guava_32_1_2_jre
implementation library.java.slf4j_api

testImplementation library.java.junit
testRuntimeOnly "org.yaml:snakeyaml:2.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
*/
package org.apache.beam.sdk.managed;

import static org.apache.beam.sdk.managed.ManagedTransformConstants.ICEBERG_READ;
import static org.apache.beam.sdk.managed.ManagedTransformConstants.ICEBERG_WRITE;

import com.google.auto.value.AutoValue;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -80,12 +77,19 @@ public class Managed {

// TODO: Dynamically generate a list of supported transforms
public static final String ICEBERG = "iceberg";
public static final String KAFKA = "kafka";

// Supported SchemaTransforms
public static final Map<String, String> READ_TRANSFORMS =
ImmutableMap.<String, String>builder().put(ICEBERG, ICEBERG_READ).build();
ImmutableMap.<String, String>builder()
.put(ICEBERG, ManagedTransformConstants.ICEBERG_READ)
.put(KAFKA, ManagedTransformConstants.KAFKA_READ)
.build();
public static final Map<String, String> WRITE_TRANSFORMS =
ImmutableMap.<String, String>builder().put(ICEBERG, ICEBERG_WRITE).build();
ImmutableMap.<String, String>builder()
.put(ICEBERG, ManagedTransformConstants.ICEBERG_WRITE)
.put(KAFKA, ManagedTransformConstants.KAFKA_WRITE)
.build();

/**
* Instantiates a {@link Managed.ManagedTransform} transform for the specified source. The
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.managed;

import static org.apache.beam.sdk.managed.ManagedTransformConstants.MAPPINGS;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;

import com.google.auto.service.AutoService;
Expand Down Expand Up @@ -49,10 +50,13 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicates;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@AutoService(SchemaTransformProvider.class)
public class ManagedSchemaTransformProvider
extends TypedSchemaTransformProvider<ManagedSchemaTransformProvider.ManagedConfig> {
private static final Logger LOG = LoggerFactory.getLogger(ManagedSchemaTransformProvider.class);

@Override
public String identifier() {
Expand Down Expand Up @@ -179,6 +183,11 @@ static class ManagedSchemaTransform extends SchemaTransform {

/**
 * Builds the underlying SchemaTransform from the stored Row configuration and applies it to
 * {@code input}, logging the provider identifier and configuration at debug level first.
 */
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
  LOG.debug(
      "Building transform \"{}\" with Row configuration: {}",
      underlyingTransformProvider.identifier(),
      underlyingTransformConfig);

  SchemaTransform underlyingTransform =
      underlyingTransformProvider.from(underlyingTransformConfig);
  return input.apply(underlyingTransform);
}

Expand All @@ -205,7 +214,26 @@ Row getConfigurationRow() {
/**
 * Converts a Managed config into the Row used to build the underlying SchemaTransform.
 *
 * <p>The underlying YAML config is parsed into a map. If {@code MAPPINGS} has an entry for the
 * config's transform identifier, each top-level key found in that mapping is renamed to its
 * mapped name; keys without a mapping are kept unchanged. The resulting map is then converted to
 * a Row against {@code transformSchema}.
 *
 * @param config the Managed configuration holding the transform identifier and underlying config
 * @param transformSchema the configuration schema of the underlying SchemaTransform
 * @return the configuration Row produced by {@code YamlUtils.toBeamRow}
 */
static Row getRowConfig(ManagedConfig config, Schema transformSchema) {
// May return an empty row (perhaps the underlying transform doesn't have any required
// parameters)
return YamlUtils.toBeamRow(config.resolveUnderlyingConfig(), transformSchema, false);
String yamlConfig = config.resolveUnderlyingConfig();
Map<String, Object> configMap = YamlUtils.yamlStringToMap(yamlConfig);

// The config Row object will be used to build the underlying SchemaTransform.
// If a mapping for the SchemaTransform exists, we use it to update parameter names and align
// with the underlying config schema
Map<String, String> mapping = MAPPINGS.get(config.getTransformIdentifier());
if (mapping != null && configMap != null) {
// Copy into a fresh map so renamed keys never touch the parsed map while it is iterated.
Map<String, Object> remappedConfig = new HashMap<>();
for (Map.Entry<String, Object> entry : configMap.entrySet()) {
String paramName = entry.getKey();
// Only keys present in the mapping are renamed; all others pass through as-is.
if (mapping.containsKey(paramName)) {
paramName = mapping.get(paramName);
}
remappedConfig.put(paramName, entry.getValue());
}
configMap = remappedConfig;
}

// NOTE(review): the trailing `false` presumably disables name conversion inside toBeamRow,
// since keys were already aligned above — confirm against YamlUtils.toBeamRow.
return YamlUtils.toBeamRow(configMap, transformSchema, false);
}

Map<String, SchemaTransformProvider> getAllProviders() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,59 @@
*/
package org.apache.beam.sdk.managed;

/** This class contains constants for supported managed transform identifiers. */
import java.util.Map;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;

/**
* This class contains constants for supported managed transforms, including:
*
* <ul>
* <li>Identifiers of supported transforms
* <li>Configuration parameter renaming
* </ul>
*
* <p>Configuration parameter names exposed via Managed interface may differ from the parameter
* names in the underlying SchemaTransform implementation.
*
* <p>Any naming differences are laid out in {@link ManagedTransformConstants#MAPPINGS} to update
* the configuration object before it's used to build the underlying transform.
*
* <p>Mappings don't need to include ALL underlying parameter names, as we may not want to expose
* every single parameter through the Managed interface.
*/
public class ManagedTransformConstants {
  // Identifiers of the SchemaTransforms supported through Managed.
  public static final String ICEBERG_READ = "beam:schematransform:org.apache.beam:iceberg_read:v1";
  public static final String ICEBERG_WRITE =
      "beam:schematransform:org.apache.beam:iceberg_write:v1";
  public static final String KAFKA_READ = "beam:schematransform:org.apache.beam:kafka_read:v1";
  public static final String KAFKA_WRITE = "beam:schematransform:org.apache.beam:kafka_write:v1";

  /** Managed parameter name to underlying parameter name, for the Kafka read transform. */
  private static final Map<String, String> KAFKA_READ_MAPPINGS =
      ImmutableMap.of(
          "topic", "topic",
          "bootstrap_servers", "bootstrapServers",
          "consumer_config_updates", "consumerConfigUpdates",
          "confluent_schema_registry_url", "confluentSchemaRegistryUrl",
          "confluent_schema_registry_subject", "confluentSchemaRegistrySubject",
          "data_format", "format",
          "schema", "schema",
          "file_descriptor_path", "fileDescriptorPath",
          "message_name", "messageName");

  /** Managed parameter name to underlying parameter name, for the Kafka write transform. */
  private static final Map<String, String> KAFKA_WRITE_MAPPINGS =
      ImmutableMap.of(
          "topic", "topic",
          "bootstrap_servers", "bootstrapServers",
          "producer_config_updates", "producerConfigUpdates",
          "data_format", "format",
          "file_descriptor_path", "fileDescriptorPath",
          "message_name", "messageName");

  /** Parameter-name mappings keyed by SchemaTransform identifier. */
  public static final Map<String, Map<String, String>> MAPPINGS =
      ImmutableMap.of(
          KAFKA_READ, KAFKA_READ_MAPPINGS,
          KAFKA_WRITE, KAFKA_WRITE_MAPPINGS);
}

0 comments on commit 365c2d9

Please sign in to comment.