Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Procedural constraints annotation processor + build logic #1611

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue
import gov.nasa.jpl.aerie.merlin.protocol.types.ValueSchema

interface ConstraintProcedureMapper<T: Constraint> {
fun valueSchema(): ValueSchema
fun serialize(procedure: T): SerializedValue
fun deserialize(arguments: SerializedValue): T
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package gov.nasa.ammos.aerie.procedural.constraints.annotations

annotation class ConstraintProcedure
46 changes: 46 additions & 0 deletions procedural/examples/foo-procedures/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,52 @@ tasks.create("generateSchedulingProcedureJarTasks") {
}
}

tasks.register('buildAllConstraintProcedureJars') {
group = 'ConstraintProcedureJars'

dependsOn "generateConstraintProcedureJarTasks"
dependsOn {
tasks.findAll { task -> task.name.startsWith('buildConstraintProcedureJar_') }
}
}

tasks.create("generateConstraintProcedureJarTasks") {
group = 'ConstraintProcedureJars'

final proceduresDir = findFirstMatchingBuildDir("generated/procedures")

if (proceduresDir == null) {
println "No procedures folder found"
return
}
println "Generating jar tasks for the following procedures directory: ${proceduresDir}"

final files = file(proceduresDir).listFiles()
if (files.length == 0) {
println "No procedures available within folder ${proceduresDir}"
return
}

files.toList().each { file ->
final nameWithoutExtension = file.name.replace(".java", "")
final taskName = "buildConstraintProcedureJar_${nameWithoutExtension}"

println "Generating ${taskName} task, which will build ${nameWithoutExtension}.jar"

tasks.create(taskName, ShadowJar) {
group = 'ConstraintProcedureJars'
configurations = [project.configurations.runtimeClasspath]
from sourceSets.main.output
archiveBaseName = "" // clear
archiveClassifier.set(nameWithoutExtension) // set output jar name
manifest {
attributes 'Main-Class': getMainClassFromGeneratedFile(file)
}
minimize()
}
}
}

private String findFirstMatchingBuildDir(String pattern) {
String found = null
final generatedDir = file("build/generated/sources")
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
package gov.nasa.ammos.aerie.procedural.examples.fooprocedures.constraints;

import gov.nasa.ammos.aerie.procedural.constraints.GeneratorConstraint;
import gov.nasa.ammos.aerie.procedural.constraints.Constraint;
import gov.nasa.ammos.aerie.procedural.constraints.Violations;
import gov.nasa.ammos.aerie.procedural.constraints.annotations.ConstraintProcedure;
import gov.nasa.ammos.aerie.procedural.timeline.collections.profiles.Real;
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan;
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults;
import org.jetbrains.annotations.NotNull;

public class ConstFruit extends GeneratorConstraint {
@ConstraintProcedure
public record ConstFruit() implements Constraint {
@Override
public void generate(@NotNull Plan plan, @NotNull SimulationResults simResults) {
public Violations run(@NotNull Plan plan, @NotNull SimulationResults simResults) {
final var fruit = simResults.resource("/fruit", Real.deserializer());

violate(Violations.on(
return Violations.on(
fruit.equalTo(4),
false
));
);
}
}

This file was deleted.

1 change: 1 addition & 0 deletions procedural/processor/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies {
implementation project(':merlin-sdk')
implementation project(':contrib')
implementation project(':procedural:scheduling')
implementation project(':procedural:constraints')
implementation 'org.apache.commons:commons-lang3:3.13.0'
implementation 'com.squareup:javapoet:1.13.0'
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ static JavaFile generateAutoValueMappers(final ClassName typeName, final Iterabl
.addAnnotation(
AnnotationSpec
.builder(javax.annotation.processing.Generated.class)
.addMember("value", "$S", SchedulingProcedureProcessor.class.getCanonicalName())
.addMember("value", "$S", ProcedureProcessor.class.getCanonicalName())
.build())
.addAnnotation(
AnnotationSpec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
import gov.nasa.jpl.aerie.merlin.framework.ValueMapper;
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue;
import gov.nasa.jpl.aerie.merlin.protocol.types.ValueSchema;
import gov.nasa.ammos.aerie.procedural.scheduling.ProcedureMapper;
import gov.nasa.ammos.aerie.procedural.scheduling.SchedulingProcedureMapper;
import gov.nasa.ammos.aerie.procedural.scheduling.annotations.SchedulingProcedure;
import gov.nasa.ammos.aerie.procedural.scheduling.annotations.WithMappers;

import gov.nasa.ammos.aerie.procedural.constraints.ConstraintProcedureMapper;
import gov.nasa.ammos.aerie.procedural.constraints.annotations.ConstraintProcedure;

import javax.annotation.processing.Completion;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
Expand Down Expand Up @@ -42,8 +45,11 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public final class SchedulingProcedureProcessor implements Processor {
public final class ProcedureProcessor implements Processor {
// Effectively final, late-initialized
private Messager messager = null;
private Filer filer = null;
Expand All @@ -58,7 +64,11 @@ public Set<String> getSupportedOptions() {
/** Elements marked by these annotations will be treated as processing roots. */
@Override
public Set<String> getSupportedAnnotationTypes() {
return Set.of(SchedulingProcedure.class.getCanonicalName(), WithMappers.class.getCanonicalName());
return Set.of(
SchedulingProcedure.class.getCanonicalName(),
ConstraintProcedure.class.getCanonicalName(),
WithMappers.class.getCanonicalName()
);
}

@Override
Expand Down Expand Up @@ -104,34 +114,40 @@ public boolean process(final Set<? extends TypeElement> annotations, final Round
typeRules.addAll(parseValueMappers(factory));
}

final var procedures = roundEnv.getElementsAnnotatedWith(SchedulingProcedure.class);
final var schedulingProcedures = roundEnv.getElementsAnnotatedWith(SchedulingProcedure.class);
final var constraintProcedures = roundEnv.getElementsAnnotatedWith(ConstraintProcedure.class);

final var generatedClassName = ClassName.get(packageElement.getQualifiedName() + ".generated", "AutoValueMappers");
for (final var procedure : procedures) {
for (final var procedure : schedulingProcedures) {
final var procedureElement = (TypeElement) procedure;
typeRules.add(AutoValueMappers.recordTypeRule(procedureElement, generatedClassName));
}

for (final var procedure : constraintProcedures) {
final var procedureElement = (TypeElement) procedure;
typeRules.add(AutoValueMappers.recordTypeRule(procedureElement, generatedClassName));
}

final var generatedFiles = new ArrayList<JavaFile>();

generatedFiles.add(AutoValueMappers.generateAutoValueMappers(generatedClassName, procedures, List.of()));
final var allProcedures = Stream.concat(schedulingProcedures.stream(),constraintProcedures.stream()).collect(Collectors.toSet());

generatedFiles.add(AutoValueMappers.generateAutoValueMappers(generatedClassName, allProcedures, List.of()));

// For each procedure, generate a file that implements Procedure, Supplier<ValueMapper>
for (final var procedure : procedures) {
for (final var procedure : schedulingProcedures) {
final TypeName procedureType = TypeName.get(procedure.asType());
final ParameterizedTypeName valueMapperType = ParameterizedTypeName.get(
ClassName.get(ValueMapper.class),
procedureType);

final var valueMapperCode = new Resolver(typeUtils, elementUtils, typeRules).applyRules(new TypePattern.ClassPattern(ClassName.get(ValueMapper.class), List.of(new TypePattern.ClassPattern((ClassName) procedureType, List.of()))));
final var valueMapperCode = new Resolver(typeUtils, elementUtils, typeRules)
.applyRules(new TypePattern.ClassPattern(ClassName.get(ValueMapper.class), List.of(new TypePattern.ClassPattern((ClassName) procedureType, List.of()))));
if (valueMapperCode.isEmpty()) throw new Error("Could not generate a valuemapper for procedure " + procedure.getSimpleName());


generatedFiles.add(JavaFile
.builder(generatedClassName.packageName() + ".procedures", TypeSpec
.classBuilder(procedure.getSimpleName().toString())
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
.addSuperinterface(ParameterizedTypeName.get(ClassName.get(ProcedureMapper.class), procedureType))
.addSuperinterface(ParameterizedTypeName.get(ClassName.get(SchedulingProcedureMapper.class), procedureType))
.addMethod(MethodSpec
.methodBuilder("valueSchema")
.addModifiers(Modifier.PUBLIC)
Expand Down Expand Up @@ -160,6 +176,52 @@ public boolean process(final Set<? extends TypeElement> annotations, final Round
.build());
}

// For each procedure, generate a file that implements Procedure, Supplier<ValueMapper>
for (final var procedure : constraintProcedures) {
final TypeName procedureType = TypeName.get(procedure.asType());

this.messager.printMessage(
Diagnostic.Kind.NOTE,
"Looking at: " + procedure.toString());

final var valueMapperCode = new Resolver(typeUtils, elementUtils, typeRules)
.applyRules(new TypePattern.ClassPattern(ClassName.get(ValueMapper.class), List.of(new TypePattern.ClassPattern((ClassName) procedureType, List.of()))));
if (valueMapperCode.isEmpty()) throw new Error("Could not generate a valuemapper for procedure " + procedure.getSimpleName());


generatedFiles.add(JavaFile
.builder(generatedClassName.packageName() + ".procedures", TypeSpec
.classBuilder(procedure.getSimpleName().toString())
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
.addSuperinterface(ParameterizedTypeName.get(ClassName.get(ConstraintProcedureMapper.class), procedureType))
.addMethod(MethodSpec
.methodBuilder("valueSchema")
.addModifiers(Modifier.PUBLIC)
.addAnnotation(Override.class)
.returns(ValueSchema.class)
.addStatement("return $L.getValueSchema()", valueMapperCode.get())
.build())
.addMethod(MethodSpec
.methodBuilder("serialize")
.addModifiers(Modifier.PUBLIC)
.addAnnotation(Override.class)
.addParameter(procedureType, "procedure")
.returns(SerializedValue.class)
.addStatement("return $L.serializeValue(procedure)", valueMapperCode.get())
.build())
.addMethod(MethodSpec
.methodBuilder("deserialize")
.addModifiers(Modifier.PUBLIC)
.addAnnotation(Override.class)
.addParameter(SerializedValue.class, "value")
.returns(procedureType)
.addStatement("return $L.deserializeValue(value).getSuccessOrThrow(e -> new $T(e))", valueMapperCode.get(), RuntimeException.class)
.build())
.build())
.skipJavaLangImports(true)
.build());
}

for (final var generatedFile : generatedFiles) {
this.messager.printMessage(
Diagnostic.Kind.NOTE,
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
gov.nasa.ammos.aerie.procedural.processor.SchedulingProcedureProcessor
gov.nasa.ammos.aerie.procedural.processor.ProcedureProcessor
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package gov.nasa.ammos.aerie.procedural.scheduling
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue
import gov.nasa.jpl.aerie.merlin.protocol.types.ValueSchema

interface ProcedureMapper<T: Goal> {
interface SchedulingProcedureMapper<T: Goal> {
fun valueSchema(): ValueSchema
fun serialize(procedure: T): SerializedValue
fun deserialize(arguments: SerializedValue): T
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gov.nasa.jpl.aerie.scheduler;

import gov.nasa.jpl.aerie.merlin.protocol.types.ValueSchema;
import gov.nasa.ammos.aerie.procedural.scheduling.ProcedureMapper;
import gov.nasa.ammos.aerie.procedural.scheduling.SchedulingProcedureMapper;

import java.io.IOException;
import java.net.MalformedURLException;
Expand All @@ -12,19 +11,19 @@
import java.util.jar.JarFile;

public final class ProcedureLoader {
public static ProcedureMapper<?> loadProcedure(final Path path)
public static SchedulingProcedureMapper<?> loadProcedure(final Path path)
throws ProcedureLoadException
{
final var className = getImplementingClassName(path);
final var classLoader = new URLClassLoader(new URL[] {pathToUrl(path)});

try {
final var pluginClass$ = classLoader.loadClass(className);
if (!ProcedureMapper.class.isAssignableFrom(pluginClass$)) {
if (!SchedulingProcedureMapper.class.isAssignableFrom(pluginClass$)) {
throw new ProcedureLoadException(path);
}

return (ProcedureMapper<?>) pluginClass$.getConstructor().newInstance();
return (SchedulingProcedureMapper<?>) pluginClass$.getConstructor().newInstance();
} catch (final ReflectiveOperationException ex) {
throw new ProcedureLoadException(path, ex);
}
Expand Down Expand Up @@ -58,7 +57,7 @@ private ProcedureLoadException(final Path path, final Throwable cause) {
super(
String.format(
"No implementation found for `%s` at path `%s`",
ProcedureMapper.class.getSimpleName(),
SchedulingProcedureMapper.class.getSimpleName(),
path),
cause);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import gov.nasa.jpl.aerie.merlin.driver.MissionModel;
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue;
import gov.nasa.ammos.aerie.procedural.scheduling.ProcedureMapper;
import gov.nasa.ammos.aerie.procedural.scheduling.SchedulingProcedureMapper;
import gov.nasa.ammos.aerie.procedural.scheduling.plan.Edit;
import gov.nasa.jpl.aerie.scheduler.DirectiveIdGenerator;
import gov.nasa.jpl.aerie.scheduler.ProcedureLoader;
Expand Down Expand Up @@ -45,7 +45,7 @@ public void run(
final SimulationFacade simulationFacade,
final DirectiveIdGenerator idGenerator
) {
final ProcedureMapper<?> procedureMapper;
final SchedulingProcedureMapper<?> procedureMapper;
try {
procedureMapper = ProcedureLoader.loadProcedure(jarPath);
} catch (ProcedureLoader.ProcedureLoadException e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package gov.nasa.jpl.aerie.scheduler.server.services;

import gov.nasa.ammos.aerie.procedural.scheduling.ProcedureMapper;
import gov.nasa.ammos.aerie.procedural.scheduling.SchedulingProcedureMapper;
import gov.nasa.jpl.aerie.scheduler.ProcedureLoader;
import gov.nasa.jpl.aerie.scheduler.server.exceptions.NoSuchSchedulingGoalException;
import gov.nasa.jpl.aerie.scheduler.server.exceptions.NoSuchSpecificationException;
Expand Down Expand Up @@ -40,7 +40,7 @@ public void refreshSchedulingProcedureParameterTypes(long goalId, long revision)
// Do nothing
}
case GoalType.JAR jar -> {
final ProcedureMapper<?> mapper;
final SchedulingProcedureMapper<?> mapper;
try {
mapper = ProcedureLoader.loadProcedure(Path.of("/usr/src/app/merlin_file_store", jar.path().toString()));
} catch (ProcedureLoader.ProcedureLoadException e) {
Expand Down