117: Optimize task cleanup, move worker finalization to C++ #141

Merged
2 changes: 2 additions & 0 deletions src/c/backend/include/runtime.hpp
@@ -1070,6 +1070,8 @@ class InnerScheduler {
void task_cleanup_presync(InnerWorker *worker, InnerTask *task, int state);
void task_cleanup_postsync(InnerWorker *worker, InnerTask *task, int state);

void task_cleanup_and_wait_for_task(InnerWorker *worker, InnerTask *task, int state);

/* Get number of active tasks. A task is active if it is spawned but not
* complete */
int get_num_active_tasks();
11 changes: 11 additions & 0 deletions src/c/backend/scheduler.cpp
@@ -441,6 +441,17 @@ void InnerScheduler::task_cleanup(InnerWorker *worker, InnerTask *task,
task_cleanup_postsync(worker, task, state);
}

void InnerScheduler::task_cleanup_and_wait_for_task(InnerWorker *worker, InnerTask *task,
int state) {
NVTX_RANGE("Scheduler::task_cleanup", NVTX_COLOR_MAGENTA)

task_cleanup(worker, task, state);
// wait for task
if(this->should_run)
worker->wait();

}

int InnerScheduler::get_num_active_tasks() { return this->num_active_tasks; }

void InnerScheduler::increase_num_active_tasks() {
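The new `task_cleanup_and_wait_for_task` combines the existing cleanup with the worker's wait for its next assignment, so both happen in a single nogil call on the C++ side instead of two round trips through Python. A minimal sketch of the pattern, using hypothetical `Worker`/`Scheduler` types rather than the actual Parla classes:

```cpp
// Sketch only: hypothetical Worker/Scheduler types, not Parla's InnerWorker/InnerScheduler.
#include <atomic>
#include <condition_variable>
#include <mutex>

struct Worker {
  std::mutex m;
  std::condition_variable cv;
  bool has_task = false;

  // Block until the scheduler assigns this worker another task.
  void wait() {
    std::unique_lock<std::mutex> lock(m);
    cv.wait(lock, [this] { return has_task; });
  }
};

struct Scheduler {
  std::atomic<bool> should_run{true};

  void cleanup(Worker *worker) {
    // Release the finished task's resources and notify its dependents.
  }

  // Combined entry point: finish the previous task and immediately block
  // for the next one, without bouncing back through Python in between.
  void cleanup_and_wait(Worker *worker) {
    cleanup(worker);
    if (should_run.load())
      worker->wait();
  }
};
```

The guard on `should_run` mirrors the diff above: during shutdown the worker still cleans up but does not block again.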
8 changes: 8 additions & 0 deletions src/c/backend/task.cpp
@@ -446,11 +446,19 @@ void InnerTask::finalize_assigned_devices() {
TaskState InnerTask::set_state(TaskState state) {
TaskState new_state = state;
TaskState old_state;
bool success = true;

do {
old_state = this->state.load();
if (old_state >= new_state) {
success = false;
}
} while (!this->state.compare_exchange_weak(old_state, new_state));

if (!success) {
throw std::runtime_error("Task States must always be increasing.");
}

return old_state;
}

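`set_state` now enforces that task states only move forward: the CAS loop still publishes the requested state, and a non-increasing transition is reported by throwing after the exchange. A standalone sketch of this monotonic-state idiom, with a hypothetical `TaskState` enum (the real state values live elsewhere in the runtime):

```cpp
// Sketch only: hypothetical TaskState values; mirrors the check added above.
#include <atomic>
#include <stdexcept>

enum class TaskState { Spawned, Mapped, Reserved, Running, Completed };

class TaskStateMachine {
  std::atomic<TaskState> state{TaskState::Spawned};

public:
  // Publish `next` and return the previous state; states may only increase.
  TaskState set_state(TaskState next) {
    TaskState old_state = state.load();
    bool increasing = true;
    do {
      if (old_state >= next)
        increasing = false;
      // compare_exchange_weak reloads old_state if another thread won the race.
    } while (!state.compare_exchange_weak(old_state, next));

    if (!increasing)
      throw std::runtime_error("Task states must always be increasing.");
    return old_state;
  }
};
```

Paired with the `except +` declarations in core.pxd below, this C++ exception surfaces in Python as a RuntimeError, which spawn() catches to report a likely duplicate TaskID.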
9 changes: 4 additions & 5 deletions src/python/parla/common/spawn.py
@@ -14,7 +14,6 @@
from parla.common.globals import default_sync, VCU_BASELINE, SynchronizationType, crosspy, CROSSPY_ENABLED
from crosspy import CrossPyArray
import inspect

from parla.cython import tasks

from typing import Collection, Any, Union, List, Tuple
@@ -70,9 +69,7 @@ def spawn(task=None,
runahead: SynchronizationType = default_sync
):
nvtx.push_range(message="Spawn::spawn", domain="launch", color="blue")

scheduler = get_scheduler_context().scheduler

if not isinstance(task, tasks.Task):
taskspace = scheduler.default_taskspace

@@ -146,8 +143,10 @@ def decorator(body):
dataflow=dataflow,
runahead=runahead
)

scheduler.spawn_task(task)
try:
scheduler.spawn_task(task)
except RuntimeError:
raise RuntimeError("Conflicting task state while spawning task. Possible duplicate TaskID: " + str(task))
# scheduler.run_scheduler()
nvtx.pop_range(domain="launch")

3 changes: 2 additions & 1 deletion src/python/parla/cython/core.pxd
@@ -148,13 +148,14 @@ cdef extern from "include/runtime.hpp" nogil:

void activate_wrapper()

void spawn_task(InnerTask* task)
void spawn_task(InnerTask* task) except +

void add_worker(InnerWorker* worker)
void enqueue_worker(InnerWorker* worker)
void task_cleanup(InnerWorker* worker, InnerTask* task, int state) except +
void task_cleanup_presync(InnerWorker* worker, InnerTask* task, int state) except +
void task_cleanup_postsync(InnerWorker* worker, InnerTask* task, int state) except +
void task_cleanup_and_wait_for_task(InnerWorker* worker, InnerTask* task, int state) except +

void increase_num_active_tasks()
void decrease_num_active_tasks()
11 changes: 11 additions & 0 deletions src/python/parla/cython/core.pyx
@@ -673,6 +673,17 @@ cdef class PyInnerScheduler:
with nogil:
c_self.task_cleanup_postsync(c_worker, c_task, state)

cpdef task_cleanup_and_wait_for_task(self, PyInnerWorker worker, PyInnerTask task, int state):
cdef InnerScheduler* c_self = self.inner_scheduler
cdef InnerWorker* c_worker = worker.inner_worker
cdef InnerTask* c_task = task.c_task
with nogil:
c_self.task_cleanup_and_wait_for_task(c_worker, c_task, state)

cpdef get_num_active_tasks(self):
cdef InnerScheduler* c_self = self.inner_scheduler
return c_self.get_num_active_tasks()

cpdef increase_num_active_tasks(self):
cdef InnerScheduler* c_self = self.inner_scheduler
c_self.increase_num_active_tasks()
77 changes: 49 additions & 28 deletions src/python/parla/cython/scheduler.pyx
@@ -196,16 +196,16 @@ class WorkerThread(ControllableThread, SchedulerContext):
self.scheduler.start_monitor.notify_all()

while self._should_run:
self.status = "Waiting"
#print("WAITING", flush=True)

#with self._monitor:
# if not self.task:
# self._monitor.wait()
nvtx.push_range(message="worker::wait", domain="Python Runtime", color="blue")
self.inner_worker.wait_for_task()

self.task = self.inner_worker.get_task()
if(self.task is None):
self.status = "Waiting"
# print("WAITING", flush=True)
#with self._monitor:
# if not self.task:
# self._monitor.wait()
nvtx.push_range(message="worker::wait", domain="Python Runtime", color="blue")
self.inner_worker.wait_for_task() # GIL Release
self.task = self.inner_worker.get_task()
if isinstance(self.task, core.DataMovementTaskAttributes):
self.task_attrs = self.task
self.task = DataMovementTask()
@@ -219,9 +219,7 @@
_global_data_tasks[id(self.task)] = self.task

nvtx.pop_range(domain="Python Runtime")

#print("THREAD AWAKE", self.index, self.task, self._should_run, flush=True)

# print("THREAD AWAKE", self.index, self.task, self._should_run, flush=True)
self.status = "Running"

if isinstance(self.task, Task):
@@ -270,8 +268,7 @@
device_context.record_events()

nvtx.pop_range(domain="Python Runtime")
#print("Finished Task", self.index, active_task.taskid.full_name, flush=True)

# print("Finished Task", self.index, active_task.taskid.full_name, flush=True)
nvtx.push_range(message="worker::cleanup", domain="Python Runtime", color="blue")

final_state = active_task.state
@@ -286,7 +283,7 @@

elif isinstance(final_state, tasks.TaskRunning):
nvtx.push_range(message="worker::continuation", domain="Python Runtime", color="red")
#print("CONTINUATION: ", active_task.taskid.full_name, active_task.state.dependencies, flush=True)
# print("CONTINUATION: ", active_task.taskid.full_name, active_task.state.dependencies, flush=True)
active_task.dependencies = active_task.state.dependencies
active_task.func = active_task.state.func
active_task.args = active_task.state.args
@@ -298,33 +295,57 @@
elif isinstance(final_state, tasks.TaskRunahead):
core.binlog_2("Worker", "Runahead task: ", active_task.inner_task, " on worker: ", self.inner_worker)

#TODO(wlr): Add better exception handling
#print("Cleaning up Task", active_task, flush=True)

if USE_PYTHON_RUNAHEAD:
#Handle synchronization in Python (for debugging, works!)
self.scheduler.inner_scheduler.task_cleanup_presync(self.inner_worker, active_task.inner_task, active_task.state.value)
if active_task.runahead != SyncType.NONE:
device_context.synchronize(events=True)
self.scheduler.inner_scheduler.task_cleanup_postsync(self.inner_worker, active_task.inner_task, active_task.state.value)
#print("Should run before cleanup_and_wait", self._should_run, active_task.inner_task, flush=True)
if self._should_run:
#print("In if", flush=True)
self.status = "Waiting"
nvtx.push_range(message="worker::wait::2", domain="Python Runtime", color="red")
self.scheduler.inner_scheduler.task_cleanup_and_wait_for_task(self.inner_worker, active_task.inner_task, active_task.state.value)
else:
#print("In else", flush=True)
self.scheduler.inner_scheduler.task_cleanup_presync(self.inner_worker, active_task.inner_task, active_task.state.value)
if active_task.runahead != SyncType.NONE:
device_context.synchronize(events=True)
self.scheduler.inner_scheduler.task_cleanup_postsync(self.inner_worker, active_task.inner_task, active_task.state.value)
else:
#Handle synchronization in C++
self.scheduler.inner_scheduler.task_cleanup(self.inner_worker, active_task.inner_task, active_task.state.value)

#print("Finished Cleaning up Task", active_task, flush=True)

# self.scheduler.inner_scheduler.task_cleanup(self.inner_worker, active_task.inner_task, active_task.state.value)
# Adding wait here to reduce context switch between GIL
print("Should run before cleanup_and_wait", self._should_run, active_task.inner_task, flush=True)
if self._should_run:
self.status = "Waiting"
nvtx.push_range(message="worker::wait::2", domain="Python Runtime", color="red")
self.scheduler.inner_scheduler.task_cleanup_and_wait_for_task(self.inner_worker, active_task.inner_task, active_task.state.value)
#self.task = self.inner_worker.get_task()
else:
self.scheduler.inner_scheduler.task_cleanup(self.inner_worker, active_task.inner_task, active_task.state.value)
# print("Finished Cleaning up Task", active_task, flush=True)
#print("Should run before device_context", self._should_run, task, flush=True)
if active_task.runahead != SyncType.NONE:
device_context.return_streams()

#print("Should run before final_state cleanup", self._should_run, task, flush=True)
if isinstance(final_state, tasks.TaskRunahead):
final_state = tasks.TaskCompleted(final_state.return_value)
active_task.cleanup()
core.binlog_2("Worker", "Completed task: ", active_task.inner_task, " on worker: ", self.inner_worker)

# print("Finished Task", active_task, flush=True)
# print("Should run before reassigning active_task", self._should_run, task, flush=True)
active_task.state = final_state
self.task = None
nvtx.pop_range(domain="Python Runtime")


# Adding wait here to reduce context switch between GIL
# if self._should_run:
# self.status = "Waiting"
# nvtx.push_range(message="worker::wait", domain="Python Runtime", color="blue")
# self.inner_worker.wait_for_task() # GIL Release
# self.task = self.inner_worker.get_task()

elif self._should_run:
raise WorkerThreadException("%r Worker: Woke without a task", self.index)
else:
@@ -500,7 +521,7 @@ class Scheduler(ControllableThread, SchedulerContext):
def spawn_task(self, task):
#print("Scheduler: Spawning Task", task, flush=True)
self.inner_scheduler.spawn_task(task.inner_task)

def assign_task(self, task, worker):
task.state = tasks.TaskRunning(task.func, task.args, task.dependencies)
worker.assign_task(task)
@@ -556,7 +577,7 @@
device.

:param global_dev_id: global logical device id that
this function interests
this function interests
:param parray_parent_id: parent PArray ID
"""
return self.inner_scheduler.get_mapped_parray_state( \