Fix stream sync for scratch pad eviction (#2843)
Summary:
Pull Request resolved: #2843

Before this diff, SSD TBE unsafely evicted data from the scratch pad
(the buffer that stores conflict-missed data).  The eviction ran on
the SSD stream while the backward TBE ran on the default stream, and
SSD TBE did not synchronize the two streams to guarantee that the
backward TBE had completed before data was evicted from the scratch
pad.  This diff fixes the problem by adding synchronization between
the streams to enforce the correct execution order between the
backward TBE and the scratch pad eviction.

**Before and after this diff**

{F1761028617}

**Before:** the scratch pad eviction happens as soon as the cache eviction
is done, which is incorrect. In the figure, it overlaps with the TBE
forward and backward.

**After:** the scratch pad eviction happens only after the TBE backward
pass is done.
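
As an illustration of the intended ordering, here is a minimal sketch (not the
FBGEMM implementation itself) of the event-based synchronization this diff adds:
the backward pass records an event on the default stream, the SSD stream waits on
that event before evicting the scratch pad, and the next iteration's prefetch waits
on the eviction-completion event. `run_backward` and `evict_to_ssd` are hypothetical
stand-ins for the TBE backward and the RocksDB set; the stream/event names mirror
the ones introduced in this diff.

```python
import torch

def run_backward(grad: torch.Tensor) -> None:
    # Hypothetical stand-in for the TBE backward kernels on the current stream
    grad.mul_(0.5)

def evict_to_ssd(rows: torch.Tensor) -> None:
    # Hypothetical stand-in for the D2H copy + ssd_db.set_cuda(...) call
    rows.to("cpu", non_blocking=True)

if torch.cuda.is_available():
    low_priority, _ = torch.cuda.Stream.priority_range()
    ssd_stream = torch.cuda.Stream(priority=low_priority)
    ssd_event_backward = torch.cuda.Event()  # backward completion
    ssd_event_evict_sp = torch.cuda.Event()  # scratch pad eviction completion

    grad = torch.randn(1024, 128, device="cuda")
    scratch_pad_rows = torch.randn(1024, 128, device="cuda")

    # Default stream: run backward, then mark its completion.
    run_backward(grad)
    torch.cuda.current_stream().record_event(ssd_event_backward)

    # SSD stream: wait for backward before evicting the scratch pad.
    with torch.cuda.stream(ssd_stream):
        ssd_stream.wait_event(ssd_event_backward)
        scratch_pad_rows.record_stream(ssd_stream)  # keep the allocation alive
        evict_to_ssd(scratch_pad_rows)
        ssd_stream.record_event(ssd_event_evict_sp)

    # Next iteration's prefetch (default stream) must wait for the eviction.
    torch.cuda.current_stream().wait_event(ssd_event_evict_sp)
```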

Reviewed By: q10

Differential Revision: D59716516

fbshipit-source-id: 7c60116b7cea13948d221a2eff2b4be0b99bf17a
sryap authored and facebook-github-bot committed Jul 19, 2024
1 parent 035a02a commit 13d5470
Showing 1 changed file with 44 additions and 21 deletions.
65 changes: 44 additions & 21 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -293,8 +293,16 @@ def __init__(
# pyre-fixme[20]: Argument `self` expected.
(low_priority, high_priority) = torch.cuda.Stream.priority_range()
self.ssd_stream = torch.cuda.Stream(priority=low_priority)
self.ssd_set_start = torch.cuda.Event()
self.ssd_set_end = torch.cuda.Event()

# SSD get completion event
self.ssd_event_get = torch.cuda.Event()
# SSD eviction completion event
self.ssd_event_evict = torch.cuda.Event()
# SSD backward completion event
self.ssd_event_backward = torch.cuda.Event()
# SSD scratch pad eviction completion event
self.ssd_event_evict_sp = torch.cuda.Event()

self.timesteps_prefetched: List[int] = []
self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = []
# TODO: add type annotation
@@ -460,43 +468,54 @@ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor:

def evict(
self,
evicted_rows: Tensor,
evicted_indices: Tensor,
rows: Tensor,
indices: Tensor,
actions_count_cpu: Tensor,
stream: torch.cuda.Stream,
pre_event: torch.cuda.Event,
post_event: torch.cuda.Event,
name: Optional[str] = "",
) -> None:
"""
Evict data from the given input tensors to SSD via RocksDB
"""
with record_function(f"## ssd_evict_{name} ##"):
with torch.cuda.stream(self.ssd_stream):
self.ssd_stream.wait_event(self.ssd_set_start)
evicted_rows_cpu = self.to_pinned_cpu(evicted_rows)
evicted_indices_cpu = self.to_pinned_cpu(evicted_indices)
evicted_rows.record_stream(self.ssd_stream)
evicted_indices.record_stream(self.ssd_stream)
with torch.cuda.stream(stream):
stream.wait_event(pre_event)

rows_cpu = self.to_pinned_cpu(rows)
indices_cpu = self.to_pinned_cpu(indices)

rows.record_stream(stream)
indices.record_stream(stream)

self.record_function_via_dummy_profile(
f"## ssd_set_{name} ##",
self.ssd_db.set_cuda,
evicted_indices_cpu,
evicted_rows_cpu,
indices_cpu,
rows_cpu,
actions_count_cpu,
self.timestep,
)

# TODO: is this needed?
# Need a way to synchronize
# actions_count_cpu.record_stream(self.ssd_stream)
self.ssd_stream.record_event(self.ssd_set_end)
stream.record_event(post_event)

def _evict_from_scratch_pad(self, grad: Tensor) -> None:
assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
(inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) = (
self.ssd_scratch_pads.pop(0)
)
torch.cuda.current_stream().record_event(self.ssd_event_backward)
self.evict(
inserted_rows_gpu,
post_bwd_evicted_indices,
actions_count_cpu,
rows=inserted_rows_gpu,
indices=post_bwd_evicted_indices,
actions_count_cpu=actions_count_cpu,
stream=self.ssd_stream,
pre_event=self.ssd_event_backward,
post_event=self.ssd_event_evict_sp,
name="scratch_pad",
)

@@ -592,7 +611,8 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
current_stream = torch.cuda.current_stream()

# Ensure the previous iterations l3_db.set(..) has completed.
current_stream.wait_event(self.ssd_set_end)
current_stream.wait_event(self.ssd_event_evict)
current_stream.wait_event(self.ssd_event_evict_sp)

inserted_indices_cpu = self.to_pinned_cpu(inserted_indices)

@@ -604,7 +624,7 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
actions_count_cpu,
)

current_stream.record_event(self.ssd_set_start)
current_stream.record_event(self.ssd_event_get)
# TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate.
# Should we allocate on HBM?
inserted_rows_gpu = inserted_rows.cuda(non_blocking=True)
@@ -618,9 +638,12 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:

# Evict rows from cache to SSD
self.evict(
evicted_rows,
evicted_indices,
actions_count_cpu,
rows=evicted_rows,
indices=evicted_indices,
actions_count_cpu=actions_count_cpu,
stream=self.ssd_stream,
pre_event=self.ssd_event_get,
post_event=self.ssd_event_evict,
name="cache",
)

