uvm_cache_stats for direct mapped (pytorch#1952)
Summary:
Pull Request resolved: pytorch#1952

- Implement the Python frontend for the previous diff (split out for backward compatibility).
- Revise the existing benchmark to report stats from uvm_cache_stats instead of cache_miss_counter.
- Add a unit test for uvm_cache_stats with the direct-mapped cache.

Differential Revision: D48439568

fbshipit-source-id: 3909b9925ce192a75069b1eb8e9a29e8901119c1
SungMinCho authored and facebook-github-bot committed Sep 4, 2023
1 parent 2131e9d commit b30fa60
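For reference, a minimal usage sketch of the Python frontend this diff adds, based on the unit test below. The table sizes, inputs, and import paths are illustrative assumptions and may differ across FBGEMM versions; the constructor arguments and the get/reset stats methods are the ones exercised in the test.

# Illustrative sketch: enable UVM cache stats for a direct-mapped (cache_assoc=1) cache.
# Import paths are assumptions; adjust to the FBGEMM version in use.
import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops import (
    EmbeddingLocation,
    IntNBitTableBatchedEmbeddingBagsCodegen,
)

emb = IntNBitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        ("", 1000, 8, SparseType.INT4, EmbeddingLocation.MANAGED_CACHING),
    ],
    device=torch.cuda.current_device(),
    gather_uvm_cache_stats=True,  # accumulate per-prefetch stats into uvm_cache_stats
    cache_assoc=1,  # direct-mapped cache
)
emb.fill_random_weights()

# One table, two bags: bag 0 = [1], bag 1 = [3, 4].
indices = torch.tensor([1, 3, 4], dtype=torch.int32, device="cuda")
offsets = torch.tensor([0, 1, 3], dtype=torch.int32, device="cuda")
emb(indices, offsets)

# Cumulative counters: num_calls, num_requested_indices, num_unique_indices,
# num_unique_misses, num_conflict_unique_misses, num_conflict_misses.
print(emb.get_uvm_cache_stats().cpu())
emb.reset_uvm_cache_stats()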
Showing 3 changed files with 162 additions and 18 deletions.
33 changes: 22 additions & 11 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -1727,6 +1727,7 @@ def nbit_uvm(
@click.option("--fp8-exponent-bits", type=int, default=None)
@click.option("--fp8-exponent-bias", type=int, default=None)
@click.option("--record-cache", is_flag=True, default=False)
@click.option("--uvm-host-mapped", is_flag=True, default=False)
@click.option(
"--dump-requests", type=int, default=0, help="number of reqs to dump (0=no dump)"
)
@@ -1753,6 +1754,7 @@ def nbit_uvm_compare_direct_mapped(
fp8_exponent_bits: Optional[int],
fp8_exponent_bias: Optional[int],
record_cache: bool,
uvm_host_mapped: bool,
dump_requests: int,
) -> None:
logging.info(json.dumps({k: str(v) for k, v in locals().items()}, indent=2))
@@ -1837,18 +1839,21 @@ def bench_uvm_cls(
enforce_hbm=enforce_hbm,
fp8_exponent_bits=fp8_exponent_bits,
fp8_exponent_bias=fp8_exponent_bias,
record_cache_metrics=RecordCacheMetrics(record_cache, record_cache),
gather_uvm_cache_stats=record_cache,
uvm_host_mapped=uvm_host_mapped,
).cuda()
emb.fill_random_weights()

# label nvtx only when cache counter is off
nvtx_range = "" if record_cache else f"UVM-{name.upper()}"
callback_after_warmup = emb.reset_cache_miss_counter if record_cache else None
requests = requests_uvm[:1] if record_cache else requests_uvm
nvtx_range = (
f"UVM-RECORD-CACHE-{name.upper()}"
if record_cache
else f"UVM-{name.upper()}"
)
callback_after_warmup = emb.reset_uvm_cache_stats if record_cache else None

torch.cuda.cudart().cudaProfilerStart()
time_per_iter = benchmark_requests(
requests,
requests_uvm,
lambda indices, offsets, per_sample_weights: emb.forward(
indices.int(),
offsets.int(),
@@ -1881,12 +1886,14 @@ def bench_uvm_cls(
)

if record_cache:
cmc = emb.cache_miss_counter.detach().cpu().numpy().tolist()
ucs = emb.uvm_cache_stats.detach().cpu().numpy().tolist()
cache_stats = {
"miss_forward_count": cmc[0],
"unique_miss": cmc[1],
"unique_req": cmc[2],
"nondedup_req": cmc[3],
"num_calls": ucs[0],
"num_requested_indices": ucs[1],
"num_unique_indices": ucs[2],
"num_unique_misses": ucs[3],
"num_conflict_unique_misses": ucs[4],
"num_conflict_misses": ucs[5],
}
stats[name]["cache_stats"] = cache_stats
logging.info(f"[{name:>8s}] cache stats {cache_stats}")
@@ -1932,6 +1939,7 @@ def bench_uvm_cls(
@click.option("--batch-size", default=512)
@click.option("--cache-algorithm", default="lru")
@click.option("--cache-load-factor", default=0.2)
@click.option("--cache-assoc", default=32)
@click.option("--embedding-dim", default=128)
@click.option("--weights-precision", type=SparseType, default=SparseType.INT4)
@click.option("--iters", default=100)
@@ -1954,6 +1962,7 @@ def nbit_cache( # noqa C901
batch_size: int,
cache_algorithm: str,
cache_load_factor: float,
cache_assoc: int,
embedding_dim: int,
weights_precision: SparseType,
iters: int,
@@ -2003,6 +2012,7 @@ def nbit_cache( # noqa C901
enforce_hbm=enforce_hbm,
fp8_exponent_bits=fp8_exponent_bits,
fp8_exponent_bias=fp8_exponent_bias,
cache_assoc=cache_assoc,
).cuda()
emb_nc.fill_random_weights()

@@ -2027,6 +2037,7 @@ def nbit_cache( # noqa C901
enforce_hbm=enforce_hbm,
fp8_exponent_bits=fp8_exponent_bits,
fp8_exponent_bias=fp8_exponent_bias,
cache_assoc=cache_assoc,
).cuda()
emb.fill_random_weights()

@@ -594,13 +594,7 @@ def prefetch_32way(self, linear_cache_indices: Tensor) -> None:
)
)
if self.gather_uvm_cache_stats:
# Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
# We may wanna do this accumulation atomically, but as it's only for monitoring,
# slightly inaccurate result may be acceptable.
self.uvm_cache_stats = torch.add(
self.uvm_cache_stats, self.local_uvm_cache_stats
)
self.local_uvm_cache_stats.zero_()
self._accumulate_uvm_cache_stats()

def prefetch_1way(self, linear_cache_indices: Tensor) -> None:
if self.cache_algorithm == CacheAlgorithm.LRU:
@@ -618,6 +612,9 @@ def prefetch_1way(self, linear_cache_indices: Tensor) -> None:
self.timestep_counter.get(),
self.lxu_state,
self.lxu_cache_miss_timestamp,
16, # row_alignment; using default value.
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
)
else:
raise ValueError("Direct Mapped for LRU only")
@@ -630,8 +627,21 @@ def prefetch_1way(self, linear_cache_indices: Tensor) -> None:
linear_cache_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
)
)
if self.gather_uvm_cache_stats:
self._accumulate_uvm_cache_stats()

def _accumulate_uvm_cache_stats(self) -> None:
# Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
# We could do this accumulation atomically, but since the stats are only
# used for monitoring, a slightly inaccurate result is acceptable.
self.uvm_cache_stats = torch.add(
self.uvm_cache_stats, self.local_uvm_cache_stats
)
self.local_uvm_cache_stats.zero_()
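The same accumulation pattern in isolation, as a standalone sketch (not FBGEMM code): the per-prefetch int32 scratch counters fold into cumulative int64 counters so long runs do not overflow, and the scratch buffer is cleared for the next prefetch.

import torch

# Standalone illustration of the accumulation above (values are made up).
uvm_cache_stats = torch.zeros(6, dtype=torch.int64)
local_uvm_cache_stats = torch.tensor([1, 34, 0, 0, 1, 17], dtype=torch.int32)

uvm_cache_stats = torch.add(uvm_cache_stats, local_uvm_cache_stats)  # promoted to int64
local_uvm_cache_stats.zero_()  # reset the scratch counters for the next prefetch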

def _update_cache_miss_counter(
self,
123 changes: 123 additions & 0 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -5575,6 +5575,129 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
self.assertEqual(num_conflict_miss, e[1])
cc1.reset_uvm_cache_stats()

@unittest.skipIf(*gpu_unavailable)
@given(
N=st.integers(min_value=1, max_value=8),
dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]),
)
@settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
def test_nbit_direct_mapped_uvm_cache_stats(
self, N: int, dtype: SparseType
) -> None:
# Create an abstract split table
D = 8
T = 2
E = 10**3
Ds = [D] * T
Es = [E] * T
cc = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
(
"",
E,
D,
dtype,
EmbeddingLocation.MANAGED_CACHING,
)
for (E, D) in zip(Es, Ds)
],
device=torch.cuda.current_device(),
gather_uvm_cache_stats=True,
cache_assoc=1, # Direct Mapped
)
cc.fill_random_weights()

# Create fake input data and the target output
x1 = torch.Tensor([[[1], [1]], [[3], [4]]]).cuda()
x2 = torch.Tensor([[[2], [1]], [[3], [4]]]).cuda()
x3 = torch.Tensor([[[5], [6]], [[7], [8]]]).cuda()

xs = [x1, x2, x3]
# num_unique_indices, num_unique_misses
# Note that these counters are cumulative over calls, and "unique" is per batch.
target_counter_list = [[3, 3], [4, 4], [4, 8]]
num_calls_expected = 0
num_indices_expected = 0
num_unique_indices_expected = 0
for x, t_counter in zip(xs, target_counter_list):
(indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
for _ in range(N):
num_calls_expected = num_calls_expected + 1
num_indices_expected = num_indices_expected + len(indices)
cc(indices.int(), offsets.int())
(
num_calls,
num_indices,
num_unique_indices,
num_unique_misses,
num_conflict_unique_miss,
num_conflict_miss,
) = cc.get_uvm_cache_stats().cpu()
# Note that num_unique_indices is a cumulative stat.
num_unique_indices_expected = num_unique_indices_expected + t_counter[0]
self.assertEqual(num_calls, num_calls_expected)
self.assertEqual(num_indices, num_indices_expected)
self.assertEqual(num_unique_indices, 0) # N/A for Direct Mapped
self.assertEqual(num_unique_misses, 0) # N/A for Direct Mapped
self.assertEqual(
num_conflict_unique_miss, t_counter[1]
) # number of actually inserted rows for Direct Mapped
self.assertEqual(num_conflict_miss, 0)

T = 1 # for simplicity
Ds = [D] * T
Es = [E] * T
cc1 = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
(
"",
E,
D,
SparseType.INT8,
EmbeddingLocation.MANAGED_CACHING,
)
for (E, D) in zip(Es, Ds)
],
device=torch.cuda.current_device(),
gather_uvm_cache_stats=True,
cache_sets=1, # Only one set.
cache_assoc=1, # Direct Mapped
)
cc1.fill_random_weights()

associativity = 1  # Direct-Mapped
repetition = 17
indices1 = torch.Tensor(
[[list(range(0, associativity))] * repetition]
).cuda() # no conflict miss
indices2 = torch.Tensor(
[[list(range(0, associativity + 1))] * repetition]
).cuda() # 1 * 17 conflict miss per request
indices3 = torch.Tensor(
[[list(range(0, associativity + 10))] * repetition]
).cuda() # 10 * 17 conflict misses per request

# num_conflict_unique_miss, num_conflict_miss
expected = [[1, 0], [1, 17], [1, 170]]

accum_num_conflict_miss = 0
for x, e in zip((indices1, indices2, indices3), expected):
(indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
for _ in range(N):
cc1(indices.int(), offsets.int())
(
_,
_,
_,
_,
num_conflict_unique_miss,
num_conflict_miss,
) = cc1.get_uvm_cache_stats().cpu()
# For direct-mapped, this represents the number of actually inserted rows.
self.assertEqual(num_conflict_unique_miss, e[0])
accum_num_conflict_miss += e[1]
self.assertEqual(num_conflict_miss, accum_num_conflict_miss)

@given(
T=st.integers(min_value=1, max_value=64),
B=st.integers(min_value=1, max_value=64),
