Skip to content

Commit

Permalink
Biased sampling primitive bug fix (#4607)
Browse files Browse the repository at this point in the history
Fix bugs in biased sampling with 0 bias values.

Add tests that include 0 bias edges.

Authors:
  - Seunghwa Kang (https://github.com/seunghwak)

Approvers:
  - Chuck Hastings (https://github.com/ChuckHastings)
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Joseph Nke (https://github.com/jnke2016)

URL: #4607
  • Loading branch information
seunghwak authored Aug 9, 2024
1 parent ce5d3dd commit a5cdea2
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2041,7 +2041,7 @@ biased_sample_and_compute_local_nbr_indices(
zero_bias_frontier_indices.resize(zero_bias_count_inclusive_sums.back(),
handle.get_stream());
zero_bias_frontier_indices.shrink_to_fit(handle.get_stream());
zero_bias_local_nbr_indices.resize(frontier_indices.size(), handle.get_stream());
zero_bias_local_nbr_indices.resize(zero_bias_frontier_indices.size(), handle.get_stream());
zero_bias_local_nbr_indices.shrink_to_fit(handle.get_stream());
std::vector<size_t> zero_bias_counts(zero_bias_count_inclusive_sums.size());
std::adjacent_difference(zero_bias_count_inclusive_sums.begin(),
Expand Down
55 changes: 28 additions & 27 deletions cpp/src/prims/detail/transform_v_frontier_e.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,6 @@ __global__ static void transform_v_frontier_e_mid_degree(
auto const lane_id = tid % raft::warp_size();
size_t idx = static_cast<size_t>(tid / raft::warp_size());

using WarpScan = cub::WarpScan<edge_t, raft::warp_size()>;
__shared__ typename WarpScan::TempStorage temp_storage;

while (idx < static_cast<size_t>(thrust::distance(edge_partition_frontier_key_index_first,
edge_partition_frontier_key_index_last))) {
auto key_idx = *(edge_partition_frontier_key_index_first + idx);
Expand All @@ -224,16 +221,15 @@ __global__ static void transform_v_frontier_e_mid_degree(
thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(major_offset);
auto this_key_value_first = value_first + edge_partition_frontier_local_degree_offsets[key_idx];
if (edge_partition_e_mask) {
// FIXME: it might be faster to update in warp-sync way
edge_t counter{0};
for (edge_t i = lane_id; i < local_degree; i += raft::warp_size()) {
if ((*edge_partition_e_mask).get(edge_offset + i)) { ++counter; }
}
edge_t offset_within_warp{};
WarpScan(temp_storage).ExclusiveSum(counter, offset_within_warp);
counter = 0;
for (edge_t i = lane_id; i < local_degree; i += raft::warp_size()) {
if ((*edge_partition_e_mask).get(edge_offset + i)) {
auto rounded_up_local_degree =
((static_cast<size_t>(local_degree) + (raft::warp_size() - 1)) / raft::warp_size()) *
raft::warp_size();
edge_t base_offset{0};
for (edge_t i = lane_id; i < rounded_up_local_degree; i += raft::warp_size()) {
auto valid = (i < local_degree) && (*edge_partition_e_mask).get(edge_offset + i);
auto ballot = __ballot_sync(raft::warp_full_mask(), valid ? uint32_t{1} : uint32_t{0});
if (valid) {
auto intra_warp_offset = __popc(ballot & ~(raft::warp_full_mask() << lane_id));
transform_v_frontier_e_update_buffer_element<key_t, GraphViewType>(
edge_partition,
key,
Expand All @@ -244,9 +240,9 @@ __global__ static void transform_v_frontier_e_mid_degree(
edge_partition_dst_value_input,
edge_partition_e_value_input,
e_op,
this_key_value_first + offset_within_warp + counter);
++counter;
this_key_value_first + base_offset + intra_warp_offset);
}
base_offset += __popc(ballot);
}
} else {
for (edge_t i = lane_id; i < local_degree; i += raft::warp_size()) {
Expand Down Expand Up @@ -300,6 +296,7 @@ __global__ static void transform_v_frontier_e_high_degree(

using BlockScan = cub::BlockScan<edge_t, transform_v_frontier_e_kernel_block_size>;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ edge_t increment;

while (idx < static_cast<size_t>(thrust::distance(edge_partition_frontier_key_index_first,
edge_partition_frontier_key_index_last))) {
Expand All @@ -313,16 +310,16 @@ __global__ static void transform_v_frontier_e_high_degree(
thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(major_offset);
auto this_key_value_first = value_first + edge_partition_frontier_local_degree_offsets[key_idx];
if (edge_partition_e_mask) {
// FIXME: it might be faster to update in block-sync way
edge_t counter{0};
for (edge_t i = threadIdx.x; i < local_degree; i += blockDim.x) {
if ((*edge_partition_e_mask).get(edge_offset + i)) { ++counter; }
}
edge_t offset_within_block{};
BlockScan(temp_storage).ExclusiveSum(counter, offset_within_block);
counter = 0;
for (edge_t i = threadIdx.x; i < local_degree; i += blockDim.x) {
if ((*edge_partition_e_mask).get(edge_offset + i)) {
auto rounded_up_local_degree =
((static_cast<size_t>(local_degree) + (transform_v_frontier_e_kernel_block_size - 1)) /
transform_v_frontier_e_kernel_block_size) *
transform_v_frontier_e_kernel_block_size;
edge_t base_offset{0};
for (size_t i = threadIdx.x; i < rounded_up_local_degree; i += blockDim.x) {
auto valid = (i < local_degree) && (*edge_partition_e_mask).get(edge_offset + i);
edge_t intra_block_offset{};
BlockScan(temp_storage).ExclusiveSum(valid ? edge_t{1} : edge_t{0}, intra_block_offset);
if (valid) {
transform_v_frontier_e_update_buffer_element<key_t, GraphViewType>(
edge_partition,
key,
Expand All @@ -333,9 +330,13 @@ __global__ static void transform_v_frontier_e_high_degree(
edge_partition_dst_value_input,
edge_partition_e_value_input,
e_op,
this_key_value_first + offset_within_block + counter);
++counter;
this_key_value_first + base_offset + intra_block_offset);
}
if (threadIdx.x == transform_v_frontier_e_kernel_block_size - 1) {
increment = intra_block_offset + (valid ? edge_t{1} : edge_t{0});
}
__syncthreads();
base_offset += increment;
}
} else {
for (edge_t i = threadIdx.x; i < local_degree; i += blockDim.x) {
Expand Down
Loading

0 comments on commit a5cdea2

Please sign in to comment.