diff --git a/CHANGELOG.md b/CHANGELOG.md index c5433ae0..d1b7eca2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `CMake` support ([#5](https://github.com/pyg-team/pyg-lib/pull/5)) - Added `pyg.cuda_version()` ([#4](https://github.com/pyg-team/pyg-lib/pull/4)) ### Changed +- Optional return types in `pyg.subgraph()` ([#40](https://github.com/pyg-team/pyg-lib/pull/40)) - Absolute headers ([#30](https://github.com/pyg-team/pyg-lib/pull/30)) - Use `at::equal` rather than `at::all` in tests ([#37](https://github.com/pyg-team/pyg-lib/pull/37)) ### Removed diff --git a/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp b/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp index 6cdd5c09..b8469e8c 100644 --- a/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp @@ -9,17 +9,19 @@ namespace sampler { namespace { -std::tuple subgraph_kernel( +std::tuple> subgraph_kernel( const at::Tensor& rowptr, const at::Tensor& col, - const at::Tensor& nodes) { + const at::Tensor& nodes, + const bool return_edge_id) { TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor"); TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor"); TORCH_CHECK(nodes.is_cpu(), "'nodes' must be a CPU tensor"); const auto deg = rowptr.new_empty({nodes.size(0)}); const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1}); - at::Tensor out_col, out_edge_id; + at::Tensor out_col; + c10::optional out_edge_id = c10::nullopt; AT_DISPATCH_INTEGRAL_TYPES(nodes.scalar_type(), "subgraph_kernel", [&] { const auto rowptr_data = rowptr.data_ptr(); @@ -53,9 +55,12 @@ std::tuple subgraph_kernel( at::cumsum_out(tmp, deg, /*dim=*/0); out_col = col.new_empty({out_rowptr_data[nodes.size(0)]}); - out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]}); auto out_col_data = out_col.data_ptr(); - auto out_edge_id_data = out_edge_id.data_ptr(); + scalar_t* out_edge_id_data; + if (return_edge_id) { + out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]}); + out_edge_id_data = out_edge_id.value().data_ptr(); + } // Customize `grain_size` based on the work each thread does (it will need // to find `col.size(0) / nodes.size(0)` neighbors on average). @@ -72,7 +77,8 @@ std::tuple subgraph_kernel( const auto search = to_local_node.find(w); if (search != to_local_node.end()) { out_col_data[offset] = search->second; - out_edge_id_data[offset] = j; + if (return_edge_id) + out_edge_id_data[offset] = j; offset++; } } diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp index 0cbb6c02..dbedb5b6 100644 --- a/pyg_lib/csrc/sampler/subgraph.cpp +++ b/pyg_lib/csrc/sampler/subgraph.cpp @@ -6,10 +6,11 @@ namespace pyg { namespace sampler { -std::tuple subgraph( +std::tuple> subgraph( const at::Tensor& rowptr, const at::Tensor& col, - const at::Tensor& nodes) { + const at::Tensor& nodes, + const bool return_edge_id) { at::TensorArg rowptr_t{rowptr, "rowtpr", 1}; at::TensorArg col_t{col, "col", 1}; at::TensorArg nodes_t{nodes, "nodes", 1}; @@ -21,13 +22,13 @@ std::tuple subgraph( static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::subgraph", "") .typed(); - return op.call(rowptr, col, nodes); + return op.call(rowptr, col, nodes, return_edge_id); } TORCH_LIBRARY_FRAGMENT(pyg, m) { - m.def( - TORCH_SELECTIVE_SCHEMA("pyg::subgraph(Tensor rowptr, Tensor col, Tensor " - "nodes) -> (Tensor, Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::subgraph(Tensor rowptr, Tensor col, Tensor " + "nodes, bool return_edge_id) -> (Tensor, Tensor, Tensor?)")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/subgraph.h b/pyg_lib/csrc/sampler/subgraph.h index 5326f4e4..a2f8de54 100644 --- a/pyg_lib/csrc/sampler/subgraph.h +++ b/pyg_lib/csrc/sampler/subgraph.h @@ -9,10 +9,11 @@ namespace sampler { // Returns the induced subgraph of the graph given by `(rowptr, col)`, // containing only the nodes in `nodes`. // Returns: (rowptr, col, edge_id) -PYG_API std::tuple subgraph( +PYG_API std::tuple> subgraph( const at::Tensor& rowptr, const at::Tensor& col, - const at::Tensor& nodes); + const at::Tensor& nodes, + const bool return_edge_id = true); } // namespace sampler } // namespace pyg diff --git a/test/csrc/sampler/test_subgraph.cpp b/test/csrc/sampler/test_subgraph.cpp index 8528b253..ea923514 100644 --- a/test/csrc/sampler/test_subgraph.cpp +++ b/test/csrc/sampler/test_subgraph.cpp @@ -18,5 +18,5 @@ TEST(SubgraphTest, BasicAssertions) { auto expected_col = at::tensor({1, 0, 2, 1, 3, 2}, options); EXPECT_TRUE(at::equal(std::get<1>(out), expected_col)); auto expected_edge_id = at::tensor({3, 4, 5, 6, 7, 8}, options); - EXPECT_TRUE(at::equal(std::get<2>(out), expected_edge_id)); + EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_edge_id)); }