Add Protobuf oneof support for Confluent schemas
adamjshook authored and hashhar committed Apr 23, 2023
1 parent e118541 commit 02cc332
Showing 10 changed files with 474 additions and 38 deletions.
17 changes: 17 additions & 0 deletions lib/trino-record-decoder/pom.xml
@@ -54,6 +54,11 @@
<artifactId>protobuf-java</artifactId>
</dependency>

<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.wire</groupId>
<artifactId>wire-runtime-jvm</artifactId>
@@ -122,6 +127,18 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.confluent</groupId>
<artifactId>kafka-protobuf-provider</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.kafka</groupId>
<artifactId>kafka-clients</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
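The protobuf-java-util dependency added above provides the JsonFormat utility that the decoder changes below rely on. A minimal sketch of how it renders a Protobuf message as compact JSON — the class and method names here are illustrative only, not part of this commit:

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.util.JsonFormat;

public class CompactJsonSketch
{
    // Renders any Protobuf message (including DynamicMessage) as a single-line JSON string
    static String toCompactJson(MessageOrBuilder message)
            throws InvalidProtocolBufferException
    {
        return JsonFormat.printer()
                .omittingInsignificantWhitespace()
                .print(message);
    }
}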
ProtobufColumnDecoder.java
@@ -17,7 +17,11 @@
import com.google.common.collect.ImmutableSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.OneofDescriptor;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.util.JsonFormat;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.decoder.DecoderColumnHandle;
import io.trino.decoder.FieldValueProvider;
import io.trino.spi.TrinoException;
@@ -33,21 +37,31 @@
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;

import javax.annotation.Nullable;

import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR;
import static io.trino.spi.type.StandardTypes.JSON;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class ProtobufColumnDecoder
{
private static final Slice EMPTY_JSON = Slices.utf8Slice("{}");

private static final Set<Type> SUPPORTED_PRIMITIVE_TYPES = ImmutableSet.of(
BooleanType.BOOLEAN,
TinyintType.TINYINT,
@@ -61,11 +75,15 @@ public class ProtobufColumnDecoder
private final Type columnType;
private final String columnMapping;
private final String columnName;
private final TypeManager typeManager;
private final Type jsonType;

public ProtobufColumnDecoder(DecoderColumnHandle columnHandle)
public ProtobufColumnDecoder(DecoderColumnHandle columnHandle, TypeManager typeManager)
{
try {
requireNonNull(columnHandle, "columnHandle is null");
this.typeManager = requireNonNull(typeManager, "typeManager is null");
this.jsonType = typeManager.getType(new TypeSignature(JSON));
this.columnType = columnHandle.getType();
this.columnMapping = columnHandle.getMapping();
this.columnName = columnHandle.getName();
@@ -81,7 +99,7 @@ public ProtobufColumnDecoder(DecoderColumnHandle columnHandle)
}
}

private static boolean isSupportedType(Type type)
private boolean isSupportedType(Type type)
{
if (isSupportedPrimitive(type)) {
return true;
@@ -106,7 +124,8 @@ private static boolean isSupportedType(Type type)
}
return true;
}
return false;

return type.equals(jsonType);
}

private static boolean isSupportedPrimitive(Type type)
@@ -118,7 +137,7 @@ private static boolean isSupportedPrimitive(Type type)

public FieldValueProvider decodeField(DynamicMessage dynamicMessage)
{
return new ProtobufValueProvider(locateField(dynamicMessage, columnMapping), columnType, columnName);
return new ProtobufValueProvider(locateField(dynamicMessage, columnMapping), columnType, columnName, typeManager);
}

@Nullable
@@ -128,8 +147,15 @@ private static Object locateField(DynamicMessage message, String columnMapping)
Optional<Descriptor> valueDescriptor = Optional.of(message.getDescriptorForType());
for (String pathElement : Splitter.on('/').omitEmptyStrings().split(columnMapping)) {
if (valueDescriptor.filter(descriptor -> descriptor.findFieldByName(pathElement) != null).isEmpty()) {
return null;
// Search the message to see if this column is a oneof type
Optional<OneofDescriptor> oneofDescriptor = message.getDescriptorForType().getOneofs().stream()
.filter(descriptor -> descriptor.getName().equals(columnMapping))
.findFirst();

return oneofDescriptor.map(descriptor -> createOneofJson(message, descriptor))
.orElse(null);
}

FieldDescriptor fieldDescriptor = valueDescriptor.get().findFieldByName(pathElement);
value = ((DynamicMessage) value).getField(fieldDescriptor);
valueDescriptor = getDescriptor(fieldDescriptor);
@@ -144,4 +170,40 @@ private static Optional<Descriptor> getDescriptor(FieldDescriptor fieldDescriptor)
}
return Optional.empty();
}

private static Object createOneofJson(DynamicMessage message, OneofDescriptor descriptor)
{
// Collect all oneof field names from the descriptor
Set<String> oneofColumns = descriptor.getFields().stream()
.map(FieldDescriptor::getName)
.collect(toImmutableSet());

// Find the oneof field in the message; there will be at most one
List<Entry<FieldDescriptor, Object>> oneofFields = message.getAllFields().entrySet().stream()
.filter(entry -> oneofColumns.contains(entry.getKey().getName()))
.collect(toImmutableList());

if (oneofFields.size() > 1) {
throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Expected to find at most one 'oneof' field in message, found fields: %s", oneofFields));
}

// If found, map the field to a JSON string containing a single field:value pair, else return an empty JSON object {}
if (!oneofFields.isEmpty()) {
try {
// Create a new DynamicMessage where the only set field is the oneof field, so we can use protobuf-java-util to encode the message as JSON
// If we encoded the entire input message, it would include all fields
Entry<FieldDescriptor, Object> oneofField = oneofFields.get(0);
DynamicMessage oneofMessage = DynamicMessage.newBuilder(oneofField.getKey().getContainingType())
.setField(oneofField.getKey(), oneofField.getValue())
.build();
return Slices.utf8Slice(JsonFormat.printer()
.omittingInsignificantWhitespace()
.print(oneofMessage));
}
catch (Exception e) {
throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to convert oneof message to JSON", e);
}
}
return EMPTY_JSON;
}
}
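To illustrate the behavior of locateField and createOneofJson above, here is a self-contained sketch (not part of this commit; the Event message, its payload oneof, and the field names are hypothetical). It builds a descriptor with a oneof, sets exactly one of its fields on a DynamicMessage, and prints it with the same JsonFormat call the decoder uses. The expected output is {"text":"hello"}; a message with no oneof field set would instead map to the decoder's EMPTY_JSON value {}.

import com.google.protobuf.DescriptorProtos.DescriptorProto;
import com.google.protobuf.DescriptorProtos.FieldDescriptorProto;
import com.google.protobuf.DescriptorProtos.FileDescriptorProto;
import com.google.protobuf.DescriptorProtos.OneofDescriptorProto;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.util.JsonFormat;

public class OneofDecodingSketch
{
    public static void main(String[] args)
            throws Exception
    {
        // Hypothetical schema: message Event { oneof payload { string text = 1; int64 count = 2; } }
        DescriptorProto event = DescriptorProto.newBuilder()
                .setName("Event")
                .addOneofDecl(OneofDescriptorProto.newBuilder().setName("payload"))
                .addField(FieldDescriptorProto.newBuilder()
                        .setName("text")
                        .setNumber(1)
                        .setType(FieldDescriptorProto.Type.TYPE_STRING)
                        .setOneofIndex(0))
                .addField(FieldDescriptorProto.newBuilder()
                        .setName("count")
                        .setNumber(2)
                        .setType(FieldDescriptorProto.Type.TYPE_INT64)
                        .setOneofIndex(0))
                .build();
        FileDescriptor file = FileDescriptor.buildFrom(
                FileDescriptorProto.newBuilder()
                        .setName("event.proto")
                        .setSyntax("proto3")
                        .addMessageType(event)
                        .build(),
                new FileDescriptor[] {});
        Descriptor descriptor = file.findMessageTypeByName("Event");

        // Set exactly one field of the 'payload' oneof, as a producer would
        DynamicMessage message = DynamicMessage.newBuilder(descriptor)
                .setField(descriptor.findFieldByName("text"), "hello")
                .build();

        // Same JsonFormat call as createOneofJson: prints {"text":"hello"},
        // the JSON value the decoder exposes for the 'payload' column
        System.out.println(JsonFormat.printer()
                .omittingInsignificantWhitespace()
                .print(message));
    }
}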
ProtobufRowDecoder.java
@@ -17,6 +17,7 @@
import io.trino.decoder.DecoderColumnHandle;
import io.trino.decoder.FieldValueProvider;
import io.trino.decoder.RowDecoder;
import io.trino.spi.type.TypeManager;

import java.util.Map;
import java.util.Optional;
@@ -34,13 +35,13 @@ public class ProtobufRowDecoder
private final DynamicMessageProvider dynamicMessageProvider;
private final Map<DecoderColumnHandle, ProtobufColumnDecoder> columnDecoders;

public ProtobufRowDecoder(DynamicMessageProvider dynamicMessageProvider, Set<DecoderColumnHandle> columns)
public ProtobufRowDecoder(DynamicMessageProvider dynamicMessageProvider, Set<DecoderColumnHandle> columns, TypeManager typeManager)
{
this.dynamicMessageProvider = requireNonNull(dynamicMessageProvider, "dynamicMessageSupplier is null");
this.columnDecoders = columns.stream()
.collect(toImmutableMap(
identity(),
ProtobufColumnDecoder::new));
column -> new ProtobufColumnDecoder(column, typeManager)));
}

@Override
ProtobufRowDecoderFactory.java
@@ -17,6 +17,7 @@
import io.trino.decoder.RowDecoder;
import io.trino.decoder.RowDecoderFactory;
import io.trino.decoder.protobuf.DynamicMessageProvider.Factory;
import io.trino.spi.type.TypeManager;

import javax.inject.Inject;

@@ -32,18 +33,21 @@ public class ProtobufRowDecoderFactory
public static final String DEFAULT_MESSAGE = "schema";

private final Factory dynamicMessageProviderFactory;
private final TypeManager typeManager;

@Inject
public ProtobufRowDecoderFactory(Factory dynamicMessageProviderFactory)
public ProtobufRowDecoderFactory(Factory dynamicMessageProviderFactory, TypeManager typeManager)
{
this.dynamicMessageProviderFactory = requireNonNull(dynamicMessageProviderFactory, "dynamicMessageProviderFactory is null");
this.typeManager = requireNonNull(typeManager, "typeManager is null");
}

@Override
public RowDecoder create(Map<String, String> decoderParams, Set<DecoderColumnHandle> columns)
{
return new ProtobufRowDecoder(
dynamicMessageProviderFactory.create(Optional.ofNullable(decoderParams.get("dataSchema"))),
columns);
columns,
typeManager);
}
}
ProtobufValueProvider.java
@@ -35,6 +35,8 @@
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;

@@ -48,6 +50,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.decoder.DecoderErrorCode.DECODER_CONVERSION_NOT_SUPPORTED;
import static io.trino.spi.type.StandardTypes.JSON;
import static io.trino.spi.type.TimestampType.MAX_SHORT_PRECISION;
import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND;
import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND;
@@ -65,12 +68,14 @@ public class ProtobufValueProvider
private final Object value;
private final Type columnType;
private final String columnName;
private final Type jsonType;

public ProtobufValueProvider(@Nullable Object value, Type columnType, String columnName)
public ProtobufValueProvider(@Nullable Object value, Type columnType, String columnName, TypeManager typeManager)
{
this.value = value;
this.columnType = requireNonNull(columnType, "columnType is null");
this.columnName = requireNonNull(columnName, "columnName is null");
this.jsonType = typeManager.getType(new TypeSignature(JSON));
}

@Override
@@ -128,7 +133,7 @@ public Block getBlock()
return serializeObject(null, value, columnType, columnName);
}

private static Slice getSlice(Object value, Type type, String columnName)
private Slice getSlice(Object value, Type type, String columnName)
{
requireNonNull(value, "value is null");
if ((type instanceof VarcharType && value instanceof CharSequence) || value instanceof EnumValueDescriptor) {
@@ -139,11 +144,15 @@ private static Slice getSlice(Object value, Type type, String columnName)
return Slices.wrappedBuffer(((ByteString) value).toByteArray());
}

if (type.equals(jsonType)) {
return (Slice) value;
}

throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), type, columnName));
}

@Nullable
private static Block serializeObject(BlockBuilder builder, Object value, Type type, String columnName)
private Block serializeObject(BlockBuilder builder, Object value, Type type, String columnName)
{
if (type instanceof ArrayType) {
return serializeList(builder, value, type, columnName);
@@ -155,12 +164,16 @@ private static Block serializeObject(BlockBuilder builder, Object value, Type type, String columnName)
return serializeRow(builder, value, type, columnName);
}

if (type.equals(jsonType)) {
return serializeJson(builder, value, type);
}

serializePrimitive(builder, value, type, columnName);
return null;
}

@Nullable
private static Block serializeList(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
private Block serializeList(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
{
if (value == null) {
checkState(parentBlockBuilder != null, "parentBlockBuilder is null");
@@ -182,7 +195,7 @@ private static Block serializeList(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
return blockBuilder.build();
}

private static void serializePrimitive(BlockBuilder blockBuilder, @Nullable Object value, Type type, String columnName)
private void serializePrimitive(BlockBuilder blockBuilder, @Nullable Object value, Type type, String columnName)
{
requireNonNull(blockBuilder, "parent blockBuilder is null");

@@ -226,7 +239,7 @@ private static void serializePrimitive(BlockBuilder blockBuilder, @Nullable Object value, Type type, String columnName)
}

@Nullable
private static Block serializeMap(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
private Block serializeMap(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
{
if (value == null) {
checkState(parentBlockBuilder != null, "parentBlockBuilder is null");
@@ -265,7 +278,7 @@ private static Block serializeMap(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
}

@Nullable
private static Block serializeRow(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
private Block serializeRow(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
{
if (value == null) {
checkState(parentBlockBuilder != null, "parent block builder is null");
@@ -300,6 +313,16 @@ private static Block serializeRow(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
return null;
}

@Nullable
private static Block serializeJson(BlockBuilder builder, Object value, Type type)
{
if (builder != null) {
type.writeObject(builder, value);
return null;
}
return (Block) value;
}

private static long parseTimestamp(int precision, DynamicMessage timestamp)
{
long seconds = (Long) timestamp.getField(timestamp.getDescriptorForType().findFieldByName("seconds"));
[Diff for the remaining five changed files not loaded in this view]
