Enable pipeline prefetching (#2963)
Summary:
Pull Request resolved: #2963

This diff enables pipeline cache prefetching for SSD-TBE.  With it, the
prefetch for the next iteration's batch can run while the current
batch is still being computed.

We have done the following to guarantee cache consistency when
pipeline prefetching is enabled:

(1) Enable cache line locking (implemented in D46172802, D47638502,
D60812956) to ensure that cache lines are not prematurely evicted by
the prefetch when the previous iteration's computation is not
complete.
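
To illustrate the idea behind (1), here is a minimal sketch of an LRU cache whose lines can be locked against eviction. It is a toy model written for this description, not the implementation from the diffs listed above; the `LockingCache` class and all of its method names are hypothetical.

```python
from collections import OrderedDict


class LockingCache:
    """Toy LRU cache whose lines can be locked against eviction (hypothetical)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.lines = OrderedDict()  # row id -> row data, least recently used first
        self.lock_count = {}        # row id -> number of in-flight users

    def insert(self, row_id, data, lock=True):
        """Insert a row; return False if the cache is full of locked lines."""
        if row_id in self.lines:
            self.lines.move_to_end(row_id)
        else:
            if len(self.lines) >= self.capacity and not self._evict_one():
                return False  # caller must fall back, e.g. to a scratch pad
            self.lines[row_id] = data
        if lock:
            self.lock_count[row_id] = self.lock_count.get(row_id, 0) + 1
        return True

    def unlock(self, row_id):
        """Release one lock, e.g. once the iteration using the row is done."""
        if self.lock_count.get(row_id, 0) > 0:
            self.lock_count[row_id] -= 1

    def _evict_one(self):
        # Evict the least recently used line that nobody is still using.
        for row_id in self.lines:
            if self.lock_count.get(row_id, 0) == 0:
                del self.lines[row_id]
                self.lock_count.pop(row_id, None)
                return True
        return False


cache = LockingCache(capacity=2)
cache.insert(10, "row10")             # locked by the previous iteration
cache.insert(11, "row11")             # locked by the previous iteration
blocked = cache.insert(12, "row12")   # prefetch finds no evictable line
cache.unlock(10)                      # previous iteration is done with row 10
admitted = cache.insert(12, "row12")  # row 10 can now be evicted
```

In this sketch the first attempt to insert row 12 is refused because both resident lines are still locked; only after row 10 is unlocked can the prefetch evict it.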

(2) Look up the L1 cache, the previous iteration's scratch pad (let's
call it SP(i-1)), and the SSD/L2 cache. Move rows from SSD/L2 and/or
SP(i-1) to either L1 or the current iteration's scratch pad (let's call
it SP(i)).  Then update the row pointers of the previous iteration's
indices to the new locations, i.e., L1 or SP(i).  A detailed
explanation of the process is shown in the figure below:

{F1802341461}
https://internalfb.com/excalidraw/EX264315
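
The lookup-and-move step in (2) can be pictured with a small toy model. This is only a sketch under simplifying assumptions: plain dictionaries stand in for L1, SP(i-1), SP(i), and SSD/L2, a fixed `l1_capacity` decides whether a row fits in L1, and the function names `prefetch_lookup` and `repoint_previous` are hypothetical, not part of the actual code.

```python
def prefetch_lookup(indices, l1, l1_capacity, sp_prev, ssd):
    """Resolve each index of iteration i to "L1" or "SP(i)".

    l1, sp_prev, ssd: dicts mapping row index -> row data (toy stand-ins).
    l1 and sp_prev are updated in place. Returns (sp_cur, pointers).
    """
    sp_cur = {}
    pointers = {}
    for idx in dict.fromkeys(indices):  # deduplicate, keep first-seen order
        if idx in l1:
            pointers[idx] = "L1"        # L1 hit, nothing to move
            continue
        # L1 miss: take the row from SP(i-1) if it is there, else from SSD/L2.
        row = sp_prev.pop(idx) if idx in sp_prev else ssd[idx]
        if len(l1) < l1_capacity:
            l1[idx] = row
            pointers[idx] = "L1"
        else:
            sp_cur[idx] = row
            pointers[idx] = "SP(i)"
    return sp_cur, pointers


def repoint_previous(prev_pointers, new_pointers):
    """Update iteration i-1's pointers for rows that left SP(i-1)."""
    return {
        idx: new_pointers.get(idx, loc) if loc == "SP(i-1)" else loc
        for idx, loc in prev_pointers.items()
    }


l1 = {1: "r1"}
sp_prev = {2: "r2", 3: "r3"}
ssd = {4: "r4", 5: "r5"}
sp_cur, pointers = prefetch_lookup([1, 2, 3, 4], l1, 2, sp_prev, ssd)
prev_pointers = repoint_previous({1: "L1", 2: "SP(i-1)", 3: "SP(i-1)"}, pointers)
```

Here rows 2 and 3 leave SP(i-1): row 2 lands in the one free L1 slot and row 3 goes to SP(i), so the previous iteration's pointers for both rows are rewritten to match.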

(3) Ensure proper synchronization between streams and events:
- Ensure that prefetch of iteration i is complete before backward TBE
  of iteration i-1
- Ensure that prefetch of iteration i+1 starts after the backward TBE
  of iteration i is complete
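
Both ordering rules follow a record-then-wait pattern. The sketch below is a CPU-only analogy, assuming Python threads stand in for GPU streams and `threading.Event` stands in for CUDA events. The worker names and log labels are illustrative only, and the sketch does not model specific iteration numbers.

```python
import threading

log = []
prefetch_done = threading.Event()  # stand-in for an event recorded after a prefetch
backward_done = threading.Event()  # stand-in for an event recorded after backward TBE


def prefetch_worker():
    log.append("prefetch A")
    prefetch_done.set()            # signal that the prefetch has completed
    backward_done.wait()           # the later prefetch waits for backward TBE
    log.append("prefetch B")


def compute_worker():
    prefetch_done.wait()           # backward TBE waits for the earlier prefetch
    log.append("backward TBE")
    backward_done.set()            # signal that backward TBE has completed


threads = [
    threading.Thread(target=compute_worker),
    threading.Thread(target=prefetch_worker),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Whichever thread starts first, the events force the order: the first prefetch, then backward TBE, then the later prefetch.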

The following is how prefetch operators run on GPU streams/CPU:

{F1802798301}

**Usage:**

```
# Initialize the module with prefetch_pipeline=True
emb = SSDTableBatchedEmbeddingBags(
    embedding_specs=...,
    prefetch_pipeline=True,
).cuda()

# When calling prefetch, make sure to pass the forward stream if using
# prefetch_stream so that TBE records tensors on streams properly
with torch.cuda.stream(prefetch_stream):
    emb.prefetch(
        indices,
        offsets,
        forward_stream=forward_stream
    )
```

Differential Revision: D60727327
sryap authored and facebook-github-bot committed Aug 14, 2024
1 parent 3070f88 commit 07dd860
Showing 5 changed files with 807 additions and 318 deletions.