Redesign discussion: Launch Tasks from Tasks #5671

Open
crusaderky opened this issue Jan 19, 2022 · 19 comments
Labels
discussion Discussing a topic with no specific actions yet feature Something is missing


@crusaderky
Collaborator

crusaderky commented Jan 19, 2022

Dask performs best when the whole graph of tasks is defined ahead of time from the client and submitted at once. This is however not always possible.

Use case and current best practice

A typical example is a top-down discovery + bottom-up aggregation of a tree where the discovery of the parent-child relationships is an operation too expensive to be performed on the client.

Use case in pure Python:

def get_children(node):
    """Expensive operation that discovers the direct children of a node
    """
    ...


def aggregate(node, children_outputs):
    """Expensive operation that calculates the value of a node based on its own
    properties plus the output of the aggregate function for each of its children
    """
    ...


def crawl(node):
    children = get_children(node)
    children_outputs = [crawl(child) for child in children]  # top-down recursion
    return aggregate(node, children_outputs)  # bottom-up aggregation


out = crawl(root)
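Here is a minimal runnable version of the pure-Python baseline that shows the shape of the recursion. A hypothetical in-memory tree stands in for the expensive discovery step. The `TREE` and `WEIGHT` tables and the summing `aggregate` are illustrative assumptions and are not part of the original use case:

```python
# Hypothetical tree: node -> direct children. In the real use case,
# discovering these edges is the expensive part.
TREE = {
    "root": ["a", "b"],
    "a": ["a1", "a2"],
    "b": [],
    "a1": [],
    "a2": [],
}

# Hypothetical per-node value consumed by aggregate()
WEIGHT = {"root": 1, "a": 2, "b": 3, "a1": 4, "a2": 5}


def get_children(node):
    """Stand-in for the expensive discovery step."""
    return TREE[node]


def aggregate(node, children_outputs):
    """Stand-in aggregation: own weight plus the totals of the children."""
    return WEIGHT[node] + sum(children_outputs)


def crawl(node):
    children = get_children(node)
    children_outputs = [crawl(child) for child in children]  # top-down recursion
    return aggregate(node, children_outputs)  # bottom-up aggregation


print(crawl("root"))  # 15
```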

The first way to solve this problem today with Dask is to have the client invoke client.submit(get_children, node) for every node and wait for results. This can be very network and CPU intensive for the client.

The second approach is to use secede/rejoin, as described in https://distributed.dask.org/en/latest/task-launch.html:

import distributed


def crawl(node):
    children = get_children(node)
    client = distributed.get_client()
    children_futures = client.map(crawl, children)
    distributed.secede()
    children_outputs = client.gather(children_futures)
    distributed.rejoin()
    return aggregate(node, children_outputs)


client = distributed.Client(...)
out = client.submit(crawl, root).result()

The above is problematic because:

  1. There will be an uncontrolled increase in unmanaged memory on the workers, caused by all the local variables of the seceded tasks plus the accessory data needed to track them.
  2. Each seceded task adds a thread to the worker, which puts a burden on the Linux kernel. If the graph has enough nodes, the cluster will eventually die as the workers hit their ulimit.
  3. Because of point 2, it works very poorly for very large numbers of individually small nodes. You could redesign the algorithm so that a single crawl call goes through a fixed-size cluster of contiguous nodes; such a change, however, is algorithmically complex.
  4. Last but not least, if any worker dies during the computation, you have to restart from scratch.

A situationally slightly better variant is as follows:

import distributed


def crawl(node):
    children = get_children(node)
    client = distributed.get_client()
    children_futures = client.map(crawl, children)
    out_future = client.submit(aggregate, node, children_futures)
    distributed.secede()
    return out_future.result()


client = distributed.Client(...)
out = client.submit(crawl, root).result()

The difference is subtle: as the subgraph of each child gets resolved, its (potentially large) output is not stored in the local variables of the seceded crawl task (which are unmanaged memory); it goes into managed memory instead, with all the benefits that entails (e.g. it can be spilled to disk). On the flip side, the scheduler is now burdened with two futures per node instead of one. Regardless, all of the problems listed above remain.

Proposed redesign

I would like to suggest deprecating secede()/rejoin().
In its place, I would like to introduce the following rule:

If a task returns a Future, then the scheduler will wait for it and return its result instead. This may be nested (the result of the Future may itself be a Future).
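The standard library's `concurrent.futures.Future` can stand in for `distributed.Future` to model the rule in a few lines. This is only a sketch of the intended semantics under that assumption. The real scheduler would register a dependency on the returned key rather than block a thread the way the hypothetical `resolve()` does here:

```python
from concurrent.futures import Future


def resolve(value):
    """Follow a chain of futures until a plain value is reached.

    Sketch of the proposed rule: if a task's result is a Future, wait for
    it and use its result instead, repeating while that is a Future too.
    """
    while isinstance(value, Future):
        value = value.result()  # blocks until that future is done
    return value


# Nested chain: outer resolves to inner, which resolves to 42
inner = Future()
inner.set_result(42)
outer = Future()
outer.set_result(inner)

print(resolve(outer))  # 42
```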

The use case code becomes as follows:

import distributed


def crawl(node):
    children = get_children(node)
    client = distributed.get_client()
    children_futures = client.map(crawl, children)
    return client.submit(aggregate, node, children_futures)

client = distributed.Client(...)
out = client.submit(crawl, root).result()

No extra threads are ever created. Everything is managed by the scheduler - as it should. The network and CPU load on the user's client (e.g. a jupyter notebook) remain trivial.

Nested resolution of futures aside, the above code does not work today. After the task returns the future, the future is serialised and removed from Worker.data. At that point its destructor kicks in and releases the refcount on the scheduler, so by the time the future is rebuilt on the receiving side, the data it references may have been lost. The same happens if the future is spilled to disk.
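The race can be reduced to a toy refcounting model. `RefcountScheduler` and `Handle` below are hypothetical and much simpler than distributed's real bookkeeping. They only show that the order of "release on the sender" versus "acquire on the receiver" decides whether the data survives:

```python
class RefcountScheduler:
    """Hypothetical, heavily simplified model of scheduler-side refcounting."""

    def __init__(self):
        self.data = {}      # key -> value held in cluster memory
        self.refcount = {}  # key -> number of live client-side handles

    def store(self, key, value):
        self.data[key] = value
        self.refcount[key] = 0

    def acquire(self, key):
        self.refcount[key] += 1

    def release(self, key):
        self.refcount[key] -= 1
        if self.refcount[key] == 0:
            # No handle left anywhere: the scheduler forgets the data
            del self.data[key]
            del self.refcount[key]


class Handle:
    """Hypothetical client-side handle; holds one reference while alive."""

    def __init__(self, scheduler, key):
        self.scheduler = scheduler
        self.key = key
        scheduler.acquire(key)

    def release(self):
        # Stands in for the Future destructor firing
        self.scheduler.release(self.key)


sched = RefcountScheduler()

# Today: the returned handle is serialised down to its key and released on
# the worker before anyone rebuilds it on the other side.
sched.store("x", "big result")
worker_handle = Handle(sched, "x")
payload = worker_handle.key  # what actually travels back
worker_handle.release()
print("x" in sched.data)  # False: lost in transit

# If the receiver took its reference before the worker let go,
# the data would survive the handoff.
sched.store("y", "big result")
worker_handle = Handle(sched, "y")
receiver_handle = Handle(sched, worker_handle.key)
worker_handle.release()
print("y" in sched.data)  # True
```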

Challenges

Publish/unpublish

It is currently possible to work around the scheduler forgetting the future upon return by publishing it temporarily. This is generally a bad idea because, short of implementing a user-defined garbage collector, you may end up with cluster-wide memory leaks of managed memory (datasets that are published and then forgotten, because the task that was supposed to unpublish them crashed or never started). Nonetheless, automatically resolving returned futures will break this pattern.

Workaround

Users can still use this hack but return the name of the temporary dataset instead of the Future.
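In distributed, this workaround would use `Client.publish_dataset`, `Client.get_dataset` and `Client.unpublish_dataset`. The stdlib toy below models only the bookkeeping (the `TempRegistry` class is hypothetical). It also shows where the managed-memory leak comes from when the consuming task never runs:

```python
import uuid


class TempRegistry:
    """Hypothetical stand-in for the scheduler's store of published datasets."""

    def __init__(self):
        self.datasets = {}

    def publish(self, value):
        # Unique temporary name, so concurrent tasks never collide
        name = f"tmp-{uuid.uuid4().hex}"
        self.datasets[name] = value
        return name

    def take(self, name):
        # Fetch and unpublish in one step
        return self.datasets.pop(name)


registry = TempRegistry()


def producer_task():
    # Return the dataset name instead of the Future itself
    return registry.publish("stand-in for a Future")


def consumer_task(name):
    return registry.take(name)


name = producer_task()
print(consumer_task(name))     # stand-in for a Future
print(len(registry.datasets))  # 0

# If the consumer crashes or never starts, nobody unpublishes: a leak
orphan = producer_task()
print(len(registry.datasets))  # 1
```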

Additions and nice-to-haves

Client-side tracking

It would be nice to see distributed.diagnostics.progressbar.progress display the growing number of tasks in real time. This does not happen with the current secede/rejoin design either.

Collections

Returned dask collections could be treated like Futures. For example, the code below would halve the number of worker->scheduler comms and (personal preference) would also look nicer:

import distributed
from dask import delayed


def get_children(node):
    ...

@delayed
def aggregate(node, children_outputs):
    ...

@delayed
def crawl(node):
    children = get_children(node)
    children_delayeds = [crawl(child) for child in children]
    return aggregate(node, children_delayeds)


client = distributed.Client(...)
out = crawl(root).compute()

Under the hood, all that's happening is a two-liner that converts the collection into a future, reverting to the base use case:

from dask.base import is_dask_collection
from distributed import get_client

if is_dask_collection(retvalue):
    retvalue = get_client().compute(retvalue)

The same should be implemented in dask/dask, so that it works on the threading/multiprocessing schedulers too.

@crusaderky crusaderky added the discussion Discussing a topic with no specific actions yet label Jan 19, 2022
@gjoseph92
Collaborator

Overall, I like this proposal, and I think tasks-within-tasks is certainly in need of a redesign.

However, one downside of the "futures returned from a task are always awaited" rule is that it eliminates some of the flexibility which might be your reason for using tasks-within-tasks in the first place.

Specifically, there may be other ways in which you want to wait for the futures to be done:

  • as_completed (you can start processing them as soon as one/a few are done, not all)
  • handing off the Futures themselves to some other task/client, to wait for them in whatever pattern it wants

For example, you couldn't do something like:

import distributed


def get_all_pages(x: Collection) -> list[PageId]: ...
def process_page(id: PageId) -> Page: ...
def combine_pages(pages: list[Page]) -> Summary: ...

def summarize(x: Collection, group_size: int = 4) -> list[distributed.Future[Summary]]:
    pages = get_all_pages(x)

    client = distributed.get_client()
    page_futures = client.map(process_page, pages)
    distributed.secede()

    done_pages = []
    summary_futures = []
    for f in distributed.as_completed(page_futures):
        done_pages.append(f)
        if len(done_pages) == group_size:
            summary_futures.append(client.submit(combine_pages, done_pages))
            done_pages = []  # new list; don't mutate the one just submitted

    return summary_futures

client = distributed.Client(...)
summarize_future = client.submit(summarize, x)
summary_futures = summarize_future.result()
# Note that `summarize` returns a list of Futures, not the actual Summaries,
# so that we can stream them to our Real Time Business Intelligence System

for f, summary in distributed.as_completed(summary_futures, with_results=True):
    display_summary_on_dashboard(summary)

I do like the simplicity of your proposal, though. And these extra-complex use cases may be worth giving up.

I do think Ray is worth looking at for prior art here, as a system that handles tasks-within-tasks as a core use case instead of an edge case, as it is for distributed. The API is extremely simple, but I think it hides some careful thinking about the rules needed to make this work well in a distributed context. (For example, you can submit and wait for futures without managing any of the secede/rejoin logic yourself; I'm still trying to find a good reference for how that works.)

@crusaderky
Collaborator Author

crusaderky commented Jan 24, 2022

@gjoseph92 your example does not work today.
When summarize returns, it will (likely) close its client. The summary futures are then likely to be forgotten by the scheduler, because the only client holding a reference to them has been shut down. It becomes a race between the garbage collection of the worker client and the return value travelling all the way back to the user client. If there is a circular reference in the worker client, everything may appear to work, because it takes a while before the next gc run; then you increase the load on the worker and suddenly gc runs faster than your return value. Alternatively, on a multithreaded worker the same Client instance may be shared by multiple threads: as long as 2+ threads hold a reference to the Client at all times, it will work, but as soon as only one thread is left, the Client will be garbage collected. For an inexperienced user, this is a nightmare to debug.

Also, in your example the dashboard will not display any updates until all pages to be summarized are complete. You could rely on the (not always true) assumption that futures are completed more or less in FIFO order and just immediately schedule a summary every 4 page futures - which is what all of the dask recursive aggregations do.

To me, you're highlighting two different issues:

  • dask lacks facilities for streaming results; as_completed is insufficient on its own. I personally would like to reopen the discussion on distributed.Queue. A new feature like "spawn a combine_pages task whenever there are 4 or more elements in the queue" would nicely solve your use case.
  • my proposal could be expanded as follows:
    If a task returns a list or tuple and the first element of the list/tuple is a Future, or it returns a set and an arbitrary element of the set is a Future, or it returns a dict and the first value is a Future, then ensure that the internal futures are not released when their client is closed; instead, their survival will depend on whatever client holds the future for the return value. When the future of the list is awaited, you get the list of futures, linked to the receiving client.
    This would solve your use case (but you would have to either schedule the summaries from the user's client or immediately schedule all summaries without waiting for the pages). However, I think it's a major addition to fix a use case that can't be implemented today to begin with...
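The heuristic in the second bullet is cheap because it inspects a single element per container. Here is a sketch of that check, using the standard library's `concurrent.futures.Future` as a stand-in for `distributed.Future`. The function name `holds_futures` is hypothetical:

```python
from concurrent.futures import Future


def holds_futures(value):
    """Hypothetical check following the proposed rule: inspect only the
    first element of a list/tuple, an arbitrary element of a set, or the
    first value of a dict."""
    if isinstance(value, (list, tuple)):
        return len(value) > 0 and isinstance(value[0], Future)
    if isinstance(value, (set, frozenset)):
        return len(value) > 0 and isinstance(next(iter(value)), Future)
    if isinstance(value, dict):
        return len(value) > 0 and isinstance(next(iter(value.values())), Future)
    return False


print(holds_futures([Future(), Future()]))  # True
print(holds_futures([1, Future()]))         # False: only the first element is checked
print(holds_futures({"a": Future()}))       # True
```

Checking one element keeps the cost constant regardless of container size. The price is that a container whose first element happens not to be a Future goes undetected.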

@crusaderky
Collaborator Author

crusaderky commented Jan 24, 2022

@gjoseph92 I tried rewriting your use case in a way that works today, and what I got is very complicated and brittle.

from __future__ import annotations  # placeholder types below are never evaluated

import asyncio
import distributed


def get_all_pages(x: Collection) -> list[PageId]: ...
def process_page(id: PageId) -> Page: ...
def combine_pages(pages: list[Page]) -> Summary: ...


def summarize(x: Collection, q: distributed.Queue, group_size: int = 4) -> None:
    pages = get_all_pages(x)

    client = distributed.get_client()
    page_futures = client.map(process_page, pages)
    del pages
    distributed.secede()

    done_pages = []
    for f in distributed.as_completed(page_futures):
        done_pages.append(f)
        if len(done_pages) == group_size:
            q.put(client.submit(combine_pages, done_pages))
            done_pages = []
    if done_pages:
        # Don't drop the last, incomplete group
        q.put(client.submit(combine_pages, done_pages))
    q.put(None)  # Sentinel: no more summaries will be produced


async def summarize_collection(x: Collection):
    client = await distributed.Client(..., asynchronous=True)
    q = await distributed.Queue()
    summarize_future = client.submit(summarize, x, q)
    # Convert distributed.Future to asyncio.Future
    summarize_future = asyncio.create_task(summarize_future.result())

    queue_closed = False
    pending = {summarize_future}
    queue_get_future = None
    while pending or not queue_closed:
        if queue_get_future is None and not queue_closed:
            queue_get_future = asyncio.create_task(q.get())
            pending.add(queue_get_future)
        finished, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for f in finished:
            if f is queue_get_future:
                queue_output = await f
                if queue_output is None:
                    queue_closed = True
                else:
                    assert isinstance(queue_output, distributed.Future)
                    # Convert distributed.Future to asyncio.Future
                    pending.add(asyncio.create_task(queue_output.result()))
                queue_get_future = None
            elif f is summarize_future:
                # If summarize crashed remotely, bomb out instead of getting stuck
                # forever on a queue that will never get closed
                await f
            else:
                summary = await f
                display_summary_on_dashboard(summary)

I cooked up a mock-up that requires a new class, distributed.FutureStream:

from functools import partial
import distributed


def get_all_pages(x: Collection) -> list[PageId]: ...
def process_page(id: PageId) -> Page: ...
def combine_pages(pages: list[Page]) -> Summary: ...


def process_pages(
    x: Collection,
    pages_stream: distributed.FutureStream[Page],
) -> None:
    pages = get_all_pages(x)
    client = distributed.get_client()
    page_futures = client.map(process_page, pages)
    for fut in page_futures:
        pages_stream.put(fut)
    pages_stream.close()


def summarize(
    futures: list[distributed.Future],
    pages: list[Page],
    summaries_stream: distributed.FutureStream[Summary],
) -> None:
    # TODO handle exceptions from futures
    if futures:
        client = distributed.get_client()
        summaries_stream.put(client.submit(combine_pages, pages))
    else:
        summaries_stream.close()


def summarize_collection(x: Collection):
    client = distributed.Client(...)
    pages_stream = distributed.FutureStream(order=False)
    summaries_stream = distributed.FutureStream(order=False)
    pages_stream.add_done_callback(
        partial(summarize, summaries_stream=summaries_stream),
        with_result=True,
        remote=True,
        chunk_size=4,
    )
    client.submit(process_pages, x, pages_stream).result()

    try:
        while True:
            summary = summaries_stream.get()
            display_summary_on_dashboard(summary)
    except distributed.StreamClosed:
        pass

This... works and is reasonably clean, however it serves specifically the use case of a multi-stage as_completed stream. I wonder how much real-life user appetite there is for it?
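Without the distributed machinery, the behaviour FutureStream is after can be sketched in a single process with `concurrent.futures`: summaries are streamed while pages are still being processed, and the last partial group is flushed. This is a local analogue under that assumption, not a distributed implementation. `process_page`, `combine_pages` and `stream_summaries` are hypothetical stand-ins:

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def process_page(page_id):
    # Hypothetical stand-in for the expensive per-page work
    return page_id * 10


def combine_pages(pages):
    # Hypothetical stand-in: summarise a group of processed pages
    return sum(pages)


def stream_summaries(page_ids, group_size=4):
    """Yield each summary as soon as it is ready, while pages are still running."""
    with ThreadPoolExecutor(max_workers=4) as pool:
        page_futures = {pool.submit(process_page, p) for p in page_ids}
        pending = set(page_futures)
        pages_left = len(page_futures)
        done_pages = []
        while pending:
            finished, pending = wait(pending, return_when=FIRST_COMPLETED)
            for f in finished:
                if f in page_futures:
                    done_pages.append(f.result())
                    pages_left -= 1
                    # One summary per full group, plus a flush of the last partial group
                    if len(done_pages) == group_size or (pages_left == 0 and done_pages):
                        pending.add(pool.submit(combine_pages, done_pages))
                        done_pages = []
                else:
                    yield f.result()  # a finished summary


for summary in stream_summaries(range(10)):
    print("new summary:", summary)
```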

@akhmerov

To chip in a bit: another application that would be covered by the original proposal, but not by a future stream, is iterative optimization. With the approach of returning futures, it would go as follows:

def optimization_step(current):
    if converged(current):  # converged() and guess_next() are placeholders
        return current
    next_guess = guess_next(current)
    return distributed.get_client().submit(optimization_step, next_guess)

As far as I understand, this doesn't seem addressed by a FutureStream.

@gjoseph92
Collaborator

your example currently does not work

Apologies, I should have been clear: my example code wasn't meant to work, just to illustrate a use-case. I was imagining some semantic hybrid between the current distributed API and the Ray API here.

The point I was trying to make was just that always awaiting returned futures restricts some use cases. In particular, this sort of result streaming is either really tricky or requires a special structure like FutureStream.

My hypothesis was just that if we could make it so that my code worked (correctly handling the reference-counting/ownership of futures-pointing-to-futures), then I think tasks-within-tasks would still be easy to use, but strictly more powerful than in your auto-await proposal, since you could control how the awaiting happened to easily implement things like streaming yourself, without any additional structures.

Whether it's worth the effort to make this reference-counting work correctly, I really can't say.


But since we're talking about ways to do this, here's my original code, converted to Ray, and actually working. You can see it's still quite simple and readable:

@ray.remote
def summarize(x: Collection, group_size: int = 4) -> list[ray.ObjectRef[Summary]]:
    pages = get_all_pages(x)

    page_oids = [process_page.remote(p) for p in pages]

    summary_oids: list[ray.ObjectRef[Summary]] = []
    while page_oids:
        done, page_oids = ray.wait(
            page_oids, num_returns=min(group_size, len(page_oids)), fetch_local=False
        )
        summary_oids.append(combine_pages.remote(*done))
        del done

    return summary_oids


summarize_future = summarize.remote(Collection("abcdefghijklmnopqrstuvwxyz"))
summary_futures = ray.get(summarize_future)
# Note that `summarize` returns a list of ObjectIDs, not the actual Summaries,
# so that we can stream them to our Real Time Business Intelligence System

while summary_futures:
    [summary], summary_futures = ray.wait(summary_futures)
    display_summary_on_dashboard(ray.get(summary))
Fully-runnable example in Ray
from __future__ import annotations

from collections import Counter
import time
from typing import NewType

import ray

ray.init()

Collection = NewType("Collection", str)
PageId = NewType("PageId", str)
Page = NewType("Page", list[int])
Summary = NewType("Summary", dict)


def get_all_pages(c: Collection) -> list[PageId]:
    return [PageId(x) for x in c]


@ray.remote
def process_page(id: PageId) -> Page:
    print(f"process page {id}")
    time.sleep(0.5)
    return Page(list(range(ord(id))))


@ray.remote
def combine_pages(*pages: Page) -> Summary:
    print(f"combine {len(pages)} pages")
    time.sleep(0.2)
    return Summary(Counter(x for p in pages for x in p))


def display_summary_on_dashboard(summary: Summary):
    print(f"NEW AI DATA: {max(summary.values())}")


@ray.remote
def summarize(x: Collection, group_size: int = 4) -> list[ray.ObjectRef[Summary]]:
    pages = get_all_pages(x)

    page_oids = [process_page.remote(p) for p in pages]

    summary_oids: list[ray.ObjectRef[Summary]] = []
    while page_oids:
        done, page_oids = ray.wait(
            page_oids, num_returns=min(group_size, len(page_oids)), fetch_local=False
        )
        summary_oids.append(combine_pages.remote(*done))
        del done

    return summary_oids


summarize_future = summarize.remote(Collection("abcdefghijklmnopqrstuvwxyz"))
summary_futures = ray.get(summarize_future)
# Note that `summarize` returns a list of ObjectIDs, not the actual Summaries,
# so that we can stream them to our Real Time Business Intelligence System

while summary_futures:
    [summary], summary_futures = ray.wait(summary_futures)
    display_summary_on_dashboard(ray.get(summary))

I think the reason this works in Ray but doesn't in distributed is because in

summarize_future = summarize.remote(Collection("abcdefghijklmnopqrstuvwxyz"))
summary_futures = ray.get(summarize_future)

Ray is able to track that summarize_future is a pointer to a pointer / future to a future, and therefore the inner futures are "owned"/kept alive by the outer future. As you mentioned @crusaderky, distributed can't track this—it's just a race condition whether the worker client releases the inner futures before the user client can dereference the outer futures and pick up references to the inner futures.

I don't think we have a way to represent this dependency structure right now in dask? Currently, if a task is complete, all its dependencies must be complete. This would be a new situation, where a task (summarize) is complete, yet in the process of running, it added more "dependencies" (not dependencies—we'd need a different term) to itself, which are not yet complete. So because your client wants summarize_future, and summarize_future "depends on" all of the summary_futures, the summary_futures are kept alive even though no client wants them directly.

Maybe we could add wants_what (and tasks_who_want) to scheduler.TaskState, as a corollary to ClientState.wants_what? So tasks could pin references to other keys, in the same way that Clients can? That might be sufficient to track this nested ownership. Any keys returned by a task (including traversing lists/tuples/dicts) would be added to wants_what. Keys that were wanted—either by a client, or another task—would not be released. Therefore, by holding a reference to the outer future on your client, you'd transitively keep any tree of keys that it points to alive.

Though we might then want checks for ownership reference cycles, which would be interesting.
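Below is a toy model of the `wants_what`-on-`TaskState` idea. It assumes that ownership is evaluated as reachability from what clients want, rather than as plain refcounts. `OwnershipModel` and its methods are hypothetical and are not scheduler APIs. One design consequence: with a reachability check, an ownership cycle that no client can reach is simply released, at the cost of a graph traversal.

```python
from collections import defaultdict


class OwnershipModel:
    """Hypothetical model: a key stays alive while some client wants it,
    or while an alive task's result points to it (transitively)."""

    def __init__(self):
        self.client_wants = defaultdict(set)  # client -> keys (like ClientState.wants_what)
        self.task_wants = defaultdict(set)    # task key -> keys found in its result

    def task_returned(self, key, returned_keys):
        # Keys found in a task's return value are pinned by that task
        self.task_wants[key].update(returned_keys)

    def alive(self):
        """Every key reachable from any client's wants."""
        seen = set()
        stack = [k for keys in self.client_wants.values() for k in keys]
        while stack:
            key = stack.pop()
            if key not in seen:
                seen.add(key)
                stack.extend(self.task_wants.get(key, ()))
        return seen


m = OwnershipModel()
m.client_wants["user"].add("summarize-1")
m.task_returned("summarize-1", ["combine-1", "combine-2"])
print(sorted(m.alive()))  # ['combine-1', 'combine-2', 'summarize-1']

# A cycle that no client wants is not kept alive
m.task_returned("orphan-a", ["orphan-b"])
m.task_returned("orphan-b", ["orphan-a"])
print(sorted(m.alive()))  # ['combine-1', 'combine-2', 'summarize-1']

m.client_wants["user"].discard("summarize-1")
print(m.alive())  # set(): releasing the outer future releases the whole tree
```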

With some cleverness, you could even implement some sort of chain-fusion/de-aliasing on the scheduler, where you collapse linear chains of futures down to just the first and last element. I'm thinking about this because, in the example of recursion @akhmerov posted, the recursion will potentially produce a huge number of futures, which serve no purpose other than to point to the next future. Both tracking this on the scheduler, and traversing the many Future objects to get the final result, could be expensive. But this is a classic example of where you'd want tail-call optimization. And I think you might be able to implement it on the scheduler with a chain-collapsing rule.
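The chain-collapsing rule can be sketched as a pass over a hypothetical `points_to` mapping. Each key whose result is just another key (as produced by the recursion in the iterative-optimization example) maps to that key. After the pass, every key points straight at the end of its chain, with an effect similar to path compression, and a loop raises instead of spinning forever. The function name and data layout are assumptions, and each chain is walked independently for clarity; a real implementation would reuse already-collapsed results.

```python
def collapse_chains(points_to):
    """Hypothetical pass: map every key straight to the last key of its chain.

    `points_to[k] == j` means the result of k is nothing but the future j.
    """
    collapsed = {}
    for start in points_to:
        end = start
        visited = {start}
        # Walk forward until reaching a key that holds a real value
        while end in points_to:
            end = points_to[end]
            if end in visited:
                raise ValueError(f"ownership cycle involving {start!r}")
            visited.add(end)
        collapsed[start] = end
    return collapsed


chain = {"step-1": "step-2", "step-2": "step-3", "step-3": "result"}
print(collapse_chains(chain))
# {'step-1': 'result', 'step-2': 'result', 'step-3': 'result'}
```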

@crusaderky
Collaborator Author

here's my original code, converted to Ray, and actually working

You're still not executing any calls to display_summary_on_dashboard until all calls to process_page have been completed though.

I confess my ignorance of Ray. Do I understand correctly that

done, pending = ray.wait(...)

is the same as

distributed.secede()
done, pending = distributed.wait(...)
distributed.rejoin()

?
Or is there some python interpreter magic going on where what looks like a synchronous function is actually asynchronous?

yet in the process of running, it added more "dependencies" (not dependencies—we'd need a different term) to itself

consequences?

@gjoseph92
Collaborator

Yes, ray.wait / ray.get basically handles the secede/rejoin for you automatically. I haven't read their implementation, but the end result is that as a user, you can just use it without worrying about deadlocks.

consequences?

Not sure what you're asking here. The consequences of a task's result (a Future) depending on other tasks are exactly what you talked about:

It ends in a race condition on what's fastest, the garbage collection of the worker client or the return value going all the way back to the user client.

If we codify the fact that a task's result can depend on other tasks (adding wants_what to TaskState as I'm suggesting), then this is all tracked on the scheduler, there's no more race condition, and returning Futures from tasks has first-class support.

@crusaderky
Collaborator Author

consequences?

Not sure what you're asking here.

I'm asking if you like the idea to call these spawned tasks "consequences" since we agree they are not dependencies.

@gjoseph92
Collaborator

Ah I see. Yes, that makes sense as a user-facing term maybe. I think I'd prefer wants_what on the TaskState though, since it matches with ClientState.wants_what, and is effectively the same thing.

@AmineDiro

Hello, 😄

I'm a new Dask user, currently working on a workflow where I need to process a huge number of documents, and I stumbled on this open issue.

The basic idea is to open a huge number of documents. Each document is a collection of an arbitrary number of pages.
We then apply OCR to each page. The time to process each page varies, and we have many more pages than documents (20M pages for 400,000 documents).

I tried writing this pipeline in Dask in several different ways and settled on the tasks-within-tasks pattern, like the mock example @gjoseph92 showed:

OCR pages in document mock code

import random
from itertools import chain
from time import sleep

from distributed import Client, LocalCluster, worker_client


def ocr_image(page):
    time_delay = random.randrange(1, 10)
    sleep(time_delay)  # simulate actual OCR work
    return "this is ocr"


def load_pages(doc):
    # simulate opening the file
    sleep(0.5)
    futures = []
    n = random.randint(1, 5)
    with worker_client() as client:
        for page in range(n):
            future_ocr = client.submit(ocr_image, page, pure=False)
            futures.append(future_ocr)
    return futures


def main():
    # Load and submit tasks (assumes a global `client` and a list of `filenames`)
    loaders = [client.submit(load_pages, doc, pure=False) for doc in filenames]
    res_loaders = client.gather(loaders)

    res_ocr = client.gather(list(chain.from_iterable(res_loaders)))
    return res_ocr

The issue with this approach is having to schedule a LOT of small tasks. I thought about batching, but the problem there is the arbitrary number of pages per document (from 1 to 40000!).

The correct approach, in my humble opinion, would be a distributed producer/consumer architecture with a distributed queue of pages that we can consume.

I tried the distributed.Queue class, waiting for the first loader to complete, but it has some major issues if you don't know the exact number of tasks spawned within tasks:

Queue code

import random
from time import sleep

from distributed import wait

# Assumes module-level `client`, `q = distributed.Queue()`, `batch_size` and `filenames`


def batch_ocr_image():
    # You can't combine a batch size with a timeout:
    # pages = [q.get(timeout='1s') for _ in range(batch_size)]
    pages = q.get(batch=batch_size)
    for _ in range(batch_size):
        time_delay = random.randrange(1, 10)
        sleep(time_delay)  # simulate actual OCR work
    return ["this is ocr"] * batch_size


def ocr_image():
    page = q.get(timeout='1s')
    time_delay = random.randrange(1, 10)
    sleep(time_delay)
    return "this is ocr"


def load_pages(doc):
    # simulate opening the file
    sleep(0.5)
    n = random.randint(1, 5)
    n = 10
    for page in range(n):
        q.put(page)
    return n


def main():
    ## Load pages into the queue
    loaders = [client.submit(load_pages, doc, pure=False) for doc in filenames]

    # Sync 1: gather loaders
    # approach 1: wait for all loaders to finish: res_loaders = client.gather(loaders)
    # approach 2: wait for the first one and then submit
    loaders = wait(loaders, return_when='FIRST_COMPLETED')

    ## Batching
    # Batching is very hard: q.qsize() will fail here
    consumers = [client.submit(batch_ocr_image, pure=False, retries=4)
                 for _ in range(q.qsize() // batch_size)]
    # Sync 2: consume the queue
    res_consumer = client.gather(consumers)

    return loaders, res_consumer

I might be missing something about how to correctly implement producer/consumer using distributed.
The distributed.FutureStream proposed by @crusaderky could fill the gap left by the lack of an easy-to-implement producer/consumer pattern in dask.distributed.

I have just submitted a feature request listing the methods that, in my opinion, a proper distributed queue based on the existing class needs:

  • The queue is not garbage collected: calling del q doesn't actually free up the distributed memory
  • We might need to implement q.join() in a distributed manner, to block until all the queue items have been processed
  • Batching in the queue doesn't have a 'last object mode', where fewer objects than batch_size are returned if the queue is empty
  • Spawning a "pool of consumers" in a distributed system can be tricky: basically, the machines producing are also the ones waiting for tasks to be added? Is there a way to load balance producers and consumers in the dask scheduler?
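The `join()` and "last object mode" behaviours already exist in the standard library's single-process `queue.Queue` (`task_done()`/`join()`, and `get()` with a timeout). That makes it a handy reference for what a distributed version could look like. In this sketch, `get_batch` is a hypothetical helper that implements "last object mode" on top of it:

```python
import queue


def get_batch(q, batch_size, timeout=1.0):
    """Hypothetical helper: return up to `batch_size` items, waiting at most
    `timeout` seconds for each; return a shorter (possibly empty) batch as
    soon as the queue stays empty for that long."""
    items = []
    while len(items) < batch_size:
        try:
            items.append(q.get(timeout=timeout))
        except queue.Empty:
            break
    return items


q = queue.Queue()
for page in range(10):
    q.put(page)

batches = []
while True:
    batch = get_batch(q, batch_size=4, timeout=0.1)
    if not batch:
        break
    batches.append(batch)
    for _ in batch:
        q.task_done()  # tell join() that this item has been processed

q.join()  # returns at once: every put() has a matching task_done()
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```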

I don't know the internal design of distributed, but submitting futures to the queue might solve the dereferencing issue when the root function returns: the future would continue to live until q.get() is called, and we could then block the consumer until the result is done.

@crusaderky
Collaborator Author

@AmineDiro if you just had one task per document (400k tasks) and then processed each page sequentially inside the same task, would you experience substantially suboptimal behaviour? I assume here that the number of workers you use is a tiny fraction of 400k.

Even if you do need to spawn futures for the pages, wouldn't it be covered by my design in the opening post?

(20M pages for 400 000 documents)

I'm afraid that, as of today, if you have 20M tasks you will likely hit 100% CPU load on the scheduler - and consequently experience a wealth of random timeouts. I would advise to perform some clustering.

@AmineDiro

Thanks for the response @crusaderky! Processing the pages of each document sequentially will be bottlenecked by the workers handling documents with 40000 pages... I am using an HPC cluster with 100 workers of 24 cores each. You are right, I did see 100% CPU load on the scheduler.
I would like to batch pages, but I need an object that buffers opened pages before submitting them; a queue would do the job...

@crusaderky
Collaborator Author

crusaderky commented Feb 24, 2022

@AmineDiro your problem can be solved by the design in the initial post; no need for queues:

import distributed
from dask import delayed

CHUNK_SIZE = 1000  # pages processed by a single task

def parse_document(path: str) -> list[Image]:
    # Load document from disk and return one raw image per page
    ...

def ocr_page(page: Image) -> OCRExitStatus:
    # Run a single page through OCR, dump output to disk, and return metadata / useful info
    ...

@delayed
def ocr_pages(pages: list[Image]) -> list[OCRExitStatus]:
    return [ocr_page(page) for page in pages]

@delayed
def aggregate_ocr_results(*chunk_results: list[OCRExitStatus]) -> list[OCRExitStatus]:
    return [r for chunk in chunk_results for r in chunk]

@delayed
def ocr_document(doc_path: str):
    raw_pages = parse_document(doc_path)
    client = distributed.get_client()
    chunks = client.scatter(
        [
            raw_pages[i: i + CHUNK_SIZE]
            for i in range(0, len(raw_pages), CHUNK_SIZE)
        ]
    )
    return aggregate_ocr_results(*(ocr_pages(chunk) for chunk in chunks))


client = distributed.Client()
all_results = aggregate_ocr_results(*(ocr_document(path) for path in paths))
all_results.compute()  # returns list[OCRExitStatus]

Also note that, WITHOUT this change, you can achieve today what you want with a slightly less efficient two-stage approach:

import distributed
from dask import delayed

CHUNK_SIZE = 1000  # pages processed by a single task

def count_pages(path: str) -> int:
    # Open document, peek at the header, and return number of pages contained within
    ...

@delayed
def parse_document(path: str) -> list[Image]:
    # Fully load document from disk and return one raw image per page
    ...

def ocr_page(page: Image) -> OCRExitStatus:
    # Run a single page through OCR, dump output to disk, and return metadata / useful info
    ...

@delayed
def ocr_pages(pages: list[Image]) -> list[OCRExitStatus]:
    return [ocr_page(page) for page in pages]

@delayed
def aggregate_ocr_results(*chunk_results: list[OCRExitStatus]) -> list[OCRExitStatus]:
    return [r for chunk in chunk_results for r in chunk]

client = distributed.Client()
npages = client.gather(client.map(count_pages, paths))
ocr_chunk_delayeds = []
for path, npages_i in zip(paths, npages):
   raw_pages_delayed = parse_document(path)
   ocr_chunk_delayeds += [
       ocr_pages(raw_pages_delayed[i: i + CHUNK_SIZE])
       for i in range(0, npages_i, CHUNK_SIZE)
   ]
all_results = aggregate_ocr_results(*ocr_chunk_delayeds)
all_results.compute()  # returns list[OCRExitStatus]
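The client-side work in the second stage depends only on the page counts returned by the first stage, so it grows with the number of chunks rather than the number of pages. As a library-agnostic sketch (the `plan_chunks` helper is hypothetical and not part of Dask), this is the page range that each `ocr_pages` task covers:

```python
from typing import List, Sequence, Tuple

CHUNK_SIZE = 1000  # pages processed by a single task, as in the snippet above


def plan_chunks(
    paths: Sequence[str], npages: Sequence[int], chunk_size: int = CHUNK_SIZE
) -> List[Tuple[str, int, int]]:
    """Return one (path, start, stop) page range per chunk task.

    Mirrors the stage-2 loop: one range per step of range(0, npages_i, chunk_size),
    with the last range of each document clipped to its page count.
    """
    plan: List[Tuple[str, int, int]] = []
    for path, n in zip(paths, npages):
        for start in range(0, n, chunk_size):
            plan.append((path, start, min(start + chunk_size, n)))
    return plan


# Hypothetical page counts, as stage 1 (count_pages) might return them
print(len(plan_chunks(["big.pdf", "small.pdf"], [40_000, 1_500])))  # prints 42
```

A 40000-page document therefore becomes 40 chunk tasks, which can be spread across many workers instead of being processed sequentially by one.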

@AmineDiro

AmineDiro commented Feb 25, 2022

Amazing! Thanks a lot, @crusaderky, for taking the time to write up this code.

I have had some issues resolving the delayed objects from aggregate_ocr_results. The .compute() call never runs the actual OCR and returns a List[Delayed]. Is there something I'm missing?

I didn't think about using this approach because I thought that I needed to chunk pages across documents rather than within a single document, but I see how this could work!

The producer/consumer discussion is a bit more general, and gives a "cleaner" way to solve these kinds of problems.

@crusaderky
Collaborator Author

> I have had some issues resolving the delayed objects from aggregate_ocr_results. The .compute() call never runs the actual OCR and returns a List[Delayed]. Is there something I'm missing?

The first block of code in my previous post does not work today; it requires the change described in the OP.

@AmineDiro

@crusaderky OK, got it! The first one needs the change you mentioned above.

I think that the second one also needs that change to work.

client = distributed.Client()
npages = client.gather(client.map(count_pages, paths))
ocr_chunk_delayeds = []
for path, npages_i in zip(paths, npages):
    raw_pages_delayed = parse_document(path)
    ocr_chunk_delayeds += [
        ocr_pages(raw_pages_delayed[i: i + CHUNK_SIZE])  # added the delayed ocr_pages call here
        for i in range(0, npages_i, CHUNK_SIZE)
    ]
all_results = aggregate_ocr_results(*ocr_chunk_delayeds)
all_results.compute()  # expected list[OCRExitStatus], but returns a list[Delayed]

The issue is that I still need another compute call to get the pages:

result = client.gather(client.compute(all_results.compute()))

Thanks for your help

@crusaderky
Collaborator Author

crusaderky commented Mar 2, 2022

@AmineDiro, there was one wrong line:

     ocr_chunk_delayeds += [
-        raw_pages_delayed[i: i + CHUNK_SIZE]
+        ocr_pages(raw_pages_delayed[i: i + CHUNK_SIZE])
         for i in range(0, npages_i, CHUNK_SIZE)
     ]

> I think that the second one also needs that change to work.

It doesn't. I just tested it with mocked functions, and it works today as intended.

> result = client.gather(client.compute(all_results.compute()))

This is not correct. The compute() method already returns the final output.
For a single Delayed object such as this, these are all equivalent, as long as you're using a synchronous Client:

all_results.compute()
client.compute(all_results).result()
client.gather(client.compute(all_results))

@jakirkham
Member

Haven't read through the full issue, but I just want to say I'm happy to see this discussion 😄

In particular, this stuck out to me:

> The same should be implemented in dask/dask, so that it works on the threading/multiprocessing schedulers too.

Couldn't agree more. There are real use cases we could solve with improved nested task submission (especially if all schedulers support it, since that makes it easier to add such code to Dask itself).
