Commit ae3748a: Add initial support of distributed sampling (#171)
This PR introduces support for distributed graph sampling (via the NCCL backend). The initial implementation focuses on the uniform neighbor sampler; we plan to extend it to other samplers in the future.

Highlights:
- Distributed Graph Storage: The graph structure (represented by the `row_ptr` and `col_indx` tensors) can now be stored as wholememory arrays distributed evenly across ranks, with both `cpu` and `cuda` storage types supported (see the partitioning sketch after this list).
- Distributed Sampling: The sampling process leverages the existing wholegraph gather function to collect sampled nodes and edges across all ranks.
- Uniform Neighbor Sampler Support: Currently, only the uniform neighbor sampler is supported.
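
As an illustration of the even distribution noted above, the sketch below shows how a global index into an evenly partitioned array maps to a rank and a local offset. The `locate_entry` helper is hypothetical (not part of this PR) and ignores wholememory's allocation granularity:

#include <cstdint>
#include <utility>

// Hypothetical helper, for illustration only: find which rank owns entry
// `global_idx` when `entry_count` entries are split as evenly as possible
// across `world_size` ranks, and the offset within that rank's partition.
inline std::pair<int, int64_t> locate_entry(int64_t global_idx, int64_t entry_count, int world_size)
{
  const int64_t per_rank = (entry_count + world_size - 1) / world_size;  // ceil division
  const int rank         = static_cast<int>(global_idx / per_rank);
  const int64_t local    = global_idx % per_rank;
  return {rank, local};
}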

cc. @linhu-nv @dongxuy04 @BradReesWork  @nvcastet @TristonC

Authors:
  - Chang Liu (https://github.com/chang-l)

Approvers:
  - linhu-nv (https://github.com/linhu-nv)
  - Brad Rees (https://github.com/BradReesWork)

URL: #171
chang-l authored May 29, 2024
1 parent 1921080 commit ae3748a
Showing 8 changed files with 566 additions and 13 deletions.
39 changes: 38 additions & 1 deletion cpp/src/wholegraph_ops/sample_comm.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -57,4 +57,41 @@ __global__ void sample_all_kernel(wholememory_gref_t wm_csr_row_ptr,
}
}
}

// Computes ceil(log2(x)) for x >= 1, via the CUDA __clz (count-leading-zeros) intrinsic.
__device__ __forceinline__ int log2_up_device(int x)
{
  if (x <= 2) return x - 1;
  return 32 - __clz(x - 1);
}

// Duplicates `length` indices with a +1 shift (out[i] = in[i % length] + i / length),
// producing [in, in + 1] so a single gather can fetch both CSR row offsets of each node.
template <typename IdType>
struct ExpandWithOffsetFunc {
  const IdType* indptr;
  IdType* indptr_shift;
  int length;
  __host__ __device__ auto operator()(int64_t tIdx)
  {
    indptr_shift[tIdx] = indptr[tIdx % length] + tIdx / length;
  }
};

// Computes each node's neighbor count from the gathered offset pairs:
// in_degree_ptr[i] = rowoffsets[i + length] - rowoffsets[i].
template <typename WMIdType, typename DegreeType>
struct ReduceForDegrees {
  WMIdType* rowoffsets;
  DegreeType* in_degree_ptr;
  int length;
  __host__ __device__ auto operator()(int64_t tIdx)
  {
    in_degree_ptr[tIdx] = rowoffsets[tIdx + length] - rowoffsets[tIdx];
  }
};

// Caps each node's degree at the requested fanout (max_sample_count).
template <typename DegreeType>
struct MinInDegreeFanout {
  int max_sample_count;
  __host__ __device__ auto operator()(DegreeType degree)
  {
    return min(static_cast<int>(degree), max_sample_count);
  }
};

} // namespace wholegraph_ops
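
For orientation, here is a minimal sketch of how these functors are typically driven. The sampler's actual launch code is not shown in this diff, so the Thrust calls, the `compute_capped_degrees` helper, and its parameters are illustrative assumptions only:

#include <cstdint>
#include <cuda_runtime_api.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/transform.h>

// Sketch only (assumes the functor definitions above are visible): given
// `offsets` holding 2 * n gathered row offsets laid out as
// [begin_0..begin_{n-1}, end_0..end_{n-1}], compute per-node degrees and
// fanout-capped sample counts.
void compute_capped_degrees(const int64_t* offsets,
                            int* degrees,   // out: degree per node, size n
                            int* fanouts,   // out: min(degree, fanout), size n
                            int n,
                            int max_sample_count,
                            cudaStream_t stream)
{
  // degrees[i] = offsets[i + n] - offsets[i], via ReduceForDegrees.
  thrust::for_each(thrust::cuda::par.on(stream),
                   thrust::make_counting_iterator<int64_t>(0),
                   thrust::make_counting_iterator<int64_t>(n),
                   wholegraph_ops::ReduceForDegrees<const int64_t, int>{offsets, degrees, n});
  // fanouts[i] = min(degrees[i], max_sample_count), via MinInDegreeFanout.
  thrust::transform(thrust::cuda::par.on(stream),
                    degrees,
                    degrees + n,
                    fanouts,
                    wholegraph_ops::MinInDegreeFanout<int>{max_sample_count});
}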
42 changes: 39 additions & 3 deletions cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -41,7 +41,8 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement(
}
WHOLEMEMORY_EXPECTS_NOTHROW(!csr_row_ptr_has_handle ||
csr_row_ptr_memory_type == WHOLEMEMORY_MT_CHUNKED ||
csr_row_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS,
csr_row_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS ||
csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED,
"Memory type not supported.");
bool const csr_col_ptr_has_handle = wholememory_tensor_has_handle(wm_csr_col_ptr_tensor);
wholememory_memory_type_t csr_col_ptr_memory_type = WHOLEMEMORY_MT_NONE;
@@ -51,7 +52,8 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement(
}
WHOLEMEMORY_EXPECTS_NOTHROW(!csr_col_ptr_has_handle ||
csr_col_ptr_memory_type == WHOLEMEMORY_MT_CHUNKED ||
csr_col_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS,
csr_col_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS ||
                            csr_col_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED,
"Memory type not supported.");

auto csr_row_ptr_tensor_description =
@@ -108,6 +110,40 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement(
void* center_nodes = wholememory_tensor_get_data_pointer(center_nodes_tensor);
void* output_sample_offset = wholememory_tensor_get_data_pointer(output_sample_offset_tensor);

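  // Route to the NCCL-based distributed sampler when both CSR tensors are stored as distributed wholememory.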
  if (csr_col_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED &&
      csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED) {
    wholememory_distributed_backend_t distributed_backend_row = wholememory_get_distributed_backend(
      wholememory_tensor_get_memory_handle(wm_csr_row_ptr_tensor));
    wholememory_distributed_backend_t distributed_backend_col = wholememory_get_distributed_backend(
      wholememory_tensor_get_memory_handle(wm_csr_col_ptr_tensor));
    if (distributed_backend_col == WHOLEMEMORY_DB_NCCL &&
        distributed_backend_row == WHOLEMEMORY_DB_NCCL) {
      wholememory_handle_t wm_csr_row_ptr_handle =
        wholememory_tensor_get_memory_handle(wm_csr_row_ptr_tensor);
      wholememory_handle_t wm_csr_col_ptr_handle =
        wholememory_tensor_get_memory_handle(wm_csr_col_ptr_tensor);
      return wholegraph_ops::wholegraph_csr_unweighted_sample_without_replacement_nccl(
        wm_csr_row_ptr_handle,
        wm_csr_col_ptr_handle,
        csr_row_ptr_tensor_description,
        csr_col_ptr_tensor_description,
        center_nodes,
        center_nodes_desc,
        max_sample_count,
        output_sample_offset,
        output_sample_offset_desc,
        output_dest_memory_context,
        output_center_localid_memory_context,
        output_edge_gid_memory_context,
        random_seed,
        p_env_fns,
        static_cast<cudaStream_t>(stream));
    } else {
      WHOLEMEMORY_ERROR("Only the NCCL communication backend is supported for distributed sampling.");
      return WHOLEMEMORY_INVALID_INPUT;
    }
  }

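  // Fallback (non-distributed) path: sample through global references for chunked/continuous memory.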
  wholememory_gref_t wm_csr_row_ptr_gref, wm_csr_col_ptr_gref;
  WHOLEMEMORY_RETURN_ON_FAIL(
    wholememory_tensor_get_global_reference(wm_csr_row_ptr_tensor, &wm_csr_row_ptr_gref));
@@ -123,12 +123,6 @@ __global__ void large_sample_kernel(
}
}

__device__ __forceinline__ int log2_up_device(int x)
{
if (x <= 2) return x - 1;
return 32 - __clz(x - 1);
}

template <typename IdType,
typename LocalIdType,
typename WMIdType,
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -37,4 +37,21 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_ma
unsigned long long random_seed,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream);

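// Distributed (NCCL-backend) variant of the sampler above: takes wholememory
// handles for the CSR row and column arrays instead of tensor global references.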
wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_nccl(
  wholememory_handle_t csr_row_wholememory_handle,
  wholememory_handle_t csr_col_wholememory_handle,
  wholememory_tensor_description_t wm_csr_row_ptr_desc,
  wholememory_tensor_description_t wm_csr_col_ptr_desc,
  void* center_nodes,
  wholememory_array_description_t center_nodes_desc,
  int max_sample_count,
  void* output_sample_offset,
  wholememory_array_description_t output_sample_offset_desc,
  void* output_dest_memory_context,
  void* output_center_localid_memory_context,
  void* output_edge_gid_memory_context,
  unsigned long long random_seed,
  wholememory_env_func_t* p_env_fns,
  cudaStream_t stream);
} // namespace wholegraph_ops
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime_api.h>

#include <wholememory/env_func_ptrs.h>
#include <wholememory/wholememory.h>

#include "unweighted_sample_without_replacement_nccl_func.cuh"
#include "wholememory_ops/register.hpp"

namespace wholegraph_ops {

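// Register the templated NCCL sampler for all signed 32-/64-bit integer
// combinations of the center-node dtype and the column-index dtype.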
REGISTER_DISPATCH_TWO_TYPES(UnweightedSampleWithoutReplacementCSRNCCL,
wholegraph_csr_unweighted_sample_without_replacement_nccl_func,
SINT3264,
SINT3264)

wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_nccl(
  wholememory_handle_t csr_row_wholememory_handle,
  wholememory_handle_t csr_col_wholememory_handle,
  wholememory_tensor_description_t wm_csr_row_ptr_desc,
  wholememory_tensor_description_t wm_csr_col_ptr_desc,
  void* center_nodes,
  wholememory_array_description_t center_nodes_desc,
  int max_sample_count,
  void* output_sample_offset,
  wholememory_array_description_t output_sample_offset_desc,
  void* output_dest_memory_context,
  void* output_center_localid_memory_context,
  void* output_edge_gid_memory_context,
  unsigned long long random_seed,
  wholememory_env_func_t* p_env_fns,
  cudaStream_t stream)
{
  try {
    // Dispatch on the center-node dtype and the column-index dtype
    // (the signed 32-/64-bit combinations registered above).
    DISPATCH_TWO_TYPES(center_nodes_desc.dtype,
                       wm_csr_col_ptr_desc.dtype,
                       UnweightedSampleWithoutReplacementCSRNCCL,
                       csr_row_wholememory_handle,
                       csr_col_wholememory_handle,
                       wm_csr_row_ptr_desc,
                       wm_csr_col_ptr_desc,
                       center_nodes,
                       center_nodes_desc,
                       max_sample_count,
                       output_sample_offset,
                       output_sample_offset_desc,
                       output_dest_memory_context,
                       output_center_localid_memory_context,
                       output_edge_gid_memory_context,
                       random_seed,
                       p_env_fns,
                       stream);
  } catch (const wholememory::cuda_error& rle) {
    // WHOLEMEMORY_FAIL_NOTHROW("%s", rle.what());
    return WHOLEMEMORY_LOGIC_ERROR;
  } catch (const wholememory::logic_error& le) {
    return WHOLEMEMORY_LOGIC_ERROR;
  } catch (...) {
    return WHOLEMEMORY_LOGIC_ERROR;
  }
  return WHOLEMEMORY_SUCCESS;
}

} // namespace wholegraph_ops