Commit

add ssd ods stats
Summary:
Add ODS stats (a usage sketch follows the list):

1. RocksDB memory stats
2. RocksDB internal metrics enablement
3. RocksDB I/O performance stats
4. SSD offloading stats reporting
5. Prefetch duration calculation
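
A minimal usage sketch of the two new knobs on SSDTableBatchedEmbeddingBags (gather_ssd_cache_stats and stats_reporter_config). Only those two arguments come from this diff; the remaining constructor arguments, the OdsReporterConfig subclass, and its behavior are assumptions for illustration.

from typing import Optional

from fbgemm_gpu.runtime_monitor import TBEStatsReporter, TBEStatsReporterConfig
from fbgemm_gpu.tbe.ssd.training import SSDTableBatchedEmbeddingBags


class OdsReporterConfig(TBEStatsReporterConfig):
    # Hypothetical config: a real implementation would return an ODS-backed reporter.
    def create_reporter(self) -> Optional[TBEStatsReporter]:
        return None  # placeholder; return a concrete TBEStatsReporter in practice


emb = SSDTableBatchedEmbeddingBags(
    embedding_specs=[(1_000_000, 128)],     # assumed/illustrative (rows, dim) per table
    cache_sets=10_000,                      # assumed/illustrative L1 cache size
    ssd_storage_directory="/tmp/ssd_tbe",   # assumed/illustrative path
    gather_ssd_cache_stats=True,            # new: accumulate L1 cache counters during prefetch
    stats_reporter_config=OdsReporterConfig(),  # new: reporter is created from this config
)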

Differential Revision: D59740171
Joe Wang authored and facebook-github-bot committed Jul 26, 2024
1 parent 9901736 commit 3c63748
Showing 11 changed files with 400 additions and 15 deletions.
4 changes: 3 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/runtime_monitor.py
@@ -41,6 +41,7 @@ def report_duration(
duration_ms: float,
embedding_id: str = "",
tbe_id: str = "",
time_unit: str = "ms",
) -> None:
"""
Report the duration of a timed event.
@@ -77,9 +78,10 @@ def report_duration(
duration_ms: float,
embedding_id: str = "",
tbe_id: str = "",
time_unit: str = "ms",
) -> None:
logging.info(
f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} ms"
f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} {time_unit}"
)

def report_data_amount(
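
The interface change above threads a time_unit argument through report_duration while keeping the parameter name duration_ms, so a value reported in microseconds still arrives via duration_ms. Below is a minimal sketch of a reporter that normalizes both cases; it only mirrors the methods exercised in this diff and does not subclass the real TBEStatsReporter base, which may require additional methods.

class MicrosecondNormalizingReporter:
    """Standalone sketch mirroring the reporter methods used in this diff."""

    def __init__(self, report_interval: int = 100) -> None:
        self.report_interval = report_interval  # assumed reporting cadence

    def should_report(self, iteration_step: int) -> bool:
        return iteration_step % self.report_interval == 0

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
        time_unit: str = "ms",
    ) -> None:
        # Normalize to microseconds whether the caller passed "ms" (default) or "us".
        duration_us = duration_ms if time_unit == "us" else duration_ms * 1000.0
        print(f"[Batch #{iteration_step}] {event_name}: {duration_us:.1f} us")

    def report_data_amount(
        self, iteration_step: int, event_name: str, data_bytes: int
    ) -> None:
        print(f"[Batch #{iteration_step}] {event_name}: {data_bytes} bytes")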
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/inference.py
@@ -443,6 +443,7 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor:
evicted_rows_cpu,
actions_count_cpu,
self.timestep_counter.get(),
False, # is_bwd
)
# TODO: is this needed?
# Need a way to synchronize
275 changes: 274 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -8,6 +8,7 @@
# pyre-strict
# pyre-ignore-all-errors[13,56]

import contextlib
import functools
import itertools
import logging
@@ -19,6 +20,11 @@
import torch # usort:skip

import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
from fbgemm_gpu.runtime_monitor import (
AsyncSeriesTimer,
TBEStatsReporter,
TBEStatsReporterConfig,
)
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
CacheAlgorithm,
@@ -34,6 +40,7 @@
)

from torch import distributed as dist, nn, Tensor # usort:skip
from pyre_extensions import none_throws
from torch.autograd.profiler import record_function

from .common import ASSOC
@@ -105,6 +112,8 @@ def __init__(
use_passed_in_path: int = True,
# in local tests we need to use the passed-in path for rocksdb creation
# in production we need to do it inside the SSD mount path, which ignores the passed-in path
gather_ssd_cache_stats: Optional[bool] = False,
stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
) -> None:
super(SSDTableBatchedEmbeddingBags, self).__init__()

@@ -181,6 +190,8 @@ def __init__(
"lru_state", torch.zeros(cache_sets, ASSOC, dtype=torch.int64)
)

self.step = 0

assert ssd_cache_location in (
EmbeddingLocation.MANAGED,
EmbeddingLocation.DEVICE,
@@ -413,6 +424,87 @@ def __init__(
), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
self.optimizer = optimizer

# stats reporter
self.gather_ssd_cache_stats = gather_ssd_cache_stats
self.stats_reporter: Optional[TBEStatsReporter] = (
stats_reporter_config.create_reporter() if stats_reporter_config else None
)
self.ssd_cache_stats_size = 6
# 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
# 4: N_conflict_unique_misses, 5: N_conflict_misses
self.last_reported_ssd_stats: List[float] = []
self.last_reported_step = 0

self.register_buffer(
"ssd_cache_stats",
torch.zeros(
self.ssd_cache_stats_size,
device=self.current_device,
dtype=torch.int64,
),
persistent=False,
)
logging.info(
f"logging stats reporter setup, {self.gather_ssd_cache_stats=}, "
f"stats_reporter:{none_throws(self.stats_reporter) if self.stats_reporter else 'none'}, "
)

# prefetch launches a series of kernels; we use AsyncSeriesTimer to track the kernel time
self.ssd_prefetch_read_timer: Optional[AsyncSeriesTimer] = None
self.ssd_prefetch_evict_timer: Optional[AsyncSeriesTimer] = None
self.prefetch_parallel_stream_cnt: int = 2
# tuple of (iteration, prefetch parallel stream count, reported duration)
# since there are 2 streams in parallel in prefetch, we want to report the longest one
self.prefetch_duration_us: Tuple[int, int, float] = (
-1,
self.prefetch_parallel_stream_cnt,
0,
)
if self.stats_reporter:
self.ssd_prefetch_read_timer = AsyncSeriesTimer(
functools.partial(
SSDTableBatchedEmbeddingBags._report_duration,
self,
event_name="tbe.prefetch_duration_us",
time_unit="us",
)
)
self.ssd_prefetch_evict_timer = AsyncSeriesTimer(
functools.partial(
SSDTableBatchedEmbeddingBags._report_duration,
self,
event_name="tbe.prefetch_duration_us",
time_unit="us",
)
)

@torch.jit.ignore
def _report_duration(
self,
it_step: int,
dur_ms: float,
event_name: str,
time_unit: str,
) -> None:
recorded_itr, stream_cnt, report_val = self.prefetch_duration_us
duration = dur_ms
if time_unit == "us":
duration = dur_ms * 1000
if it_step == recorded_itr:
report_val = max(report_val, duration)
stream_cnt -= 1
else:
recorded_itr = it_step
report_val = duration
stream_cnt = self.prefetch_parallel_stream_cnt
self.prefetch_duration_us = (recorded_itr, stream_cnt, report_val)

if stream_cnt == 1:
# this is the last stream; handle the ODS report
none_throws(self.stats_reporter).report_duration(
it_step, event_name, report_val, time_unit=time_unit
)

# pyre-ignore[3]
def record_function_via_dummy_profile_factory(
self,
@@ -526,6 +618,7 @@ def evict(
post_event: torch.cuda.Event,
is_rows_uvm: bool,
name: Optional[str] = "",
is_bwd: bool = True,
) -> None:
"""
Evict data from the given input tensors to SSD via RocksDB
@@ -564,6 +657,7 @@ def evict(
rows_cpu,
actions_count_cpu,
self.timestep,
is_bwd,
)

# TODO: is this needed?
@@ -732,6 +826,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
current_stream.wait_event(self.ssd_event_evict_sp)
current_stream.wait_event(self.ssd_event_get_inputs_cpy)

if self.gather_ssd_cache_stats:
# collect past SSD IO durations right before the next rocksdb IO
self._report_ssd_io_stats()

if linear_cache_indices.numel() > 0:
self.record_function_via_dummy_profile(
"## ssd_get ##",
@@ -760,6 +858,7 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
post_event=self.ssd_event_evict,
is_rows_uvm=False,
name="cache",
is_bwd=False,
)

# TODO: keep only necessary tensors
@@ -777,6 +876,12 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
)
)

if self.gather_ssd_cache_stats:
self.ssd_cache_stats[3] += actions_count_gpu[0] # num unique misses
self.ssd_cache_stats[2] += unique_indices_length[0] # num unique indices
self._report_ssd_mem_usage()
self._report_ssd_stats()

def forward(
self,
indices: Tensor,
Expand All @@ -789,7 +894,16 @@ def forward(
if per_sample_weights is not None:
per_sample_weights = per_sample_weights.float()
if len(self.timesteps_prefetched) == 0:
self.prefetch(indices, offsets)
with self._recording_to_timer(
self.ssd_prefetch_read_timer,
context=self.step,
stream=torch.cuda.current_stream(),
), self._recording_to_timer(
self.ssd_prefetch_evict_timer,
context=self.step,
stream=self.ssd_eviction_stream,
):
self.prefetch(indices, offsets)
assert len(self.ssd_prefetch_data) > 0

prefetch_data = self.ssd_prefetch_data.pop(0)
@@ -844,6 +958,7 @@ def forward(
)

self.timesteps_prefetched.pop(0)
self.step += 1

if self.optimizer == OptimType.EXACT_SGD:
raise AssertionError(
@@ -966,4 +1081,162 @@ def flush(self) -> None:
active_weights,
torch.tensor([active_ids.numel()]),
self.timestep,
False,
)

@torch.jit.ignore
def _report_ssd_io_stats(self) -> None:
"""
EmbeddingRocksDB holds stats for the total read/write duration in fwd/bwd;
this function fetches those stats from EmbeddingRocksDB and reports them with stats_reporter
"""
if self.stats_reporter is None:
return

stats_reporter: TBEStatsReporter = self.stats_reporter
if not stats_reporter.should_report(self.step):
return

ssd_io_duration = self.ssd_db.get_io_duration(
self.step, stats_reporter.report_interval # pyre-ignore
)

if len(ssd_io_duration) != 3:
logging.error("ssd io duration should have 3 elements")
return

ssd_read_dur_us = ssd_io_duration[0]
fwd_ssd_write_dur_us = ssd_io_duration[1]
bwd_ssd_write_dur_us = ssd_io_duration[2]

stats_reporter.report_duration(
iteration_step=self.step,
event_name="ssd.io_duration.read_us",
duration_ms=ssd_read_dur_us,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="ssd.io_duration.fwd_write_us",
duration_ms=fwd_ssd_write_dur_us,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="ssd.io_duration.bwd_write_us",
duration_ms=bwd_ssd_write_dur_us,
time_unit="us",
)

@torch.jit.ignore
def _report_ssd_stats(self) -> None:
"""
Each iteration we record stats about the L1 SSD cache in the ssd_cache_stats tensor;
this function extracts those stats and reports them with stats_reporter
"""
if self.stats_reporter is None:
return

stats_reporter: TBEStatsReporter = self.stats_reporter
passed_steps = self.step - self.last_reported_step
if passed_steps == 0:
return
if not stats_reporter.should_report(self.step):
return

# ssd hbm cache stats

ssd_cache_stats = self.ssd_cache_stats.tolist()
if len(self.last_reported_ssd_stats) == 0:
self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats)
ssd_cache_stats_delta: List[float] = [0.0] * len(ssd_cache_stats)
for i in range(len(ssd_cache_stats)):
ssd_cache_stats_delta[i] = (
ssd_cache_stats[i] - self.last_reported_ssd_stats[i]
)
self.last_reported_step = self.step
self.last_reported_ssd_stats = ssd_cache_stats
element_size = self.lxu_cache_weights.element_size()

stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_indices_bytes",
data_bytes=int(
ssd_cache_stats_delta[2] * element_size * self.max_D / passed_steps
),
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_misses_bytes",
data_bytes=int(
ssd_cache_stats_delta[3] * element_size * self.max_D / passed_steps
),
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_indices",
data_bytes=int(ssd_cache_stats_delta[2] / passed_steps),
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_misses",
data_bytes=int(ssd_cache_stats_delta[3] / passed_steps),
)

@torch.jit.ignore
def _report_ssd_mem_usage(
self,
) -> None:
"""
rocksdb has internal stats for DRAM memory usage; here we call EmbeddingRocksDB to
extract those stats and report them with stats_reporter
"""
if self.stats_reporter is None:
return

stats_reporter: TBEStatsReporter = self.stats_reporter
if not stats_reporter.should_report(self.step):
return

mem_usage_list = self.ssd_db.get_mem_usage()
block_cache_usage = mem_usage_list[0]
estimate_table_reader_usage = mem_usage_list[1]
memtable_usage = mem_usage_list[2]
block_cache_pinned_usage = mem_usage_list[3]
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.mem_usage.block_cache",
data_bytes=block_cache_usage,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.mem_usage.estimate_table_reader",
data_bytes=estimate_table_reader_usage,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.mem_usage.memtable",
data_bytes=memtable_usage,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.mem_usage.block_cache_pinned",
data_bytes=block_cache_pinned_usage,
)

# pyre-ignore
def _recording_to_timer(
self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
) -> Any:
"""
helper function to call AsyncSeriesTimer; it wraps the kernels we want to record
"""
if self.stats_reporter is not None and self.stats_reporter.should_report(
self.step
):
assert (
timer
), "We shouldn't be here, async timer must have been initiated if reporter is present."
return timer.recording(**kwargs)
# No-Op context manager
return contextlib.nullcontext()
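
The prefetch timing above runs one AsyncSeriesTimer per parallel stream (read and eviction) and reports only the longer of the two durations for a given step, once the second callback for that step arrives. Below is a self-contained sketch of that bookkeeping with illustrative numbers; it is a simplification for illustration, not code from this diff.

from typing import List, Tuple

PARALLEL_STREAMS = 2  # prefetch uses a read stream and an eviction stream


def report_longest_stream(
    callbacks: List[Tuple[int, float]]
) -> List[Tuple[int, float]]:
    # Feed (step, duration_us) callbacks in arrival order; return what gets reported.
    state = (-1, PARALLEL_STREAMS, 0.0)  # (recorded step, streams left, max duration)
    reported: List[Tuple[int, float]] = []
    for step, duration in callbacks:
        recorded_step, streams_left, max_dur = state
        if step == recorded_step:
            max_dur = max(max_dur, duration)
            streams_left -= 1
        else:  # first callback of a new step resets the window
            recorded_step, streams_left, max_dur = step, PARALLEL_STREAMS, duration
        state = (recorded_step, streams_left, max_dur)
        if streams_left == 1:  # the slower of the two streams just finished
            reported.append((step, max_dur))
    return reported


# Step 7: read stream took 1200 us, eviction stream took 900 us -> 1200 is reported.
assert report_longest_stream(
    [(7, 1200.0), (7, 900.0), (8, 800.0), (8, 1500.0)]
) == [(7, 1200.0), (8, 1500.0)]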
