Merge branch 'main' into bf16-emb-fix
zhuhaozhe authored Feb 9, 2023
2 parents 7b9eed6 + d88187d commit 026c33e
Showing 2 changed files with 189 additions and 27 deletions.
168 changes: 141 additions & 27 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -178,6 +178,8 @@ def construct_cache_state(
return s


# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
"""
Multiple sparse features can share one embedding table.
@@ -193,6 +195,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
lxu_cache_locations_empty: Tensor
timesteps_prefetched: List[int]
record_cache_metrics: RecordCacheMetrics
uvm_cache_stats: torch.Tensor
local_uvm_cache_stats: torch.Tensor

def __init__( # noqa C901
self,
@@ -210,6 +214,7 @@ def __init__( # noqa C901
enforce_hbm: bool = False, # place all weights/momentums in HBM when using cache
optimizer: OptimType = OptimType.EXACT_SGD,
record_cache_metrics: Optional[RecordCacheMetrics] = None,
gather_uvm_cache_stats: Optional[bool] = False,
# General Optimizer args
stochastic_rounding: bool = True,
gradient_clipping: bool = False,
@@ -287,6 +292,13 @@ def __init__( # noqa C901
torch.zeros(0, device=self.current_device, dtype=torch.float)
)

self.gather_uvm_cache_stats = gather_uvm_cache_stats
# Define the size of uvm cache stats as a class variable
# so that it works with TorchScript.
self.uvm_cache_stats_size = 6
# 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
# 4: N_conflict_unique_misses, 5: N_conflict_misses

self.int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET

self.feature_table_map: List[int] = (
@@ -751,6 +763,37 @@ def forward(

raise ValueError(f"Invalid OptimType: {self.optimizer}")

def reset_uvm_cache_stats(self) -> None:
assert (
self.gather_uvm_cache_stats
), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
self.uvm_cache_stats.zero_()
self.local_uvm_cache_stats.zero_()

def get_uvm_cache_stats(self) -> Tensor:
assert (
self.gather_uvm_cache_stats
), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
return self.uvm_cache_stats

def print_uvm_cache_stats(self) -> None:
assert (
self.gather_uvm_cache_stats
), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
uvm_cache_stats = self.uvm_cache_stats.tolist()
logging.info(
f"N_called: {uvm_cache_stats[0]}\n"
f"N_requested_indices: {uvm_cache_stats[1]}\n"
f"N_unique_indices: {uvm_cache_stats[2]}\n"
f"N_unique_misses: {uvm_cache_stats[3]}\n"
f"N_conflict_unique_misses: {uvm_cache_stats[4]}\n"
f"N_conflict_misses: {uvm_cache_stats[5]}\n"
)
logging.info(
f"unique indices / requested indices: {uvm_cache_stats[2]/uvm_cache_stats[1]}\n"
f"unique misses / requested indices: {uvm_cache_stats[3]/uvm_cache_stats[1]}\n"
)

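As a quick usage sketch (not part of the diff above; it assumes a CUDA device with fbgemm_gpu installed, and the embedding spec values are illustrative, mirroring the unit test added below), the new counters could be read back roughly as follows:

import torch
import fbgemm_gpu.split_table_batched_embeddings_ops as ops

emb = ops.SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_embeddings, embedding_dim, location, compute_device)
        (1000, 8, ops.EmbeddingLocation.MANAGED_CACHING, ops.ComputeDevice.CUDA),
    ],
    gather_uvm_cache_stats=True,  # enable the new UVM cache counters
)
indices = torch.tensor([1, 1, 3, 4], dtype=torch.int64, device="cuda")
offsets = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64, device="cuda")  # 4 bags of length 1
emb(indices, offsets)              # forward() prefetches and updates local_uvm_cache_stats
stats = emb.get_uvm_cache_stats()  # int64 tensor of length 6 (slot layout documented in __init__)
emb.print_uvm_cache_stats()        # logs the raw counters and derived ratios
emb.reset_uvm_cache_stats()        # zeroes both the accumulated and per-batch counters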
def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
self.timestep += 1
self.timesteps_prefetched.append(self.timestep)
@@ -775,6 +818,8 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
linear_cache_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
)
if self.record_cache_metrics.record_cache_miss_counter:
self._update_cache_miss_counter(
@@ -799,6 +844,8 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
self.timestep,
self.lxu_state,
self.stochastic_rounding,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
)
elif self.cache_algorithm == CacheAlgorithm.LFU:
torch.ops.fbgemm.lfu_cache_populate(
@@ -823,8 +870,18 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
linear_cache_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
self.gather_uvm_cache_stats,
self.local_uvm_cache_stats,
)
)
if self.gather_uvm_cache_stats:
# Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
# We may want to do this accumulation atomically, but since it is only for
# monitoring, a slightly inaccurate result is acceptable.
self.uvm_cache_stats = torch.add(
self.uvm_cache_stats, self.local_uvm_cache_stats
)
self.local_uvm_cache_stats.zero_()

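The per-batch counters live in an int32 buffer (the dtype the fbgemm cache ops take for local_uvm_cache_stats) and are folded into an int64 running total here, presumably so long-running jobs do not overflow. A small illustrative helper, not part of this commit, for turning the accumulated tensor into rates using the slot layout documented in __init__:

def uvm_cache_stats_report(emb) -> dict:
    # Slots: 0 N_calls, 1 N_requested_indices, 2 N_unique_indices,
    #        3 N_unique_misses, 4 N_conflict_unique_misses, 5 N_conflict_misses
    s = emb.get_uvm_cache_stats().tolist()
    requested = max(s[1], 1)  # guard against division by zero before any lookup
    return {
        "calls": s[0],
        "unique_ratio": s[2] / requested,
        "unique_miss_rate": s[3] / requested,
        "conflict_miss_rate": s[5] / requested,
    }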
def _update_cache_miss_counter(
self,
@@ -1198,6 +1255,8 @@ def _apply_cache_state(
0, device=self.current_device, dtype=torch.int32
).fill_(-1)

self._init_uvm_cache_stats()

# NOTE: no cache for CPU mode!
if cache_state.total_cache_hash_size == 0 or self.use_cpu:
self.register_buffer(
@@ -1322,12 +1381,54 @@ def _apply_cache_state(
"cache_miss_counter",
torch.tensor([0, 0], device=self.current_device, dtype=torch.int64),
)

if cache_algorithm not in (CacheAlgorithm.LFU, CacheAlgorithm.LRU):
raise ValueError(
f"cache_algorithm must be {CacheAlgorithm.LRU} "
f"or {CacheAlgorithm.LFU}"
)

def _init_uvm_cache_stats(self) -> None:
if not self.gather_uvm_cache_stats:
# If uvm_cache_stats is not enabled, register stub buffers so that TorchScript can JIT the module properly.
# Since we're not using these variables, we can minimize the tensor size to keep state_dict small.
self.register_buffer(
"uvm_cache_stats",
torch.zeros(
1,
device=self.current_device,
dtype=torch.int64,
),
persistent=False,
)
self.register_buffer(
"local_uvm_cache_stats",
torch.zeros(
1,
device=self.current_device,
dtype=torch.int32,
),
persistent=False,
)
else:
self.register_buffer(
"uvm_cache_stats",
torch.zeros(
size=(self.uvm_cache_stats_size,),
device=self.current_device,
dtype=torch.int64,
),
)
self.register_buffer(
"local_uvm_cache_stats",
torch.zeros(
size=(self.uvm_cache_stats_size,),
device=self.current_device,
dtype=torch.int32,
),
)
self.reset_uvm_cache_stats()

def reset_cache_states(self) -> None:
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(Tensor.numel)[[Named(self, Tensor)],
@@ -1631,6 +1732,7 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
record_cache_metrics: RecordCacheMetrics
cache_miss_counter: torch.Tensor
uvm_cache_stats: torch.Tensor
local_uvm_cache_stats: torch.Tensor

def __init__(
self,
@@ -1948,11 +2050,11 @@ def reset_cache_miss_counter(self) -> None:
)

def reset_uvm_cache_stats(self) -> None:
self.uvm_cache_stats = torch.zeros(
size=(self.uvm_cache_stats_size,),
device=self.current_device,
dtype=torch.int64,
)
assert (
self.gather_uvm_cache_stats
), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
self.uvm_cache_stats.zero_()
self.local_uvm_cache_stats.zero_()

def print_cache_miss_counter(self) -> None:
assert (
@@ -1982,17 +2084,18 @@ def print_uvm_cache_stats(self) -> None:
assert (
self.gather_uvm_cache_stats
), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
uvm_cache_stats = self.uvm_cache_stats.tolist()
logging.info(
f"N_called: {self.uvm_cache_stats[0]}\n"
f"N_requested_indices: {self.uvm_cache_stats[1]}\n"
f"N_unique_indices: {self.uvm_cache_stats[2]}\n"
f"N_unique_misses: {self.uvm_cache_stats[3]}\n"
f"N_conflict_unique_misses: {self.uvm_cache_stats[4]}\n"
f"N_conflict_misses: {self.uvm_cache_stats[5]}\n"
f"N_called: {uvm_cache_stats[0]}\n"
f"N_requested_indices: {uvm_cache_stats[1]}\n"
f"N_unique_indices: {uvm_cache_stats[2]}\n"
f"N_unique_misses: {uvm_cache_stats[3]}\n"
f"N_conflict_unique_misses: {uvm_cache_stats[4]}\n"
f"N_conflict_misses: {uvm_cache_stats[5]}\n"
)
logging.info(
f"unique indices / requested indices: {self.uvm_cache_stats[2]/self.uvm_cache_stats[1]}\n"
f"unique misses / requested indices: {self.uvm_cache_stats[3]/self.uvm_cache_stats[1]}\n"
f"unique indices / requested indices: {uvm_cache_stats[2]/uvm_cache_stats[1]}\n"
f"unique misses / requested indices: {uvm_cache_stats[3]/uvm_cache_stats[1]}\n"
)

@torch.jit.export
@@ -2050,17 +2153,6 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
raise ValueError(f"{self.cache_assoc} not in [1, 32, 64]")

def prefetch_32way(self, linear_cache_indices: Tensor) -> None:
# uvm cache stats for this batch, using int32.
local_uvm_cache_stats = torch.empty(
0, device=self.current_device, dtype=torch.int32
)
if self.gather_uvm_cache_stats:
# Temporary UVM_CACHE_STATS (int32 counters).
local_uvm_cache_stats = torch.zeros(
size=(self.uvm_cache_stats_size,),
device=self.current_device,
dtype=torch.int32,
)
if self.cache_algorithm == CacheAlgorithm.LRU:
torch.ops.fbgemm.lru_cache_populate_byte(
self.weights_uvm,
@@ -2077,7 +2169,7 @@ def prefetch_32way(self, linear_cache_indices: Tensor) -> None:
self.lxu_state,
16, # row_alignment; using default value.
self.gather_uvm_cache_stats,
local_uvm_cache_stats,
self.local_uvm_cache_stats,
)
elif self.cache_algorithm == CacheAlgorithm.LFU:
torch.ops.fbgemm.lfu_cache_populate_byte(
@@ -2103,16 +2195,17 @@ def prefetch_32way(self, linear_cache_indices: Tensor) -> None:
self.lxu_cache_state,
self.total_cache_hash_size,
self.gather_uvm_cache_stats,
local_uvm_cache_stats,
self.local_uvm_cache_stats,
)
)
if self.gather_uvm_cache_stats:
# Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64).
# We may want to do this accumulation atomically, but since it is only for
# monitoring, a slightly inaccurate result is acceptable.
self.uvm_cache_stats = torch.add(
self.uvm_cache_stats, local_uvm_cache_stats
self.uvm_cache_stats, self.local_uvm_cache_stats
)
self.local_uvm_cache_stats.zero_()

def prefetch_1way(self, linear_cache_indices: Tensor) -> None:
if self.cache_algorithm == CacheAlgorithm.LRU:
@@ -2445,6 +2538,15 @@ def _apply_cache_state(
),
persistent=False,
)
self.register_buffer(
"local_uvm_cache_stats",
torch.zeros(
size=(self.uvm_cache_stats_size,),
device=self.current_device,
dtype=torch.int32,
),
persistent=False,
)
return

assert cache_load_factor > 0
@@ -2559,12 +2661,24 @@ def _apply_cache_state(
),
persistent=False,
)
self.register_buffer(
"local_uvm_cache_stats",
torch.zeros(
size=(self.uvm_cache_stats_size,),
device=self.current_device,
dtype=torch.int32,
),
persistent=False,
)
if cache_algorithm not in (CacheAlgorithm.LFU, CacheAlgorithm.LRU):
raise ValueError(
f"cache_algorithm must be {CacheAlgorithm.LRU} "
f"or {CacheAlgorithm.LFU}"
)

if self.gather_uvm_cache_stats:
self.reset_uvm_cache_stats()

def reset_cache_states(self) -> None:
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(Tensor.numel)[[Named(self, Tensor)],
48 changes: 48 additions & 0 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -4150,6 +4150,54 @@ def test_cache_miss_counter(self, N: int) -> None:
for i in range(len(tablewise_cache_miss)):
self.assertEqual(tablewise_cache_miss[i], t_tablewise_cache_miss[i])

@given(N=st.integers(min_value=1, max_value=2))
@settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
def test_stb_uvm_cache_stats(self, N: int) -> None:
# Create an abstract split table
D = 8
T = 2
E = 10**3
Ds = [D] * T
Es = [E] * T
emb_op = (
split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen
)
cc = emb_op(
embedding_specs=[
(
E,
D,
split_table_batched_embeddings_ops.EmbeddingLocation.MANAGED_CACHING,
split_table_batched_embeddings_ops.ComputeDevice.CUDA,
)
for (E, D) in zip(Es, Ds)
],
gather_uvm_cache_stats=True,
)

x = torch.Tensor([[[1], [1]], [[3], [4]]])
x = to_device(torch.tensor(x, dtype=torch.int64), use_cpu=False)

for _ in range(N):
indices, offsets = get_table_batched_offsets_from_dense(x, use_cpu=False)
cc.reset_cache_states()
cc.reset_uvm_cache_stats()
cc(indices, offsets)
(
n_calls,
n_requested_indices,
n_unique_indices,
n_unique_misses,
n_conflict_unique_misses,
n_conflict_misses,
) = cc.get_uvm_cache_stats()
self.assertEqual(n_calls, 1)
self.assertEqual(n_requested_indices, len(indices))
self.assertEqual(n_unique_indices, len(set(indices.tolist())))
self.assertEqual(n_unique_misses, len(set(indices.tolist())))
self.assertEqual(n_conflict_unique_misses, 0)
self.assertEqual(n_conflict_misses, 0)

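For the sample batch above, the expected counter values can be worked out by hand; a small sketch of the arithmetic behind the assertions, assuming the cache starts empty at each iteration (which reset_cache_states() provides):

x = [[1, 1], [3, 4]]                            # T=2 tables, B=2 bags, L=1 index per bag
requested = [i for table in x for i in table]   # flattened -> [1, 1, 3, 4], so N_requested_indices == 4
unique = set(requested)                         # {1, 3, 4}, so N_unique_indices == 3
# With the cache freshly reset, every unique index is a cold miss, so N_unique_misses == 3;
# the cache is far larger than this tiny request, so both conflict counters stay at 0.
assert (len(requested), len(unique)) == (4, 3)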
@unittest.skipIf(*gpu_unavailable)
@given(
L=st.integers(min_value=0, max_value=16),
