Skip to content

Commit

Permalink
add ssd ods stats counterpart for UVM offloading (#2906)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/torchrec#2248

Pull Request resolved: #2906

add ods stats for unique misses and unique indices

Differential Revision: D59740171
  • Loading branch information
Joe Wang authored and facebook-github-bot committed Jul 29, 2024
1 parent e167847 commit 0528f10
Showing 1 changed file with 95 additions and 0 deletions.
95 changes: 95 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import torch # usort:skip

import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
from fbgemm_gpu.runtime_monitor import (
TBEStatsReporter,
TBEStatsReporterConfig,
)
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
CacheAlgorithm,
Expand All @@ -34,6 +38,7 @@
)

from torch import distributed as dist, nn, Tensor # usort:skip
from pyre_extensions import none_throws
from torch.autograd.profiler import record_function

from .common import ASSOC
Expand Down Expand Up @@ -105,6 +110,8 @@ def __init__(
# in local test we need to use the pass in path for rocksdb creation
# in production we need to do it inside SSD mount path which will ignores the passed in path
use_passed_in_path: int = True,
gather_ssd_cache_stats: Optional[bool] = False,
stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
) -> None:
super(SSDTableBatchedEmbeddingBags, self).__init__()

Expand Down Expand Up @@ -181,6 +188,8 @@ def __init__(
"lru_state", torch.zeros(cache_sets, ASSOC, dtype=torch.int64)
)

self.step = 0

assert ssd_cache_location in (
EmbeddingLocation.MANAGED,
EmbeddingLocation.DEVICE,
Expand Down Expand Up @@ -413,6 +422,31 @@ def __init__(
), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
self.optimizer = optimizer

# stats reporter
self.gather_ssd_cache_stats = gather_ssd_cache_stats
self.stats_reporter: Optional[TBEStatsReporter] = (
stats_reporter_config.create_reporter() if stats_reporter_config else None
)
self.ssd_cache_stats_size = 6
# 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
# 4: N_conflict_unique_misses, 5: N_conflict_misses
self.last_reported_ssd_stats: List[float] = []
self.last_reported_step = 0

self.register_buffer(
"ssd_cache_stats",
torch.zeros(
self.ssd_cache_stats_size,
device=self.current_device,
dtype=torch.int64,
),
persistent=False,
)
logging.info(
f"logging stats reporter setup, {self.gather_ssd_cache_stats=}, "
f"stats_reporter:{none_throws(self.stats_reporter) if self.stats_reporter else 'none'}, "
)

# pyre-ignore[3]
def record_function_via_dummy_profile_factory(
self,
Expand Down Expand Up @@ -777,6 +811,11 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
)
)

if self.gather_ssd_cache_stats:
self.ssd_cache_stats[3] += actions_count_gpu[0] # num unique misses
self.ssd_cache_stats[2] += unique_indices_length[0] # num unique indices
self._report_ssd_stats()

def forward(
self,
indices: Tensor,
Expand Down Expand Up @@ -844,6 +883,7 @@ def forward(
)

self.timesteps_prefetched.pop(0)
self.step += 1

if self.optimizer == OptimType.EXACT_SGD:
raise AssertionError(
Expand Down Expand Up @@ -967,3 +1007,58 @@ def flush(self) -> None:
torch.tensor([active_ids.numel()]),
self.timestep,
)

@torch.jit.ignore
def _report_ssd_stats(self) -> None:
"""
Each iteration we will record cache stats about L1 SSD cache in ssd_cache_stats tensor
this function extract those stats and report it with stats_reporter
"""
if self.stats_reporter is None:
return

stats_reporter: TBEStatsReporter = self.stats_reporter
passed_steps = self.step - self.last_reported_step
if passed_steps == 0:
return
if not stats_reporter.should_report(self.step):
return

# ssd hbm cache stats

ssd_cache_stats = self.ssd_cache_stats.tolist()
if len(self.last_reported_ssd_stats) == 0:
self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats)
ssd_cache_stats_delta: List[float] = [0.0] * len(ssd_cache_stats)
for i in range(len(ssd_cache_stats)):
ssd_cache_stats_delta[i] = (
ssd_cache_stats[i] - self.last_reported_ssd_stats[i]
)
self.last_reported_step = self.step
self.last_reported_ssd_stats = ssd_cache_stats
element_size = self.lxu_cache_weights.element_size()

stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_indices_bytes",
data_bytes=int(
ssd_cache_stats_delta[2] * element_size * self.max_D / passed_steps
),
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_misses_bytes",
data_bytes=int(
ssd_cache_stats_delta[3] * element_size * self.max_D / passed_steps
),
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_indices",
data_bytes=int(ssd_cache_stats_delta[2] / passed_steps),
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_misses",
data_bytes=int(ssd_cache_stats_delta[3] / passed_steps),
)

0 comments on commit 0528f10

Please sign in to comment.