Optimize the cache fetch for forward split, pt. 2 (pytorch#2217)
Summary:

This follows up on the work in D51865590 by plumbing the `uvm_cache_stats` argument up to the Python API level. `local_uvm_cache_stats` is now zeroed out before the prefetch step rather than after it, so that the gathered data can be passed into the forward step.

Differential Revision: D51995949
q10 authored and facebook-github-bot committed Dec 27, 2023
1 parent 82bbc93 commit d6cb617
Showing 5 changed files with 33 additions and 14 deletions.
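
As a rough sketch of the lifecycle the summary describes (a hypothetical stand-in, not the actual TBE module: only the names `gather_uvm_cache_stats`, `local_uvm_cache_stats`, and `uvm_cache_stats` come from the diffs below; the kernel work is faked):

import torch

class CacheStatsSketch:
    def __init__(self, gather_uvm_cache_stats: bool = True) -> None:
        self.gather_uvm_cache_stats = gather_uvm_cache_stats
        # Per-iteration counters (int32) and running totals (int64).
        self.local_uvm_cache_stats = torch.zeros(6, dtype=torch.int32)
        self.uvm_cache_stats = torch.zeros(6, dtype=torch.int64)

    def prefetch(self) -> None:
        # Zero BEFORE the prefetch (the change in this commit), so the
        # per-iteration stats survive long enough for forward() to read them.
        if self.gather_uvm_cache_stats:
            self.local_uvm_cache_stats.zero_()
        # ... the cache-lookup kernels would fill local_uvm_cache_stats here ...
        self.local_uvm_cache_stats += 1  # fake kernel update
        if self.gather_uvm_cache_stats:
            # Accumulate into the int64 totals, but no longer zero the local
            # stats here; forward() still needs them.
            self.uvm_cache_stats = torch.add(
                self.uvm_cache_stats, self.local_uvm_cache_stats
            )

    def forward(self):
        # forward() can now consume this iteration's stats.
        return self.local_uvm_cache_stats if self.gather_uvm_cache_stats else None

m = CacheStatsSketch()
m.prefetch()
print(m.forward())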
28 changes: 17 additions & 11 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -167,6 +167,7 @@ class {{ autograd_func }} :
     const c10::optional<Tensor>& feature_requires_grad,
     {%- endif %}
     const Tensor& lxu_cache_locations,
+    c10::optional<Tensor> uvm_cache_stats,
     {%- if optimizer != "none" %}
     const bool gradient_clipping,
     const double max_gradient,
@@ -196,6 +197,12 @@ class {{ autograd_func }} :
     const auto max_B_ = offsets.sym_size(0) / T;
     {%- endif %}
 
+    // NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t
+    // TODO: Hook up with frontend code
+    const auto uvm_cache_stats_ = uvm_cache_stats
+        .value_or(at::empty({0}, uvm_weights.options().dtype(at::kInt)));
+    // const auto uvm_cache_stats_ = at::empty({0}, uvm_weights.options().dtype(at::kInt));
+
     // TODO: don't guard here
     auto [info_B_num_bits, info_B_mask] = adjust_info_B_num_bits(max_B_.guard_int(__FILE__, __LINE__), T.guard_int(__FILE__, __LINE__));
 
@@ -283,13 +290,6 @@ class {{ autograd_func }} :
     const auto& flatten_dev_weights = dev_weights;
     {%- endif %}
 
-
-
-
-    const auto uvm_cache_stats = at::empty({0}, uvm_weights.options().dtype(at::kInt));
-
-
-
     {%- if not nobag %}
     {%- for weighted in [False, True] %}
     {%- set wdesc = "weighted" if weighted else "unweighted" %}
@@ -324,7 +324,7 @@ class {{ autograd_func }} :
         *indice_weights,
         {%- endif %}
         lxu_cache_locations,
-        uvm_cache_stats,
+        uvm_cache_stats_,
         output_dtype,
         {%- if vbe %}
         vbe_row_output_offsets,
@@ -355,7 +355,7 @@ class {{ autograd_func }} :
         indices,
         offsets,
         lxu_cache_locations,
-        uvm_cache_stats,
+        uvm_cache_stats_,
         output_dtype,
         /*is_experimental=*/false
         )
@@ -555,6 +555,7 @@ class {{ autograd_func }} :
         grad_indice_weights, // indice_weights
         Variable(), // feature_requires_grad
         Variable(), // lxu_cache_locations
+        Variable(), // uvm_cache_stats
         {%- if optimizer != "none" %}
         Variable(), // gradient_clipping
         Variable(), // max_gradient
@@ -628,6 +629,7 @@ class {{ autograd_func }} :
         Variable(), // indices
         Variable(), // offsets
         Variable(), // lxu_cache_locations
+        Variable(), // uvm_cache_stats
         {%- if optimizer != "none" %}
         Variable(), // gradient_clipping
         Variable(), // max_gradient
@@ -688,7 +690,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
     const int64_t vbe_output_size = -1,
     const bool is_experimental = false,
     const bool use_uniq_cache_locations_bwd = false,
-    const bool use_homogeneous_placements = false
+    const bool use_homogeneous_placements = false,
+    c10::optional<Tensor> uvm_cache_stats = c10::optional<Tensor>()
 ) {
     {%- if has_gpu_support %}
     {%- for vbe in ([True, False] if has_vbe_support else [False]) %}
@@ -738,6 +741,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
         feature_requires_grad,
         {%- endif %}
         lxu_cache_locations,
+        uvm_cache_stats,
         {%- if optimizer != "none" %}
         gradient_clipping,
         max_gradient,
@@ -802,7 +806,9 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
     " int vbe_output_size=-1, "
     " bool is_experimental=False, "
     " bool use_uniq_cache_locations_bwd=False, "
-    " bool use_homogeneous_placements=False) -> Tensor",
+    " bool use_homogeneous_placements=False, "
+    " Tensor? uvm_cache_stats=None"
+    ") -> Tensor",
     {PT2_COMPLIANT_TAG});
     // We're playing a funny trick here: we're using the autograd
     // implementation of the operator at all the dispatch keys. This is OK
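
The `value_or` fallback above is what keeps the new argument backward-compatible: when the caller passes no stats tensor, the operator substitutes an empty int32 tensor and behaves exactly as before. A minimal Python analogue of that defaulting logic (the helper name is made up for illustration):

from typing import Optional

import torch

def resolve_uvm_cache_stats(
    uvm_cache_stats: Optional[torch.Tensor],
    uvm_weights: torch.Tensor,
) -> torch.Tensor:
    # Mirrors `uvm_cache_stats.value_or(at::empty({0}, ...dtype(at::kInt)))`:
    # a zero-element int32 tensor on the same device means "no stats supplied".
    if uvm_cache_stats is not None:
        return uvm_cache_stats
    return torch.empty(0, dtype=torch.int32, device=uvm_weights.device)

uvm_weights = torch.zeros(4)
assert resolve_uvm_cache_stats(None, uvm_weights).numel() == 0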
1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/lookup_args.py
@@ -38,6 +38,7 @@ class CommonArgs(NamedTuple):
     indice_weights: Optional[torch.Tensor]
     feature_requires_grad: Optional[torch.Tensor]
     lxu_cache_locations: torch.Tensor
+    uvm_cache_stats: Optional[torch.Tensor]
     output_dtype: int
     vbe_metadata: VBEMetadata
     is_experimental: bool
1 change: 1 addition & 0 deletions (file name not captured)
@@ -196,6 +196,7 @@ def invoke(
         indice_weights=common_args.indice_weights,
         feature_requires_grad=common_args.feature_requires_grad,
         lxu_cache_locations=common_args.lxu_cache_locations,
+        uvm_cache_stats=common_args.uvm_cache_stats,
         # VBE metadata
         B_offsets=vbe_metadata.B_offsets,
         vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
16 changes: 13 additions & 3 deletions (file name not captured)
@@ -970,6 +970,11 @@ def forward(  # noqa: C901
             indice_weights=per_sample_weights,
             feature_requires_grad=feature_requires_grad,
             lxu_cache_locations=self.lxu_cache_locations,
+            # Pass the local_uvm_cache_stats because only that information is
+            # relevant for the current iteration
+            uvm_cache_stats=self.local_uvm_cache_stats
+            if self.gather_uvm_cache_stats
+            else None,
             output_dtype=self.output_dtype,
             vbe_metadata=vbe_metadata,
             is_experimental=self.is_experimental,
@@ -1160,6 +1165,12 @@ def _prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         if not self.lxu_cache_weights.numel():
             return
 
+        # Clear the local_uvm_cache_stats before the prefetch instead of after
+        # the prefetch step, since it will be used in the CommonArgs in the
+        # forward step
+        if self.gather_uvm_cache_stats:
+            self.local_uvm_cache_stats.zero_()
+
         linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
             self.cache_hash_size_cumsum,
             indices,
@@ -1236,12 +1247,11 @@ def _prefetch(self, indices: Tensor, offsets: Tensor) -> None:
 
         if self.gather_uvm_cache_stats:
             # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
-            # We may wanna do this accumulation atomically, but as it's only for monitoring,
-            # slightly inaccurate result may be acceptable.
+            # We may want to do this accumulation atomically, but as it's only
+            # for monitoring, a slightly inaccurate result may be acceptable.
             self.uvm_cache_stats = torch.add(
                 self.uvm_cache_stats, self.local_uvm_cache_stats
             )
-            self.local_uvm_cache_stats.zero_()
 
     def _prefetch_tensors_record_stream(
         self, forward_stream: torch.cuda.Stream
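
The accumulation in the last hunk adds the int32 per-iteration counters into an int64 running total. A small standalone illustration of why the wider dtype matters for long-running totals (the counter values are made up):

import torch

local_stats = torch.full((6,), 2**30, dtype=torch.int32)  # per-iteration counts
totals = torch.zeros(6, dtype=torch.int64)                # running totals

# torch.add promotes to the wider dtype of `totals`, so repeated accumulation
# does not wrap around the int32 range.
for _ in range(8):
    totals = torch.add(totals, local_stats)

assert totals[0].item() == 8 * 2**30  # 2**33 would overflow int32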
1 change: 1 addition & 0 deletions (file name not captured)
@@ -443,6 +443,7 @@ def forward(
         indice_weights=per_sample_weights,
         feature_requires_grad=feature_requires_grad,
         lxu_cache_locations=lxu_cache_locations,
+        uvm_cache_stats=None,
         vbe_metadata=invokers.lookup_args.VBEMetadata(
             B_offsets=None,
             output_offsets_feature_rank=None,