Make evicted_rows a UVA buffer (#3079)
Summary:
Pull Request resolved: #3079

X-link: facebookresearch/FBGEMM#173

Prior to this diff, SSD-TBE used a combination of a pinned CPU buffer
and a GPU buffer for `evicted_rows` (the buffer for staging rows that
are evicted from the L1 cache). It explicitly performed an
asynchronous memory copy (via `cudaMemcpyAsync`) to transfer
`evicted_rows` from device to host. Since the number of evicted rows
is known only on the device, SSD-TBE overallocated both the CPU and
GPU `evicted_rows` buffers and therefore transferred extra data during
the device-to-host memory copy. This extra data could be large and
could make the memory copy an execution bottleneck.
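
For concreteness, the replaced pattern looks roughly like the sketch
below. Buffer names and sizes are illustrative, not the actual SSD-TBE
code; the point is that the host cannot see the device-side eviction
count, so the full worst-case buffer is always copied.

```python
import torch

# Assumed worst-case sizes, for illustration only.
max_evicted_rows, max_D = 4096, 128

# Overallocated GPU staging buffer and pinned CPU destination.
evicted_rows_gpu = torch.empty(max_evicted_rows, max_D, device="cuda")
evicted_rows_cpu = torch.empty(max_evicted_rows, max_D, pin_memory=True)

# The eviction count is only known on the device, so the D2H transfer
# (issued as cudaMemcpyAsync under the hood) must move the entire
# overallocated buffer, even if only a handful of rows were evicted.
evicted_rows_cpu.copy_(evicted_rows_gpu, non_blocking=True)
```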

This diff mitigates the problem above by using a unified virtual
address (UVA) buffer for `evicted_rows` and a kernel (namely
`masked_index_select`) to load/store data instead of a CUDA memory
copy operation. This mechanism avoids the extra memory copy. However,
the copy can be less efficient (it might not fully saturate the
available memory bandwidth) since it does not use the copy engine.
Moreover, since it uses SMs for the memory copy, it can compete for SM
resources with other computation when the operator is overlapped with
other work.
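
A minimal sketch of the new mechanism, built from the two FBGEMM ops
used in the diff below (`new_unified_tensor` and
`masked_index_select`). The argument meanings and dtypes are inferred
from the call sites in this change and should be treated as
assumptions rather than a spec.

```python
import torch

num_cache_slots, max_D = 4096, 128  # assumed sizes
lxu_cache_weights = torch.rand(num_cache_slots, max_D, device="cuda")

# A UVA tensor is a single allocation addressable from both host and
# device, so evicted rows written here need no separate cudaMemcpyAsync.
lxu_cache_evicted_weights = torch.ops.fbgemm.new_unified_tensor(
    torch.zeros(1, device="cuda", dtype=lxu_cache_weights.dtype),
    lxu_cache_weights.shape,
    is_host_mapped=True,
)

# masked_index_select is an SM-based gather kernel: it copies only the
# first `count` indexed rows from the L1 cache into the UVA buffer,
# where `count` lives on the device (no host synchronization needed).
cache_slots = torch.tensor([3, 42, 7], device="cuda", dtype=torch.int64)
count_gpu = torch.tensor([3], device="cuda", dtype=torch.int32)
torch.ops.fbgemm.masked_index_select(
    lxu_cache_evicted_weights,  # output rows (UVA)
    cache_slots,                # rows of the L1 cache to read
    lxu_cache_weights,          # input (L1 cache)
    count_gpu,                  # number of valid indices, on device
)
```

Because the gather runs on SMs rather than the copy engine, it trades
peak copy bandwidth for the ability to size the transfer on-device,
which is exactly the tradeoff described above.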

Reviewed By: q10

Differential Revision: D62114877

fbshipit-source-id: d91ad0be2820ac21033270f14a784a0bc3193d78
sryap authored and facebook-github-bot committed Sep 5, 2024
1 parent e48b9f7 commit 53d84ad
Showing 1 changed file with 58 additions and 15 deletions.
fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -310,6 +310,22 @@ def __init__(
* self.lxu_cache_weights.element_size()
), "The precomputed cache_size does not match the actual cache size"

+        # For storing weights to evict
+        # The max number of rows to be evicted is limited by the number of
+        # slots in the cache. Thus, we allocate `lxu_cache_evicted_weights` to
+        # be the same shape as the L1 cache (lxu_cache_weights)
+        self.register_buffer(
+            "lxu_cache_evicted_weights",
+            torch.ops.fbgemm.new_unified_tensor(
+                torch.zeros(
+                    1,
+                    device=self.current_device,
+                    dtype=cache_dtype,
+                ),
+                self.lxu_cache_weights.shape,
+                is_host_mapped=self.uvm_host_mapped,
+            ),
+        )
self.timestep = 0

# Dummy profile configuration for measuring the SSD get/set time
@@ -419,8 +435,10 @@ def __init__(

# SSD get completion event
self.ssd_event_get = torch.cuda.Event()
-        # SSD eviction completion event
-        self.ssd_event_evict = torch.cuda.Event()
+        # SSD scratch pad eviction completion event
+        self.ssd_event_sp_evict = torch.cuda.Event()
+        # SSD cache eviction completion event
+        self.ssd_event_cache_evict = torch.cuda.Event()
# SSD backward completion event
self.ssd_event_backward = torch.cuda.Event()
# SSD get's input copy completion event
@@ -876,7 +894,7 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None:
actions_count_cpu=actions_count_cpu,
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_backward,
-            post_event=self.ssd_event_sp_evict,
+            post_event=self.ssd_event_sp_evict,
is_rows_uvm=True,
name="scratch_pad",
)
@@ -1066,13 +1084,35 @@ def prefetch( # noqa C901
self.local_ssd_cache_stats,
)

+        # Allocate output tensors for compact_indices
+        compact_evicted_indices = torch.empty_like(evicted_indices)
+        compact_assigned_cache_slots = torch.empty_like(assigned_cache_slots)
+        compact_actions_count_gpu = torch.empty_like(actions_count_gpu)
+
+        # Defrag indices based on evicted_indices (removing -1 and making
+        # the non -1 elements contiguous). We need to do this because the
+        # number of rows in `lxu_cache_evicted_weights` might be smaller
+        # than the number of elements in `evicted_indices`. Without this
+        # step, we can run into the index out of bound issue
+        current_stream.wait_event(self.ssd_event_cache_evict)
+        torch.ops.fbgemm.compact_indices(
+            compact_indices=[compact_evicted_indices, compact_assigned_cache_slots],
+            compact_count=compact_actions_count_gpu,
+            indices=[evicted_indices, assigned_cache_slots],
+            masks=torch.where(evicted_indices != -1, 1, 0),
+            count=actions_count_gpu,
+        )
+
+        evicted_indices = compact_evicted_indices
+
with record_function("## ssd_d2h_inserted_indices ##"):
# Transfer actions_count and insert_indices right away to
# increase an overlap opportunity
-            actions_count_cpu, inserted_indices_cpu = (
+            actions_count_cpu, compact_actions_count_cpu, inserted_indices_cpu = (
self.to_pinned_cpu_on_stream_wait_on_another_stream(
tensors=[
actions_count_gpu,
+                        compact_actions_count_gpu,
inserted_indices,
],
stream=self.ssd_memcpy_stream,
@@ -1096,26 +1136,29 @@ def prefetch( # noqa C901
# Copy rows to be evicted into a separate buffer (will be evicted
# later in the prefetch step)
with record_function("## ssd_compute_evicted_rows ##"):
-            assigned_cache_slots = assigned_cache_slots.long()
-            evicted_rows = self.lxu_cache_weights[
-                assigned_cache_slots.clamp(min=0).long(), :
-            ]
+            torch.ops.fbgemm.masked_index_select(
+                self.lxu_cache_evicted_weights,
+                compact_assigned_cache_slots,
+                self.lxu_cache_weights,
+                compact_actions_count_gpu,
+            )

# Allocation a scratch pad for the current iteration. The scratch
# pad is a UVA tensor
+        inserted_rows_shape = (assigned_cache_slots.numel(), self.max_D)
if linear_cache_indices.numel() > 0:
inserted_rows = torch.ops.fbgemm.new_unified_tensor(
torch.zeros(
1,
device=self.current_device,
dtype=self.lxu_cache_weights.dtype,
),
-                evicted_rows.shape,
+                inserted_rows_shape,
is_host_mapped=self.uvm_host_mapped,
)
else:
inserted_rows = torch.empty(
-                evicted_rows.shape,
+                inserted_rows_shape,
dtype=self.lxu_cache_weights.dtype,
device=self.current_device,
)
@@ -1213,7 +1256,7 @@ def prefetch( # noqa C901
)

# Ensure the previous iteration's eviction is complete
-        current_stream.wait_event(self.ssd_event_evict)
+        current_stream.wait_event(self.ssd_event_sp_evict)
# Ensure that D2H is done
current_stream.wait_event(self.ssd_event_get_inputs_cpy)

@@ -1250,15 +1293,15 @@ def prefetch( # noqa C901
if linear_cache_indices.numel() > 0:
# Evict rows from cache to SSD
self.evict(
-                rows=evicted_rows,
+                rows=self.lxu_cache_evicted_weights,
indices_cpu=evicted_indices_cpu,
-                actions_count_cpu=actions_count_cpu,
+                actions_count_cpu=compact_actions_count_cpu,
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_get,
# Record completion event after scratch pad eviction
# instead since that happens after L1 eviction
-                post_event=None,
-                is_rows_uvm=False,
+                post_event=self.ssd_event_cache_evict,
+                is_rows_uvm=True,
name="cache",
is_bwd=False,
)
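
The `compact_indices` call introduced in `prefetch` above is what
keeps the smaller UVA buffer safe to index. Below is a pure-PyTorch
reference of what the op is expected to compute, based solely on the
in-diff comment; the real `torch.ops.fbgemm.compact_indices` is a
fused GPU kernel, and this sketch is for semantics only.

```python
import torch

def compact_indices_reference(indices_list, masks, count):
    # Keep only positions whose mask is nonzero among the first `count`
    # elements, packing the survivors contiguously at the front.
    n = int(count.item())
    keep = masks[:n].bool()
    compact_count = keep.sum()
    outs = []
    for idx in indices_list:
        out = idx.clone()  # entries beyond compact_count are stale
        kept = idx[:n][keep]
        out[: kept.numel()] = kept
        outs.append(out)
    return outs, compact_count

# Example: drop the -1 sentinels from evicted_indices while keeping the
# corresponding cache slots aligned with them.
evicted_indices = torch.tensor([10, -1, 12, -1, 15])
assigned_cache_slots = torch.tensor([0, 1, 2, 3, 4])
masks = torch.where(evicted_indices != -1, 1, 0)
count = torch.tensor(5)

compacted, compact_count = compact_indices_reference(
    [evicted_indices, assigned_cache_slots], masks, count
)
# compacted[0][:3] -> tensor([10, 12, 15])
# compacted[1][:3] -> tensor([0, 2, 4]); compact_count -> tensor(3)
```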
