Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add CascadingUpdateShadowVariable to Python #989

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public class PythonClassTranslator {
public static final String TYPE_FIELD_NAME = "$TYPE";
public static final String CPYTHON_TYPE_FIELD_NAME = "$CPYTHON_TYPE";
public static final String JAVA_METHOD_PREFIX = "$method$";
public static final String JAVA_METHOD_HOLDER_PREFIX = "$methodholder$";
public static final String PYTHON_JAVA_TYPE_MAPPING_PREFIX = "$pythonJavaTypeMapping";

public record PreparedClassInfo(PythonLikeType type, String className, String classInternalName) {
Expand Down Expand Up @@ -484,6 +485,10 @@ public static String getJavaMethodName(String pythonMethodName) {
return JAVA_METHOD_PREFIX + pythonMethodName;
}

public static String getJavaMethodHolderName(String pythonMethodName) {
return JAVA_METHOD_HOLDER_PREFIX + pythonMethodName;
}

public static String getPythonMethodName(String javaMethodName) {
return javaMethodName.substring(JAVA_METHOD_PREFIX.length());
}
Expand Down Expand Up @@ -566,7 +571,7 @@ private static Class<?> createBytecodeForMethodAndSetOnClass(String className, P
throw new IllegalStateException("Unhandled case: " + pythonMethodKind);
}

generatedClass.getField(getJavaMethodName(methodEntry.getKey()))
generatedClass.getField(getJavaMethodHolderName(methodEntry.getKey()))
.set(null, functionInstance);
pythonLikeType.$setAttribute(methodEntry.getKey(), translatedPythonMethodWrapper);
return functionClass;
Expand Down Expand Up @@ -797,7 +802,7 @@ private static PythonLikeFunction createConstructor(String classInternalName,
Class<? extends PythonLikeFunction> generatedClass =
(Class<? extends PythonLikeFunction>) BuiltinTypes.asmClassLoader.loadClass(constructorClassName);
if (initFunction != null) {
Object method = typeGeneratedClass.getField(getJavaMethodName("__init__")).get(null);
Object method = typeGeneratedClass.getField(getJavaMethodHolderName("__init__")).get(null);
ArgumentSpec spec =
(ArgumentSpec) method.getClass().getField(ARGUMENT_SPEC_INSTANCE_FIELD_NAME).get(method);
generatedClass.getField(ARGUMENT_SPEC_INSTANCE_FIELD_NAME).set(null, spec);
Expand Down Expand Up @@ -957,7 +962,7 @@ private static void createInstanceMethod(PythonLikeType pythonLikeType, ClassWri
String interfaceDescriptor = interfaceDeclaration.descriptor();
String javaMethodName = getJavaMethodName(methodName);

classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, javaMethodName, interfaceDescriptor,
classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, getJavaMethodHolderName(methodName), interfaceDescriptor,
null, null);
instanceMethodNameToMethodDescriptor.put(methodName, interfaceDeclaration);
Type returnType = getVirtualFunctionReturnType(function);
Expand All @@ -973,7 +978,7 @@ private static void createInstanceMethod(PythonLikeType pythonLikeType, ClassWri
MethodVisitor methodVisitor =
classWriter.visitMethod(Modifier.PUBLIC, javaMethodName, javaMethodDescriptor, signature, null);

createInstanceOrStaticMethodBody(internalClassName, javaMethodName, javaParameterTypes,
createInstanceOrStaticMethodBody(internalClassName, methodName, javaParameterTypes,
interfaceDeclaration.methodDescriptor, function,
interfaceDeclaration.interfaceName, interfaceDescriptor, methodVisitor);

Expand All @@ -993,7 +998,7 @@ private static void createStaticMethod(PythonLikeType pythonLikeType, ClassWrite
String javaMethodName = getJavaMethodName(methodName);
String signature = getFunctionSignature(function, function.getAsmMethodDescriptorString());

classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, javaMethodName, interfaceDescriptor,
classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, getJavaMethodHolderName(methodName), interfaceDescriptor,
null, null);
MethodVisitor methodVisitor = classWriter.visitMethod(Modifier.PUBLIC | Modifier.STATIC, javaMethodName,
function.getAsmMethodDescriptorString(), signature, null);
Expand All @@ -1005,7 +1010,7 @@ private static void createStaticMethod(PythonLikeType pythonLikeType, ClassWrite
javaParameterTypes[i] = Type.getType('L' + parameterPythonTypeList.get(i).getJavaTypeInternalName() + ';');
}

createInstanceOrStaticMethodBody(internalClassName, javaMethodName, javaParameterTypes,
createInstanceOrStaticMethodBody(internalClassName, methodName, javaParameterTypes,
interfaceDeclaration.methodDescriptor, function,
interfaceDeclaration.interfaceName, interfaceDescriptor, methodVisitor);

Expand All @@ -1022,7 +1027,7 @@ private static void createClassMethod(PythonLikeType pythonLikeType, ClassWriter
String interfaceDescriptor = 'L' + interfaceDeclaration.interfaceName + ';';
String javaMethodName = getJavaMethodName(methodName);

classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, javaMethodName, interfaceDescriptor,
classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, getJavaMethodHolderName(methodName), interfaceDescriptor,
null, null);

String javaMethodDescriptor = interfaceDeclaration.methodDescriptor;
Expand All @@ -1042,7 +1047,8 @@ private static void createClassMethod(PythonLikeType pythonLikeType, ClassWriter
methodVisitor.visitLabel(start);
methodVisitor.visitLineNumber(function.getFirstLine(), start);

methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, javaMethodName, interfaceDescriptor);
methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, getJavaMethodHolderName(methodName),
interfaceDescriptor);

for (int i = 0; i < function.totalArgCount(); i++) {
methodVisitor.visitVarInsn(Opcodes.ALOAD, i);
Expand All @@ -1067,7 +1073,7 @@ private static void createClassMethod(PythonLikeType pythonLikeType, ClassWriter
parameterTypes));
}

private static void createInstanceOrStaticMethodBody(String internalClassName, String javaMethodName,
private static void createInstanceOrStaticMethodBody(String internalClassName, String methodName,
Type[] javaParameterTypes,
String methodDescriptorString,
PythonCompiledFunction function, String interfaceInternalName, String interfaceDescriptor,
Expand All @@ -1082,7 +1088,8 @@ private static void createInstanceOrStaticMethodBody(String internalClassName, S
methodVisitor.visitLabel(start);
methodVisitor.visitLineNumber(function.getFirstLine(), start);

methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, javaMethodName, interfaceDescriptor);
methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, getJavaMethodHolderName(methodName),
interfaceDescriptor);
for (int i = 0; i < function.totalArgCount(); i++) {
methodVisitor.visitVarInsn(Opcodes.ALOAD, i);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ private void implementMethod(ClassWriter classWriter, PythonCompiledClass compil
var functionInterfaceDeclaration = methodNameToFieldDescriptor.get(interfaceMethod.getName());
interfaceMethodVisitor.visitVarInsn(Opcodes.ALOAD, 0);
interfaceMethodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName,
PythonClassTranslator.getJavaMethodName(interfaceMethod.getName()),
PythonClassTranslator.getJavaMethodHolderName(interfaceMethod.getName()),
functionInterfaceDeclaration.descriptor());
interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Object.class),
"getClass", Type.getMethodDescriptor(Type.getType(Class.class)), false);
interfaceMethodVisitor.visitLdcInsn(PythonBytecodeToJavaBytecodeTranslator.ARGUMENT_SPEC_INSTANCE_FIELD_NAME);
interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class),
"getField", Type.getMethodDescriptor(Type.getType(Field.class), Type.getType(String.class)), false);
interfaceMethodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName,
PythonClassTranslator.getJavaMethodName(interfaceMethod.getName()),
PythonClassTranslator.getJavaMethodHolderName(interfaceMethod.getName()),
functionInterfaceDeclaration.descriptor());
interfaceMethodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
"get", Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class)), false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ public PythonSuperObject(PythonLikeType previousType, PythonLikeObject instance)
if (typeResult instanceof PythonLikeFunction && !(typeResult instanceof PythonLikeType)) {
try {
Object methodInstance =
candidate.getJavaClass().getField(PythonClassTranslator.getJavaMethodName(name)).get(null);
candidate.getJavaClass().getField(PythonClassTranslator.getJavaMethodHolderName(name))
.get(null);
typeResult = new GeneratedFunctionMethodReference(methodInstance,
methodInstance.getClass().getDeclaredMethods()[0],
Map.of(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public void testPythonClassTranslation() throws ClassNotFoundException, NoSuchMe
Class<?> generatedClass = BuiltinTypes.asmClassLoader.loadClass(
classType.getJavaTypeInternalName().replace('/', '.'));

assertThat(generatedClass).hasPublicFields(PythonClassTranslator.getJavaMethodName("get_age"),
assertThat(generatedClass).hasPublicFields(PythonClassTranslator.getJavaMethodHolderName("get_age"),
PythonClassTranslator.getJavaFieldName("age"));
assertThat(generatedClass).hasPublicMethods(
PythonClassTranslator.getJavaMethodName("__init__"),
Expand Down
66 changes: 55 additions & 11 deletions python/python-core/src/main/python/domain/_annotations.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import jpype

from ._variable_listener import VariableListener
from .._timefold_java_interop import ensure_init, get_asm_type
from _jpyinterpreter import JavaAnnotation, AnnotationValueSupplier
from jpype import JImplements, JOverride
from typing import Union, List, Callable, Type, TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
from ai.timefold.solver.core.api.solver.change import ProblemChange as _ProblemChange
from typing import Union, List, Callable, Type, TypeVar

from ._variable_listener import VariableListener
from .._timefold_java_interop import ensure_init, get_asm_type

Solution_ = TypeVar('Solution_')

Expand Down Expand Up @@ -240,6 +235,56 @@ def __init__(self, *,
})


class CascadingUpdateShadowVariable(JavaAnnotation):
"""
Specifies that field may be updated by the target method when one or more source variables change.

Automatically cascades change events to `NextElementShadowVariable` of a `PlanningListVariable`.

Notes
-----
Important: it must only change the shadow variable(s) for which it's configured.
It can be applied to multiple attributes to modify different shadow variables.
It should never change a genuine variable or a problem fact.
It can change its shadow variable(s) on multiple entity instances
(for example: an arrival_time change affects all trailing entities too).

Examples
--------
>>> from timefold.solver.domain import CascadingUpdateShadowVariable, PreviousElementShadowVariable, planning_entity
>>> from typing import Annotated
>>> from domain import ArrivalTimeVariableListener
>>> from datetime import datetime, timedelta
>>>
>>> @planning_entity
>>> class Visit:
... previous: Annotated['Visit', PreviousElementShadowVariable]
... arrival_time: Annotated[datetime,
... CascadingUpdateShadowVariable(
... target_method_name='update_arrival_time',
... source_variable_name='previous'
... )
... ]
...
... def update_arrival_time(self):
... self.arrival_time = previous.arrival_time + timedelta(hours=1)
"""

def __init__(self, *,
source_variable_name: str,
target_method_name: str):
ensure_init()
from ai.timefold.jpyinterpreter import PythonClassTranslator
from ai.timefold.solver.core.api.domain.variable import \
CascadingUpdateShadowVariable as JavaCascadingUpdateShadowVariable

super().__init__(JavaCascadingUpdateShadowVariable,
{
'sourceVariableName': PythonClassTranslator.getJavaFieldName(source_variable_name),
'targetMethodName': PythonClassTranslator.getJavaMethodName(target_method_name),
})


class IndexShadowVariable(JavaAnnotation):
"""
Specifies that an attribute is an index of this planning value in another entity's `PlanningListVariable`.
Expand Down Expand Up @@ -662,8 +707,7 @@ def planning_entity(entity_class: Type = None, /, *, pinning_filter: Callable =

def planning_entity_wrapper(entity_class_argument):
from .._timefold_java_interop import _add_to_compilation_queue
from ai.timefold.solver.core.api.domain.entity import PinningFilter
from _jpyinterpreter import add_class_annotation, translate_python_bytecode_to_java_bytecode
from _jpyinterpreter import add_class_annotation
from typing import get_origin, Annotated

planning_pin_field = None
Expand Down Expand Up @@ -773,7 +817,7 @@ def constraint_configuration(constraint_configuration_class: Type[Solution_]) ->

__all__ = ['PlanningId', 'PlanningScore', 'PlanningPin', 'PlanningVariable',
'PlanningListVariable', 'ShadowVariable',
'PiggybackShadowVariable',
'PiggybackShadowVariable', 'CascadingUpdateShadowVariable',
'IndexShadowVariable', 'PreviousElementShadowVariable', 'NextElementShadowVariable',
'AnchorShadowVariable', 'InverseRelationShadowVariable',
'ProblemFactProperty', 'ProblemFactCollectionProperty',
Expand Down
68 changes: 28 additions & 40 deletions python/python-core/tests/test_vehicle_routing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from dataclasses import dataclass, field
from datetime import datetime, timedelta

from timefold.solver import *
from timefold.solver.domain import *
from timefold.solver.config import *
from timefold.solver.domain import *
from timefold.solver.score import *

from typing import Annotated, List, Optional
from dataclasses import dataclass, field
from typing import Annotated, Optional


@dataclass
Expand All @@ -19,34 +17,6 @@ def driving_time_to(self, other: 'Location') -> int:
return self.driving_time_seconds[id(other)]


class ArrivalTimeUpdatingVariableListener(VariableListener):
def after_variable_changed(self, score_director: ScoreDirector, visit: 'Visit') -> None:
if visit.vehicle is None:
if visit.arrival_time is not None:
score_director.before_variable_changed(visit, 'arrival_time')
visit.arrival_time = None
score_director.after_variable_changed(visit, 'arrival_time')
return
previous_visit = visit.previous_visit
departure_time = visit.vehicle.departure_time if previous_visit is None else previous_visit.departure_time()
next_visit = visit
arrival_time = ArrivalTimeUpdatingVariableListener.calculate_arrival_time(next_visit, departure_time)
while next_visit is not None and next_visit.arrival_time != arrival_time:
score_director.before_variable_changed(next_visit, 'arrival_time')
next_visit.arrival_time = arrival_time
score_director.after_variable_changed(next_visit, 'arrival_time')
departure_time = next_visit.departure_time()
next_visit = next_visit.next_visit
arrival_time = ArrivalTimeUpdatingVariableListener.calculate_arrival_time(next_visit, departure_time)

@staticmethod
def calculate_arrival_time(visit: Optional['Visit'], previous_departure_time: Optional[datetime]) \
-> datetime | None:
if visit is None or previous_departure_time is None:
return None
return previous_departure_time + timedelta(seconds=visit.driving_time_seconds_from_previous_standstill())


@planning_entity
@dataclass
class Visit:
Expand All @@ -63,11 +33,22 @@ class Visit:
field(default=None))
next_visit: Annotated[Optional['Visit'],
NextElementShadowVariable(source_variable_name='visits')] = field(default=None)
arrival_time: Annotated[Optional[datetime],
ShadowVariable(variable_listener_class=ArrivalTimeUpdatingVariableListener,
source_variable_name='vehicle'),
ShadowVariable(variable_listener_class=ArrivalTimeUpdatingVariableListener,
source_variable_name='previous_visit')] = field(default=None)
arrival_time: Annotated[
Optional[datetime],
CascadingUpdateShadowVariable(source_variable_name='previous_visit',
target_method_name='update_arrival_time'),
CascadingUpdateShadowVariable(source_variable_name='vehicle',
target_method_name='update_arrival_time')] = field(default=None)

def update_arrival_time(self):
if self.vehicle is None or (self.previous_visit is not None and self.previous_visit.arrival_time is None):
self.arrival_time = None
elif self.previous_visit is None:
self.arrival_time = (self.vehicle.departure_time +
timedelta(seconds=self.vehicle.home_location.driving_time_to(self.location)))
else:
self.arrival_time = (self.previous_visit.departure_time() +
timedelta(seconds=self.previous_visit.location.driving_time_to(self.location)))

def departure_time(self) -> Optional[datetime]:
if self.arrival_time is None:
Expand Down Expand Up @@ -203,7 +184,7 @@ def test_vrp():
constraint_provider_function=vehicle_routing_constraints
),
termination_config=TerminationConfig(
best_score_limit='0hard/-300soft'
best_score_limit='0hard/-300soft',
)
)

Expand Down Expand Up @@ -300,6 +281,13 @@ def test_vrp():
]
)
solution = solver.solve(problem)

assert [visit.arrival_time for visit in solution.visits] == [
# Visit 1: 1-minute travel time from Vehicle A start
datetime(2020, 1, 1, hour=0, minute=1),
# Visit 2: 1-minute travel time from visit 1 + 1-hour service
datetime(2020, 1, 1, hour=1, minute=2),
# Visit 3: 1-minute travel time from Vehicle B start
datetime(2020, 1, 1, hour=0, minute=1)
]
assert [visit.id for visit in solution.vehicles[0].visits] == ['1', '2']
assert [visit.id for visit in solution.vehicles[1].visits] == ['3']
Loading