Skip to content

Commit

Permalink
transform_e bug fix in edge masking (#4221)
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak authored Mar 8, 2024
1 parent f202fa3 commit ab9e445
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions cpp/src/prims/transform_e.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ __global__ void transform_e_packed_bool(
if (local_edge_idx < num_edges) {
bool compute_predicate = true;
if constexpr (check_edge_mask) {
compute_predicate = (edge_mask & packed_bool_mask(lane_id) != packed_bool_empty_mask());
compute_predicate = ((edge_mask & packed_bool_mask(lane_id)) != packed_bool_empty_mask());
}

if (compute_predicate) {
Expand All @@ -111,10 +111,10 @@ __global__ void transform_e_packed_bool(
uint32_t new_val = __ballot_sync(raft::warp_full_mask(), predicate);
if (lane_id == 0) {
if constexpr (check_edge_mask) {
*(edge_partition_e_value_output.value_first() + idx) = new_val;
} else {
auto old_val = *(edge_partition_e_value_output.value_first() + idx);
*(edge_partition_e_value_output.value_first() + idx) = (old_val & ~edge_mask) | new_val;
} else {
*(edge_partition_e_value_output.value_first() + idx) = new_val;
}
}

Expand Down Expand Up @@ -196,6 +196,9 @@ struct update_e_value_t {

__device__ void operator()(typename GraphViewType::edge_type i) const
{
if constexpr (check_edge_mask) {
if (!edge_partition_e_mask.get(i)) { return; }
}
auto major_idx = edge_partition.major_idx_from_local_edge_idx_nocheck(i);
auto major = edge_partition.major_from_major_idx_nocheck(major_idx);
auto major_offset = edge_partition.major_offset_from_major_nocheck(major);
Expand Down

0 comments on commit ab9e445

Please sign in to comment.