
Optimize the cache fetch for forward split, pt. 3 (#2216)
Summary:

This adds TBE UVM caching benchmarks to support the work on the D51865590 stack.

Differential Revision: D52177208
q10 authored and facebook-github-bot committed Dec 27, 2023
1 parent b781232 commit a9eb11d
Showing 2 changed files with 75 additions and 11 deletions.
73 changes: 69 additions & 4 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
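The two new options added below extend this file's existing uvm benchmark command. As a purely hypothetical invocation for illustration (the exact command line is not shown in this commit), the new mode could be exercised roughly as

    python split_table_batched_embeddings_benchmark.py uvm --use-cache --cache-load-factor 0.5 --no-conflict-misses

where --use-cache is required because the simulated cache scenarios only apply when UVM caching is enabled, and --no-conflict-misses / --all-conflict-misses are mutually exclusive.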
@@ -330,6 +330,8 @@ def device( # noqa C901
@click.option("--cache-algorithm", default="lru")
@click.option("--cache-load-factor", default=0.2)
@click.option("--enforce-hbm", is_flag=True, default=False)
@click.option("--no-conflict-misses", is_flag=True, default=False)
@click.option("--all-conflict-misses", is_flag=True, default=False)
def uvm(
    alpha: bool,
    bag_size: int,
@@ -354,6 +356,8 @@ def uvm(
    cache_algorithm: str,
    cache_load_factor: float,
    enforce_hbm: bool,
    no_conflict_misses: bool,
    all_conflict_misses: bool,
) -> None:
    np.random.seed(42)
    torch.manual_seed(42)
@@ -369,6 +373,7 @@ def uvm(
    ), f"T_uvm specified {T_uvm} <= 0. If not testing UVM, please use device benchmark."
    T_gpu = T - T_uvm
    L_uvm = uvm_bag_size
    eval_conflict_misses: bool = no_conflict_misses or all_conflict_misses

    cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU
    managed_type = (
@@ -383,6 +388,7 @@ def uvm(
        D = np.average(Ds)
    else:
        Ds = [D] * T

    emb_uvm = SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
@@ -481,13 +487,69 @@ def uvm(
        + param_size_multiplier * B * sum(Ds[:T_uvm]) * L_uvm
    )

-    time_per_iter = benchmark_requests(
-        requests_uvm,
-        lambda indices, offsets, per_sample_weights: emb_uvm.forward(
    if eval_conflict_misses:
        assert (
            use_cache
        ), "--use-cache is required for --no-conflict-misses or all-conflict-misses"
        assert (no_conflict_misses and not all_conflict_misses) or (
            not no_conflict_misses and all_conflict_misses
        )
        logging.info(
            "Evaluate {}: Cache shape {}".format(
                "no_conflict_misses" if no_conflict_misses else "all_conflict_misses",
                emb_uvm.lxu_cache_weights.shape,
            )
        )
        num_cache_slots = emb_uvm.lxu_cache_weights.shape[0]
        for it, (indices, offsets, _) in enumerate(requests_uvm):
            num_uniq = 0
            all_inverse = []
            for t in range(T_uvm):
                uniq, inverse = indices[offsets[t * B] : offsets[(t + 1) * B]].unique(
                    return_inverse=True
                )
                all_inverse.append(inverse + num_uniq)
                num_uniq += uniq.numel()
            assert (
                num_cache_slots >= num_uniq
            ), "num_cache_slots < num_uniq: Please increase --cache-load-factor"

            # Intercept prefetch
            if no_conflict_misses:
                locations = np.random.choice(
                    np.arange(num_cache_slots), size=num_uniq, replace=False
                )
                locations = (
                    torch.from_numpy(locations).to(torch.int32).to(indices.device)
                )
                locations = locations.index_select(
                    dim=0, index=torch.concat(all_inverse)
                )
                assert locations.numel() == indices.numel()
            else:
                locations = torch.full_like(
                    indices, -1, dtype=torch.int32, device=indices.device
                )
            emb_uvm.lxu_cache_locations_list.append(locations)
            emb_uvm.timesteps_prefetched.append(it)

    # pyre-ignore[53]
    def run_bench(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor) -> None:
        if eval_conflict_misses:
            # Set uvm_cache_stats
            assert emb_uvm.local_uvm_cache_stats.numel() == emb_uvm.uvm_cache_stats_size
            # Use uvm_cache_stats_index::num_conflict_unique_misses
            emb_uvm.local_uvm_cache_stats[4] = 0 if no_conflict_misses else 1

        emb_uvm.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
-        ),
        )

    time_per_iter = benchmark_requests(
        requests_uvm,
        run_bench,
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
        num_warmups=warmup_runs,
    )
@@ -497,6 +559,9 @@ def uvm(
        f"BW: {read_write_bytes_uvm / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )
    print(
        f"|{uvm_tables}|{embedding_dim}|{read_write_bytes_uvm / time_per_iter / 1.0e9: .2f}|"
    )

    if T_gpu > 0:
        requests = []
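The benchmark change above works by skipping the real prefetch and handing the forward pass pre-fabricated cache locations: under --no-conflict-misses every unique index gets its own cache slot, so every lookup is a cache hit, while under --all-conflict-misses every location is -1, so every lookup falls back to UVM; run_bench additionally seeds local_uvm_cache_stats[4], the num_conflict_unique_misses slot, with 0 or 1 so the cache-fetch path sees the corresponding scenario. A minimal, self-contained sketch of the location fabrication, using hypothetical toy values instead of the benchmark's real request tensors:

import torch

# Hypothetical toy request: 6 lookups over 4 distinct rows, cache with 8 slots.
indices = torch.tensor([12, 7, 12, 3, 7, 9])
num_cache_slots = 8

# --no-conflict-misses: give each unique row its own slot, then broadcast that
# slot to every occurrence of the row via the inverse mapping from unique().
uniq, inverse = indices.unique(return_inverse=True)
assert num_cache_slots >= uniq.numel()
slots = torch.randperm(num_cache_slots)[: uniq.numel()].to(torch.int32)
no_conflict_locations = slots.index_select(0, inverse)  # a cache hit per lookup

# --all-conflict-misses: -1 means "not cached", so every lookup goes to UVM.
all_conflict_locations = torch.full_like(indices, -1, dtype=torch.int32)

print(no_conflict_locations)
print(all_conflict_locations)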
Second changed file: 6 additions & 7 deletions
@@ -1118,17 +1118,16 @@ def reset_uvm_cache_stats(self) -> None:
        self.uvm_cache_stats.zero_()
        self.local_uvm_cache_stats.zero_()

-    def get_uvm_cache_stats(self) -> Tensor:
    def get_uvm_cache_stats(self, use_local_cache: bool = False) -> Tensor:
        assert (
            self.gather_uvm_cache_stats
        ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
-        return self.uvm_cache_stats
        return self.local_uvm_cache_stats if use_local_cache else self.uvm_cache_stats

-    def print_uvm_cache_stats(self) -> None:
-        assert (
-            self.gather_uvm_cache_stats
-        ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
-        uvm_cache_stats: List[float] = self.uvm_cache_stats.tolist()
    def print_uvm_cache_stats(self, use_local_cache: bool = False) -> None:
        uvm_cache_stats: List[float] = self.get_uvm_cache_stats(
            use_local_cache
        ).tolist()
        logging.info(
            f"N_called: {uvm_cache_stats[0]}\n"
            f"N_requested_indices: {uvm_cache_stats[1]}\n"
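The accessor change above lets callers read either the accumulated uvm_cache_stats tensor or the per-call local_uvm_cache_stats tensor through the same method. A rough usage sketch (assuming emb is a SplitTableBatchedEmbeddingBagsCodegen constructed elsewhere with gather_uvm_cache_stats enabled; that setup is not part of this commit):

# emb is a hypothetical, already-constructed SplitTableBatchedEmbeddingBagsCodegen
# with gather_uvm_cache_stats enabled, as required by the assert above.
accumulated_stats = emb.get_uvm_cache_stats()                # global counters
local_stats = emb.get_uvm_cache_stats(use_local_cache=True)  # per-call counters
emb.print_uvm_cache_stats(use_local_cache=True)              # logs N_called, N_requested_indices, ...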
