Skip to content

Commit

Permalink
Improve TCO in the presence of warnings (#7116)
Browse files Browse the repository at this point in the history
Partially revert #6849, which introduced a regression in TCO in the presence of warnings. Rather than modifying the tail call status, `TailCallException` now propagates the extracted warnings and appends them to the final result.

Closes #7093

# Important Notes
Compared to the previous attempt we don't pay the penalty of adding the warnings or even checking for them because it is being dealt in a separate specialization.
  • Loading branch information
hubertp authored Jun 26, 2023
1 parent c4f19e7 commit ae4666c
Show file tree
Hide file tree
Showing 13 changed files with 154 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.enso.interpreter.runtime.callable.atom.Atom;
import org.enso.interpreter.runtime.callable.atom.AtomConstructor;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.error.DataflowError;
import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.error.PanicSentinel;
Expand Down Expand Up @@ -264,6 +265,15 @@ public Object invokeWarnings(
State state,
Object[] arguments,
@CachedLibrary(limit = "3") WarningsLibrary warnings) {

Warning[] extracted;
Object callable;
try {
extracted = warnings.getWarnings(warning, null);
callable = warnings.removeWarnings(warning);
} catch (UnsupportedMessageException e) {
throw CompilerDirectives.shouldNotReachHere(e);
}
try {
if (childDispatch == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
Expand All @@ -277,7 +287,7 @@ public Object invokeWarnings(
invokeFunctionNode.getSchema(),
invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode()));
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
Expand All @@ -287,21 +297,21 @@ public Object invokeWarnings(
}

var result = childDispatch.execute(
warnings.removeWarnings(warning),
callable,
callerFrame,
state,
arguments);

Warning[] extracted = warnings.getWarnings(warning, null);

if (result instanceof DataflowError) {
return result;
} else if (result instanceof WithWarnings withWarnings) {
return withWarnings.append(EnsoContext.get(this), extracted);
} else {
return WithWarnings.wrap(EnsoContext.get(this), result, extracted);
}
} catch (UnsupportedMessageException e) {
throw CompilerDirectives.shouldNotReachHere(e);
} catch (TailCallException e) {
throw new TailCallException(e, extracted);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.enso.interpreter.runtime.callable.UnresolvedConversion;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.data.ArrayRope;
import org.enso.interpreter.runtime.data.Type;
import org.enso.interpreter.runtime.data.text.Text;
Expand Down Expand Up @@ -162,19 +163,23 @@ Object doWarning(
invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode(),
thatArgumentPosition));
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
} finally {
lock.unlock();
}
}
arguments[thatArgumentPosition] = that.getValue();
Object value = that.getValue();
arguments[thatArgumentPosition] = value;
ArrayRope<Warning> warnings = that.getReassignedWarningsAsRope(this);
Object result =
childDispatch.execute(frame, state, conversion, self, that.getValue(), arguments);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
try {
Object result = childDispatch.execute(frame, state, conversion, self, value, arguments);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
} catch (TailCallException e) {
throw new TailCallException(e, warnings.toArray(Warning[]::new));
}
}

@Specialization(guards = "interop.isString(that)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.callable.function.FunctionSchema;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.data.ArrayRope;
import org.enso.interpreter.runtime.data.EnsoDate;
import org.enso.interpreter.runtime.data.EnsoDateTime;
Expand Down Expand Up @@ -407,7 +408,7 @@ Object doWarning(
invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode(),
thisArgumentPosition));
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
Expand All @@ -418,8 +419,12 @@ Object doWarning(

arguments[thisArgumentPosition] = selfWithoutWarnings;

Object result = childDispatch.execute(frame, state, symbol, selfWithoutWarnings, arguments);
return WithWarnings.appendTo(EnsoContext.get(this), result, arrOfWarnings);
try {
Object result = childDispatch.execute(frame, state, symbol, selfWithoutWarnings, arguments);
return WithWarnings.appendTo(EnsoContext.get(this), result, arrOfWarnings);
} catch (TailCallException e) {
throw new TailCallException(e, arrOfWarnings);
}
}

@ExplodeLoop
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.oracle.truffle.api.nodes.NodeInfo;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.state.State;

/**
Expand Down Expand Up @@ -33,12 +34,14 @@ public static CallOptimiserNode build() {
* @param callerInfo the caller info to pass to the function
* @param state the state to pass to the function
* @param arguments the arguments to {@code callable}
* @param warnings warnings associated with the callable, null if empty
* @return the result of executing {@code callable} using {@code arguments}
*/
public abstract Object executeDispatch(
VirtualFrame frame,
Function callable,
CallerInfo callerInfo,
State state,
Object[] arguments);
Object[] arguments,
Warning[] warnings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public Object execute(
return value;
}
} else {
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);

return this.oversaturatedCallableNode.execute(
evaluatedVal, frame, state, oversaturatedArguments);
Expand All @@ -154,7 +154,7 @@ private Object doCall(
return switch (getTailStatus()) {
case TAIL_DIRECT -> directCall.executeCall(frame, function, callerInfo, state, arguments);
case TAIL_LOOP -> throw new TailCallException(function, callerInfo, arguments);
default -> loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
default -> loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ Object doCurry(
return value;
}
} else {
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);

return oversaturatedCallableNode.execute(
evaluatedVal,
Expand Down Expand Up @@ -129,7 +129,7 @@ private Object doCall(
case TAIL_LOOP:
throw new TailCallException(function, callerInfo, arguments);
default:
return loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
return loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
import com.oracle.truffle.api.nodes.RepeatingNode;
import org.enso.interpreter.node.callable.ExecuteCallNode;
import org.enso.interpreter.node.callable.ExecuteCallNodeGen;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.error.WithWarnings;
import org.enso.interpreter.runtime.state.State;

/**
Expand Down Expand Up @@ -54,31 +57,78 @@ public static LoopingCallOptimiserNode build() {
* @param loopNode a cached instance of the loop node used by this node
* @return the result of executing {@code function} using {@code arguments}
*/
@Specialization
public Object dispatch(
@Specialization(guards = "warnings == null")
public Object cachedDispatch(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached(value = "createLoopNode()") LoopNode loopNode) {
return dispatch(function, callerInfo, state, arguments, loopNode);
}

@Specialization(guards = "warnings != null")
public Object cachedDispatchWarnings(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached(value = "createLoopNode()") LoopNode loopNode) {
Object result = dispatch(function, callerInfo, state, arguments, loopNode);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
}

private Object dispatch(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
LoopNode loopNode) {
RepeatedCallNode repeatedCallNode = (RepeatedCallNode) loopNode.getRepeatingNode();
VirtualFrame frame = repeatedCallNode.createFrame();
repeatedCallNode.setNextCall(frame, function, callerInfo, arguments);
repeatedCallNode.setState(frame, state);
loopNode.execute(frame);

return repeatedCallNode.getResult(frame);
}

@Specialization(replaces = "dispatch")
@Specialization(replaces = "cachedDispatch", guards = "warnings == null")
@CompilerDirectives.TruffleBoundary
public Object uncachedDispatch(
MaterializedFrame frame,
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached ExecuteCallNode executeCallNode) {
return loopUntilCompletion(frame, function, callerInfo, state, arguments, executeCallNode);
}

@Specialization(replaces = "cachedDispatchWarnings", guards = "warnings != null")
@CompilerDirectives.TruffleBoundary
public Object uncachedDispatchWarnings(
MaterializedFrame frame,
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached ExecuteCallNode executeCallNode) {
Object result =
loopUntilCompletion(frame, function, callerInfo, state, arguments, executeCallNode);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
}

private Object loopUntilCompletion(
MaterializedFrame frame,
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
ExecuteCallNode executeCallNode) {
while (true) {
try {
return executeCallNode.executeCall(frame, function, callerInfo, state, arguments);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.state.State;

/**
Expand Down Expand Up @@ -40,6 +41,7 @@ public static SimpleCallOptimiserNode build() {
* @param callerInfo the caller info to pass to the function
* @param state the state to pass to the function
* @param arguments the arguments to {@code function}
* @param warnings warnings associated with the callable, null if empty
* @return the result of executing {@code function} using {@code arguments}
*/
@Override
Expand All @@ -48,7 +50,8 @@ public Object executeDispatch(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments) {
Object[] arguments,
Warning[] warnings) {
try {
return executeCallNode.executeCall(frame, function, callerInfo, state, arguments);
} catch (TailCallException e) {
Expand All @@ -65,7 +68,7 @@ public Object executeDispatch(
}
}
return next.executeDispatch(
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Object doCached(
return callNode.call(Function.ArgumentsHelper.buildArguments(function, state));
} catch (TailCallException e) {
return loopingCallOptimiserNode.executeDispatch(
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
}
}
}
Expand All @@ -89,7 +89,7 @@ Object doUncached(
function.getCallTarget(), Function.ArgumentsHelper.buildArguments(function, state));
} catch (TailCallException e) {
return loopingCallOptimiserNode.executeDispatch(
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,8 @@ public int compare(Object x, Object y) {
Object yConverted;
if (hasCustomOnFunc) {
// onFunc cannot have `self` argument, we assume it has just one argument.
xConverted = callNode.executeDispatch(null, onFunc.get(x), null, state, new Object[]{x});
yConverted = callNode.executeDispatch(null, onFunc.get(y), null, state, new Object[]{y});
xConverted = callNode.executeDispatch(null, onFunc.get(x), null, state, new Object[]{x}, null);
yConverted = callNode.executeDispatch(null, onFunc.get(y), null, state, new Object[]{y}, null);
} else {
xConverted = x;
yConverted = y;
Expand All @@ -823,7 +823,7 @@ public int compare(Object x, Object y) {
} else {
args = new Object[] {xConverted, yConverted};
}
Object res = callNode.executeDispatch(null, compareFunc.get(xConverted), null, state, args);
Object res = callNode.executeDispatch(null, compareFunc.get(xConverted), null, state, args, null);
if (res == less) {
return ascending ? -1 : 1;
} else if (res == equal) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,8 @@ private static Object evalExpression(
eval.getFunction(),
callerInfo,
context.emptyState(),
new Object[] {builtins.debug(), Text.create(expr)});
new Object[] {builtins.debug(), Text.create(expr)},
null);
}

private static Object generateDocs(Module module, EnsoContext context) {
Expand Down
Loading

0 comments on commit ae4666c

Please sign in to comment.