add parameter server attributes
Summary:
### THIS DIFF

We added more Parameter Server attributes for `KeyValueParams`:
- ps_client_thread_num
- ps_max_key_per_request
- ps_max_local_index_length

Reviewed By: q10

Differential Revision: D60793394
Franco Mo authored and facebook-github-bot committed Aug 7, 2024
1 parent 0ebb3ae commit e2914e2
Showing 4 changed files with 34 additions and 8 deletions.
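
For orientation, here is a hedged sketch of how a caller might supply the new knobs. The `make_ps_kwargs` helper is hypothetical and only mirrors the keyword names added in this diff; the real constructors in inference.py and training.py take many additional, unrelated arguments.

```python
from typing import Any, Dict, Optional, Tuple

# Hypothetical helper mirroring the new Parameter Server keyword arguments
# from this diff. Passing None for any ps_* knob defers to the defaults
# applied at the call sites (54 / 32 / 500, shown in the hunks below).
def make_ps_kwargs(
    ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None,
    ps_max_key_per_request: Optional[int] = None,
    ps_client_thread_num: Optional[int] = None,
    ps_max_local_index_length: Optional[int] = None,
    tbe_unique_id: int = -1,
) -> Dict[str, Any]:
    return {
        "ps_hosts": ps_hosts,
        "ps_max_key_per_request": ps_max_key_per_request,
        "ps_client_thread_num": ps_client_thread_num,
        "ps_max_local_index_length": ps_max_local_index_length,
        "tbe_unique_id": tbe_unique_id,
    }

# Example (other constructor arguments elided):
# module = SSDIntNBitTableBatchedEmbeddingBags(
#     ...,  # embedding specs and other unrelated arguments
#     **make_ps_kwargs(
#         ps_hosts=(("10.0.0.1", 12345), ("10.0.0.2", 12345)),
#         ps_client_thread_num=64,
#     ),
# )
```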

fbgemm_gpu/fbgemm_gpu/tbe/ssd/inference.py (12 changes: 10 additions & 2 deletions)

@@ -74,6 +74,9 @@ def __init__(
         ssd_uniform_init_upper: float = 0.01,
         # Parameter Server Configs
         ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
+        ps_max_key_per_request: Optional[int] = None,
+        ps_client_thread_num: Optional[int] = None,
+        ps_max_local_index_length: Optional[int] = None,
         tbe_unique_id: int = -1,  # unique id for this embedding; if not set, it is derived from the current rank and the tbe index id
     ) -> None:  # noqa C901  # tuple of (rows, dims,)
         super(SSDIntNBitTableBatchedEmbeddingBags, self).__init__()

@@ -291,8 +294,13 @@ def max_ty_D(ty: SparseType) -> int:
                 [host[0] for host in ps_hosts],
                 [host[1] for host in ps_hosts],
                 tbe_unique_id,
-                54,
-                32,
+                (
+                    ps_max_local_index_length
+                    if ps_max_local_index_length is not None
+                    else 54
+                ),
+                ps_client_thread_num if ps_client_thread_num is not None else 32,
+                ps_max_key_per_request if ps_max_key_per_request is not None else 500,
             )

         # pyre-fixme[20]: Argument `self` expected.
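
The previously hard-coded 54 and 32 become fallback defaults, selected with an `x if x is not None else default` expression. A minimal standalone sketch of the idiom (the `resolve` helper is illustrative, not part of the diff):

```python
# `is not None` is used rather than `or`, so an explicit falsy value such
# as 0 passed by the caller is preserved instead of being silently
# replaced by the default.
def resolve(value, default):
    return value if value is not None else default

assert resolve(None, 32) == 32  # unset -> fall back to the default
assert resolve(64, 32) == 64    # caller override wins
assert resolve(0, 32) == 0      # explicit falsy values survive
```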

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py (12 changes: 10 additions & 2 deletions)

@@ -103,6 +103,9 @@ def __init__(
         pooling_mode: PoolingMode = PoolingMode.SUM,
         # Parameter Server Configs
         ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
+        ps_max_key_per_request: Optional[int] = None,
+        ps_client_thread_num: Optional[int] = None,
+        ps_max_local_index_length: Optional[int] = None,
         tbe_unique_id: int = -1,
         # in local tests we need to use the passed-in path for rocksdb creation;
         # in production we create it inside the SSD mount path, which ignores the passed-in path

@@ -314,8 +317,13 @@ def __init__(
                 [host[0] for host in ps_hosts],
                 [host[1] for host in ps_hosts],
                 tbe_unique_id,
-                54,
-                32,
+                (
+                    ps_max_local_index_length
+                    if ps_max_local_index_length is not None
+                    else 54
+                ),
+                ps_client_thread_num if ps_client_thread_num is not None else 32,
+                ps_max_key_per_request if ps_max_key_per_request is not None else 500,
             )
             # pyre-fixme[20]: Argument `self` expected.
             (low_priority, high_priority) = torch.cuda.Stream.priority_range()
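
training.py mirrors inference.py here, unpacking `ps_hosts` into parallel IP and port lists before calling the wrapper. A small self-contained illustration of that shape (host addresses are made up):

```python
# ps_hosts is a tuple of (ip, port) pairs, matching the annotation
# Optional[Tuple[Tuple[str, int]]] in both __init__ signatures.
ps_hosts = (("10.0.0.1", 12345), ("10.0.0.2", 12345))

ips = [host[0] for host in ps_hosts]    # ["10.0.0.1", "10.0.0.2"]
ports = [host[1] for host in ps_hosts]  # [12345, 12345]
```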

[file path not shown in this capture: C++ TorchScript wrapper around the embedding parameter server] (12 changes: 10 additions & 2 deletions)

@@ -11,6 +11,8 @@
 #include <torch/custom_class.h>
 #include "fbgemm_gpu/sparse_ops_utils.h"

+#include <folly/logging/xlog.h>
+
 using namespace at;
 using namespace ps;

@@ -22,7 +24,8 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
       const std::vector<int64_t>& tps_ports,
       int64_t tbe_id,
       int64_t maxLocalIndexLength = 54,
-      int64_t num_threads = 32) {
+      int64_t num_threads = 32,
+      int64_t maxKeysPerRequest = 500) {
     TORCH_CHECK(
         tps_ips.size() == tps_ports.size(),
         "tps_ips and tps_ports must have the same size");

@@ -32,7 +35,11 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
     }

     impl_ = std::make_shared<ps::EmbeddingParameterServer>(
-        std::move(tpsHosts), tbe_id, maxLocalIndexLength, num_threads);
+        std::move(tpsHosts),
+        tbe_id,
+        maxLocalIndexLength,
+        num_threads,
+        maxKeysPerRequest);
   }

   void

@@ -78,6 +85,7 @@ static auto embedding_parameter_server_wrapper =
             const std::vector<int64_t>,
             int64_t,
             int64_t,
+            int64_t,
             int64_t>())
         .def("set_cuda", &EmbeddingParameterServerWrapper::set_cuda)
         .def("get_cuda", &EmbeddingParameterServerWrapper::get_cuda)
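
The extra `int64_t` keeps the `torch::init<...>` registration in sync with the constructor's new arity. Below is a sketch of constructing the wrapper from Python, matching the call sites above; the `torch.classes.fbgemm.EmbeddingParameterServerWrapper` path is an assumption, since the registration string falls outside this diff's context lines.

```python
import torch

# Assumes the fbgemm_gpu extension that registers this custom class has
# been loaded. Positional order mirrors the C++ constructor:
# (tps_ips, tps_ports, tbe_id, maxLocalIndexLength, num_threads,
#  maxKeysPerRequest)
ps_wrapper = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
    ["10.0.0.1", "10.0.0.2"],  # tps_ips
    [12345, 12345],            # tps_ports
    0,                         # tbe_unique_id
    54,                        # maxLocalIndexLength (existing default)
    32,                        # num_threads (existing default)
    500,                       # maxKeysPerRequest (new in this diff)
)
```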

[file path not shown in this capture: C++ header declaring ps::EmbeddingParameterServer] (6 changes: 4 additions & 2 deletions)

@@ -20,14 +20,16 @@ class EmbeddingParameterServer : public kv_db::EmbeddingKVDB {
       std::vector<std::pair<std::string, int>>&& tps_hosts,
       int64_t tbe_id,
       int64_t maxLocalIndexLength = 54,
-      int64_t num_threads = 32)
+      int64_t num_threads = 32,
+      int64_t maxKeysPerRequest = 500)
       : tps_client_(
             std::make_shared<mvai_infra::experimental::ps_training::tps_client::
                                  TrainingParameterServiceClient>(
                 std::move(tps_hosts),
                 tbe_id,
                 maxLocalIndexLength,
-                num_threads)) {}
+                num_threads,
+                maxKeysPerRequest)) {}

   void set(
       const at::Tensor& indices,
