Skip to content

Commit

Permalink
Add parameter checks to BFS and SSSP in C API (#2844)
Browse files Browse the repository at this point in the history
MG testing resulted in discovering an assumption in the C API about types that needs to be validated.  Added checks to BFS and SSSP so that the type checks will be done right away before trying to interpret the data incorrectly.

Won't pass until #2847 is addressed

Authors:
  - Chuck Hastings (https://github.com/ChuckHastings)

Approvers:
  - Seunghwa Kang (https://github.com/seunghwak)
  - Joseph Nke (https://github.com/jnke2016)
  - Rick Ratzel (https://github.com/rlratzel)

URL: #2844
  • Loading branch information
ChuckHastings authored Nov 3, 2022
1 parent 9be46fd commit a045684
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 2 deletions.
8 changes: 8 additions & 0 deletions cpp/src/c_api/bfs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ extern "C" cugraph_error_code_t cugraph_bfs(const cugraph_resource_handle_t* han
cugraph_paths_result_t** result,
cugraph_error_t** error)
{
CAPI_EXPECTS(
reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)->vertex_type_ ==
reinterpret_cast<cugraph::c_api::cugraph_type_erased_device_array_view_t const*>(sources)
->type_,
CUGRAPH_INVALID_INPUT,
"vertex type of graph and sources must match",
*error);

cugraph::c_api::bfs_functor functor(handle,
graph,
sources,
Expand Down
53 changes: 53 additions & 0 deletions cpp/tests/c_api/bfs_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,58 @@ int generic_bfs_test(vertex_t* h_src,
return test_ret_value;
}

int test_bfs_exceptions()
{
size_t num_edges = 8;
size_t num_vertices = 6;
size_t depth_limit = 1;
size_t num_seeds = 1;

vertex_t src[] = {0, 1, 1, 2, 2, 2, 3, 4};
vertex_t dst[] = {1, 3, 4, 0, 1, 3, 5, 5};
weight_t wgt[] = {0.1f, 2.1f, 1.1f, 5.1f, 3.1f, 4.1f, 7.2f, 3.2f};
int64_t seeds[] = {0};

int test_ret_value = 0;

cugraph_error_code_t ret_code = CUGRAPH_SUCCESS;
cugraph_error_t* ret_error = NULL;

cugraph_resource_handle_t* p_handle = NULL;
cugraph_graph_t* p_graph = NULL;
cugraph_paths_result_t* p_result = NULL;
cugraph_type_erased_device_array_t* p_sources = NULL;
cugraph_type_erased_device_array_view_t* p_source_view = NULL;

p_handle = cugraph_create_resource_handle(NULL);
TEST_ASSERT(test_ret_value, p_handle != NULL, "resource handle creation failed.");

ret_code = create_test_graph(
p_handle, src, dst, wgt, num_edges, FALSE, FALSE, FALSE, &p_graph, &ret_error);

/*
* FIXME: in create_graph_test.c, variables are defined but then hard-coded to
* the constant INT32. It would be better to pass the types into the functions
* in both cases so that the test cases could be parameterized in the main.
*/
ret_code =
cugraph_type_erased_device_array_create(p_handle, num_seeds, INT64, &p_sources, &ret_error);
TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "p_sources create failed.");

p_source_view = cugraph_type_erased_device_array_view(p_sources);

ret_code = cugraph_type_erased_device_array_view_copy_from_host(
p_handle, p_source_view, (byte_t*)seeds, &ret_error);
TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "src copy_from_host failed.");

ret_code = cugraph_bfs(
p_handle, p_graph, p_source_view, FALSE, depth_limit, TRUE, FALSE, &p_result, &ret_error);

TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_INVALID_INPUT, "cugraph_bfs expected to fail");

return test_ret_value;
}

int test_bfs()
{
size_t num_edges = 8;
Expand Down Expand Up @@ -176,5 +228,6 @@ int main(int argc, char** argv)
int result = 0;
result |= RUN_TEST(test_bfs);
result |= RUN_TEST(test_bfs_with_transpose);
result |= RUN_TEST(test_bfs_exceptions);
return result;
}
8 changes: 6 additions & 2 deletions python/cugraph/cugraph/traversal/bfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def _ensure_args(G, start, i_start, directed):
else:
if not isinstance(start, cudf.DataFrame):
if not isinstance(start, dask_cudf.DataFrame):
start = cudf.DataFrame({"starts": cudf.Series(start)})
vertex_dtype = G.nodes().dtype
start = cudf.DataFrame(
{"starts": cudf.Series(start, dtype=vertex_dtype)}
)

if G.is_renumbered():
validlen = len(
Expand Down Expand Up @@ -224,7 +227,8 @@ def bfs(
if is_dataframe:
start = start[start.columns[0]]
else:
start = cudf.Series(start, name="starts")
vertex_dtype = G.nodes().dtype
start = cudf.Series(start, dtype=vertex_dtype)

distances, predecessors, vertices = pylibcugraph_bfs(
handle=ResourceHandle(),
Expand Down

0 comments on commit a045684

Please sign in to comment.