Skip to content

Commit

Permalink
Merge pull request #32757: Schema inference parameterized types
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenlax authored Oct 15, 2024
1 parent e39e5d7 commit a50f91c
Show file tree
Hide file tree
Showing 27 changed files with 961 additions and 197 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.schemas.annotations.SchemaIgnore;
import org.apache.beam.sdk.schemas.utils.AutoValueUtils;
Expand Down Expand Up @@ -61,8 +63,9 @@ public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.collect(Collectors.toList());
List<FieldValueTypeInformation> types = Lists.newArrayListWithCapacity(methods.size());
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
for (int i = 0; i < methods.size(); ++i) {
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i));
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes));
}
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber));
validateFieldNumbers(types);
Expand Down Expand Up @@ -143,7 +146,8 @@ public SchemaUserTypeCreator schemaTypeCreator(

@Override
public <T> @Nullable Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
return JavaBeanUtils.schemaFromJavaBeanClass(
typeDescriptor, AbstractGetterTypeSupplier.INSTANCE);
typeDescriptor, AbstractGetterTypeSupplier.INSTANCE, boundTypes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
import java.lang.reflect.Field;
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.stream.Stream;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldName;
Expand All @@ -44,6 +46,7 @@
"nullness", // TODO(https://github.com/apache/beam/issues/20497)
"rawtypes"
})
@Internal
public abstract class FieldValueTypeInformation implements Serializable {
/** Optionally returns the field index. */
public abstract @Nullable Integer getNumber();
Expand Down Expand Up @@ -125,18 +128,20 @@ public static FieldValueTypeInformation forOneOf(
.build();
}

public static FieldValueTypeInformation forField(Field field, int index) {
TypeDescriptor<?> type = TypeDescriptor.of(field.getGenericType());
public static FieldValueTypeInformation forField(
Field field, int index, Map<Type, Type> boundTypes) {
TypeDescriptor<?> type =
TypeDescriptor.of(ReflectUtils.resolveType(field.getGenericType(), boundTypes));
return new AutoValue_FieldValueTypeInformation.Builder()
.setName(getNameOverride(field.getName(), field))
.setNumber(getNumberOverride(index, field))
.setNullable(hasNullableAnnotation(field))
.setType(type)
.setRawType(type.getRawType())
.setField(field)
.setElementType(getIterableComponentType(field))
.setMapKeyType(getMapKeyType(field))
.setMapValueType(getMapValueType(field))
.setElementType(getIterableComponentType(field, boundTypes))
.setMapKeyType(getMapKeyType(field, boundTypes))
.setMapValueType(getMapValueType(field, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.setDescription(getFieldDescription(field))
.build();
Expand Down Expand Up @@ -184,7 +189,8 @@ public static <T extends AnnotatedElement & Member> String getNameOverride(
return fieldDescription.value();
}

public static FieldValueTypeInformation forGetter(Method method, int index) {
public static FieldValueTypeInformation forGetter(
Method method, int index, Map<Type, Type> boundTypes) {
String name;
if (method.getName().startsWith("get")) {
name = ReflectUtils.stripPrefix(method.getName(), "get");
Expand All @@ -194,7 +200,8 @@ public static FieldValueTypeInformation forGetter(Method method, int index) {
throw new RuntimeException("Getter has wrong prefix " + method.getName());
}

TypeDescriptor<?> type = TypeDescriptor.of(method.getGenericReturnType());
TypeDescriptor<?> type =
TypeDescriptor.of(ReflectUtils.resolveType(method.getGenericReturnType(), boundTypes));
boolean nullable = hasNullableReturnType(method);
return new AutoValue_FieldValueTypeInformation.Builder()
.setName(getNameOverride(name, method))
Expand All @@ -203,9 +210,9 @@ public static FieldValueTypeInformation forGetter(Method method, int index) {
.setType(type)
.setRawType(type.getRawType())
.setMethod(method)
.setElementType(getIterableComponentType(type))
.setMapKeyType(getMapKeyType(type))
.setMapValueType(getMapValueType(type))
.setElementType(getIterableComponentType(type, boundTypes))
.setMapKeyType(getMapKeyType(type, boundTypes))
.setMapValueType(getMapValueType(type, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.setDescription(getFieldDescription(method))
.build();
Expand Down Expand Up @@ -252,29 +259,33 @@ private static boolean isNullableAnnotation(Annotation annotation) {
return annotation.annotationType().getSimpleName().equals("Nullable");
}

public static FieldValueTypeInformation forSetter(Method method) {
return forSetter(method, "set");
public static FieldValueTypeInformation forSetter(
Method method, Map<Type, Type> boundParameters) {
return forSetter(method, "set", boundParameters);
}

public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) {
public static FieldValueTypeInformation forSetter(
Method method, String setterPrefix, Map<Type, Type> boundTypes) {
String name;
if (method.getName().startsWith(setterPrefix)) {
name = ReflectUtils.stripPrefix(method.getName(), setterPrefix);
} else {
throw new RuntimeException("Setter has wrong prefix " + method.getName());
}

TypeDescriptor<?> type = TypeDescriptor.of(method.getGenericParameterTypes()[0]);
TypeDescriptor<?> type =
TypeDescriptor.of(
ReflectUtils.resolveType(method.getGenericParameterTypes()[0], boundTypes));
boolean nullable = hasSingleNullableParameter(method);
return new AutoValue_FieldValueTypeInformation.Builder()
.setName(name)
.setNullable(nullable)
.setType(type)
.setRawType(type.getRawType())
.setMethod(method)
.setElementType(getIterableComponentType(type))
.setMapKeyType(getMapKeyType(type))
.setMapValueType(getMapValueType(type))
.setElementType(getIterableComponentType(type, boundTypes))
.setMapKeyType(getMapKeyType(type, boundTypes))
.setMapValueType(getMapValueType(type, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.build();
}
Expand All @@ -283,13 +294,15 @@ public FieldValueTypeInformation withName(String name) {
return toBuilder().setName(name).build();
}

private static FieldValueTypeInformation getIterableComponentType(Field field) {
return getIterableComponentType(TypeDescriptor.of(field.getGenericType()));
private static FieldValueTypeInformation getIterableComponentType(
Field field, Map<Type, Type> boundTypes) {
return getIterableComponentType(TypeDescriptor.of(field.getGenericType()), boundTypes);
}

static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor<?> valueType) {
static @Nullable FieldValueTypeInformation getIterableComponentType(
TypeDescriptor<?> valueType, Map<Type, Type> boundTypes) {
// TODO: Figure out nullable elements.
TypeDescriptor<?> componentType = ReflectUtils.getIterableComponentType(valueType);
TypeDescriptor<?> componentType = ReflectUtils.getIterableComponentType(valueType, boundTypes);
if (componentType == null) {
return null;
}
Expand All @@ -299,41 +312,43 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) {
.setNullable(false)
.setType(componentType)
.setRawType(componentType.getRawType())
.setElementType(getIterableComponentType(componentType))
.setMapKeyType(getMapKeyType(componentType))
.setMapValueType(getMapValueType(componentType))
.setElementType(getIterableComponentType(componentType, boundTypes))
.setMapKeyType(getMapKeyType(componentType, boundTypes))
.setMapValueType(getMapValueType(componentType, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.build();
}

// If the Field is a map type, returns the key type, otherwise returns a null reference.

private static @Nullable FieldValueTypeInformation getMapKeyType(Field field) {
return getMapKeyType(TypeDescriptor.of(field.getGenericType()));
private static @Nullable FieldValueTypeInformation getMapKeyType(
Field field, Map<Type, Type> boundTypes) {
return getMapKeyType(TypeDescriptor.of(field.getGenericType()), boundTypes);
}

private static @Nullable FieldValueTypeInformation getMapKeyType(
TypeDescriptor<?> typeDescriptor) {
return getMapType(typeDescriptor, 0);
TypeDescriptor<?> typeDescriptor, Map<Type, Type> boundTypes) {
return getMapType(typeDescriptor, 0, boundTypes);
}

// If the Field is a map type, returns the value type, otherwise returns a null reference.

private static @Nullable FieldValueTypeInformation getMapValueType(Field field) {
return getMapType(TypeDescriptor.of(field.getGenericType()), 1);
private static @Nullable FieldValueTypeInformation getMapValueType(
Field field, Map<Type, Type> boundTypes) {
return getMapType(TypeDescriptor.of(field.getGenericType()), 1, boundTypes);
}

private static @Nullable FieldValueTypeInformation getMapValueType(
TypeDescriptor<?> typeDescriptor) {
return getMapType(typeDescriptor, 1);
TypeDescriptor<?> typeDescriptor, Map<Type, Type> boundTypes) {
return getMapType(typeDescriptor, 1, boundTypes);
}

// If the Field is a map type, returns the key or value type (0 is key type, 1 is value).
// Otherwise returns a null reference.
@SuppressWarnings("unchecked")
private static @Nullable FieldValueTypeInformation getMapType(
TypeDescriptor<?> valueType, int index) {
TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index);
TypeDescriptor<?> valueType, int index, Map<Type, Type> boundTypes) {
TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index, boundTypes);
if (mapType == null) {
return null;
}
Expand All @@ -342,9 +357,9 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) {
.setNullable(false)
.setType(mapType)
.setRawType(mapType.getRawType())
.setElementType(getIterableComponentType(mapType))
.setMapKeyType(getMapKeyType(mapType))
.setMapValueType(getMapValueType(mapType))
.setElementType(getIterableComponentType(mapType, boundTypes))
.setMapKeyType(getMapKeyType(mapType, boundTypes))
.setMapValueType(getMapValueType(mapType, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldName;
Expand Down Expand Up @@ -67,8 +69,9 @@ public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.collect(Collectors.toList());
List<FieldValueTypeInformation> types = Lists.newArrayListWithCapacity(methods.size());
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
for (int i = 0; i < methods.size(); ++i) {
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i));
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes));
}
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber));
validateFieldNumbers(types);
Expand Down Expand Up @@ -111,10 +114,11 @@ public static class SetterTypeSupplier implements FieldValueTypeSupplier {

@Override
public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream()
.filter(ReflectUtils::isSetter)
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.map(FieldValueTypeInformation::forSetter)
.map(m -> FieldValueTypeInformation.forSetter(m, boundTypes))
.map(
t -> {
if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) {
Expand Down Expand Up @@ -156,8 +160,10 @@ public boolean equals(@Nullable Object obj) {

@Override
public <T> Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
Schema schema =
JavaBeanUtils.schemaFromJavaBeanClass(typeDescriptor, GetterTypeSupplier.INSTANCE);
JavaBeanUtils.schemaFromJavaBeanClass(
typeDescriptor, GetterTypeSupplier.INSTANCE, boundTypes);

// If there are no creator methods, then validate that we have setters for every field.
// Otherwise, we will have no way of creating instances of the class.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -62,9 +64,11 @@ public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
ReflectUtils.getFields(typeDescriptor.getRawType()).stream()
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.collect(Collectors.toList());

List<FieldValueTypeInformation> types = Lists.newArrayListWithCapacity(fields.size());
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
for (int i = 0; i < fields.size(); ++i) {
types.add(FieldValueTypeInformation.forField(fields.get(i), i));
types.add(FieldValueTypeInformation.forField(fields.get(i), i, boundTypes));
}
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber));
validateFieldNumbers(types);
Expand Down Expand Up @@ -111,7 +115,9 @@ private static void validateFieldNumbers(List<FieldValueTypeInformation> types)

@Override
public <T> Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
return POJOUtils.schemaFromPojoClass(typeDescriptor, JavaFieldTypeSupplier.INSTANCE);
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
return POJOUtils.schemaFromPojoClass(
typeDescriptor, JavaFieldTypeSupplier.INSTANCE, boundTypes);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ public interface SchemaProvider extends Serializable {
* Given a type, return a function that converts that type to a {@link Row} object If no schema
* exists, returns null.
*/
@Nullable
<T> SerializableFunction<T, Row> toRowFunction(TypeDescriptor<T> typeDescriptor);
<T> @Nullable SerializableFunction<T, Row> toRowFunction(TypeDescriptor<T> typeDescriptor);

/**
* Given a type, returns a function that converts from a {@link Row} object to that type. If no
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid
providers.put(typeDescriptor, schemaProvider);
}

@Override
public <T> @Nullable Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
private <T> @Nullable SchemaProvider schemaProviderFor(TypeDescriptor<T> typeDescriptor) {
TypeDescriptor<?> type = typeDescriptor;
do {
SchemaProvider schemaProvider = providers.get(type);
if (schemaProvider != null) {
return schemaProvider.schemaFor(type);
return schemaProvider;
}
Class<?> superClass = type.getRawType().getSuperclass();
if (superClass == null || superClass.equals(Object.class)) {
Expand All @@ -92,38 +91,24 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid
} while (true);
}

@Override
public <T> @Nullable Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
@Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor);
return schemaProvider != null ? schemaProvider.schemaFor(typeDescriptor) : null;
}

@Override
public <T> @Nullable SerializableFunction<T, Row> toRowFunction(
TypeDescriptor<T> typeDescriptor) {
TypeDescriptor<?> type = typeDescriptor;
do {
SchemaProvider schemaProvider = providers.get(type);
if (schemaProvider != null) {
return (SerializableFunction<T, Row>) schemaProvider.toRowFunction(type);
}
Class<?> superClass = type.getRawType().getSuperclass();
if (superClass == null || superClass.equals(Object.class)) {
return null;
}
type = TypeDescriptor.of(superClass);
} while (true);
@Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor);
return schemaProvider != null ? schemaProvider.toRowFunction(typeDescriptor) : null;
}

@Override
public <T> @Nullable SerializableFunction<Row, T> fromRowFunction(
TypeDescriptor<T> typeDescriptor) {
TypeDescriptor<?> type = typeDescriptor;
do {
SchemaProvider schemaProvider = providers.get(type);
if (schemaProvider != null) {
return (SerializableFunction<Row, T>) schemaProvider.fromRowFunction(type);
}
Class<?> superClass = type.getRawType().getSuperclass();
if (superClass == null || superClass.equals(Object.class)) {
return null;
}
type = TypeDescriptor.of(superClass);
} while (true);
@Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor);
return schemaProvider != null ? schemaProvider.fromRowFunction(typeDescriptor) : null;
}
}

Expand Down
Loading

0 comments on commit a50f91c

Please sign in to comment.