Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove dependencies from TaskExecutionSpecification #5166

Merged
merged 13 commits into from
Jul 16, 2019
23 changes: 14 additions & 9 deletions java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int t

@Override
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
TaskSpec spec = createTaskSpec(func, null, RayActorImpl.NIL, args, false, options);
TaskSpec spec = createTaskSpec(func, null, RayActorImpl.NIL, args, false, false, options);
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
Expand All @@ -257,8 +257,7 @@ public RayObject call(RayFunc func, RayActor<?> actor, Object[] args) {
RayActorImpl<?> actorImpl = (RayActorImpl) actor;
TaskSpec spec;
synchronized (actor) {
spec = createTaskSpec(func, null, actorImpl, args, false, null);
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
spec = createTaskSpec(func, null, actorImpl, args, false, true, null);
actorImpl.setTaskCursor(spec.returnIds[1]);
actorImpl.clearNewActorHandles();
}
Expand All @@ -271,7 +270,7 @@ public RayObject call(RayFunc func, RayActor<?> actor, Object[] args) {
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc,
Object[] args, ActorCreationOptions options) {
TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL,
args, true, options);
args, true, false, options);
RayActorImpl<?> actor = new RayActorImpl(new UniqueId(spec.returnIds[0].getBytes()));
actor.increaseTaskCounter();
actor.setTaskCursor(spec.returnIds[0]);
Expand All @@ -293,7 +292,7 @@ public RayObject callPy(String moduleName, String functionName, Object[] args,
CallOptions options) {
checkPyArguments(args);
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, "", functionName);
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, false, options);
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, false, false, options);
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
Expand All @@ -306,8 +305,7 @@ public RayObject callPy(RayPyActor pyActor, String functionName, Object... args)
RayPyActorImpl actorImpl = (RayPyActorImpl) pyActor;
TaskSpec spec;
synchronized (pyActor) {
spec = createTaskSpec(null, desc, actorImpl, args, false, null);
spec.getExecutionDependencies().add(actorImpl.getTaskCursor());
spec = createTaskSpec(null, desc, actorImpl, args, false, true, null);
actorImpl.setTaskCursor(spec.returnIds[1]);
actorImpl.clearNewActorHandles();
}
Expand All @@ -320,7 +318,7 @@ public RayPyActor createPyActor(String moduleName, String className, Object[] ar
ActorCreationOptions options) {
checkPyArguments(args);
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, className, "__init__");
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, true, options);
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, true, false, options);
RayPyActorImpl actor = new RayPyActorImpl(spec.actorCreationId, moduleName, className);
actor.increaseTaskCounter();
actor.setTaskCursor(spec.returnIds[0]);
Expand All @@ -337,11 +335,12 @@ public RayPyActor createPyActor(String moduleName, String className, Object[] ar
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
* @param isActorCreationTask Whether this task is an actor creation task.
* @param isActorTask Whether this task is an actor task.
* @return A TaskSpec object.
*/
private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDescriptor,
RayActorImpl<?> actor, Object[] args,
boolean isActorCreationTask, BaseTaskOptions taskOptions) {
boolean isActorCreationTask, boolean isActorTask, BaseTaskOptions taskOptions) {
Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null));

TaskId taskId = rayletClient.generateTaskId(workerContext.getCurrentJobId(),
Expand Down Expand Up @@ -382,6 +381,11 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes
functionDescriptor = pyFunctionDescriptor;
}

ObjectId previousActorTaskDummyObjectId = ObjectId.NIL;
if (isActorTask) {
previousActorTaskDummyObjectId = actor.getTaskCursor();
}

return new TaskSpec(
workerContext.getCurrentJobId(),
taskId,
Expand All @@ -392,6 +396,7 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes
actor.getId(),
actor.getHandleId(),
actor.increaseTaskCounter(),
previousActorTaskDummyObjectId,
actor.getNewActorHandles().toArray(new UniqueId[0]),
ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON),
numReturns,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,9 @@ private Set<ObjectId> getUnreadyObjects(TaskSpec spec) {
}
}
}
// Check whether task dependencies are ready.
for (ObjectId id : spec.getExecutionDependencies()) {
if (!store.isObjectReady(id)) {
unreadyObjects.add(id);
if (spec.isActorTask()) {
if (!store.isObjectReady(spec.previousActorTaskDummyObjectId)) {
unreadyObjects.add(spec.previousActorTaskDummyObjectId);
}
}
return unreadyObjects;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,7 @@ public void submitTask(TaskSpec spec) {
Preconditions.checkState(!spec.jobId.isNil());

byte[] taskSpec = convertTaskSpecToProtobuf(spec);
byte[] cursorId = null;
if (!spec.getExecutionDependencies().isEmpty()) {
//TODO(hchen): handle more than one dependencies.
cursorId = spec.getExecutionDependencies().get(0).getBytes();
}
nativeSubmitTask(client, cursorId, taskSpec);
nativeSubmitTask(client, taskSpec);
}

@Override
Expand Down Expand Up @@ -195,21 +190,25 @@ private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
// Parse ActorTaskSpec.
UniqueId actorId = UniqueId.NIL;
UniqueId actorHandleId = UniqueId.NIL;
ObjectId previousActorTaskDummyObjectId = ObjectId.NIL;
int actorCounter = 0;
if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
Common.ActorTaskSpec actorTaskSpec = taskSpec.getActorTaskSpec();
actorId = UniqueId.fromByteBuffer(actorTaskSpec.getActorId().asReadOnlyByteBuffer());
actorHandleId = UniqueId
.fromByteBuffer(actorTaskSpec.getActorHandleId().asReadOnlyByteBuffer());
actorCounter = (int) actorTaskSpec.getActorCounter();
previousActorTaskDummyObjectId = ObjectId.fromByteBuffer(
actorTaskSpec.getPreviousActorTaskDummyObjectId().asReadOnlyByteBuffer());
newActorHandles = actorTaskSpec.getNewActorHandlesList().stream()
.map(byteString -> UniqueId.fromByteBuffer(byteString.asReadOnlyByteBuffer()))
.toArray(UniqueId[]::new);
}

return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId,
maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles,
args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions);
maxActorReconstructions, actorId, actorHandleId, actorCounter,
previousActorTaskDummyObjectId, newActorHandles, args, numReturns, resources,
TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions);
}

/**
Expand Down Expand Up @@ -275,6 +274,8 @@ private static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
.setActorId(ByteString.copyFrom(task.actorId.getBytes()))
.setActorHandleId(ByteString.copyFrom(task.actorHandleId.getBytes()))
.setActorCreationDummyObjectId(ByteString.copyFrom(task.actorId.getBytes()))
.setPreviousActorTaskDummyObjectId(
ByteString.copyFrom(task.previousActorTaskDummyObjectId.getBytes()))
.setActorCounter(task.actorCounter)
.addAllNewActorHandles(newHandles)
);
Expand Down Expand Up @@ -310,7 +311,7 @@ public void destroy() {
private static native long nativeInit(String localSchedulerSocket, byte[] workerId,
boolean isWorker, byte[] driverTaskId);

private static native void nativeSubmitTask(long client, byte[] cursorId, byte[] taskSpec)
private static native void nativeSubmitTask(long client, byte[] taskSpec)
throws RayException;

private static native byte[] nativeGetTask(long client) throws RayException;
Expand Down
18 changes: 8 additions & 10 deletions java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ public class TaskSpec {
// Number of tasks that have been submitted to this actor so far.
public final int actorCounter;

// Object id returned by the previous task submitted to the same actor.
public final ObjectId previousActorTaskDummyObjectId;

// Task arguments.
public final UniqueId[] newActorHandles;

Expand All @@ -55,7 +58,7 @@ public class TaskSpec {
// number of return objects.
public final int numReturns;

// returns ids.
// Return ids.
public final ObjectId[] returnIds;

// The task's resource demands.
Expand All @@ -71,8 +74,6 @@ public class TaskSpec {
// is Python, the type is PyFunctionDescriptor.
private final FunctionDescriptor functionDescriptor;

private List<ObjectId> executionDependencies;

public boolean isActorTask() {
return !actorId.isNil();
}
Expand All @@ -91,6 +92,7 @@ public TaskSpec(
UniqueId actorId,
UniqueId actorHandleId,
int actorCounter,
ObjectId previousActorTaskDummyObjectId,
UniqueId[] newActorHandles,
FunctionArg[] args,
int numReturns,
Expand All @@ -107,6 +109,7 @@ public TaskSpec(
this.actorId = actorId;
this.actorHandleId = actorHandleId;
this.actorCounter = actorCounter;
this.previousActorTaskDummyObjectId = previousActorTaskDummyObjectId;
this.newActorHandles = newActorHandles;
this.args = args;
this.numReturns = numReturns;
Expand All @@ -128,7 +131,6 @@ public TaskSpec(
Preconditions.checkArgument(false, "Unknown task language: {}.", language);
}
this.functionDescriptor = functionDescriptor;
this.executionDependencies = new ArrayList<>();
}

public JavaFunctionDescriptor getJavaFunctionDescriptor() {
Expand All @@ -141,10 +143,6 @@ public PyFunctionDescriptor getPyFunctionDescriptor() {
return (PyFunctionDescriptor) functionDescriptor;
}

public List<ObjectId> getExecutionDependencies() {
return executionDependencies;
}

@Override
public String toString() {
return "TaskSpec{" +
Expand All @@ -157,14 +155,14 @@ public String toString() {
", actorId=" + actorId +
", actorHandleId=" + actorHandleId +
", actorCounter=" + actorCounter +
", previousActorTaskDummyObjectId=" + previousActorTaskDummyObjectId +
", newActorHandles=" + Arrays.toString(newActorHandles) +
", args=" + Arrays.toString(args) +
", numReturns=" + numReturns +
", resources=" + resources +
", language=" + language +
", functionDescriptor=" + functionDescriptor +
", dynamicWorkerOptions=" + dynamicWorkerOptions +
", executionDependencies=" + executionDependencies +
", dynamicWorkerOptions=" + dynamicWorkerOptions +
'}';
}
}
7 changes: 2 additions & 5 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,11 @@ cdef class RayletClient:
def disconnect(self):
check_status(self.client.get().Disconnect())

def submit_task(self, TaskSpec task_spec, execution_dependencies):
def submit_task(self, TaskSpec task_spec):
cdef:
CObjectID c_id
c_vector[CObjectID] c_dependencies
for dep in execution_dependencies:
c_dependencies.push_back((<ObjectID>dep).native())
check_status(self.client.get().SubmitTask(
c_dependencies, task_spec.task_spec.get()[0]))
task_spec.task_spec.get()[0]))

def get_task(self):
cdef:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def _actor_method_call(self,
actor_counter=self._ray_actor_counter,
actor_creation_dummy_object_id=(
self._ray_actor_creation_dummy_object_id),
execution_dependencies=[self._ray_actor_cursor],
previous_actor_task_dummy_object_id=self._ray_actor_cursor,
new_actor_handles=self._ray_new_actor_handles,
# We add one for the dummy return ID.
num_return_vals=num_return_vals + 1,
Expand Down
4 changes: 1 addition & 3 deletions python/ray/includes/libraylet.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
c_bool is_worker, const CJobID &job_id,
const CLanguage &language)
CRayStatus Disconnect()
CRayStatus SubmitTask(
const c_vector[CObjectID] &execution_dependencies,
const CTaskSpec &task_spec)
CRayStatus SubmitTask(const CTaskSpec &task_spec)
CRayStatus GetTask(unique_ptr[CTaskSpec] *task_spec)
CRayStatus TaskDone()
CRayStatus FetchOrReconstruct(c_vector[CObjectID] &object_ids,
Expand Down
7 changes: 5 additions & 2 deletions python/ray/includes/task.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ cdef extern from "ray/common/task/task_spec.h" namespace "ray" nogil:
c_bool IsActorTask() const
CActorID ActorCreationId() const
CObjectID ActorCreationDummyObjectId() const
CObjectID PreviousActorTaskDummyObjectId() const
uint64_t MaxActorReconstructions() const
CActorID ActorId() const
CActorHandleID ActorHandleId() const
Expand All @@ -92,8 +93,10 @@ cdef extern from "ray/common/task/task_util.h" namespace "ray" nogil:

TaskSpecBuilder &SetActorTaskSpec(
const CActorID &actor_id, const CActorHandleID &actor_handle_id,
const CObjectID &actor_creation_dummy_object_id, uint64_t actor_counter,
const c_vector[CActorHandleID] &new_handle_ids)
const CObjectID &actor_creation_dummy_object_id,
const CObjectID &previous_actor_task_dummy_object_id,
uint64_t actor_counter,
const c_vector[CActorHandleID] &new_handle_ids);

RpcTaskSpec GetMessage()

Expand Down
24 changes: 10 additions & 14 deletions python/ray/includes/task.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ cdef class TaskSpec:
int num_returns, TaskID parent_task_id, int parent_counter,
ActorID actor_creation_id,
ObjectID actor_creation_dummy_object_id,
ObjectID previous_actor_task_dummy_object_id,
int32_t max_actor_reconstructions, ActorID actor_id,
ActorHandleID actor_handle_id, int actor_counter,
new_actor_handles, resource_map, placement_resource_map):
Expand Down Expand Up @@ -85,6 +86,7 @@ cdef class TaskSpec:
actor_id.native(),
actor_handle_id.native(),
actor_creation_dummy_object_id.native(),
previous_actor_task_dummy_object_id.native(),
actor_counter,
c_new_actor_handles,
)
Expand Down Expand Up @@ -229,6 +231,13 @@ cdef class TaskSpec:
return ObjectID(
self.task_spec.get().ActorCreationDummyObjectId().Binary())

def previous_actor_task_dummy_object_id(self):
"""Return the object ID of the previously executed actor task."""
if not self.is_actor_task():
return ObjectID.nil()
return ObjectID(
edoakes marked this conversation as resolved.
Show resolved Hide resolved
self.task_spec.get().PreviousActorTaskDummyObjectId().Binary())

def actor_id(self):
"""Return the actor ID for this task."""
if not self.is_actor_task():
Expand All @@ -247,13 +256,10 @@ cdef class TaskExecutionSpec:
cdef:
unique_ptr[CTaskExecutionSpec] c_spec

def __init__(self, execution_dependencies):
def __init__(self):
cdef:
RpcTaskExecutionSpec message;

for dependency in execution_dependencies:
message.add_dependencies(
(<ObjectID?>dependency).binary())
self.c_spec.reset(new CTaskExecutionSpec(message))

@staticmethod
Expand All @@ -264,16 +270,6 @@ cdef class TaskExecutionSpec:
self.c_spec.reset(new CTaskExecutionSpec(string))
return self

def dependencies(self):
cdef:
CObjectID c_id
c_vector[CObjectID] dependencies = (
self.c_spec.get().ExecutionDependencies())
ret = []
for c_id in dependencies:
ret.append(ObjectID(c_id.Binary()))
return ret

def num_forwards(self):
return self.c_spec.get().NumForwards()

Expand Down
3 changes: 2 additions & 1 deletion python/ray/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ def _task_table(self, task_id):
"ActorCreationID": task.actor_creation_id().hex(),
"ActorCreationDummyObjectID": (
task.actor_creation_dummy_object_id().hex()),
"PreviousActorTaskDummyObjectID": (
task.previous_actor_task_dummy_object_id().hex()),
"ActorCounter": task.actor_counter(),
"Args": task.arguments(),
"ReturnObjectIDs": task.returns(),
Expand All @@ -356,7 +358,6 @@ def _task_table(self, task_id):
task_table_data.task.task_execution_spec.SerializeToString())
return {
"ExecutionSpec": {
"Dependencies": execution_spec.dependencies(),
"NumForwards": execution_spec.num_forwards(),
},
"TaskSpec": task_spec_info
Expand Down
Loading