From 724191c29fa9741aa1d10b40fbe3661d6614be73 Mon Sep 17 00:00:00 2001 From: Christopher Chianelli Date: Tue, 23 Jul 2024 16:06:43 -0400 Subject: [PATCH 1/2] feat: Add CascadingUpdateShadowVariable to Python - Rename the static final field that holds the implementation of each method to be different from the method's name, to avoid confusion during member lookup --- .../jpyinterpreter/PythonClassTranslator.java | 27 +++++--- .../DelegatingInterfaceImplementor.java | 4 +- .../types/PythonSuperObject.java | 3 +- .../PythonClassTranslatorTest.java | 2 +- .../src/main/python/domain/_annotations.py | 64 ++++++++++++++--- .../python-core/tests/test_vehicle_routing.py | 68 ++++++++----------- 6 files changed, 106 insertions(+), 62 deletions(-) diff --git a/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java b/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java index 3b988c1682..1102ceb97b 100644 --- a/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java +++ b/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java @@ -67,6 +67,7 @@ public class PythonClassTranslator { public static final String TYPE_FIELD_NAME = "$TYPE"; public static final String CPYTHON_TYPE_FIELD_NAME = "$CPYTHON_TYPE"; public static final String JAVA_METHOD_PREFIX = "$method$"; + public static final String JAVA_METHOD_HOLDER_PREFIX = "$methodholder$"; public static final String PYTHON_JAVA_TYPE_MAPPING_PREFIX = "$pythonJavaTypeMapping"; public record PreparedClassInfo(PythonLikeType type, String className, String classInternalName) { @@ -484,6 +485,10 @@ public static String getJavaMethodName(String pythonMethodName) { return JAVA_METHOD_PREFIX + pythonMethodName; } + public static String getJavaMethodHolderName(String pythonMethodName) { + return JAVA_METHOD_HOLDER_PREFIX + pythonMethodName; + } + public static String getPythonMethodName(String javaMethodName) { return javaMethodName.substring(JAVA_METHOD_PREFIX.length()); } @@ -566,7 +571,7 @@ private static Class createBytecodeForMethodAndSetOnClass(String className, P throw new IllegalStateException("Unhandled case: " + pythonMethodKind); } - generatedClass.getField(getJavaMethodName(methodEntry.getKey())) + generatedClass.getField(getJavaMethodHolderName(methodEntry.getKey())) .set(null, functionInstance); pythonLikeType.$setAttribute(methodEntry.getKey(), translatedPythonMethodWrapper); return functionClass; @@ -797,7 +802,7 @@ private static PythonLikeFunction createConstructor(String classInternalName, Class generatedClass = (Class) BuiltinTypes.asmClassLoader.loadClass(constructorClassName); if (initFunction != null) { - Object method = typeGeneratedClass.getField(getJavaMethodName("__init__")).get(null); + Object method = typeGeneratedClass.getField(getJavaMethodHolderName("__init__")).get(null); ArgumentSpec spec = (ArgumentSpec) method.getClass().getField(ARGUMENT_SPEC_INSTANCE_FIELD_NAME).get(method); generatedClass.getField(ARGUMENT_SPEC_INSTANCE_FIELD_NAME).set(null, spec); @@ -957,7 +962,7 @@ private static void createInstanceMethod(PythonLikeType pythonLikeType, ClassWri String interfaceDescriptor = interfaceDeclaration.descriptor(); String javaMethodName = getJavaMethodName(methodName); - classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, javaMethodName, interfaceDescriptor, + classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, getJavaMethodHolderName(methodName), interfaceDescriptor, null, null); instanceMethodNameToMethodDescriptor.put(methodName, interfaceDeclaration); Type returnType = getVirtualFunctionReturnType(function); @@ -973,7 +978,7 @@ private static void createInstanceMethod(PythonLikeType pythonLikeType, ClassWri MethodVisitor methodVisitor = classWriter.visitMethod(Modifier.PUBLIC, javaMethodName, javaMethodDescriptor, signature, null); - createInstanceOrStaticMethodBody(internalClassName, javaMethodName, javaParameterTypes, + createInstanceOrStaticMethodBody(internalClassName, methodName, javaParameterTypes, interfaceDeclaration.methodDescriptor, function, interfaceDeclaration.interfaceName, interfaceDescriptor, methodVisitor); @@ -993,7 +998,7 @@ private static void createStaticMethod(PythonLikeType pythonLikeType, ClassWrite String javaMethodName = getJavaMethodName(methodName); String signature = getFunctionSignature(function, function.getAsmMethodDescriptorString()); - classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, javaMethodName, interfaceDescriptor, + classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, getJavaMethodHolderName(methodName), interfaceDescriptor, null, null); MethodVisitor methodVisitor = classWriter.visitMethod(Modifier.PUBLIC | Modifier.STATIC, javaMethodName, function.getAsmMethodDescriptorString(), signature, null); @@ -1005,7 +1010,7 @@ private static void createStaticMethod(PythonLikeType pythonLikeType, ClassWrite javaParameterTypes[i] = Type.getType('L' + parameterPythonTypeList.get(i).getJavaTypeInternalName() + ';'); } - createInstanceOrStaticMethodBody(internalClassName, javaMethodName, javaParameterTypes, + createInstanceOrStaticMethodBody(internalClassName, methodName, javaParameterTypes, interfaceDeclaration.methodDescriptor, function, interfaceDeclaration.interfaceName, interfaceDescriptor, methodVisitor); @@ -1022,7 +1027,7 @@ private static void createClassMethod(PythonLikeType pythonLikeType, ClassWriter String interfaceDescriptor = 'L' + interfaceDeclaration.interfaceName + ';'; String javaMethodName = getJavaMethodName(methodName); - classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, javaMethodName, interfaceDescriptor, + classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, getJavaMethodHolderName(methodName), interfaceDescriptor, null, null); String javaMethodDescriptor = interfaceDeclaration.methodDescriptor; @@ -1042,7 +1047,8 @@ private static void createClassMethod(PythonLikeType pythonLikeType, ClassWriter methodVisitor.visitLabel(start); methodVisitor.visitLineNumber(function.getFirstLine(), start); - methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, javaMethodName, interfaceDescriptor); + methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, getJavaMethodHolderName(methodName), + interfaceDescriptor); for (int i = 0; i < function.totalArgCount(); i++) { methodVisitor.visitVarInsn(Opcodes.ALOAD, i); @@ -1067,7 +1073,7 @@ private static void createClassMethod(PythonLikeType pythonLikeType, ClassWriter parameterTypes)); } - private static void createInstanceOrStaticMethodBody(String internalClassName, String javaMethodName, + private static void createInstanceOrStaticMethodBody(String internalClassName, String methodName, Type[] javaParameterTypes, String methodDescriptorString, PythonCompiledFunction function, String interfaceInternalName, String interfaceDescriptor, @@ -1082,7 +1088,8 @@ private static void createInstanceOrStaticMethodBody(String internalClassName, S methodVisitor.visitLabel(start); methodVisitor.visitLineNumber(function.getFirstLine(), start); - methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, javaMethodName, interfaceDescriptor); + methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, getJavaMethodHolderName(methodName), + interfaceDescriptor); for (int i = 0; i < function.totalArgCount(); i++) { methodVisitor.visitVarInsn(Opcodes.ALOAD, i); } diff --git a/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/DelegatingInterfaceImplementor.java b/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/DelegatingInterfaceImplementor.java index 33657508ae..84e7884e4b 100644 --- a/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/DelegatingInterfaceImplementor.java +++ b/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/implementors/DelegatingInterfaceImplementor.java @@ -81,7 +81,7 @@ private void implementMethod(ClassWriter classWriter, PythonCompiledClass compil var functionInterfaceDeclaration = methodNameToFieldDescriptor.get(interfaceMethod.getName()); interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, 0); interfaceMethodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, - PythonClassTranslator.getJavaMethodName(interfaceMethod.getName()), + PythonClassTranslator.getJavaMethodHolderName(interfaceMethod.getName()), functionInterfaceDeclaration.descriptor()); interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Object.class), "getClass", Type.getMethodDescriptor(Type.getType(Class.class)), false); @@ -89,7 +89,7 @@ private void implementMethod(ClassWriter classWriter, PythonCompiledClass compil interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class), "getField", Type.getMethodDescriptor(Type.getType(Field.class), Type.getType(String.class)), false); interfaceMethodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, - PythonClassTranslator.getJavaMethodName(interfaceMethod.getName()), + PythonClassTranslator.getJavaMethodHolderName(interfaceMethod.getName()), functionInterfaceDeclaration.descriptor()); interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class), "get", Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class)), false); diff --git a/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonSuperObject.java b/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonSuperObject.java index fe787c589f..85a31e0412 100644 --- a/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonSuperObject.java +++ b/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonSuperObject.java @@ -50,7 +50,8 @@ public PythonSuperObject(PythonLikeType previousType, PythonLikeObject instance) if (typeResult instanceof PythonLikeFunction && !(typeResult instanceof PythonLikeType)) { try { Object methodInstance = - candidate.getJavaClass().getField(PythonClassTranslator.getJavaMethodName(name)).get(null); + candidate.getJavaClass().getField(PythonClassTranslator.getJavaMethodHolderName(name)) + .get(null); typeResult = new GeneratedFunctionMethodReference(methodInstance, methodInstance.getClass().getDeclaredMethods()[0], Map.of(), diff --git a/python/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java b/python/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java index 6fd1b237f7..debe07f80c 100644 --- a/python/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java +++ b/python/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java @@ -64,7 +64,7 @@ public void testPythonClassTranslation() throws ClassNotFoundException, NoSuchMe Class generatedClass = BuiltinTypes.asmClassLoader.loadClass( classType.getJavaTypeInternalName().replace('/', '.')); - assertThat(generatedClass).hasPublicFields(PythonClassTranslator.getJavaMethodName("get_age"), + assertThat(generatedClass).hasPublicFields(PythonClassTranslator.getJavaMethodHolderName("get_age"), PythonClassTranslator.getJavaFieldName("age")); assertThat(generatedClass).hasPublicMethods( PythonClassTranslator.getJavaMethodName("__init__"), diff --git a/python/python-core/src/main/python/domain/_annotations.py b/python/python-core/src/main/python/domain/_annotations.py index 19da539f6c..25d9dcb715 100644 --- a/python/python-core/src/main/python/domain/_annotations.py +++ b/python/python-core/src/main/python/domain/_annotations.py @@ -1,13 +1,12 @@ -import jpype - -from ._variable_listener import VariableListener -from .._timefold_java_interop import ensure_init, get_asm_type from _jpyinterpreter import JavaAnnotation, AnnotationValueSupplier from jpype import JImplements, JOverride from typing import Union, List, Callable, Type, TYPE_CHECKING, TypeVar +from ._variable_listener import VariableListener +from .._timefold_java_interop import ensure_init, get_asm_type + if TYPE_CHECKING: - from ai.timefold.solver.core.api.solver.change import ProblemChange as _ProblemChange + pass Solution_ = TypeVar('Solution_') @@ -240,6 +239,56 @@ def __init__(self, *, }) +class CascadingUpdateShadowVariable(JavaAnnotation): + """ + Specifies that field may be updated by the target method when one or more source variables change. + + Automatically cascades change events to `NextElementShadowVariable` of a `PlanningListVariable`. + + Notes + ----- + Important: it must only change the shadow variable(s) for which it's configured. + It can be applied to multiple attributes to modify different shadow variables. + It should never change a genuine variable or a problem fact. + It can change its shadow variable(s) on multiple entity instances + (for example: an arrival_time change affects all trailing entities too). + + Examples + -------- + >>> from timefold.solver.domain import CascadingUpdateShadowVariable, PreviousElementShadowVariable, planning_entity + >>> from typing import Annotated + >>> from domain import ArrivalTimeVariableListener + >>> from datetime import datetime, timedelta + >>> + >>> @planning_entity + >>> class Visit: + ... previous: Annotated['Visit', PreviousElementShadowVariable] + ... arrival_time: Annotated[datetime, + ... CascadingUpdateShadowVariable( + ... target_method_name='update_arrival_time', + ... source_variable_name='previous' + ... ) + ... ] + ... + ... def update_arrival_time(self): + ... self.arrival_time = previous.arrival_time + timedelta(hours=1) + """ + + def __init__(self, *, + source_variable_name: str, + target_method_name: str): + ensure_init() + from ai.timefold.jpyinterpreter import PythonClassTranslator + from ai.timefold.solver.core.api.domain.variable import \ + CascadingUpdateShadowVariable as JavaCascadingUpdateShadowVariable + + super().__init__(JavaCascadingUpdateShadowVariable, + { + 'sourceVariableName': PythonClassTranslator.getJavaFieldName(source_variable_name), + 'targetMethodName': PythonClassTranslator.getJavaMethodName(target_method_name), + }) + + class IndexShadowVariable(JavaAnnotation): """ Specifies that an attribute is an index of this planning value in another entity's `PlanningListVariable`. @@ -662,8 +711,7 @@ def planning_entity(entity_class: Type = None, /, *, pinning_filter: Callable = def planning_entity_wrapper(entity_class_argument): from .._timefold_java_interop import _add_to_compilation_queue - from ai.timefold.solver.core.api.domain.entity import PinningFilter - from _jpyinterpreter import add_class_annotation, translate_python_bytecode_to_java_bytecode + from _jpyinterpreter import add_class_annotation from typing import get_origin, Annotated planning_pin_field = None @@ -773,7 +821,7 @@ def constraint_configuration(constraint_configuration_class: Type[Solution_]) -> __all__ = ['PlanningId', 'PlanningScore', 'PlanningPin', 'PlanningVariable', 'PlanningListVariable', 'ShadowVariable', - 'PiggybackShadowVariable', + 'PiggybackShadowVariable', 'CascadingUpdateShadowVariable', 'IndexShadowVariable', 'PreviousElementShadowVariable', 'NextElementShadowVariable', 'AnchorShadowVariable', 'InverseRelationShadowVariable', 'ProblemFactProperty', 'ProblemFactCollectionProperty', diff --git a/python/python-core/tests/test_vehicle_routing.py b/python/python-core/tests/test_vehicle_routing.py index eb3673813c..8518ccb582 100644 --- a/python/python-core/tests/test_vehicle_routing.py +++ b/python/python-core/tests/test_vehicle_routing.py @@ -1,12 +1,10 @@ +from dataclasses import dataclass, field from datetime import datetime, timedelta - from timefold.solver import * -from timefold.solver.domain import * from timefold.solver.config import * +from timefold.solver.domain import * from timefold.solver.score import * - -from typing import Annotated, List, Optional -from dataclasses import dataclass, field +from typing import Annotated, Optional @dataclass @@ -19,34 +17,6 @@ def driving_time_to(self, other: 'Location') -> int: return self.driving_time_seconds[id(other)] -class ArrivalTimeUpdatingVariableListener(VariableListener): - def after_variable_changed(self, score_director: ScoreDirector, visit: 'Visit') -> None: - if visit.vehicle is None: - if visit.arrival_time is not None: - score_director.before_variable_changed(visit, 'arrival_time') - visit.arrival_time = None - score_director.after_variable_changed(visit, 'arrival_time') - return - previous_visit = visit.previous_visit - departure_time = visit.vehicle.departure_time if previous_visit is None else previous_visit.departure_time() - next_visit = visit - arrival_time = ArrivalTimeUpdatingVariableListener.calculate_arrival_time(next_visit, departure_time) - while next_visit is not None and next_visit.arrival_time != arrival_time: - score_director.before_variable_changed(next_visit, 'arrival_time') - next_visit.arrival_time = arrival_time - score_director.after_variable_changed(next_visit, 'arrival_time') - departure_time = next_visit.departure_time() - next_visit = next_visit.next_visit - arrival_time = ArrivalTimeUpdatingVariableListener.calculate_arrival_time(next_visit, departure_time) - - @staticmethod - def calculate_arrival_time(visit: Optional['Visit'], previous_departure_time: Optional[datetime]) \ - -> datetime | None: - if visit is None or previous_departure_time is None: - return None - return previous_departure_time + timedelta(seconds=visit.driving_time_seconds_from_previous_standstill()) - - @planning_entity @dataclass class Visit: @@ -63,11 +33,22 @@ class Visit: field(default=None)) next_visit: Annotated[Optional['Visit'], NextElementShadowVariable(source_variable_name='visits')] = field(default=None) - arrival_time: Annotated[Optional[datetime], - ShadowVariable(variable_listener_class=ArrivalTimeUpdatingVariableListener, - source_variable_name='vehicle'), - ShadowVariable(variable_listener_class=ArrivalTimeUpdatingVariableListener, - source_variable_name='previous_visit')] = field(default=None) + arrival_time: Annotated[ + Optional[datetime], + CascadingUpdateShadowVariable(source_variable_name='previous_visit', + target_method_name='update_arrival_time'), + CascadingUpdateShadowVariable(source_variable_name='vehicle', + target_method_name='update_arrival_time')] = field(default=None) + + def update_arrival_time(self): + if self.vehicle is None or (self.previous_visit is not None and self.previous_visit.arrival_time is None): + self.arrival_time = None + elif self.previous_visit is None: + self.arrival_time = (self.vehicle.departure_time + + timedelta(seconds=self.vehicle.home_location.driving_time_to(self.location))) + else: + self.arrival_time = (self.previous_visit.departure_time() + + timedelta(seconds=self.previous_visit.location.driving_time_to(self.location))) def departure_time(self) -> Optional[datetime]: if self.arrival_time is None: @@ -203,7 +184,7 @@ def test_vrp(): constraint_provider_function=vehicle_routing_constraints ), termination_config=TerminationConfig( - best_score_limit='0hard/-300soft' + best_score_limit='0hard/-300soft', ) ) @@ -300,6 +281,13 @@ def test_vrp(): ] ) solution = solver.solve(problem) - + assert [visit.arrival_time for visit in solution.visits] == [ + # Visit 1: 1-minute travel time from Vehicle A start + datetime(2020, 1, 1, hour=0, minute=1), + # Visit 2: 1-minute travel time from visit 1 + 1-hour service + datetime(2020, 1, 1, hour=1, minute=2), + # Visit 3: 1-minute travel time from Vehicle B start + datetime(2020, 1, 1, hour=0, minute=1) + ] assert [visit.id for visit in solution.vehicles[0].visits] == ['1', '2'] assert [visit.id for visit in solution.vehicles[1].visits] == ['3'] From 624e98e1dc03a78fe13d5aed4357ff239d137c0e Mon Sep 17 00:00:00 2001 From: Christopher Chianelli Date: Tue, 23 Jul 2024 16:40:11 -0400 Subject: [PATCH 2/2] chore: remove empty TYPE_CHECKING block --- python/python-core/src/main/python/domain/_annotations.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/python-core/src/main/python/domain/_annotations.py b/python/python-core/src/main/python/domain/_annotations.py index 25d9dcb715..afd35e6f05 100644 --- a/python/python-core/src/main/python/domain/_annotations.py +++ b/python/python-core/src/main/python/domain/_annotations.py @@ -1,14 +1,10 @@ from _jpyinterpreter import JavaAnnotation, AnnotationValueSupplier from jpype import JImplements, JOverride -from typing import Union, List, Callable, Type, TYPE_CHECKING, TypeVar +from typing import Union, List, Callable, Type, TypeVar from ._variable_listener import VariableListener from .._timefold_java_interop import ensure_init, get_asm_type -if TYPE_CHECKING: - pass - - Solution_ = TypeVar('Solution_')