add ssd ods stats #2906

Closed · wants to merge 2 commits
4 changes: 2 additions & 2 deletions fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py
@@ -354,8 +354,8 @@ def gen_split_tbe_generator(
cache_sets=cache_set,
ssd_storage_directory=tempdir,
ssd_cache_location=EmbeddingLocation.MANAGED,
ssd_shards=8,
ssd_block_cache_size=block_cache_size_mb * (2**20),
ssd_rocksdb_shards=8,
ssd_block_cache_size_per_tbe=block_cache_size_mb * (2**20),
**common_args,
),
}
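The benchmark hunk above simply adopts the renamed SSD TBE constructor kwargs. As a hedged, caller-side sketch (not part of this PR), the three renames visible in this diff could be applied mechanically with a helper like the one below; note that ssd_max_write_buffer_num keeps its name but its default drops from 16 to 4.

# Hypothetical helper, not part of this PR: map pre-#2906 SSD TBE kwarg names
# to the names introduced by this change before forwarding them on.
_RENAMED_SSD_KWARGS = {
    "ssd_shards": "ssd_rocksdb_shards",
    "ssd_write_buffer_size": "ssd_rocksdb_write_buffer_size",
    "ssd_block_cache_size": "ssd_block_cache_size_per_tbe",
}


def migrate_ssd_kwargs(kwargs: dict) -> dict:
    # Returns a copy with renamed keys; unknown keys pass through unchanged.
    return {_RENAMED_SSD_KWARGS.get(key, key): value for key, value in kwargs.items()}

For example, migrate_ssd_kwargs({"ssd_shards": 8}) would return {"ssd_rocksdb_shards": 8}.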
130 changes: 120 additions & 10 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -19,6 +19,10 @@
import torch # usort:skip

import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
from fbgemm_gpu.runtime_monitor import (
TBEStatsReporter,
TBEStatsReporterConfig,
)
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
CacheAlgorithm,
@@ -34,6 +38,7 @@
)

from torch import distributed as dist, nn, Tensor # usort:skip
from pyre_extensions import none_throws
from torch.autograd.profiler import record_function

from .common import ASSOC
@@ -63,19 +68,19 @@ def __init__(
feature_table_map: Optional[List[int]], # [T]
cache_sets: int,
ssd_storage_directory: str,
ssd_shards: int = 1,
ssd_rocksdb_shards: int = 1,
ssd_memtable_flush_period: int = -1,
ssd_memtable_flush_offset: int = -1,
ssd_l0_files_per_compact: int = 4,
ssd_rate_limit_mbps: int = 0,
ssd_size_ratio: int = 10,
ssd_compaction_trigger: int = 8,
ssd_write_buffer_size: int = 2 * 1024 * 1024 * 1024,
ssd_max_write_buffer_num: int = 16,
ssd_rocksdb_write_buffer_size: int = 2 * 1024 * 1024 * 1024,
ssd_max_write_buffer_num: int = 4,
ssd_cache_location: EmbeddingLocation = EmbeddingLocation.MANAGED,
ssd_uniform_init_lower: float = -0.01,
ssd_uniform_init_upper: float = 0.01,
ssd_block_cache_size: int = 0,
ssd_block_cache_size_per_tbe: int = 0,
weights_precision: SparseType = SparseType.FP32,
output_dtype: SparseType = SparseType.FP32,
optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD,
@@ -102,6 +107,11 @@ def __init__(
# Parameter Server Configs
ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
tbe_unique_id: int = -1,
# in local tests we need to use the passed-in path for rocksdb creation
# in production the db is created inside the SSD mount path, which ignores the passed-in path
use_passed_in_path: bool = True,
gather_ssd_cache_stats: Optional[bool] = False,
stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
) -> None:
super(SSDTableBatchedEmbeddingBags, self).__init__()
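Taken together, the two new constructor knobs above (gather_ssd_cache_stats and stats_reporter_config) can be exercised as in the following minimal sketch. Assumptions not shown in this hunk: SSDTableBatchedEmbeddingBags is importable from fbgemm_gpu.tbe.ssd, embedding_specs takes (num_embeddings, embedding_dim) pairs, and the plain base TBEStatsReporterConfig resolves to no concrete reporter, so only the on-device counters accumulate unless a reporter-producing config subclass is supplied.

# Sketch only -- values are illustrative, not taken from the PR.
import tempfile

from fbgemm_gpu.runtime_monitor import TBEStatsReporterConfig
from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags

emb = SSDTableBatchedEmbeddingBags(
    embedding_specs=[(1_000_000, 128)],        # (rows, dim) per table -- assumed layout
    feature_table_map=[0],
    cache_sets=1024,
    ssd_storage_directory=tempfile.mkdtemp(),
    gather_ssd_cache_stats=True,               # accumulate counters in ssd_cache_stats
    stats_reporter_config=TBEStatsReporterConfig(),  # base config => reporter is None
)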

@@ -161,7 +171,7 @@ def __init__(
cache_size = cache_sets * ASSOC * element_size * self.max_D
logging.info(
f"Using cache for SSD with admission algorithm "
f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_shards} shards, "
f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_rocksdb_shards} shards, "
f"SSD storage directory: {ssd_storage_directory}, "
f"Memtable Flush Period: {ssd_memtable_flush_period}, "
f"Memtable Flush Offset: {ssd_memtable_flush_offset}, "
@@ -178,6 +188,8 @@ def __init__(
"lru_state", torch.zeros(cache_sets, ASSOC, dtype=torch.int64)
)

self.step = 0

assert ssd_cache_location in (
EmbeddingLocation.MANAGED,
EmbeddingLocation.DEVICE,
@@ -249,28 +261,40 @@ def __init__(
)
# logging.info("DEBUG: weights_precision {}".format(weights_precision))
if not ps_hosts:
logging.info(
f"Logging SSD offloading setup "
f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
f"memtable_flush_period={ssd_memtable_flush_period},memtable_flush_offset={ssd_memtable_flush_offset},"
f"l0_files_per_compact={ssd_l0_files_per_compact},max_D={self.max_D},rate_limit_mbps={ssd_rate_limit_mbps},"
f"size_ratio={ssd_size_ratio},compaction_trigger={ssd_compaction_trigger},"
f"write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size},max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num},"
f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
f"row_storage_bitwidth={weights_precision.bit_rate()},block_cache_size_per_tbe={ssd_block_cache_size_per_tbe},"
f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB"
)
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
ssd_directory,
ssd_shards,
ssd_shards,
ssd_rocksdb_shards,
ssd_rocksdb_shards,
ssd_memtable_flush_period,
ssd_memtable_flush_offset,
ssd_l0_files_per_compact,
self.max_D,
ssd_rate_limit_mbps,
ssd_size_ratio,
ssd_compaction_trigger,
ssd_write_buffer_size,
ssd_rocksdb_write_buffer_size,
ssd_max_write_buffer_num,
ssd_uniform_init_lower,
ssd_uniform_init_upper,
weights_precision.bit_rate(), # row_storage_bitwidth
ssd_block_cache_size,
ssd_block_cache_size_per_tbe,
use_passed_in_path,
)
else:
# create tbe unique id using rank index | pooling mode
# create tbe unique id using rank index | local tbe idx
if tbe_unique_id == -1:
SSDTableBatchedEmbeddingBags._local_instance_index += 1
assert (
@@ -398,6 +422,31 @@ def __init__(
), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
self.optimizer = optimizer

# stats reporter
self.gather_ssd_cache_stats = gather_ssd_cache_stats
self.stats_reporter: Optional[TBEStatsReporter] = (
stats_reporter_config.create_reporter() if stats_reporter_config else None
)
self.ssd_cache_stats_size = 6
# 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
# 4: N_conflict_unique_misses, 5: N_conflict_misses
self.last_reported_ssd_stats: List[float] = []
self.last_reported_step = 0

self.register_buffer(
"ssd_cache_stats",
torch.zeros(
self.ssd_cache_stats_size,
device=self.current_device,
dtype=torch.int64,
),
persistent=False,
)
logging.info(
f"Stats reporter setup: {self.gather_ssd_cache_stats=}, "
f"stats_reporter={none_throws(self.stats_reporter) if self.stats_reporter else 'none'}"
)
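The reporter wiring above only relies on two calls from the reporter side in this diff: should_report(step) and report_data_amount(iteration_step=..., event_name=..., data_bytes=...). As a rough sketch (not the real fbgemm_gpu API surface, which may require additional methods), a logging stand-in could look like the following; in typed code it should subclass TBEStatsReporter and be returned by a TBEStatsReporterConfig subclass's create_reporter(), but the duck-typed version below keeps the example self-contained.

# Hypothetical stand-in reporter; mirrors only the two calls made by this diff.
import logging


class PrintingReporterSketch:
    def __init__(self, interval: int = 100) -> None:
        self.interval = interval

    def should_report(self, step: int) -> bool:
        return step > 0 and step % self.interval == 0

    def report_data_amount(self, iteration_step: int, event_name: str, data_bytes: int) -> None:
        logging.info("step=%d %s=%d bytes", iteration_step, event_name, data_bytes)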

# pyre-ignore[3]
def record_function_via_dummy_profile_factory(
self,
@@ -762,6 +811,11 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
)
)

if self.gather_ssd_cache_stats:
self.ssd_cache_stats[3] += actions_count_gpu[0] # num unique misses
self.ssd_cache_stats[2] += unique_indices_length[0] # num unique indices
self._report_ssd_stats()
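For a quick sanity check of the two counters that prefetch() bumps here (slots 2 and 3 of ssd_cache_stats, per the slot map in __init__), something like the following could be run after a few prefetch calls; emb is the instance from the earlier sketch, and this is illustrative rather than part of the PR.

# Read back the running counters from the non-persistent buffer registered in __init__.
stats = emb.ssd_cache_stats.tolist()
print(f"unique indices seen: {stats[2]}, unique misses: {stats[3]}")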

def forward(
self,
indices: Tensor,
@@ -829,6 +883,7 @@ def forward(
)

self.timesteps_prefetched.pop(0)
self.step += 1

if self.optimizer == OptimType.EXACT_SGD:
raise AssertionError(
@@ -952,3 +1007,58 @@ def flush(self) -> None:
torch.tensor([active_ids.numel()]),
self.timestep,
)

@torch.jit.ignore
def _report_ssd_stats(self) -> None:
"""
Each iteration we will record cache stats about L1 SSD cache in ssd_cache_stats tensor
this function extract those stats and report it with stats_reporter
"""
if self.stats_reporter is None:
return

stats_reporter: TBEStatsReporter = self.stats_reporter
passed_steps = self.step - self.last_reported_step
if passed_steps == 0:
return
if not stats_reporter.should_report(self.step):
return

# ssd hbm cache stats

ssd_cache_stats = self.ssd_cache_stats.tolist()
if len(self.last_reported_ssd_stats) == 0:
self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats)
ssd_cache_stats_delta: List[float] = [0.0] * len(ssd_cache_stats)
for i in range(len(ssd_cache_stats)):
ssd_cache_stats_delta[i] = (
ssd_cache_stats[i] - self.last_reported_ssd_stats[i]
)
self.last_reported_step = self.step
self.last_reported_ssd_stats = ssd_cache_stats
element_size = self.lxu_cache_weights.element_size()

stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_indices_bytes",
data_bytes=int(
ssd_cache_stats_delta[2] * element_size * self.max_D / passed_steps
),
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_misses_bytes",
data_bytes=int(
ssd_cache_stats_delta[3] * element_size * self.max_D / passed_steps
),
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_indices",
data_bytes=int(ssd_cache_stats_delta[2] / passed_steps),
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="ssd.hbm_cache_stats.num_unique_misses",
data_bytes=int(ssd_cache_stats_delta[3] / passed_steps),
)
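To make the reporting arithmetic concrete, here is a small worked example with made-up numbers (not taken from the PR): counter deltas are divided by the number of steps since the last report, and index counts are converted to bytes via element_size * max_D, mirroring the expressions above.

# Illustrative numbers only.
element_size, max_D = 4, 128        # fp32 cache rows, embedding width 128 (assumed)
passed_steps = 100                  # self.step - self.last_reported_step
delta_unique_indices = 250_000      # change in ssd_cache_stats[2] since last report
delta_unique_misses = 40_000        # change in ssd_cache_stats[3] since last report

num_unique_indices_bytes = int(delta_unique_indices * element_size * max_D / passed_steps)
num_unique_misses = int(delta_unique_misses / passed_steps)
print(num_unique_indices_bytes, num_unique_misses)  # 1280000 400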
@@ -113,7 +113,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
double uniform_init_lower,
double uniform_init_upper,
int64_t row_storage_bitwidth = 32,
int64_t cache_size = 0)
int64_t cache_size = 0,
bool use_passed_in_path = false)
: impl_(std::make_shared<ssd::EmbeddingRocksDB>(
path,
num_shards,
@@ -130,7 +131,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
uniform_init_lower,
uniform_init_upper,
row_storage_bitwidth,
cache_size)) {}
cache_size,
use_passed_in_path)) {}

void
set_cuda(Tensor indices, Tensor weights, Tensor count, int64_t timestep) {
@@ -164,23 +166,45 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {

static auto embedding_rocks_db_wrapper =
torch::class_<EmbeddingRocksDBWrapper>("fbgemm", "EmbeddingRocksDBWrapper")
.def(torch::init<
std::string,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
double,
double,
int64_t,
int64_t>())
.def(
torch::init<
std::string,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
double,
double,
int64_t,
int64_t,
bool>(),
"",
{
torch::arg("path"),
torch::arg("num_shards"),
torch::arg("num_threads"),
torch::arg("memtable_flush_period"),
torch::arg("memtable_flush_offset"),
torch::arg("l0_files_per_compact"),
torch::arg("max_D"),
torch::arg("rate_limit_mbps"),
torch::arg("size_ratio"),
torch::arg("compaction_ratio"),
torch::arg("write_buffer_size"),
torch::arg("max_write_buffer_num"),
torch::arg("uniform_init_lower"),
torch::arg("uniform_init_upper"),
torch::arg("row_storage_bitwidth"),
torch::arg("cache_size"),
torch::arg("use_passed_in_path") = true,
})
.def("set_cuda", &EmbeddingRocksDBWrapper::set_cuda)
.def("get_cuda", &EmbeddingRocksDBWrapper::get_cuda)
.def("compact", &EmbeddingRocksDBWrapper::compact)