Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing @BuiltinMethod.inlineable and InlineableNode #6442

Merged
merged 16 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.enso.interpreter.dsl.test;

import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.dsl.BuiltinMethod;
import static org.junit.Assert.assertNotNull;

@BuiltinMethod(type = "InliningBuiltins", name = "need_not", needsFrame = false)
final class InliningBuiltinsNeedNotNode extends Node {

long execute(VirtualFrame frame, long a, long b) {
assertNotNull("Some frame is still provided", frame);
return a + b;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.enso.interpreter.dsl.test;

import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.dsl.BuiltinMethod;

@BuiltinMethod(type = "InliningBuiltins", name = "needs", needsFrame = true)
final class InliningBuiltinsNeedsNode extends Node {

long execute(long a, long b) {
return a + b;
}

}
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
package org.enso.interpreter.dsl.test;

import org.enso.interpreter.node.InlineableRootNode;
JaroslavTulach marked this conversation as resolved.
Show resolved Hide resolved
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.DirectCallNode;
import com.oracle.truffle.api.nodes.RootNode;
import org.enso.interpreter.runtime.callable.function.Function;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import org.junit.Test;
import org.enso.interpreter.node.InlineableNode;

public class InliningBuiltinsTest {

/** @see InliningBuiltinsInNode#execute(long, long) */
@Test
public void executeWithoutVirtualFrame() {
var fn = InliningBuiltinsInMethodGen.makeFunction(null);
if (fn.getCallTarget().getRootNode() instanceof InlineableRootNode root) {
var call = root.createDirectCallNode();
var clazz = call.getClass().getSuperclass();
assertEquals("InlinedCallNode", clazz.getSimpleName());
assertEquals("BuiltinRootNode", clazz.getEnclosingClass().getSimpleName());
if (fn.getCallTarget().getRootNode() instanceof InlineableNode.Root root) {
var call = root.createInlineableNode();
var clazz = call.getClass();
assertEquals("InlineableNode", clazz.getSuperclass().getSimpleName());
assertEquals("org.enso.interpreter.node.InlineableNode$Root", clazz.getEnclosingClass().getInterfaces()[0].getName());

var res = call.call(Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 5L, 7L }));
var res = WithFrame.invoke((frame) -> {
return call.call(frame, Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 5L, 7L }));
});
assertEquals(12L, res);
} else {
fail("It is inlineable: " + fn.getCallTarget().getRootNode());
Expand All @@ -29,15 +34,73 @@ public void executeWithoutVirtualFrame() {
@Test
public void executeWithVirtualFrame() {
var fn = InliningBuiltinsOutMethodGen.makeFunction(null);
if (fn.getCallTarget().getRootNode() instanceof InlineableRootNode root) {
var call = root.createDirectCallNode();
if (fn.getCallTarget().getRootNode() instanceof InlineableNode.Root root) {
fail("The node isn't inlineable: " + fn.getCallTarget().getRootNode());
} else {
var call = DirectCallNode.create(fn.getCallTarget());
var clazz = call.getClass().getSuperclass();
assertEquals("com.oracle.truffle.api.nodes.DirectCallNode", clazz.getName());

var res = call.call(Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 3L, 9L }));
var res = WithFrame.invoke((frame) -> {
return call.call(Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 3L, 9L }));
});
assertEquals(12L, res);
}
}

/** @see InliningBuiltinsNeedsNode#execute(long, long) */
@Test
public void executeWhenNeedsVirtualFrame() {
var fn = InliningBuiltinsNeedsMethodGen.makeFunction(null);
if (fn.getCallTarget().getRootNode() instanceof InlineableNode.Root root) {
fail("The node isn't inlineable: " + fn.getCallTarget().getRootNode());
} else {
var call = DirectCallNode.create(fn.getCallTarget());
var clazz = call.getClass().getSuperclass();
assertEquals("com.oracle.truffle.api.nodes.DirectCallNode", clazz.getName());

var res = WithFrame.invoke((frame) -> {
return call.call(Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 3L, 9L }));
});
assertEquals(12L, res);
}
}

/** @see InliningBuiltinsNeedNotNode#execute(com.oracle.truffle.api.frame.VirtualFrame, long, long) */
@Test
public void executeWhenNeedNotVirtualFrame() {
var fn = InliningBuiltinsNeedNotMethodGen.makeFunction(null);
if (fn.getCallTarget().getRootNode() instanceof InlineableNode.Root root) {
var call = root.createInlineableNode();
var clazz = call.getClass();
assertEquals("InlineableNode", clazz.getSuperclass().getSimpleName());
assertEquals("org.enso.interpreter.node.InlineableNode$Root", clazz.getEnclosingClass().getInterfaces()[0].getName());

var res = WithFrame.invoke((frame) -> {
return call.call(frame, Function.ArgumentsHelper.buildArguments(null, null, new Object[] { null, 5L, 7L }));
});
assertEquals(12L, res);
} else {
fail("It is inlineable: " + fn.getCallTarget().getRootNode());
}
}

private static final class WithFrame<T> extends RootNode {
private final java.util.function.Function<VirtualFrame, T> fn;

private WithFrame(java.util.function.Function<VirtualFrame, T> fn) {
super(null);
this.fn = fn;
}

@Override
public Object execute(VirtualFrame frame) {
return fn.apply(frame);
}

@SuppressWarnings("unchecked")
static <T> T invoke(java.util.function.Function<VirtualFrame, T> fn, Object... args) {
return (T) new WithFrame<>(fn).getCallTarget().call(args);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.enso.interpreter.node;

import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.DirectCallNode;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.RootNode;
import org.enso.interpreter.node.callable.ExecuteCallNode;

/**
* More effective {@link DirectCallNode} alternative. Supports more aggressive inlining needed by
* {@link ExecuteCallNode}.
*/
public abstract class InlineableNode extends Node {
/**
* Invokes the computation represented by the node.
*
* @param frame current frame of the caller
* @param arguments arguments for the functionality
* @return result of the computation
*/
public abstract Object call(VirtualFrame frame, Object[] arguments);

/**
* Special interface that allows various {@link RootNode} subclasses to provide more effective
* implementation of {@link DirectCallNode} alternative. Used by for example by {@code
* BuiltinRootNode}.
*/
public interface Root {
/**
* Provides access to {@link RootNode}. Usually the object shall inherit from {link RootNode} as
* well as implement the {@link InlineableNode} interface. This method thus usually returns
* {@code this}.
*
* @return {@code this} types as {link RootNode}
*/
public RootNode getRootNode();
JaroslavTulach marked this conversation as resolved.
Show resolved Hide resolved

/**
* Name of the {@link RootNode}.
*
* @return root node name
*/
public String getName();

/**
* Override to provide more effective implementation of {@link DirectCallNode} alternative.
* Suited more for Enso aggressive inlining.
*
* @return a node to call the associated {@link RootNode} - may return {@code null}
*/
public InlineableNode createInlineableNode();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.DirectCallNode;
import com.oracle.truffle.api.nodes.IndirectCallNode;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.NodeInfo;
import org.enso.interpreter.node.InlineableRootNode;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.node.InlineableNode;

/**
* This node is responsible for optimising function calls.
*
* <p>Where possible, it will make the call as a direct call, with potential for inlining.
* This node is responsible for optimising function calls. Where possible, it will handle the call via:
* <ul>
* <li>{@link InlineableNode} to force inlining</li>
* <li>{@link DirectCallNode} with potential for inlining</li>
* </ul>
*/
@NodeInfo(shortName = "ExecCall", description = "Optimises function calls")
@GenerateUncached
Expand All @@ -38,6 +41,7 @@ public static ExecuteCallNode build() {
* <p>This specialisation comes into play where the call target for the provided function is
* already cached. THis means that the call can be made quickly.
JaroslavTulach marked this conversation as resolved.
Show resolved Hide resolved
*
* @param frame current frame
* @param function the function to execute
* @param callerInfo the caller info to pass to the function
* @param state the current state value
Expand All @@ -46,20 +50,45 @@ public static ExecuteCallNode build() {
* @param callNode the cached call node for {@code cachedTarget}
* @return the result of executing {@code function} on {@code arguments}
*/
@Specialization(guards = "function.getCallTarget() == cachedTarget")
@Specialization(guards = {
"function.getCallTarget() == cachedTarget",
"callNode != null"
})
protected Object callDirect(
JaroslavTulach marked this conversation as resolved.
Show resolved Hide resolved
VirtualFrame frame,
Function function,
CallerInfo callerInfo,
Object state,
Object[] arguments,
@Cached("function.getCallTarget()") RootCallTarget cachedTarget,
@Cached("createCallNode(cachedTarget)") DirectCallNode callNode) {
return callNode.call(
Function.ArgumentsHelper.buildArguments(function, callerInfo, state, arguments));
@Cached("createInlineableNode(cachedTarget)") InlineableNode callNode) {
var args = Function.ArgumentsHelper.buildArguments(function, callerInfo, state, arguments);
return callNode.call(frame, args);
}

@Specialization(guards = {
"function.getCallTarget() == cachedTarget",
})
protected Object callDirect(
JaroslavTulach marked this conversation as resolved.
Show resolved Hide resolved
Function function,
CallerInfo callerInfo,
Object state,
Object[] arguments,
@Cached("function.getCallTarget()") RootCallTarget cachedTarget,
@Cached("createDirectCallNode(cachedTarget)") DirectCallNode callNode) {
var args = Function.ArgumentsHelper.buildArguments(function, callerInfo, state, arguments);
return callNode.call(args);
}

static InlineableNode createInlineableNode(RootCallTarget t) {
if (t.getRootNode() instanceof InlineableNode.Root inlineNodeProvider) {
return inlineNodeProvider.createInlineableNode();
}
return null;
}

static DirectCallNode createCallNode(RootCallTarget t) {
return InlineableRootNode.create(t);
static DirectCallNode createDirectCallNode(RootCallTarget t) {
return DirectCallNode.create(t);
}

/**
Expand Down Expand Up @@ -90,12 +119,13 @@ protected Object callIndirect(
/**
* Executes the function call.
*
* @param frame the caller's frame
* @param function the function to execute
* @param callerInfo the caller info to pass to the function
* @param state the state value to pass to the function
* @param arguments the arguments to be passed to {@code function}
* @return the result of executing {@code function} on {@code arguments}
*/
public abstract Object executeCall(
Function function, CallerInfo callerInfo, Object state, Object[] arguments);
VirtualFrame frame, Function function, CallerInfo callerInfo, Object state, Object[] arguments);
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ public Object invokeDynamicSymbol(
if (canApplyThis) {
Object self = arguments[thisArgumentPosition];
if (argumentsExecutionMode.shouldExecute()) {
self = thisExecutor.executeThunk(self, state, BaseNode.TailStatus.NOT_TAIL);
self = thisExecutor.executeThunk(callerFrame, self, state, BaseNode.TailStatus.NOT_TAIL);
arguments[thisArgumentPosition] = self;
}
return invokeMethodNode.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ Object doPolyglot(
@Cached HostMethodCallNode hostMethodCallNode) {
Object[] args = new Object[arguments.length - 1];
for (int i = 0; i < arguments.length - 1; i++) {
var r = argExecutor.executeThunk(arguments[i + 1], state, BaseNode.TailStatus.NOT_TAIL);
var r =
argExecutor.executeThunk(frame, arguments[i + 1], state, BaseNode.TailStatus.NOT_TAIL);
if (r instanceof DataflowError) {
return r;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ public Object invokeConversion(
lock.unlock();
}
}
selfArgument = thisExecutor.executeThunk(selfArgument, state, TailStatus.NOT_TAIL);
thatArgument = thatExecutor.executeThunk(thatArgument, state, TailStatus.NOT_TAIL);
selfArgument = thisExecutor.executeThunk(callerFrame, selfArgument, state, TailStatus.NOT_TAIL);
thatArgument = thatExecutor.executeThunk(callerFrame, thatArgument, state, TailStatus.NOT_TAIL);

arguments[thisArgumentPosition] = selfArgument;
arguments[thatArgumentPosition] = thatArgument;
Expand Down Expand Up @@ -248,7 +248,7 @@ public Object invokeDynamicSymbol(
lock.unlock();
}
}
selfArgument = thisExecutor.executeThunk(selfArgument, state, TailStatus.NOT_TAIL);
selfArgument = thisExecutor.executeThunk(callerFrame, selfArgument, state, TailStatus.NOT_TAIL);
arguments[thisArgumentPosition] = selfArgument;
}
return invokeMethodNode.execute(callerFrame, state, symbol, selfArgument, arguments);
Expand Down
Loading