add parameter server attributes #2947

Closed
wants to merge 1 commit into from
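
This PR adds three optional parameter-server arguments (`ps_max_key_per_request`, `ps_client_thread_num`, `ps_max_local_index_length`) to the SSD TBE modules and threads them through to the C++ `EmbeddingParameterServerWrapper` / `EmbeddingParameterServer`, which previously hard-coded `maxLocalIndexLength = 54` and `num_threads = 32` and now also takes `maxKeysPerRequest` (default 500). A minimal usage sketch, assuming the modules' remaining constructor arguments are unchanged and using placeholder host addresses:

```python
# Sketch only: the ps_* keywords below are the ones introduced by this PR;
# every other constructor argument of the SSD TBE modules is assumed unchanged.
ps_kwargs = dict(
    ps_hosts=(("10.0.0.1", 12345), ("10.0.0.2", 12345)),  # (ip, port) pairs, pre-existing arg
    ps_max_local_index_length=54,  # None -> falls back to the previous hard-coded 54
    ps_client_thread_num=32,       # None -> falls back to the previous hard-coded 32
    ps_max_key_per_request=500,    # None -> falls back to the new default of 500
)
# e.g. SSDIntNBitTableBatchedEmbeddingBags(..., **ps_kwargs)
```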
12 changes: 10 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/inference.py
@@ -74,6 +74,9 @@ def __init__(
ssd_uniform_init_upper: float = 0.01,
# Parameter Server Configs
ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
ps_max_key_per_request: Optional[int] = None,
ps_client_thread_num: Optional[int] = None,
ps_max_local_index_length: Optional[int] = None,
tbe_unique_id: int = -1,  # unique id for this embedding; if not set, it is derived from the current rank and tbe index id
) -> None: # noqa C901 # tuple of (rows, dims,)
super(SSDIntNBitTableBatchedEmbeddingBags, self).__init__()
@@ -291,8 +294,13 @@ def max_ty_D(ty: SparseType) -> int:
[host[0] for host in ps_hosts],
[host[1] for host in ps_hosts],
tbe_unique_id,
54,
32,
(
ps_max_local_index_length
if ps_max_local_index_length is not None
else 54
),
ps_client_thread_num if ps_client_thread_num is not None else 32,
ps_max_key_per_request if ps_max_key_per_request is not None else 500,
)

# pyre-fixme[20]: Argument `self` expected.
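
At the wrapper call site (here and in the identical block in `training.py` below), each optional argument resolves to the previous hard-coded value when it is left as `None`. A self-contained sketch of that fallback, written as a standalone helper rather than code taken from the module:

```python
from typing import Optional, Tuple


def resolve_ps_args(
    ps_max_local_index_length: Optional[int] = None,
    ps_client_thread_num: Optional[int] = None,
    ps_max_key_per_request: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Mirror the None -> default fallback used at the wrapper call sites."""
    return (
        ps_max_local_index_length if ps_max_local_index_length is not None else 54,
        ps_client_thread_num if ps_client_thread_num is not None else 32,
        ps_max_key_per_request if ps_max_key_per_request is not None else 500,
    )


# Only the overridden value changes; the rest keep their defaults.
print(resolve_ps_args(ps_client_thread_num=64))  # (54, 64, 500)
```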
12 changes: 10 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -103,6 +103,9 @@ def __init__(
pooling_mode: PoolingMode = PoolingMode.SUM,
# Parameter Server Configs
ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
ps_max_key_per_request: Optional[int] = None,
ps_client_thread_num: Optional[int] = None,
ps_max_local_index_length: Optional[int] = None,
tbe_unique_id: int = -1,
# in local tests we need to use the passed-in path for rocksdb creation
# in production we need to do it inside the SSD mount path, which ignores the passed-in path
@@ -314,8 +317,13 @@ def __init__(
[host[0] for host in ps_hosts],
[host[1] for host in ps_hosts],
tbe_unique_id,
54,
32,
(
ps_max_local_index_length
if ps_max_local_index_length is not None
else 54
),
ps_client_thread_num if ps_client_thread_num is not None else 32,
ps_max_key_per_request if ps_max_key_per_request is not None else 500,
)
# pyre-fixme[20]: Argument `self` expected.
(low_priority, high_priority) = torch.cuda.Stream.priority_range()
@@ -11,6 +11,8 @@
#include <torch/custom_class.h>
#include "fbgemm_gpu/sparse_ops_utils.h"

#include <folly/logging/xlog.h>

using namespace at;
using namespace ps;

@@ -22,7 +24,8 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
const std::vector<int64_t>& tps_ports,
int64_t tbe_id,
int64_t maxLocalIndexLength = 54,
int64_t num_threads = 32) {
int64_t num_threads = 32,
int64_t maxKeysPerRequest = 500) {
TORCH_CHECK(
tps_ips.size() == tps_ports.size(),
"tps_ips and tps_ports must have the same size");
@@ -32,7 +35,11 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
}

impl_ = std::make_shared<ps::EmbeddingParameterServer>(
std::move(tpsHosts), tbe_id, maxLocalIndexLength, num_threads);
std::move(tpsHosts),
tbe_id,
maxLocalIndexLength,
num_threads,
maxKeysPerRequest);
}

void
@@ -78,6 +85,7 @@ static auto embedding_parameter_server_wrapper =
const std::vector<int64_t>,
int64_t,
int64_t,
int64_t,
int64_t>())
.def("set_cuda", &EmbeddingParameterServerWrapper::set_cuda)
.def("get_cuda", &EmbeddingParameterServerWrapper::get_cuda)
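
On the Python side, the call sites shown above pass these values positionally, so the new `maxKeysPerRequest` lands in the last slot of the extended constructor. A sketch of that argument order, using a plain function as a stand-in since the TorchScript registration string for the wrapper is outside this hunk:

```python
from typing import List, Sequence, Tuple


def make_ps_wrapper_args(
    ps_hosts: Sequence[Tuple[str, int]],
    tbe_unique_id: int,
    max_local_index_length: int = 54,
    client_thread_num: int = 32,
    max_key_per_request: int = 500,
) -> Tuple[List[str], List[int], int, int, int, int]:
    """Assemble arguments in the order the extended constructor expects:
    tps_ips, tps_ports, tbe_id, maxLocalIndexLength, num_threads, maxKeysPerRequest."""
    return (
        [host[0] for host in ps_hosts],  # tps_ips
        [host[1] for host in ps_hosts],  # tps_ports
        tbe_unique_id,
        max_local_index_length,          # previously hard-coded to 54
        client_thread_num,               # previously hard-coded to 32
        max_key_per_request,             # new argument, defaults to 500
    )


print(make_ps_wrapper_args([("10.0.0.1", 12345)], tbe_unique_id=0))
```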
@@ -20,14 +20,16 @@ class EmbeddingParameterServer : public kv_db::EmbeddingKVDB {
std::vector<std::pair<std::string, int>>&& tps_hosts,
int64_t tbe_id,
int64_t maxLocalIndexLength = 54,
int64_t num_threads = 32)
int64_t num_threads = 32,
int64_t maxKeysPerRequest = 500)
: tps_client_(
std::make_shared<mvai_infra::experimental::ps_training::tps_client::
TrainingParameterServiceClient>(
std::move(tps_hosts),
tbe_id,
maxLocalIndexLength,
num_threads)) {}
num_threads,
maxKeysPerRequest)) {}

void set(
const at::Tensor& indices,