diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index 800b1e1893..b47366363e 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -19,6 +19,10 @@ import torch  # usort:skip
 
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
 
+from fbgemm_gpu.runtime_monitor import (
+    TBEStatsReporter,
+    TBEStatsReporterConfig,
+)
 from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     CacheAlgorithm,
@@ -34,6 +38,7 @@
 )
 from torch import distributed as dist, nn, Tensor  # usort:skip
 
+from pyre_extensions import none_throws
 from torch.autograd.profiler import record_function
 
 from .common import ASSOC
@@ -105,6 +110,8 @@ def __init__(
         # in local test we need to use the pass in path for rocksdb creation
         # in production we need to do it inside SSD mount path which will ignores the passed in path
         use_passed_in_path: int = True,
+        gather_ssd_cache_stats: Optional[bool] = False,
+        stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
 
@@ -181,6 +188,8 @@ def __init__(
             "lru_state", torch.zeros(cache_sets, ASSOC, dtype=torch.int64)
        )
 
+        self.step = 0
+
         assert ssd_cache_location in (
             EmbeddingLocation.MANAGED,
             EmbeddingLocation.DEVICE,
@@ -413,6 +422,31 @@ def __init__(
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
         self.optimizer = optimizer
 
+        # stats reporter
+        self.gather_ssd_cache_stats = gather_ssd_cache_stats
+        self.stats_reporter: Optional[TBEStatsReporter] = (
+            stats_reporter_config.create_reporter() if stats_reporter_config else None
+        )
+        self.ssd_cache_stats_size = 6
+        # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
+        # 4: N_conflict_unique_misses, 5: N_conflict_misses
+        self.last_reported_ssd_stats: List[float] = []
+        self.last_reported_step = 0
+
+        self.register_buffer(
+            "ssd_cache_stats",
+            torch.zeros(
+                self.ssd_cache_stats_size,
+                device=self.current_device,
+                dtype=torch.int64,
+            ),
+            persistent=False,
+        )
+        logging.info(
+            f"logging stats reporter setup, {self.gather_ssd_cache_stats=}, "
+            f"stats_reporter:{none_throws(self.stats_reporter) if self.stats_reporter else 'none'}, "
+        )
+
     # pyre-ignore[3]
     def record_function_via_dummy_profile_factory(
         self,
@@ -777,6 +811,11 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
             )
         )
 
+        if self.gather_ssd_cache_stats:
+            self.ssd_cache_stats[3] += actions_count_gpu[0]  # num unique misses
+            self.ssd_cache_stats[2] += unique_indices_length[0]  # num unique indices
+            self._report_ssd_stats()
+
     def forward(
         self,
         indices: Tensor,
@@ -844,6 +883,7 @@ def forward(
         )
 
         self.timesteps_prefetched.pop(0)
+        self.step += 1
 
         if self.optimizer == OptimType.EXACT_SGD:
             raise AssertionError(
@@ -967,3 +1007,58 @@ def flush(self) -> None:
             torch.tensor([active_ids.numel()]),
             self.timestep,
         )
+
+    @torch.jit.ignore
+    def _report_ssd_stats(self) -> None:
+        """
+        Each iteration we record L1 SSD cache stats in the ssd_cache_stats tensor;
+        this function extracts those stats and reports them via stats_reporter.
+        """
+        if self.stats_reporter is None:
+            return
+
+        stats_reporter: TBEStatsReporter = self.stats_reporter
+        passed_steps = self.step - self.last_reported_step
+        if passed_steps == 0:
+            return
+        if not stats_reporter.should_report(self.step):
+            return
+
+        # ssd hbm cache stats
+
+        ssd_cache_stats = self.ssd_cache_stats.tolist()
+        if len(self.last_reported_ssd_stats) == 0:
+            self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats)
+        ssd_cache_stats_delta: List[float] = [0.0] * len(ssd_cache_stats)
+        for i in range(len(ssd_cache_stats)):
+            ssd_cache_stats_delta[i] = (
+                ssd_cache_stats[i] - self.last_reported_ssd_stats[i]
+            )
+        self.last_reported_step = self.step
+        self.last_reported_ssd_stats = ssd_cache_stats
+        element_size = self.lxu_cache_weights.element_size()
+
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="ssd.hbm_cache_stats.num_unique_indices_bytes",
+            data_bytes=int(
+                ssd_cache_stats_delta[2] * element_size * self.max_D / passed_steps
+            ),
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="ssd.hbm_cache_stats.num_unique_misses_bytes",
+            data_bytes=int(
+                ssd_cache_stats_delta[3] * element_size * self.max_D / passed_steps
+            ),
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="ssd.hbm_cache_stats.num_unique_indices",
+            data_bytes=int(ssd_cache_stats_delta[2] / passed_steps),
+        )
+        stats_reporter.report_data_amount(
+            iteration_step=self.step,
+            event_name="ssd.hbm_cache_stats.num_unique_misses",
+            data_bytes=int(ssd_cache_stats_delta[3] / passed_steps),
+        )
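
Note (illustration, not part of the patch): _report_ssd_stats averages the counter
deltas over the steps since the last report and converts the unique-indices and
unique-misses counts to bytes via element_size * max_D. A minimal standalone sketch
of that arithmetic follows; only the counter index meanings and the formulas mirror
the patch, while the counter values, passed_steps, element_size, and max_D are
made-up numbers.

# Standalone sketch of the per-step averaging done in _report_ssd_stats.
# All concrete values below are invented for illustration.
from typing import List

# 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
# 4: N_conflict_unique_misses, 5: N_conflict_misses
last_reported: List[float] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
current: List[float] = [100.0, 81920.0, 20480.0, 2048.0, 64.0, 128.0]

passed_steps = 100   # self.step - self.last_reported_step
element_size = 2     # bytes per cache element, e.g. fp16 cache weights
max_D = 128          # maximum embedding dimension across tables

delta = [c - l for c, l in zip(current, last_reported)]

# Per-step averages, as reported under the ssd.hbm_cache_stats.* event names
num_unique_indices = int(delta[2] / passed_steps)                               # 204
num_unique_misses = int(delta[3] / passed_steps)                                # 20
num_unique_indices_bytes = int(delta[2] * element_size * max_D / passed_steps)  # 52428
num_unique_misses_bytes = int(delta[3] * element_size * max_D / passed_steps)   # 5242

print(num_unique_indices, num_unique_misses)
print(num_unique_indices_bytes, num_unique_misses_bytes)

Reporting is only active when a stats_reporter_config is supplied (so create_reporter()
returns a reporter), gather_ssd_cache_stats is enabled so the counters accumulate in
prefetch, and should_report(step) accepts the current step.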