From 360b4ec562d030883f1aaf7f67b946950457b199 Mon Sep 17 00:00:00 2001
From: Sungmin Cho
Date: Mon, 21 Aug 2023 22:49:13 -0700
Subject: [PATCH] uvm_cache_stats for direct mapped (#1952)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1952

- Implement the Python frontend for the previous diff (split for backward compatibility).
- Revise the existing benchmark to report stats from uvm_cache_stats instead of cache_miss_counter.
- Implement a unit test for uvm_cache_stats with direct-mapped caches.

Differential Revision: D48439568

fbshipit-source-id: c921d6fc02e2bbe3a8b77bebb55400c394c2930d
---
 ...plit_table_batched_embeddings_benchmark.py |  33 +++--
 ..._table_batched_embeddings_ops_inference.py |  24 +++-
 .../split_table_batched_embeddings_test.py    | 123 ++++++++++++++++++
 3 files changed, 162 insertions(+), 18 deletions(-)

diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
index 39d056c566..ec4d265535 100644
--- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -1727,6 +1727,7 @@ def nbit_uvm(
 @click.option("--fp8-exponent-bits", type=int, default=None)
 @click.option("--fp8-exponent-bias", type=int, default=None)
 @click.option("--record-cache", is_flag=True, default=False)
+@click.option("--uvm-host-mapped", is_flag=True, default=False)
 @click.option(
     "--dump-requests", type=int, default=0, help="number of reqs to dump (0=no dump)"
 )
@@ -1753,6 +1754,7 @@ def nbit_uvm_compare_direct_mapped(
     fp8_exponent_bits: Optional[int],
     fp8_exponent_bias: Optional[int],
     record_cache: bool,
+    uvm_host_mapped: bool,
     dump_requests: int,
 ) -> None:
     logging.info(json.dumps({k: str(v) for k, v in locals().items()}, indent=2))
@@ -1837,18 +1839,21 @@ def bench_uvm_cls(
             enforce_hbm=enforce_hbm,
             fp8_exponent_bits=fp8_exponent_bits,
             fp8_exponent_bias=fp8_exponent_bias,
-            record_cache_metrics=RecordCacheMetrics(record_cache, record_cache),
+            gather_uvm_cache_stats=record_cache,
+            uvm_host_mapped=uvm_host_mapped,
         ).cuda()
         emb.fill_random_weights()
 
-        # label nvtx only when cache counter is off
-        nvtx_range = "" if record_cache else f"UVM-{name.upper()}"
-        callback_after_warmup = emb.reset_cache_miss_counter if record_cache else None
-        requests = requests_uvm[:1] if record_cache else requests_uvm
+        nvtx_range = (
+            f"UVM-RECORD-CACHE-{name.upper()}"
+            if record_cache
+            else f"UVM-{name.upper()}"
+        )
+        callback_after_warmup = emb.reset_uvm_cache_stats if record_cache else None
 
         torch.cuda.cudart().cudaProfilerStart()
         time_per_iter = benchmark_requests(
-            requests,
+            requests_uvm,
             lambda indices, offsets, per_sample_weights: emb.forward(
                 indices.int(),
                 offsets.int(),
@@ -1881,12 +1886,14 @@ def bench_uvm_cls(
         )
 
         if record_cache:
-            cmc = emb.cache_miss_counter.detach().cpu().numpy().tolist()
+            ucs = emb.uvm_cache_stats.detach().cpu().numpy().tolist()
             cache_stats = {
-                "miss_forward_count": cmc[0],
-                "unique_miss": cmc[1],
-                "unique_req": cmc[2],
-                "nondedup_req": cmc[3],
+                "num_calls": ucs[0],
+                "num_requested_indices": ucs[1],
+                "num_unique_indices": ucs[2],
+                "num_unique_misses": ucs[3],
+                "num_conflict_unique_misses": ucs[4],
+                "num_conflict_misses": ucs[5],
             }
             stats[name]["cache_stats"] = cache_stats
             logging.info(f"[{name:>8s}] cache stats {cache_stats}")
@@ -1932,6 +1939,7 @@ def bench_uvm_cls(
 @click.option("--batch-size", default=512)
 @click.option("--cache-algorithm", default="lru")
 @click.option("--cache-load-factor", default=0.2)
+@click.option("--cache-assoc", default=32) @click.option("--embedding-dim", default=128) @click.option("--weights-precision", type=SparseType, default=SparseType.INT4) @click.option("--iters", default=100) @@ -1954,6 +1962,7 @@ def nbit_cache( # noqa C901 batch_size: int, cache_algorithm: str, cache_load_factor: float, + cache_assoc: int, embedding_dim: int, weights_precision: SparseType, iters: int, @@ -2003,6 +2012,7 @@ def nbit_cache( # noqa C901 enforce_hbm=enforce_hbm, fp8_exponent_bits=fp8_exponent_bits, fp8_exponent_bias=fp8_exponent_bias, + cache_assoc=cache_assoc, ).cuda() emb_nc.fill_random_weights() @@ -2027,6 +2037,7 @@ def nbit_cache( # noqa C901 enforce_hbm=enforce_hbm, fp8_exponent_bits=fp8_exponent_bits, fp8_exponent_bias=fp8_exponent_bias, + cache_assoc=cache_assoc, ).cuda() emb.fill_random_weights() diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index cc3b0aab78..4a7c7e4784 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -584,13 +584,7 @@ def prefetch_32way(self, linear_cache_indices: Tensor) -> None: ) ) if self.gather_uvm_cache_stats: - # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64). - # We may wanna do this accumulation atomically, but as it's only for monitoring, - # slightly inaccurate result may be acceptable. - self.uvm_cache_stats = torch.add( - self.uvm_cache_stats, self.local_uvm_cache_stats - ) - self.local_uvm_cache_stats.zero_() + self._accumulate_uvm_cache_stats() def prefetch_1way(self, linear_cache_indices: Tensor) -> None: if self.cache_algorithm == CacheAlgorithm.LRU: @@ -608,6 +602,9 @@ def prefetch_1way(self, linear_cache_indices: Tensor) -> None: self.timestep_counter.get(), self.lxu_state, self.lxu_cache_miss_timestamp, + 16, # row_alignment; using default value. + self.gather_uvm_cache_stats, + self.local_uvm_cache_stats, ) else: raise ValueError("Direct Mapped for LRU only") @@ -620,8 +617,21 @@ def prefetch_1way(self, linear_cache_indices: Tensor) -> None: linear_cache_indices, self.lxu_cache_state, self.total_cache_hash_size, + self.gather_uvm_cache_stats, + self.local_uvm_cache_stats, ) ) + if self.gather_uvm_cache_stats: + self._accumulate_uvm_cache_stats() + + def _accumulate_uvm_cache_stats(self) -> None: + # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64). + # We may wanna do this accumulation atomically, but as it's only for monitoring, + # slightly inaccurate result may be acceptable. 
+        self.uvm_cache_stats = torch.add(
+            self.uvm_cache_stats, self.local_uvm_cache_stats
+        )
+        self.local_uvm_cache_stats.zero_()
 
     def _update_cache_miss_counter(
         self,
diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
index c05d7f85fe..ddce4c0ea5 100644
--- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py
+++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -5575,6 +5575,129 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
             self.assertEqual(num_conflict_miss, e[1])
             cc1.reset_uvm_cache_stats()
 
+    @unittest.skipIf(*gpu_unavailable)
+    @given(
+        N=st.integers(min_value=1, max_value=8),
+        dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
+    def test_nbit_direct_mapped_uvm_cache_stats(
+        self, N: int, dtype: SparseType
+    ) -> None:
+        # Create an abstract split table
+        D = 8
+        T = 2
+        E = 10**3
+        Ds = [D] * T
+        Es = [E] * T
+        cc = IntNBitTableBatchedEmbeddingBagsCodegen(
+            embedding_specs=[
+                (
+                    "",
+                    E,
+                    D,
+                    dtype,
+                    EmbeddingLocation.MANAGED_CACHING,
+                )
+                for (E, D) in zip(Es, Ds)
+            ],
+            device=torch.cuda.current_device(),
+            gather_uvm_cache_stats=True,
+            cache_assoc=1,  # Direct Mapped
+        )
+        cc.fill_random_weights()
+
+        # Create fake input data and the target output
+        x1 = torch.Tensor([[[1], [1]], [[3], [4]]]).cuda()
+        x2 = torch.Tensor([[[2], [1]], [[3], [4]]]).cuda()
+        x3 = torch.Tensor([[[5], [6]], [[7], [8]]]).cuda()
+
+        xs = [x1, x2, x3]
+        # num_unique_indices, num_unique_misses
+        # Note that these are cumulative over calls, and "unique" is per batch.
+        target_counter_list = [[3, 3], [4, 4], [4, 8]]
+        num_calls_expected = 0
+        num_indices_expected = 0
+        num_unique_indices_expected = 0
+        for x, t_counter in zip(xs, target_counter_list):
+            (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
+            for _ in range(N):
+                num_calls_expected = num_calls_expected + 1
+                num_indices_expected = num_indices_expected + len(indices)
+                cc(indices.int(), offsets.int())
+                (
+                    num_calls,
+                    num_indices,
+                    num_unique_indices,
+                    num_unique_misses,
+                    num_conflict_unique_miss,
+                    num_conflict_miss,
+                ) = cc.get_uvm_cache_stats().cpu()
+                # Note that num_unique_indices is a cumulative stat.
+                num_unique_indices_expected = num_unique_indices_expected + t_counter[0]
+                self.assertEqual(num_calls, num_calls_expected)
+                self.assertEqual(num_indices, num_indices_expected)
+                self.assertEqual(num_unique_indices, 0)  # N/A for Direct Mapped
+                self.assertEqual(num_unique_misses, 0)  # N/A for Direct Mapped
+                self.assertEqual(
+                    num_conflict_unique_miss, t_counter[1]
+                )  # number of actually inserted rows for Direct Mapped
+                self.assertEqual(num_conflict_miss, 0)
+
+        T = 1  # for simplicity
+        Ds = [D] * T
+        Es = [E] * T
+        cc1 = IntNBitTableBatchedEmbeddingBagsCodegen(
+            embedding_specs=[
+                (
+                    "",
+                    E,
+                    D,
+                    SparseType.INT8,
+                    EmbeddingLocation.MANAGED_CACHING,
+                )
+                for (E, D) in zip(Es, Ds)
+            ],
+            device=torch.cuda.current_device(),
+            gather_uvm_cache_stats=True,
+            cache_sets=1,  # Only one set.
+            cache_assoc=1,  # Direct Mapped
+        )
+        cc1.fill_random_weights()
+
+        associativity = 1  # Direct Mapped
+        repetition = 17
+        indices1 = torch.Tensor(
+            [[list(range(0, associativity))] * repetition]
+        ).cuda()  # no conflict miss
+        indices2 = torch.Tensor(
+            [[list(range(0, associativity + 1))] * repetition]
+        ).cuda()  # 1 * 17 conflict misses per request
+        indices3 = torch.Tensor(
+            [[list(range(0, associativity + 10))] * repetition]
+        ).cuda()  # 10 * 17 conflict misses per request
+
+        # num_conflict_unique_miss, num_conflict_miss
+        expected = [[1, 0], [1, 17], [1, 170]]
+
+        accum_num_conflict_miss = 0
+        for x, e in zip((indices1, indices2, indices3), expected):
+            (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
+            for _ in range(N):
+                cc1(indices.int(), offsets.int())
+                (
+                    _,
+                    _,
+                    _,
+                    _,
+                    num_conflict_unique_miss,
+                    num_conflict_miss,
+                ) = cc1.get_uvm_cache_stats().cpu()
+                # For Direct Mapped, this represents the number of actually inserted rows.
+                self.assertEqual(num_conflict_unique_miss, e[0])
+                accum_num_conflict_miss += e[1]
+                self.assertEqual(num_conflict_miss, accum_num_conflict_miss)
+
     @given(
         T=st.integers(min_value=1, max_value=64),
         B=st.integers(min_value=1, max_value=64),
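
---
Usage note (not part of the patch): the sketch below shows how the frontend added in this diff can be exercised to read the six uvm_cache_stats counters from a direct-mapped (cache_assoc=1) inference TBE, mirroring what the revised benchmark does. It is a minimal sketch based only on the calls visible in this diff; the import paths, the default constructor arguments, and the exact counter ordering are assumptions that may vary across FBGEMM versions, and it requires a CUDA GPU with fbgemm_gpu installed.

import torch

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)

# One INT4 table of 1000 rows x 8 dims, held in UVM with a direct-mapped cache.
emb = IntNBitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[("", 1000, 8, SparseType.INT4, EmbeddingLocation.MANAGED_CACHING)],
    device=torch.cuda.current_device(),
    gather_uvm_cache_stats=True,  # accumulate local stats into uvm_cache_stats on prefetch
    cache_assoc=1,  # direct-mapped cache; the frontend option added by this diff
)
emb.fill_random_weights()

# A single bag of four indices for the one table: offsets delimit [0, 4).
indices = torch.tensor([1, 3, 5, 7], dtype=torch.int32, device="cuda")
offsets = torch.tensor([0, 4], dtype=torch.int32, device="cuda")
emb(indices, offsets)

(
    num_calls,
    num_requested_indices,
    num_unique_indices,  # stays 0 for direct-mapped caches (N/A)
    num_unique_misses,  # stays 0 for direct-mapped caches (N/A)
    num_conflict_unique_misses,  # rows actually inserted into the cache
    num_conflict_misses,
) = emb.get_uvm_cache_stats().cpu().tolist()
print(num_calls, num_requested_indices, num_conflict_unique_misses, num_conflict_misses)

emb.reset_uvm_cache_stats()  # clear the accumulated (int64) counters between runs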