From 5960645f37e4ad7ea71c87234af76b16434dc34a Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Tue, 19 Dec 2023 11:08:52 -0800
Subject: [PATCH] Optimize the cache fetch for forward split, pt. 3 (#2216)

Summary: This adds TBE UVM caching benchmarks to support the work on the D51865590 stack.

Differential Revision: D52177208
---
 ...plit_table_batched_embeddings_benchmark.py | 73 ++++++++++++++++++-
 ...t_table_batched_embeddings_ops_training.py | 13 ++--
 2 files changed, 75 insertions(+), 11 deletions(-)

diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
index cb7d30a817..dfa79f48a4 100644
--- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -330,6 +330,8 @@ def device(  # noqa C901
 @click.option("--cache-algorithm", default="lru")
 @click.option("--cache-load-factor", default=0.2)
 @click.option("--enforce-hbm", is_flag=True, default=False)
+@click.option("--no-conflict-misses", is_flag=True, default=False)
+@click.option("--all-conflict-misses", is_flag=True, default=False)
 def uvm(
     alpha: bool,
     bag_size: int,
@@ -354,6 +356,8 @@ def uvm(
     cache_algorithm: str,
     cache_load_factor: float,
     enforce_hbm: bool,
+    no_conflict_misses: bool,
+    all_conflict_misses: bool,
 ) -> None:
     np.random.seed(42)
     torch.manual_seed(42)
@@ -369,6 +373,7 @@ def uvm(
     ), f"T_uvm specified {T_uvm} <= 0. If not testing UVM, please use device benchmark."
     T_gpu = T - T_uvm
     L_uvm = uvm_bag_size
+    eval_conflict_misses: bool = no_conflict_misses or all_conflict_misses
 
     cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU
     managed_type = (
@@ -383,6 +388,7 @@ def uvm(
         D = np.average(Ds)
     else:
         Ds = [D] * T
+
     emb_uvm = SplitTableBatchedEmbeddingBagsCodegen(
         [
             (
@@ -481,13 +487,69 @@ def uvm(
         + param_size_multiplier * B * sum(Ds[:T_uvm]) * L_uvm
     )
 
-    time_per_iter = benchmark_requests(
-        requests_uvm,
-        lambda indices, offsets, per_sample_weights: emb_uvm.forward(
+    if eval_conflict_misses:
+        assert (
+            use_cache
+        ), "--use-cache is required for --no-conflict-misses or --all-conflict-misses"
+        assert (no_conflict_misses and not all_conflict_misses) or (
+            not no_conflict_misses and all_conflict_misses
+        )
+        logging.info(
+            "Evaluate {}: Cache shape {}".format(
+                "no_conflict_misses" if no_conflict_misses else "all_conflict_misses",
+                emb_uvm.lxu_cache_weights.shape,
+            )
+        )
+        num_cache_slots = emb_uvm.lxu_cache_weights.shape[0]
+        for it, (indices, offsets, _) in enumerate(requests_uvm):
+            num_uniq = 0
+            all_inverse = []
+            for t in range(T_uvm):
+                uniq, inverse = indices[offsets[t * B] : offsets[(t + 1) * B]].unique(
+                    return_inverse=True
+                )
+                all_inverse.append(inverse + num_uniq)
+                num_uniq += uniq.numel()
+            assert (
+                num_cache_slots >= num_uniq
+            ), "num_cache_slots < num_uniq: Please increase --cache-load-factor"
+
+            # Intercept prefetch
+            if no_conflict_misses:
+                locations = np.random.choice(
+                    np.arange(num_cache_slots), size=num_uniq, replace=False
+                )
+                locations = (
+                    torch.from_numpy(locations).to(torch.int32).to(indices.device)
+                )
+                locations = locations.index_select(
+                    dim=0, index=torch.concat(all_inverse)
+                )
+                assert locations.numel() == indices.numel()
+            else:
+                locations = torch.full_like(
+                    indices, -1, dtype=torch.int32, device=indices.device
+                )
+            emb_uvm.lxu_cache_locations_list.append(locations)
+            emb_uvm.timesteps_prefetched.append(it)
+
+    # pyre-ignore[53]
+    def run_bench(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor) -> None:
+        if eval_conflict_misses:
+            # Set uvm_cache_stats
+            assert emb_uvm.local_uvm_cache_stats.numel() == emb_uvm.uvm_cache_stats_size
+            # Use uvm_cache_stats_index::num_conflict_unique_misses
+            emb_uvm.local_uvm_cache_stats[4] = 0 if no_conflict_misses else 1
+
+        emb_uvm.forward(
             indices.long(),
             offsets.long(),
             per_sample_weights,
-        ),
+        )
+
+    time_per_iter = benchmark_requests(
+        requests_uvm,
+        run_bench,
         flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
         num_warmups=warmup_runs,
     )
@@ -497,6 +559,9 @@ def uvm(
         f"BW: {read_write_bytes_uvm / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
         f"T: {time_per_iter * 1.0e6:.0f}us"
     )
+    print(
+        f"|{uvm_tables}|{embedding_dim}|{read_write_bytes_uvm / time_per_iter / 1.0e9: .2f}|"
+    )
 
     if T_gpu > 0:
         requests = []
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index a0c8232ab7..c566a8baa6 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -1118,17 +1118,16 @@ def reset_uvm_cache_stats(self) -> None:
         self.uvm_cache_stats.zero_()
         self.local_uvm_cache_stats.zero_()
 
-    def get_uvm_cache_stats(self) -> Tensor:
+    def get_uvm_cache_stats(self, use_local_cache: bool = False) -> Tensor:
         assert (
             self.gather_uvm_cache_stats
         ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
-        return self.uvm_cache_stats
+        return self.local_uvm_cache_stats if use_local_cache else self.uvm_cache_stats
 
-    def print_uvm_cache_stats(self) -> None:
-        assert (
-            self.gather_uvm_cache_stats
-        ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
-        uvm_cache_stats: List[float] = self.uvm_cache_stats.tolist()
+    def print_uvm_cache_stats(self, use_local_cache: bool = False) -> None:
+        uvm_cache_stats: List[float] = self.get_uvm_cache_stats(
+            use_local_cache
+        ).tolist()
         logging.info(
             f"N_called: {uvm_cache_stats[0]}\n"
             f"N_requested_indices: {uvm_cache_stats[1]}\n"
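Usage sketch: a minimal example of how the updated get_uvm_cache_stats / print_uvm_cache_stats
accessors might be exercised after this patch. It assumes a CUDA device, an fbgemm_gpu build with
GPU support, and that gather_uvm_cache_stats is accepted as a constructor argument of
SplitTableBatchedEmbeddingBagsCodegen; the import paths, table sizes, and request shapes below are
illustrative, not taken from the benchmark.

    import torch
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
    from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
        ComputeDevice,
        SplitTableBatchedEmbeddingBagsCodegen,
    )

    # One UVM-cached table: 1M rows, dim 128 (illustrative sizes).
    emb = SplitTableBatchedEmbeddingBagsCodegen(
        [(1_000_000, 128, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA)],
        gather_uvm_cache_stats=True,  # assumed constructor flag enabling stats collection
    ).cuda()

    # 8192 bags of size 1 for the single table (offsets has B*T + 1 entries).
    indices = torch.randint(0, 1_000_000, (8192,), dtype=torch.long, device="cuda")
    offsets = torch.arange(0, 8193, dtype=torch.long, device="cuda")

    emb.reset_uvm_cache_stats()
    emb.forward(indices, offsets)

    # Aggregated stats across iterations (pre-existing behavior) ...
    print(emb.get_uvm_cache_stats())
    # ... or the per-iteration local stats exposed by this patch:
    emb.print_uvm_cache_stats(use_local_cache=True)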