-
Notifications
You must be signed in to change notification settings - Fork 6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[aDAG] support buffered input #47272
Changes from all commits
2c7645b
1c67606
6997436
43cf80b
15d5a1e
a8555f8
da826d2
256b0ec
fe0ff80
f560a18
36c36c5
5d7f2ce
bf690ac
286e633
da65fb4
f1426f3
fafbae6
f2a3dfb
99b1415
fca08a3
98e6553
918c339
480d289
2f912b8
3516e61
ab108e9
e72faa5
7a216d8
568de34
231cb69
e2ae0cf
ad323b4
14462c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -115,6 +115,7 @@ def do_exec_tasks( | |
if done: | ||
break | ||
for operation in schedule: | ||
print("SANG-TODO operation: ", operation) | ||
done = tasks[operation.exec_task_idx].exec_operation( | ||
self, operation.type | ||
) | ||
|
@@ -597,6 +598,7 @@ def __init__( | |
enable_asyncio: bool = False, | ||
asyncio_max_queue_size: Optional[int] = None, | ||
max_buffered_results: Optional[int] = None, | ||
max_inflight_executions: Optional[int] = None, | ||
): | ||
""" | ||
Args: | ||
|
@@ -625,6 +627,10 @@ def __init__( | |
executions is beyond the DAG capacity, the new execution would | ||
be blocked in the first place; therefore, this limit is only | ||
enforced when it is smaller than the DAG capacity. | ||
max_inflight_executions: The maximum number of in-flight executions that | ||
are allowed to be sent to this DAG. Before submitting more requests, | ||
the caller is responsible for calling ray.get to get the result, | ||
otherwise, RayAdagCapacityExceeded is raised. | ||
Returns: | ||
Channel: A wrapper around ray.ObjectRef. | ||
|
@@ -633,29 +639,37 @@ def __init__( | |
|
||
ctx = DAGContext.get_current() | ||
|
||
self._enable_asyncio: bool = enable_asyncio | ||
self._fut_queue = asyncio.Queue() | ||
self._asyncio_max_queue_size: Optional[int] = asyncio_max_queue_size | ||
# TODO(rui): consider unify it with asyncio_max_queue_size | ||
self._max_buffered_results: Optional[int] = max_buffered_results | ||
if self._max_buffered_results is None: | ||
self._max_buffered_results = ctx.max_buffered_results | ||
self._max_inflight_executions = max_inflight_executions | ||
if self._max_inflight_executions is None: | ||
self._max_inflight_executions = ctx.max_inflight_executions | ||
self._dag_id = uuid.uuid4().hex | ||
self._execution_timeout: Optional[float] = execution_timeout | ||
if self._execution_timeout is None: | ||
self._execution_timeout = ctx.execution_timeout | ||
self._buffer_size_bytes: Optional[int] = buffer_size_bytes | ||
if self._buffer_size_bytes is None: | ||
self._buffer_size_bytes = ctx.buffer_size_bytes | ||
|
||
self._default_type_hint: ChannelOutputType = SharedMemoryType( | ||
self._buffer_size_bytes | ||
self._buffer_size_bytes, | ||
# We conservatively set num_shm_buffers to _max_inflight_executions. | ||
# It means that the DAG can be underutilized, but it guarantees there's | ||
# no false positive timeouts. | ||
num_shm_buffers=self._max_inflight_executions, | ||
) | ||
if not isinstance(self._buffer_size_bytes, int) or self._buffer_size_bytes <= 0: | ||
raise ValueError( | ||
"`buffer_size_bytes` must be a positive integer, found " | ||
f"{self._buffer_size_bytes}" | ||
) | ||
|
||
self._enable_asyncio: bool = enable_asyncio | ||
self._fut_queue = asyncio.Queue() | ||
self._asyncio_max_queue_size: Optional[int] = asyncio_max_queue_size | ||
# TODO(rui): consider unify it with asyncio_max_queue_size | ||
self._max_buffered_results: Optional[int] = max_buffered_results | ||
if self._max_buffered_results is None: | ||
self._max_buffered_results = ctx.max_buffered_results | ||
# Used to ensure that the future returned to the | ||
# caller corresponds to the correct DAG output. I.e. | ||
# order of futures added to fut_queue should match the | ||
|
@@ -721,7 +735,7 @@ def __init__( | |
self._execution_index: int = 0 | ||
# The maximum index of finished executions. | ||
# All results with higher indexes have not been generated yet. | ||
self._max_execution_index: int = -1 | ||
self._max_finished_execution_index: int = -1 | ||
self._result_buffer: Dict[int, Any] = {} | ||
|
||
def _get_creator_or_proxy_actor() -> "ray.actor.ActorHandle": | ||
|
@@ -765,6 +779,12 @@ def _get_creator_or_proxy_actor() -> "ray.actor.ActorHandle": | |
|
||
self._creator_or_proxy_actor = _get_creator_or_proxy_actor() | ||
|
||
def increment_max_finished_execution_index(self) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we really need this method? It just has a single statement. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't want to increment it by accessing an internal attribute; let me know if you have another alternative. |
||
"""Increment the max finished execution index. It is used to | ||
figure out the max number of in-flight requests to the DAG | ||
""" | ||
self._max_finished_execution_index += 1 | ||
|
||
@property | ||
def has_single_output(self): | ||
return self._has_single_output | ||
|
@@ -1857,6 +1877,20 @@ def run(self): | |
monitor.start() | ||
return monitor | ||
|
||
def raise_if_too_many_inflight_requests(self) -> None:
    """Refuse a new submission when the DAG is over its in-flight limit.

    This is invoked right before an input is written to the DAG. It
    compares ``_execution_index - _max_finished_execution_index`` with
    ``_max_inflight_executions``.

    Raises:
        ray.exceptions.RayAdagCapacityExceeded: If that difference is
            strictly greater than ``_max_inflight_executions``.
    """
    # NOTE(review): `_max_finished_execution_index` appears to start at -1
    # while `_execution_index` starts at 0, so this difference presumably
    # also counts the request that is about to be submitted -- confirm.
    num_in_flight_requests = (
        self._execution_index - self._max_finished_execution_index
    )
    if num_in_flight_requests <= self._max_inflight_executions:
        return

    raise ray.exceptions.RayAdagCapacityExceeded(
        f"There are {num_in_flight_requests} in-flight requests, which "
        "is more than the specified _max_inflight_executions of the DAG: "
        f"{self._max_inflight_executions}. Retrieve the output using "
        "ray.get before submitting more requests, or increase "
        "`_max_inflight_executions` via "
        "`adag.experimental_compile(_max_inflight_executions=...)`."
    )
|
||
def _execute_until( | ||
self, | ||
execution_index: int, | ||
|
@@ -1885,10 +1919,10 @@ def _execute_until( | |
if timeout is None: | ||
timeout = ctx.retrieval_timeout | ||
|
||
while self._max_execution_index < execution_index: | ||
if self._max_execution_index + 1 == execution_index: | ||
while self._max_finished_execution_index < execution_index: | ||
if self._max_finished_execution_index + 1 == execution_index: | ||
# Directly fetch and return without buffering | ||
self._max_execution_index += 1 | ||
self.increment_max_finished_execution_index() | ||
return self._dag_output_fetcher.read(timeout) | ||
# Otherwise, buffer the result | ||
if len(self._result_buffer) >= self._max_buffered_results: | ||
|
@@ -1897,10 +1931,10 @@ def _execute_until( | |
f"buffered results is {self._max_buffered_results}; call ray.get() " | ||
"on previous CompiledDAGRefs to free them up from buffer." | ||
) | ||
self._max_execution_index += 1 | ||
self.increment_max_finished_execution_index() | ||
start_time = time.monotonic() | ||
self._result_buffer[ | ||
self._max_execution_index | ||
self._max_finished_execution_index | ||
] = self._dag_output_fetcher.read(timeout) | ||
if timeout != -1: | ||
timeout -= time.monotonic() - start_time | ||
|
@@ -1946,6 +1980,7 @@ def execute( | |
else: | ||
inp = RayDAGArgs(args=args, kwargs=kwargs) | ||
|
||
self.raise_if_too_many_inflight_requests() | ||
self._dag_submitter.write(inp, self._execution_timeout) | ||
|
||
ref = CompiledDAGRef(self, self._execution_index) | ||
|
@@ -2004,6 +2039,7 @@ async def execute_async( | |
else: | ||
inp = RayDAGArgs(args=args, kwargs=kwargs) | ||
|
||
self.raise_if_too_many_inflight_requests() | ||
await self._dag_submitter.write(inp) | ||
# Allocate a future that the caller can use to get the result. | ||
fut = asyncio.Future() | ||
|
@@ -2039,13 +2075,15 @@ def build_compiled_dag_from_ray_dag( | |
enable_asyncio: bool = False, | ||
asyncio_max_queue_size: Optional[int] = None, | ||
max_buffered_results: Optional[int] = None, | ||
max_inflight_executions: Optional[int] = None, | ||
) -> "CompiledDAG": | ||
compiled_dag = CompiledDAG( | ||
execution_timeout, | ||
buffer_size_bytes, | ||
enable_asyncio, | ||
asyncio_max_queue_size, | ||
max_buffered_results, | ||
max_inflight_executions, | ||
) | ||
|
||
def _build_compiled_dag(node): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can/should we remove this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, yeah, I think technically it is not needed anymore. cc @ruisearch42 -- I will follow up in the next PR.