Skip to content
This repository has been archived by the owner on Feb 23, 2023. It is now read-only.

Commit

Permalink
Support lambda class serialization via hints
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhuaqing committed Jul 13, 2022
1 parent bed1a65 commit b644e55
Show file tree
Hide file tree
Showing 14 changed files with 183 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
* overridable methods is not specified. See <i>Effective Java</i> Item 17, "Design and
* Document or inheritance or else prohibit it" for further information.
*/
public class JSONArray {
public class JSONArray implements JSONValue {

private final List<Object> values;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
* overrideable methods is not specified. See <i>Effective Java</i> Item 17, "Design and
* Document or inheritance or else prohibit it" for further information.
*/
public class JSONObject {
public class JSONObject implements JSONValue {

private static final Double NEGATIVE_ZERO = -0d;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,15 @@
public class NativeSerializationEntry {

private final Class<?> type;
private final boolean lambdaCapturing;

private NativeSerializationEntry(Class<?> type) {
this(type, false);
}

private NativeSerializationEntry(Class<?> type, boolean lambdaCapturing) {
this.type = type;
this.lambdaCapturing = lambdaCapturing;
}

/**
Expand All @@ -42,6 +48,16 @@ public static NativeSerializationEntry ofType(Class<?> type) {
return new NativeSerializationEntry(type);
}

/**
* Create a new {@link NativeSerializationEntry} for the lambda capturing types.
* @param type the lambda capturing type
* @return the serialization entry
*/
public static NativeSerializationEntry ofLambdaCapturingType(Class<?> type) {
Assert.notNull(type, "type must not be null");
return new NativeSerializationEntry(type, true);
}

/**
* Create a new {@link NativeSerializationEntry} for the specified types.
* @param typeName the related type name
Expand All @@ -52,7 +68,17 @@ public static NativeSerializationEntry ofTypeName(String typeName) {
return new NativeSerializationEntry(ClassUtils.resolveClassName(typeName, null));
}

/**
* Create a new {@link NativeSerializationEntry} for the lambda capturing types.
* @param typeName the lambda capturing type name
* @return the serialization entry
*/
public static NativeSerializationEntry ofLambdaCapturingTypeName(String typeName) {
Assert.notNull(typeName, "typeName must not be null");
return new NativeSerializationEntry(ClassUtils.resolveClassName(typeName, null), true);
}

public void contribute(SerializationDescriptor descriptor) {
descriptor.add(this.type.getName());
descriptor.add(this.type.getName(), lambdaCapturing);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,38 @@
public class SerializationDescriptor {

private final Set<String> serializableTypes;
private final Set<String> serializableLambdaCapturingTypes;

public SerializationDescriptor() {
this.serializableTypes = new HashSet<>();
this.serializableLambdaCapturingTypes = new HashSet<>();
}

public SerializationDescriptor(SerializationDescriptor metadata) {
this.serializableTypes = new HashSet<>(metadata.serializableTypes);
this.serializableLambdaCapturingTypes = new HashSet<>(metadata.serializableLambdaCapturingTypes);
}

public Set<String> getSerializableTypes() {
return this.serializableTypes;
}

public Set<String> getSerializableLambdaCapturingTypes() {
return this.serializableLambdaCapturingTypes;
}

public void add(String className) {
this.serializableTypes.add(className);
}

public void add(String className, boolean lambdaCapturing) {
if (lambdaCapturing) {
this.serializableLambdaCapturingTypes.add(className);
} else {
this.serializableTypes.add(className);
}
}

@Override
public String toString() {
return String.format("SerializationDescriptor #%s: %s", serializableTypes.size(),serializableTypes.toString());
Expand All @@ -64,12 +79,21 @@ public void consume(Consumer<String> consumer) {
serializableTypes.stream().forEach(t -> consumer.accept(t));
}

public void consumeLambdaCapturing(Consumer<String> consumer) {
serializableLambdaCapturingTypes.forEach(consumer);
}

public void merge(SerializationDescriptor otherSerializationDescriptor) {
serializableTypes.addAll(otherSerializationDescriptor.serializableTypes);
serializableLambdaCapturingTypes.addAll(otherSerializationDescriptor.serializableLambdaCapturingTypes);
}

public boolean contains(String className) {
return serializableTypes.contains(className);
}

public boolean containsLambdaCapturing(String className) {
return serializableLambdaCapturingTypes.contains(className);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.springframework.nativex.json.JSONArray;
import org.springframework.nativex.json.JSONObject;
import org.springframework.nativex.json.JSONValue;

/**
* Converter to change {@link SerializationDescriptor} objects into JSON objects
Expand All @@ -26,14 +27,32 @@
*/
class SerializationDescriptorJsonConverter {

public JSONArray toJsonArray(SerializationDescriptor sd) throws Exception {
JSONArray jsonArray = new JSONArray();
public JSONValue toJsonValue(SerializationDescriptor sd) throws Exception {
if (sd.getSerializableLambdaCapturingTypes().isEmpty()) {
JSONArray jsonArray = new JSONArray();
for (String type: sd.getSerializableTypes()) {
JSONObject jo = new JSONObject();
jo.put("name", type);
jsonArray.put(jo);
}
return jsonArray;
}
JSONObject jsonObject = new JSONObject();
JSONArray types = new JSONArray();
for (String type: sd.getSerializableTypes()) {
JSONObject jo = new JSONObject();
jo.put("name", type);
jsonArray.put(jo);
types.put(jo);
}
jsonObject.put("types", types);
JSONArray lambdaCapturingTypes = new JSONArray();
for (String type: sd.getSerializableLambdaCapturingTypes()) {
JSONObject jo = new JSONObject();
jo.put("name", type);
lambdaCapturingTypes.put(jo);
}
return jsonArray;
jsonObject.put("lambdaCapturingTypes", lambdaCapturingTypes);
return jsonObject;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

import org.springframework.nativex.json.JSONArray;
import org.springframework.nativex.json.JSONObject;
import org.springframework.nativex.json.JSONTokener;
import org.springframework.nativex.json.JSONValue;

/**
* Marshaller to write {@link SerializationDescriptor} as JSON.
Expand All @@ -39,8 +41,8 @@ public static void write(SerializationDescriptor descriptor, OutputStream output
throws IOException {
try {
SerializationDescriptorJsonConverter converter = new SerializationDescriptorJsonConverter();
JSONArray jsonArray = converter.toJsonArray(descriptor);
outputStream.write(jsonArray.toString(2).getBytes(StandardCharsets.UTF_8));
JSONValue jsonValue = converter.toJsonValue(descriptor);
outputStream.write(jsonValue.toString(2).getBytes(StandardCharsets.UTF_8));
}
catch (Exception ex) {
if (ex instanceof IOException) {
Expand All @@ -67,18 +69,38 @@ public static SerializationDescriptor read(byte[] input) throws Exception {

public static SerializationDescriptor read(InputStream inputStream) {
try {
SerializationDescriptor descriptor = toSerializationDescriptor(new JSONArray(toString(inputStream)));
SerializationDescriptor descriptor = toSerializationDescriptor(new JSONTokener(toString(inputStream)));
return descriptor;
} catch (Exception e) {
throw new IllegalStateException("Unable to read ProxiesDescriptor from inputstream", e);
}
}

private static SerializationDescriptor toSerializationDescriptor(JSONArray array) throws Exception {
private static SerializationDescriptor toSerializationDescriptor(JSONTokener tokenizer) throws Exception {
SerializationDescriptor descriptor = new SerializationDescriptor();
for (int i=0;i<array.length();i++) {
JSONObject object = (JSONObject) array.get(i);
descriptor.add(object.getString("name"));
Object jsonValue = tokenizer.nextValue();
if (jsonValue instanceof JSONArray) {
JSONArray array = (JSONArray) jsonValue;
for (int i=0;i<array.length();i++) {
JSONObject object = (JSONObject) array.get(i);
descriptor.add(object.getString("name"));
}
} else {
JSONObject object = (JSONObject) jsonValue;
JSONArray typesArray = object.optJSONArray("types");
if (typesArray != null) {
for (int i=0;i<typesArray.length();i++) {
JSONObject typeObject = (JSONObject) typesArray.get(i);
descriptor.add(typeObject.getString("name"));
}
}
JSONArray lambdaCapturingTypesArray = object.optJSONArray("lambdaCapturingTypes");
if (lambdaCapturingTypesArray != null) {
for (int i=0;i<lambdaCapturingTypesArray.length();i++) {
JSONObject lambdaCapturingTypeObject = (JSONObject) lambdaCapturingTypesArray.get(i);
descriptor.add(lambdaCapturingTypeObject.getString("name"), true);
}
}
}
return descriptor;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,17 @@ public boolean addSerializationType(String className, boolean verify) {
serializationDescriptor.add(className);
return true;
}

public boolean addSerializationLambdaCapturingType(String className, boolean verify) {
if (verify) {
Type clazz = ts.resolveDotted(className, true);
if (clazz == null) {
return false;
}
}
serializationDescriptor.add(className, true);
return true;
}

private boolean areMembersSpecified(ClassDescriptor cd) {
List<MethodDescriptor> methods = cd.getMethods();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ private void handleConstantHints(boolean isAgentMode) {
serializationHandler.addType(st);
}
}

Set<String> serializationLambdaCapturingTypes = ch.getSerializationLambdaCapturingTypes();
if (!serializationLambdaCapturingTypes.isEmpty()) {
logger.debug("Registering lambda capturing types as serializable: "+serializationLambdaCapturingTypes);
for (String st: serializationLambdaCapturingTypes) {
serializationHandler.addLambdaCapturingType(st);
}
}
List<org.springframework.nativex.type.ResourcesDescriptor> resourcesDescriptors = ch
.getResourcesDescriptors();
for (org.springframework.nativex.type.ResourcesDescriptor rd : resourcesDescriptors) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,9 @@ public class SerializationHandler extends Handler {
public void addType(String className) {
collector.addSerializationType(className, true);
}

public void addLambdaCapturingType(String className) {
collector.addSerializationLambdaCapturingType(className, true);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ public class HintDeclaration {
private Map<String, AccessDescriptor> jniTypes = new LinkedHashMap<>();

private Set<String> serializationTypes = new HashSet<>();

private Set<String> serializationLambdaCapturingTypes = new HashSet<>();

private List<JdkProxyDescriptor> proxyDescriptor = new ArrayList<>();

Expand Down Expand Up @@ -109,11 +111,19 @@ public Map<String, AccessDescriptor> getJNITypes() {
public void addSerializationType(String className) {
serializationTypes.add(className);
}

public void addSerializationLambdaCapturingType(String className) {
serializationLambdaCapturingTypes.add(className);
}

public Set<String> getSerializationTypes() {
return serializationTypes;
}


public Set<String> getSerializationLambdaCapturingTypes() {
return serializationLambdaCapturingTypes;
}

public void addDependantType(String className, AccessDescriptor accessDescriptor) {
specificTypes.put(className, accessDescriptor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1715,13 +1715,19 @@ private void unpackSerializationHint(AnnotationNode typeInfo, HintDeclaration ch
List<Object> values = typeInfo.values;
List<org.objectweb.asm.Type> types = new ArrayList<>();
List<String> typeNames = new ArrayList<>();
List<org.objectweb.asm.Type> lambdaCapturingTypes = new ArrayList<>();
List<String> lambdaCapturingTypeNames = new ArrayList<>();
for (int i = 0; i < values.size(); i += 2) {
String key = (String) values.get(i);
Object value = values.get(i + 1);
if (key.equals("types")) {
types = (ArrayList<org.objectweb.asm.Type>) value;
} else if (key.equals("typeNames")) {
typeNames = (ArrayList<String>) value;
} else if (key.equals("lambdaCapturingTypes")) {
lambdaCapturingTypes = (List<org.objectweb.asm.Type>) value;
} else if (key.equals("lambdaCapturingTypeNames")) {
lambdaCapturingTypeNames = (List<String>) value;
}
}
for (org.objectweb.asm.Type type : types) {
Expand All @@ -1733,6 +1739,15 @@ private void unpackSerializationHint(AnnotationNode typeInfo, HintDeclaration ch
ch.addSerializationType(typeName);
}
}
for (org.objectweb.asm.Type type : lambdaCapturingTypes) {
ch.addSerializationLambdaCapturingType(type.getClassName());
}
for (String typeName : lambdaCapturingTypeNames) {
Type resolvedType = typeSystem.resolveName(typeName, true);
if (resolvedType != null) {
ch.addSerializationLambdaCapturingType(typeName);
}
}
}

private Integer inferAccessRequired(org.objectweb.asm.Type type, List<MethodDescriptor> mds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,23 @@ void addSerialization() {
assertThat(serializationDescriptor.getSerializableTypes()).singleElement().isEqualTo(String.class.getName()));
}

@Test
void addLambdaCapturingSerialization() {
registry.serialization().add(NativeSerializationEntry.ofLambdaCapturingType(String.class));
assertThat(registry.serialization().toSerializationDescriptor()).satisfies((serializationDescriptor) ->
assertThat(serializationDescriptor.getSerializableLambdaCapturingTypes()).singleElement().isEqualTo(String.class.getName()));
}

@Test
void addSeveralSerializations() {
registry.serialization().add(NativeSerializationEntry.ofType(String.class));
registry.serialization().add(NativeSerializationEntry.ofType(Long.class));
registry.serialization().add(NativeSerializationEntry.ofLambdaCapturingType(String.class));
registry.serialization().add(NativeSerializationEntry.ofLambdaCapturingType(Long.class));
assertThat(registry.serialization().toSerializationDescriptor()).satisfies((serializationDescriptor) ->
assertThat(serializationDescriptor.getSerializableTypes()).containsOnly(String.class.getName(), Long.class.getName()));
assertThat(registry.serialization().toSerializationDescriptor()).satisfies((serializationDescriptor) ->
assertThat(serializationDescriptor.getSerializableLambdaCapturingTypes()).containsOnly(String.class.getName(), Long.class.getName()));
}

@Test
Expand Down
Loading

0 comments on commit b644e55

Please sign in to comment.