Fix stream sync for scratch pad eviction
Summary:
Before this diff, SSD TBE unsafely evicted data from the scratch pad
(the buffer that stores conflict-missed data). The eviction ran on the
SSD stream while the backward TBE ran on the default stream, and SSD
TBE did not synchronize the two streams to guarantee that the backward
TBE completed before data was evicted from the scratch pad. This diff
fixes the problem by adding synchronization between the streams to
enforce the correct execution order between the backward TBE and the
scratch pad eviction.

Differential Revision: D59716516
sryap authored and facebook-github-bot committed Jul 14, 2024
1 parent 88349ff commit d74edee
Showing 1 changed file with 44 additions and 12 deletions.
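
The fix follows the standard CUDA producer/consumer pattern in PyTorch:
the producer stream records an event once its work is queued, and the
consumer stream waits on that event before touching the shared tensors.
A minimal sketch of the pattern under assumed names (stand-in tensors
and streams, not the actual TBE state):

import torch

# Producer/consumer ordering across streams (sketch, not FBGEMM code).
current = torch.cuda.current_stream()
side_stream = torch.cuda.Stream()
producer_done = torch.cuda.Event()

x = torch.randn(1024, 1024, device="cuda")
y = x @ x                            # producer work on the current stream
current.record_event(producer_done)  # mark the point consumers must wait for

with torch.cuda.stream(side_stream):
    side_stream.wait_event(producer_done)  # do not read y before it is ready
    y_cpu = torch.empty(y.shape, dtype=y.dtype, pin_memory=True)
    y_cpu.copy_(y, non_blocking=True)      # async D2H copy into pinned memory
    # Tell the caching allocator y is in use on side_stream so its memory
    # is not reclaimed and reused before the copy completes.
    y.record_stream(side_stream)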
fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -270,8 +270,13 @@ def __init__(
         # pyre-fixme[20]: Argument `self` expected.
         (low_priority, high_priority) = torch.cuda.Stream.priority_range()
         self.ssd_stream = torch.cuda.Stream(priority=low_priority)
-        self.ssd_set_start = torch.cuda.Event()
-        self.ssd_set_end = torch.cuda.Event()
+
+        # SSD events
+        self.ssd_event_get = torch.cuda.Event()
+        self.ssd_event_evict = torch.cuda.Event()
+        self.ssd_event_backward = torch.cuda.Event()
+        self.ssd_event_evict_sp = torch.cuda.Event()
+
         self.timesteps_prefetched: List[int] = []
         self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = []
         # TODO: add type annotation
@@ -391,31 +396,49 @@ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor:
         return t_cpu

     def evict(
-        self, evicted_rows: Tensor, evicted_indices: Tensor, actions_count_cpu: Tensor
+        self,
+        evicted_rows: Tensor,
+        evicted_indices: Tensor,
+        actions_count_cpu: Tensor,
+        eviction_stream: torch.cuda.Stream,
+        pre_event: torch.cuda.Event,
+        post_event: torch.cuda.Event,
     ) -> None:
         """
         Evict data from the given input tensors to SSD via RocksDB
         """
-        with torch.cuda.stream(self.ssd_stream):
-            self.ssd_stream.wait_event(self.ssd_set_start)
+        with torch.cuda.stream(eviction_stream):
+            eviction_stream.wait_event(pre_event)
+
             evicted_rows_cpu = self.to_pinned_cpu(evicted_rows)
             evicted_indices_cpu = self.to_pinned_cpu(evicted_indices)
-            evicted_rows.record_stream(self.ssd_stream)
-            evicted_indices.record_stream(self.ssd_stream)
+
+            evicted_rows.record_stream(eviction_stream)
+            evicted_indices.record_stream(eviction_stream)
+
             self.ssd_db.set_cuda(
                 evicted_indices_cpu, evicted_rows_cpu, actions_count_cpu, self.timestep
             )
+
             # TODO: is this needed?
             # Need a way to synchronize
             # actions_count_cpu.record_stream(self.ssd_stream)
-            self.ssd_stream.record_event(self.ssd_set_end)
+            eviction_stream.record_event(post_event)

     def _evict_from_scratch_pad(self, grad: Tensor) -> None:
         assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
         (inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) = (
             self.ssd_scratch_pads.pop(0)
         )
-        self.evict(inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu)
+        torch.cuda.current_stream().record_event(self.ssd_event_backward)
+        self.evict(
+            inserted_rows_gpu,
+            post_bwd_evicted_indices,
+            actions_count_cpu,
+            self.ssd_stream,
+            self.ssd_event_backward,
+            self.ssd_event_evict_sp,
+        )

     def _compute_cache_ptrs(
         self,
Expand Down Expand Up @@ -508,11 +531,13 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
current_stream = torch.cuda.current_stream()

# Ensure the previous iterations l3_db.set(..) has completed.
current_stream.wait_event(self.ssd_set_end)
current_stream.wait_event(self.ssd_event_evict)
current_stream.wait_event(self.ssd_event_evict_sp)

self.ssd_db.get_cuda(
self.to_pinned_cpu(inserted_indices), inserted_rows, actions_count_cpu
)
current_stream.record_event(self.ssd_set_start)
current_stream.record_event(self.ssd_event_get)
# TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate.
# Should we allocate on HBM?
inserted_rows_gpu = inserted_rows.cuda(non_blocking=True)
@@ -525,7 +550,14 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
         )

         # Evict rows from cache to SSD
-        self.evict(evicted_rows, evicted_indices, actions_count_cpu)
+        self.evict(
+            evicted_rows,
+            evicted_indices,
+            actions_count_cpu,
+            self.ssd_stream,
+            self.ssd_event_get,
+            self.ssd_event_evict,
+        )

         # TODO: keep only necessary tensors
         self.ssd_prefetch_data.append(
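With the generalized signature, both call sites reuse the same eviction
path: prefetch passes (ssd_event_get, ssd_event_evict) so cache eviction
starts only after the RocksDB get, and _evict_from_scratch_pad passes
(ssd_event_backward, ssd_event_evict_sp) so scratch pad eviction starts
only after the backward TBE; the next prefetch then waits on both evict
events. A self-contained sketch of this pre/post-event protocol, using a
hypothetical evict_sketch helper rather than the real FBGEMM API:

import torch

def evict_sketch(
    rows: torch.Tensor,
    eviction_stream: torch.cuda.Stream,
    pre_event: torch.cuda.Event,
    post_event: torch.cuda.Event,
) -> None:
    # Stand-in for SSD TBE's evict(): start after pre_event, copy the rows
    # to pinned CPU memory, then signal completion via post_event.
    with torch.cuda.stream(eviction_stream):
        eviction_stream.wait_event(pre_event)
        rows_cpu = torch.empty(rows.shape, dtype=rows.dtype, pin_memory=True)
        rows_cpu.copy_(rows, non_blocking=True)
        rows.record_stream(eviction_stream)   # keep rows alive for the copy
        eviction_stream.record_event(post_event)

# Usage mirroring _evict_from_scratch_pad: record the end of backward on
# the current stream, queue the eviction behind it, and let the next
# prefetch wait for the eviction to finish.
ssd_stream = torch.cuda.Stream()
ev_backward = torch.cuda.Event()
ev_evict_sp = torch.cuda.Event()

scratch_pad_rows = torch.randn(128, 64, device="cuda")  # stand-in rows
torch.cuda.current_stream().record_event(ev_backward)
evict_sketch(scratch_pad_rows, ssd_stream, ev_backward, ev_evict_sp)
torch.cuda.current_stream().wait_event(ev_evict_sp)  # next prefetch waits here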
