From e1678473bd210d73b027d554a098bec13b7739fc Mon Sep 17 00:00:00 2001
From: Guanqiao Wang
Date: Mon, 29 Jul 2024 12:31:47 -0700
Subject: [PATCH 1/2] rocksdb setting adjustment (#2898)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2898

add some optimized rocksdb setting into the trunk

Differential Revision: D59740167

Reviewed By: sryap
---
 .../ssd_table_batched_embeddings_benchmark.py |  4 +-
 fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py     | 35 +++++---
 .../ssd_split_table_batched_embeddings.cpp    | 62 +++++++-----
 .../ssd_table_batched_embeddings.h            | 82 ++++++++++++++++++-
 .../tbe/ssd/ssd_split_tbe_training_test.py    |  2 +-
 5 files changed, 149 insertions(+), 36 deletions(-)

diff --git a/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py
index 0dfbfd7d87..b75b8874ab 100644
--- a/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py
@@ -354,8 +354,8 @@ def gen_split_tbe_generator(
             cache_sets=cache_set,
             ssd_storage_directory=tempdir,
             ssd_cache_location=EmbeddingLocation.MANAGED,
-            ssd_shards=8,
-            ssd_block_cache_size=block_cache_size_mb * (2**20),
+            ssd_rocksdb_shards=8,
+            ssd_block_cache_size_per_tbe=block_cache_size_mb * (2**20),
             **common_args,
         ),
     }
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index 64a39a444c..800b1e1893 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -63,19 +63,19 @@ def __init__(
         feature_table_map: Optional[List[int]],  # [T]
         cache_sets: int,
         ssd_storage_directory: str,
-        ssd_shards: int = 1,
+        ssd_rocksdb_shards: int = 1,
         ssd_memtable_flush_period: int = -1,
         ssd_memtable_flush_offset: int = -1,
         ssd_l0_files_per_compact: int = 4,
         ssd_rate_limit_mbps: int = 0,
         ssd_size_ratio: int = 10,
         ssd_compaction_trigger: int = 8,
-        ssd_write_buffer_size: int = 2 * 1024 * 1024 * 1024,
-        ssd_max_write_buffer_num: int = 16,
+        ssd_rocksdb_write_buffer_size: int = 2 * 1024 * 1024 * 1024,
+        ssd_max_write_buffer_num: int = 4,
         ssd_cache_location: EmbeddingLocation = EmbeddingLocation.MANAGED,
         ssd_uniform_init_lower: float = -0.01,
         ssd_uniform_init_upper: float = 0.01,
-        ssd_block_cache_size: int = 0,
+        ssd_block_cache_size_per_tbe: int = 0,
         weights_precision: SparseType = SparseType.FP32,
         output_dtype: SparseType = SparseType.FP32,
         optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD,
@@ -102,6 +102,9 @@ def __init__(
         # Parameter Server Configs
         ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
         tbe_unique_id: int = -1,
+        # In local tests we use the passed-in path for RocksDB creation; in production
+        # the DB is created under the SSD mount path, which ignores the passed-in path.
+        use_passed_in_path: int = True,
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
@@ -161,7 +164,7 @@ def __init__(
         cache_size = cache_sets * ASSOC * element_size * self.max_D
         logging.info(
             f"Using cache for SSD with admission algorithm "
-            f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_shards} shards, "
+            f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_rocksdb_shards} shards, "
             f"SSD storage directory: {ssd_storage_directory}, "
             f"Memtable Flush Period: {ssd_memtable_flush_period}, "
             f"Memtable Flush Offset: {ssd_memtable_flush_offset}, "
@@ -249,12 +252,23 @@ def __init__(
         )
         # logging.info("DEBUG: weights_precision {}".format(weights_precision))
         if not ps_hosts:
+            logging.info(
+                f"Logging SSD offloading setup "
+                f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
+                f"memtable_flush_period={ssd_memtable_flush_period},memtable_flush_offset={ssd_memtable_flush_offset},"
+                f"l0_files_per_compact={ssd_l0_files_per_compact},max_D={self.max_D},rate_limit_mbps={ssd_rate_limit_mbps},"
+                f"size_ratio={ssd_size_ratio},compaction_trigger={ssd_compaction_trigger},"
+                f"write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size},max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num},"
+                f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
+                f"row_storage_bitwidth={weights_precision.bit_rate()},block_cache_size_per_tbe={ssd_block_cache_size_per_tbe},"
+                f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB"
+            )
             # pyre-fixme[4]: Attribute must be annotated.
             # pyre-ignore[16]
             self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
                 ssd_directory,
-                ssd_shards,
-                ssd_shards,
+                ssd_rocksdb_shards,
+                ssd_rocksdb_shards,
                 ssd_memtable_flush_period,
                 ssd_memtable_flush_offset,
                 ssd_l0_files_per_compact,
@@ -262,15 +276,16 @@ def __init__(
                 ssd_rate_limit_mbps,
                 ssd_size_ratio,
                 ssd_compaction_trigger,
-                ssd_write_buffer_size,
+                ssd_rocksdb_write_buffer_size,
                 ssd_max_write_buffer_num,
                 ssd_uniform_init_lower,
                 ssd_uniform_init_upper,
                 weights_precision.bit_rate(),  # row_storage_bitwidth
-                ssd_block_cache_size,
+                ssd_block_cache_size_per_tbe,
+                use_passed_in_path,
             )
         else:
-            # create tbe unique id using rank index | pooling mode
+            # create tbe unique id using rank index | local tbe idx
             if tbe_unique_id == -1:
                 SSDTableBatchedEmbeddingBags._local_instance_index += 1
                 assert (
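The three renamed keyword arguments (ssd_shards -> ssd_rocksdb_shards, ssd_write_buffer_size -> ssd_rocksdb_write_buffer_size, ssd_block_cache_size -> ssd_block_cache_size_per_tbe) break existing call sites. A minimal migration sketch in Python: the kwarg names come from the diff above, but the concrete values are illustrative, and the leading table arguments (embedding_specs, etc.) are assumed from the class's usual signature rather than shown in this patch:

    emb = SSDTableBatchedEmbeddingBags(
        embedding_specs=[(1_000_000, 128)],  # illustrative (rows, dim)
        feature_table_map=[0],
        cache_sets=1024,                     # illustrative
        ssd_storage_directory="/tmp/ssd_tbe",
        ssd_rocksdb_shards=8,                       # was: ssd_shards
        ssd_rocksdb_write_buffer_size=2 * 1024**3,  # was: ssd_write_buffer_size
        ssd_block_cache_size_per_tbe=32 * 2**20,    # was: ssd_block_cache_size
    )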
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
index ab91d93f4e..afc58c8262 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
@@ -113,7 +113,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
       double uniform_init_lower,
       double uniform_init_upper,
       int64_t row_storage_bitwidth = 32,
-      int64_t cache_size = 0)
+      int64_t cache_size = 0,
+      bool use_passed_in_path = false)
       : impl_(std::make_shared<ssd::EmbeddingRocksDB>(
             path,
             num_shards,
@@ -130,7 +131,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
             uniform_init_lower,
             uniform_init_upper,
             row_storage_bitwidth,
-            cache_size)) {}
+            cache_size,
+            use_passed_in_path)) {}

   void set_cuda(Tensor indices, Tensor weights, Tensor count, int64_t timestep) {
@@ -164,23 +166,45 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {

 static auto embedding_rocks_db_wrapper =
     torch::class_<EmbeddingRocksDBWrapper>("fbgemm", "EmbeddingRocksDBWrapper")
-        .def(torch::init<
-             std::string,
-             int64_t,
-             int64_t,
-             int64_t,
-             int64_t,
-             int64_t,
-             int64_t,
-             int64_t,
-             int64_t,
-             int64_t,
-             int64_t,
-             int64_t,
-             double,
-             double,
-             int64_t,
-             int64_t>())
+        .def(
+            torch::init<
+                std::string,
+                int64_t,
+                int64_t,
+                int64_t,
+                int64_t,
+                int64_t,
+                int64_t,
+                int64_t,
+                int64_t,
+                int64_t,
+                int64_t,
+                int64_t,
+                double,
+                double,
+                int64_t,
+                int64_t,
+                bool>(),
+            "",
+            {
+                torch::arg("path"),
+                torch::arg("num_shards"),
+                torch::arg("num_threads"),
+                torch::arg("memtable_flush_period"),
+                torch::arg("memtable_flush_offset"),
+                torch::arg("l0_files_per_compact"),
+                torch::arg("max_D"),
+                torch::arg("rate_limit_mbps"),
+                torch::arg("size_ratio"),
+                torch::arg("compaction_ratio"),
+                torch::arg("write_buffer_size"),
+                torch::arg("max_write_buffer_num"),
+                torch::arg("uniform_init_lower"),
+                torch::arg("uniform_init_upper"),
+                torch::arg("row_storage_bitwidth"),
+                torch::arg("cache_size"),
+                torch::arg("use_passed_in_path") = true,
+            })
         .def("set_cuda", &EmbeddingRocksDBWrapper::set_cuda)
         .def("get_cuda", &EmbeddingRocksDBWrapper::get_cuda)
         .def("compact", &EmbeddingRocksDBWrapper::compact)
torch::arg("memtable_flush_period"), + torch::arg("memtable_flush_offset"), + torch::arg("l0_files_per_compact"), + torch::arg("max_D"), + torch::arg("rate_limit_mbps"), + torch::arg("size_ratio"), + torch::arg("compaction_ratio"), + torch::arg("write_buffer_size"), + torch::arg("max_write_buffer_num"), + torch::arg("uniform_init_lower"), + torch::arg("uniform_init_upper"), + torch::arg("row_storage_bitwidth"), + torch::arg("cache_size"), + torch::arg("use_passed_in_path") = true, + }) .def("set_cuda", &EmbeddingRocksDBWrapper::set_cuda) .def("get_cuda", &EmbeddingRocksDBWrapper::get_cuda) .def("compact", &EmbeddingRocksDBWrapper::compact) diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 6c0a51dd99..a7bebf168e 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -10,7 +10,15 @@ #include #include +#ifdef FBGEMM_FBCODE +#include "common/network/PortUtil.h" +#include "common/strings/UUID.h" +#include "fb_rocksdb/DBMonitor/DBMonitor.h" +#include "fb_rocksdb/FbRocksDb.h" +#include "rocks/utils/FB303Stats.h" +#endif #include "kv_db_table_batched_embeddings.h" +#include "torch/csrc/autograd/record_function_ops.h" namespace ssd { @@ -28,6 +36,11 @@ inline size_t db_shard(int64_t id, size_t num_shards) { // We can be a bit sloppy with host memory here. constexpr size_t kRowInitBufferSize = 32 * 1024; +#ifdef FBGEMM_FBCODE +constexpr size_t num_ssd_drives = 8; +const std::string ssd_mount_point = "/data00_nvidia"; +#endif + class Initializer { public: Initializer( @@ -117,7 +130,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { float uniform_init_lower, float uniform_init_upper, int64_t row_storage_bitwidth = 32, - int64_t cache_size = 0) { + int64_t cache_size = 0, + bool use_passed_in_path = false) { // TODO: lots of tunables. NNI or something for this? rocksdb::Options options; options.create_if_missing = true; @@ -126,7 +140,12 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { options.compression = rocksdb::kNoCompression; // Lots of free memory on the TC, use large write buffers. 
diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
index e5aabaad0b..7bfdaea6e1 100644
--- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
+++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
@@ -216,7 +216,7 @@ def generate_ssd_tbes(
             ssd_uniform_init_upper=0.1,
             learning_rate=lr,
             eps=eps,
-            ssd_shards=ssd_shards,
+            ssd_rocksdb_shards=ssd_shards,
             optimizer=optimizer,
             pooling_mode=pooling_mode,
             weights_precision=weights_precision,
From 0528f103d2c9e0a0456396228b1ad82b5be35b35 Mon Sep 17 00:00:00 2001
From: Joe Wang
Date: Mon, 29 Jul 2024 14:35:21 -0700
Subject: [PATCH 2/2] add ssd ods stats counterpart for UVM offloading (#2906)

Summary:
X-link: https://github.com/pytorch/torchrec/pull/2248

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2906

add ods stats for unique misses and unique indices

Differential Revision: D59740171
---
 fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 95 +++++++++++++++++++++++
 1 file changed, 95 insertions(+)

diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index 800b1e1893..b47366363e 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -19,6 +19,10 @@
 import torch  # usort:skip

 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
+from fbgemm_gpu.runtime_monitor import (
+    TBEStatsReporter,
+    TBEStatsReporterConfig,
+)
 from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     CacheAlgorithm,
@@ -34,6 +38,7 @@
 )

 from torch import distributed as dist, nn, Tensor  # usort:skip
+from pyre_extensions import none_throws
 from torch.autograd.profiler import record_function

 from .common import ASSOC
@@ -105,6 +110,8 @@ def __init__(
         # In local tests we use the passed-in path for RocksDB creation; in production
         # the DB is created under the SSD mount path, which ignores the passed-in path.
         use_passed_in_path: int = True,
+        gather_ssd_cache_stats: Optional[bool] = False,
+        stats_reporter_config: Optional[TBEStatsReporterConfig] = None,
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()

@@ -181,6 +188,8 @@ def __init__(
             "lru_state", torch.zeros(cache_sets, ASSOC, dtype=torch.int64)
         )

+        self.step = 0
+
         assert ssd_cache_location in (
             EmbeddingLocation.MANAGED,
             EmbeddingLocation.DEVICE,
@@ -413,6 +422,31 @@ def __init__(
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
         self.optimizer = optimizer

+        # stats reporter
+        self.gather_ssd_cache_stats = gather_ssd_cache_stats
+        self.stats_reporter: Optional[TBEStatsReporter] = (
+            stats_reporter_config.create_reporter() if stats_reporter_config else None
+        )
+        self.ssd_cache_stats_size = 6
+        # 0: N_calls, 1: N_requested_indices, 2: N_unique_indices, 3: N_unique_misses,
+        # 4: N_conflict_unique_misses, 5: N_conflict_misses
+        self.last_reported_ssd_stats: List[float] = []
+        self.last_reported_step = 0
+
+        self.register_buffer(
+            "ssd_cache_stats",
+            torch.zeros(
+                self.ssd_cache_stats_size,
+                device=self.current_device,
+                dtype=torch.int64,
+            ),
+            persistent=False,
+        )
+        logging.info(
+            f"logging stats reporter setup, {self.gather_ssd_cache_stats=}, "
+            f"stats_reporter:{none_throws(self.stats_reporter) if self.stats_reporter else 'none'}, "
+        )
+
     # pyre-ignore[3]
     def record_function_via_dummy_profile_factory(
         self,
@@ -777,6 +811,11 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
                 )
             )

+        if self.gather_ssd_cache_stats:
+            self.ssd_cache_stats[3] += actions_count_gpu[0]  # num unique misses
+            self.ssd_cache_stats[2] += unique_indices_length[0]  # num unique indices
+            self._report_ssd_stats()
+
     def forward(
         self,
         indices: Tensor,
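The ssd_cache_stats buffer is positional: prefetch only touches slots 2 and 3, which the reporter later reads back. A compact sketch of the accumulation as standalone Python (the function name and scalar inputs are illustrative stand-ins for the real GPU tensors):

    import torch

    # Layout from the patch: 0: N_calls, 1: N_requested_indices,
    # 2: N_unique_indices, 3: N_unique_misses,
    # 4: N_conflict_unique_misses, 5: N_conflict_misses
    ssd_cache_stats = torch.zeros(6, dtype=torch.int64)

    def accumulate(unique_indices_length: int, actions_count: int) -> None:
        # Mirrors the two increments in prefetch(): the counters are running
        # totals that are never reset, so the reporter must diff snapshots
        # to recover per-interval counts.
        ssd_cache_stats[2] += unique_indices_length
        ssd_cache_stats[3] += actions_count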
" + f"stats_reporter:{none_throws(self.stats_reporter) if self.stats_reporter else 'none'}, " + ) + # pyre-ignore[3] def record_function_via_dummy_profile_factory( self, @@ -777,6 +811,11 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: ) ) + if self.gather_ssd_cache_stats: + self.ssd_cache_stats[3] += actions_count_gpu[0] # num unique misses + self.ssd_cache_stats[2] += unique_indices_length[0] # num unique indices + self._report_ssd_stats() + def forward( self, indices: Tensor, @@ -844,6 +883,7 @@ def forward( ) self.timesteps_prefetched.pop(0) + self.step += 1 if self.optimizer == OptimType.EXACT_SGD: raise AssertionError( @@ -967,3 +1007,58 @@ def flush(self) -> None: torch.tensor([active_ids.numel()]), self.timestep, ) + + @torch.jit.ignore + def _report_ssd_stats(self) -> None: + """ + Each iteration we will record cache stats about L1 SSD cache in ssd_cache_stats tensor + this function extract those stats and report it with stats_reporter + """ + if self.stats_reporter is None: + return + + stats_reporter: TBEStatsReporter = self.stats_reporter + passed_steps = self.step - self.last_reported_step + if passed_steps == 0: + return + if not stats_reporter.should_report(self.step): + return + + # ssd hbm cache stats + + ssd_cache_stats = self.ssd_cache_stats.tolist() + if len(self.last_reported_ssd_stats) == 0: + self.last_reported_ssd_stats = [0.0] * len(ssd_cache_stats) + ssd_cache_stats_delta: List[float] = [0.0] * len(ssd_cache_stats) + for i in range(len(ssd_cache_stats)): + ssd_cache_stats_delta[i] = ( + ssd_cache_stats[i] - self.last_reported_ssd_stats[i] + ) + self.last_reported_step = self.step + self.last_reported_ssd_stats = ssd_cache_stats + element_size = self.lxu_cache_weights.element_size() + + stats_reporter.report_data_amount( + iteration_step=self.step, + event_name="ssd.hbm_cache_stats.num_unique_indices_bytes", + data_bytes=int( + ssd_cache_stats_delta[2] * element_size * self.max_D / passed_steps + ), + ) + stats_reporter.report_data_amount( + iteration_step=self.step, + event_name="ssd.hbm_cache_stats.num_unique_misses_bytes", + data_bytes=int( + ssd_cache_stats_delta[3] * element_size * self.max_D / passed_steps + ), + ) + stats_reporter.report_data_amount( + iteration_step=self.step, + event_name="ssd.hbm_cache_stats.num_unique_indices", + data_bytes=int(ssd_cache_stats_delta[2] / passed_steps), + ) + stats_reporter.report_data_amount( + iteration_step=self.step, + event_name="ssd.hbm_cache_stats.num_unique_misses", + data_bytes=int(ssd_cache_stats_delta[3] / passed_steps), + )