From ae4666c4d31b827dc4f59cd8c13771ea4ddc2324 Mon Sep 17 00:00:00 2001 From: Hubert Plociniczak Date: Mon, 26 Jun 2023 14:38:36 +0200 Subject: [PATCH] Improve TCO in the presence of warnings (#7116) Partially revert https://github.com/enso-org/enso/pull/6849, which introduced a regression in TCO in the presence of warnings. Rather than modifying the tail call status, `TailCallException` now propagates the extracted warnings and appends them to the final result. Closes #7093 # Important Notes Compared to the previous attempt we don't pay the penalty of adding the warnings or even checking for them because it is being dealt in a separate specialization. --- .../node/callable/InvokeCallableNode.java | 20 +++++-- .../node/callable/InvokeConversionNode.java | 15 +++-- .../node/callable/InvokeMethodNode.java | 11 +++- .../callable/dispatch/CallOptimiserNode.java | 5 +- .../node/callable/dispatch/CurryNode.java | 4 +- .../callable/dispatch/IndirectCurryNode.java | 4 +- .../dispatch/LoopingCallOptimiserNode.java | 58 +++++++++++++++++-- .../dispatch/SimpleCallOptimiserNode.java | 7 ++- .../callable/thunk/ThunkExecutorNode.java | 4 +- .../builtin/ordering/SortVectorNode.java | 6 +- .../org/enso/interpreter/runtime/Module.java | 3 +- .../runtime/control/TailCallException.java | 32 +++++++++- test/Tests/src/Semantic/Warnings_Spec.enso | 16 +++++ 13 files changed, 154 insertions(+), 31 deletions(-) diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeCallableNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeCallableNode.java index defd9e29a8ba..d96954ab7642 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeCallableNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeCallableNode.java @@ -21,6 +21,7 @@ import org.enso.interpreter.runtime.callable.atom.Atom; import org.enso.interpreter.runtime.callable.atom.AtomConstructor; import org.enso.interpreter.runtime.callable.function.Function; +import org.enso.interpreter.runtime.control.TailCallException; import org.enso.interpreter.runtime.error.DataflowError; import org.enso.interpreter.runtime.error.PanicException; import org.enso.interpreter.runtime.error.PanicSentinel; @@ -264,6 +265,15 @@ public Object invokeWarnings( State state, Object[] arguments, @CachedLibrary(limit = "3") WarningsLibrary warnings) { + + Warning[] extracted; + Object callable; + try { + extracted = warnings.getWarnings(warning, null); + callable = warnings.removeWarnings(warning); + } catch (UnsupportedMessageException e) { + throw CompilerDirectives.shouldNotReachHere(e); + } try { if (childDispatch == null) { CompilerDirectives.transferToInterpreterAndInvalidate(); @@ -277,7 +287,7 @@ public Object invokeWarnings( invokeFunctionNode.getSchema(), invokeFunctionNode.getDefaultsExecutionMode(), invokeFunctionNode.getArgumentsExecutionMode())); - childDispatch.setTailStatus(TailStatus.NOT_TAIL); + childDispatch.setTailStatus(getTailStatus()); childDispatch.setId(invokeFunctionNode.getId()); notifyInserted(childDispatch); } @@ -287,12 +297,12 @@ public Object invokeWarnings( } var result = childDispatch.execute( - warnings.removeWarnings(warning), + callable, callerFrame, state, arguments); - Warning[] extracted = warnings.getWarnings(warning, null); + if (result instanceof DataflowError) { return result; } else if (result instanceof WithWarnings withWarnings) { @@ -300,8 +310,8 @@ public Object invokeWarnings( } else { return WithWarnings.wrap(EnsoContext.get(this), result, extracted); } - } catch (UnsupportedMessageException e) { - throw CompilerDirectives.shouldNotReachHere(e); + } catch (TailCallException e) { + throw new TailCallException(e, extracted); } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeConversionNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeConversionNode.java index 396a1dce1aba..252f67ee2da4 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeConversionNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeConversionNode.java @@ -17,6 +17,7 @@ import org.enso.interpreter.runtime.callable.UnresolvedConversion; import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo; import org.enso.interpreter.runtime.callable.function.Function; +import org.enso.interpreter.runtime.control.TailCallException; import org.enso.interpreter.runtime.data.ArrayRope; import org.enso.interpreter.runtime.data.Type; import org.enso.interpreter.runtime.data.text.Text; @@ -162,7 +163,7 @@ Object doWarning( invokeFunctionNode.getDefaultsExecutionMode(), invokeFunctionNode.getArgumentsExecutionMode(), thatArgumentPosition)); - childDispatch.setTailStatus(TailStatus.NOT_TAIL); + childDispatch.setTailStatus(getTailStatus()); childDispatch.setId(invokeFunctionNode.getId()); notifyInserted(childDispatch); } @@ -170,11 +171,15 @@ Object doWarning( lock.unlock(); } } - arguments[thatArgumentPosition] = that.getValue(); + Object value = that.getValue(); + arguments[thatArgumentPosition] = value; ArrayRope warnings = that.getReassignedWarningsAsRope(this); - Object result = - childDispatch.execute(frame, state, conversion, self, that.getValue(), arguments); - return WithWarnings.appendTo(EnsoContext.get(this), result, warnings); + try { + Object result = childDispatch.execute(frame, state, conversion, self, value, arguments); + return WithWarnings.appendTo(EnsoContext.get(this), result, warnings); + } catch (TailCallException e) { + throw new TailCallException(e, warnings.toArray(Warning[]::new)); + } } @Specialization(guards = "interop.isString(that)") diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeMethodNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeMethodNode.java index ba2d99665a4b..913ce252c2fa 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeMethodNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/InvokeMethodNode.java @@ -40,6 +40,7 @@ import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo; import org.enso.interpreter.runtime.callable.function.Function; import org.enso.interpreter.runtime.callable.function.FunctionSchema; +import org.enso.interpreter.runtime.control.TailCallException; import org.enso.interpreter.runtime.data.ArrayRope; import org.enso.interpreter.runtime.data.EnsoDate; import org.enso.interpreter.runtime.data.EnsoDateTime; @@ -407,7 +408,7 @@ Object doWarning( invokeFunctionNode.getDefaultsExecutionMode(), invokeFunctionNode.getArgumentsExecutionMode(), thisArgumentPosition)); - childDispatch.setTailStatus(TailStatus.NOT_TAIL); + childDispatch.setTailStatus(getTailStatus()); childDispatch.setId(invokeFunctionNode.getId()); notifyInserted(childDispatch); } @@ -418,8 +419,12 @@ Object doWarning( arguments[thisArgumentPosition] = selfWithoutWarnings; - Object result = childDispatch.execute(frame, state, symbol, selfWithoutWarnings, arguments); - return WithWarnings.appendTo(EnsoContext.get(this), result, arrOfWarnings); + try { + Object result = childDispatch.execute(frame, state, symbol, selfWithoutWarnings, arguments); + return WithWarnings.appendTo(EnsoContext.get(this), result, arrOfWarnings); + } catch (TailCallException e) { + throw new TailCallException(e, arrOfWarnings); + } } @ExplodeLoop diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/CallOptimiserNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/CallOptimiserNode.java index 3d0422ea134a..f2d259eb1b7f 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/CallOptimiserNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/CallOptimiserNode.java @@ -5,6 +5,7 @@ import com.oracle.truffle.api.nodes.NodeInfo; import org.enso.interpreter.runtime.callable.CallerInfo; import org.enso.interpreter.runtime.callable.function.Function; +import org.enso.interpreter.runtime.error.Warning; import org.enso.interpreter.runtime.state.State; /** @@ -33,6 +34,7 @@ public static CallOptimiserNode build() { * @param callerInfo the caller info to pass to the function * @param state the state to pass to the function * @param arguments the arguments to {@code callable} + * @param warnings warnings associated with the callable, null if empty * @return the result of executing {@code callable} using {@code arguments} */ public abstract Object executeDispatch( @@ -40,5 +42,6 @@ public abstract Object executeDispatch( Function callable, CallerInfo callerInfo, State state, - Object[] arguments); + Object[] arguments, + Warning[] warnings); } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/CurryNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/CurryNode.java index 8803831db1c9..a03a529c5e11 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/CurryNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/CurryNode.java @@ -133,7 +133,7 @@ public Object execute( return value; } } else { - var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments); + var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null); return this.oversaturatedCallableNode.execute( evaluatedVal, frame, state, oversaturatedArguments); @@ -154,7 +154,7 @@ private Object doCall( return switch (getTailStatus()) { case TAIL_DIRECT -> directCall.executeCall(frame, function, callerInfo, state, arguments); case TAIL_LOOP -> throw new TailCallException(function, callerInfo, arguments); - default -> loopingCall.executeDispatch(frame, function, callerInfo, state, arguments); + default -> loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null); }; } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/IndirectCurryNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/IndirectCurryNode.java index 489139633460..085d1c3a2ce8 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/IndirectCurryNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/IndirectCurryNode.java @@ -92,7 +92,7 @@ Object doCurry( return value; } } else { - var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments); + var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null); return oversaturatedCallableNode.execute( evaluatedVal, @@ -129,7 +129,7 @@ private Object doCall( case TAIL_LOOP: throw new TailCallException(function, callerInfo, arguments); default: - return loopingCall.executeDispatch(frame, function, callerInfo, state, arguments); + return loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null); } } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/LoopingCallOptimiserNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/LoopingCallOptimiserNode.java index b78a94eb3a63..ddb34a735aa6 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/LoopingCallOptimiserNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/LoopingCallOptimiserNode.java @@ -15,9 +15,12 @@ import com.oracle.truffle.api.nodes.RepeatingNode; import org.enso.interpreter.node.callable.ExecuteCallNode; import org.enso.interpreter.node.callable.ExecuteCallNodeGen; +import org.enso.interpreter.runtime.EnsoContext; import org.enso.interpreter.runtime.callable.CallerInfo; import org.enso.interpreter.runtime.callable.function.Function; import org.enso.interpreter.runtime.control.TailCallException; +import org.enso.interpreter.runtime.error.Warning; +import org.enso.interpreter.runtime.error.WithWarnings; import org.enso.interpreter.runtime.state.State; /** @@ -54,23 +57,44 @@ public static LoopingCallOptimiserNode build() { * @param loopNode a cached instance of the loop node used by this node * @return the result of executing {@code function} using {@code arguments} */ - @Specialization - public Object dispatch( + @Specialization(guards = "warnings == null") + public Object cachedDispatch( Function function, CallerInfo callerInfo, State state, Object[] arguments, + Warning[] warnings, @Cached(value = "createLoopNode()") LoopNode loopNode) { + return dispatch(function, callerInfo, state, arguments, loopNode); + } + + @Specialization(guards = "warnings != null") + public Object cachedDispatchWarnings( + Function function, + CallerInfo callerInfo, + State state, + Object[] arguments, + Warning[] warnings, + @Cached(value = "createLoopNode()") LoopNode loopNode) { + Object result = dispatch(function, callerInfo, state, arguments, loopNode); + return WithWarnings.appendTo(EnsoContext.get(this), result, warnings); + } + + private Object dispatch( + Function function, + CallerInfo callerInfo, + State state, + Object[] arguments, + LoopNode loopNode) { RepeatedCallNode repeatedCallNode = (RepeatedCallNode) loopNode.getRepeatingNode(); VirtualFrame frame = repeatedCallNode.createFrame(); repeatedCallNode.setNextCall(frame, function, callerInfo, arguments); repeatedCallNode.setState(frame, state); loopNode.execute(frame); - return repeatedCallNode.getResult(frame); } - @Specialization(replaces = "dispatch") + @Specialization(replaces = "cachedDispatch", guards = "warnings == null") @CompilerDirectives.TruffleBoundary public Object uncachedDispatch( MaterializedFrame frame, @@ -78,7 +102,33 @@ public Object uncachedDispatch( CallerInfo callerInfo, State state, Object[] arguments, + Warning[] warnings, + @Cached ExecuteCallNode executeCallNode) { + return loopUntilCompletion(frame, function, callerInfo, state, arguments, executeCallNode); + } + + @Specialization(replaces = "cachedDispatchWarnings", guards = "warnings != null") + @CompilerDirectives.TruffleBoundary + public Object uncachedDispatchWarnings( + MaterializedFrame frame, + Function function, + CallerInfo callerInfo, + State state, + Object[] arguments, + Warning[] warnings, @Cached ExecuteCallNode executeCallNode) { + Object result = + loopUntilCompletion(frame, function, callerInfo, state, arguments, executeCallNode); + return WithWarnings.appendTo(EnsoContext.get(this), result, warnings); + } + + private Object loopUntilCompletion( + MaterializedFrame frame, + Function function, + CallerInfo callerInfo, + State state, + Object[] arguments, + ExecuteCallNode executeCallNode) { while (true) { try { return executeCallNode.executeCall(frame, function, callerInfo, state, arguments); diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/SimpleCallOptimiserNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/SimpleCallOptimiserNode.java index b5b3a1b97038..70496d9c0352 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/SimpleCallOptimiserNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/dispatch/SimpleCallOptimiserNode.java @@ -9,6 +9,7 @@ import org.enso.interpreter.runtime.callable.CallerInfo; import org.enso.interpreter.runtime.callable.function.Function; import org.enso.interpreter.runtime.control.TailCallException; +import org.enso.interpreter.runtime.error.Warning; import org.enso.interpreter.runtime.state.State; /** @@ -40,6 +41,7 @@ public static SimpleCallOptimiserNode build() { * @param callerInfo the caller info to pass to the function * @param state the state to pass to the function * @param arguments the arguments to {@code function} + * @param warnings warnings associated with the callable, null if empty * @return the result of executing {@code function} using {@code arguments} */ @Override @@ -48,7 +50,8 @@ public Object executeDispatch( Function function, CallerInfo callerInfo, State state, - Object[] arguments) { + Object[] arguments, + Warning[] warnings) { try { return executeCallNode.executeCall(frame, function, callerInfo, state, arguments); } catch (TailCallException e) { @@ -65,7 +68,7 @@ public Object executeDispatch( } } return next.executeDispatch( - frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments()); + frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings()); } } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/thunk/ThunkExecutorNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/thunk/ThunkExecutorNode.java index 114e9fda306b..388d2c690c37 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/thunk/ThunkExecutorNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/thunk/ThunkExecutorNode.java @@ -67,7 +67,7 @@ Object doCached( return callNode.call(Function.ArgumentsHelper.buildArguments(function, state)); } catch (TailCallException e) { return loopingCallOptimiserNode.executeDispatch( - frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments()); + frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings()); } } } @@ -89,7 +89,7 @@ Object doUncached( function.getCallTarget(), Function.ArgumentsHelper.buildArguments(function, state)); } catch (TailCallException e) { return loopingCallOptimiserNode.executeDispatch( - frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments()); + frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings()); } } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java index cfbca3b41c6b..2e485fa517c2 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java @@ -811,8 +811,8 @@ public int compare(Object x, Object y) { Object yConverted; if (hasCustomOnFunc) { // onFunc cannot have `self` argument, we assume it has just one argument. - xConverted = callNode.executeDispatch(null, onFunc.get(x), null, state, new Object[]{x}); - yConverted = callNode.executeDispatch(null, onFunc.get(y), null, state, new Object[]{y}); + xConverted = callNode.executeDispatch(null, onFunc.get(x), null, state, new Object[]{x}, null); + yConverted = callNode.executeDispatch(null, onFunc.get(y), null, state, new Object[]{y}, null); } else { xConverted = x; yConverted = y; @@ -823,7 +823,7 @@ public int compare(Object x, Object y) { } else { args = new Object[] {xConverted, yConverted}; } - Object res = callNode.executeDispatch(null, compareFunc.get(xConverted), null, state, args); + Object res = callNode.executeDispatch(null, compareFunc.get(xConverted), null, state, args, null); if (res == less) { return ascending ? -1 : 1; } else if (res == equal) { diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/Module.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/Module.java index 837afc11d55f..492b103729a3 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/Module.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/Module.java @@ -636,7 +636,8 @@ private static Object evalExpression( eval.getFunction(), callerInfo, context.emptyState(), - new Object[] {builtins.debug(), Text.create(expr)}); + new Object[] {builtins.debug(), Text.create(expr)}, + null); } private static Object generateDocs(Module module, EnsoContext context) { diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/control/TailCallException.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/control/TailCallException.java index ced3e08028b5..7dd33919868e 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/control/TailCallException.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/control/TailCallException.java @@ -3,6 +3,7 @@ import com.oracle.truffle.api.nodes.ControlFlowException; import org.enso.interpreter.runtime.callable.CallerInfo; import org.enso.interpreter.runtime.callable.function.Function; +import org.enso.interpreter.runtime.error.Warning; /** * Used to model the switch of control-flow from standard stack-based execution to looping. @@ -13,18 +14,38 @@ public class TailCallException extends ControlFlowException { private final Function function; private final CallerInfo callerInfo; private final Object[] arguments; + private final Warning[] warnings; /** * Creates a new exception containing the necessary data to continue computation. * * @param function the function to execute in a loop - * @param state the state to pass to the function + * @param callerInfo the caller execution context * @param arguments the arguments to {@code function} */ public TailCallException(Function function, CallerInfo callerInfo, Object[] arguments) { this.function = function; this.callerInfo = callerInfo; this.arguments = arguments; + this.warnings = null; + } + + private TailCallException( + Function function, CallerInfo callerInfo, Object[] arguments, Warning[] warnings) { + this.function = function; + this.callerInfo = callerInfo; + this.arguments = arguments; + this.warnings = warnings; + } + + /** + * Creates a new exception containing the necessary data to continue computation. + * + * @param origin the original tail call exception + * @param warnings warnings to be associated with the tail call exception + */ + public TailCallException(TailCallException origin, Warning[] warnings) { + this(origin.getFunction(), origin.getCallerInfo(), origin.getArguments(), warnings); } /** @@ -53,4 +74,13 @@ public Object[] getArguments() { public CallerInfo getCallerInfo() { return callerInfo; } + + /** + * Gets the warnings that should be appended to the result of calling the function. + * + * @return the warnings to be appended to the result of the call, or null if empty + */ + public Warning[] getWarnings() { + return warnings; + } } diff --git a/test/Tests/src/Semantic/Warnings_Spec.enso b/test/Tests/src/Semantic/Warnings_Spec.enso index fd52d72289da..68a5363769de 100644 --- a/test/Tests/src/Semantic/Warnings_Spec.enso +++ b/test/Tests/src/Semantic/Warnings_Spec.enso @@ -437,4 +437,20 @@ spec = Test.group "Dataflow Warnings" <| result_non_tail . should_equal 6 Warning.get_all result_non_tail . map .value . should_equal ["Foo"] + Test.specify "should not break TCO when warnings are attached to arguments" <| + vec = Vector.new 10000 (i-> i+1) + elem1 = Warning.attach "WARNING1" 998 + vec.contains 998 . should_equal True + res1 = vec.contains elem1 + res1 . should_be_true + Warning.get_all res1 . map .value . should_equal ["WARNING1"] + + elem2 = Warning.attach "WARNING2" 9988 + vec.contains 9988 . should_be_true + vec.contains elem2 . should_be_true + + res2 = vec.contains elem2 + res2 . should_equal True + Warning.get_all res2 . map .value . should_equal ["WARNING2"] + main = Test_Suite.run_main spec