Skip to content

Commit

Permalink
Optional return types in pyg::sampler::subgraph (#40)
Browse files Browse the repository at this point in the history
* update

* update

* update
  • Loading branch information
rusty1s authored May 2, 2022
1 parent 1a00013 commit 0b26fcb
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `CMake` support ([#5](https://github.com/pyg-team/pyg-lib/pull/5))
- Added `pyg.cuda_version()` ([#4](https://github.com/pyg-team/pyg-lib/pull/4))
### Changed
- Optional return types in `pyg.subgraph()` ([#40](https://github.com/pyg-team/pyg-lib/pull/40))
- Absolute headers ([#30](https://github.com/pyg-team/pyg-lib/pull/30))
- Use `at::equal` rather than `at::all` in tests ([#37](https://github.com/pyg-team/pyg-lib/pull/37))
### Removed
18 changes: 12 additions & 6 deletions pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ namespace sampler {

namespace {

std::tuple<at::Tensor, at::Tensor, at::Tensor> subgraph_kernel(
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph_kernel(
const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& nodes) {
const at::Tensor& nodes,
const bool return_edge_id) {
TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor");
TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor");
TORCH_CHECK(nodes.is_cpu(), "'nodes' must be a CPU tensor");

const auto deg = rowptr.new_empty({nodes.size(0)});
const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1});
at::Tensor out_col, out_edge_id;
at::Tensor out_col;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;

AT_DISPATCH_INTEGRAL_TYPES(nodes.scalar_type(), "subgraph_kernel", [&] {
const auto rowptr_data = rowptr.data_ptr<scalar_t>();
Expand Down Expand Up @@ -53,9 +55,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> subgraph_kernel(
at::cumsum_out(tmp, deg, /*dim=*/0);

out_col = col.new_empty({out_rowptr_data[nodes.size(0)]});
out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]});
auto out_col_data = out_col.data_ptr<scalar_t>();
auto out_edge_id_data = out_edge_id.data_ptr<scalar_t>();
scalar_t* out_edge_id_data;
if (return_edge_id) {
out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]});
out_edge_id_data = out_edge_id.value().data_ptr<scalar_t>();
}

// Customize `grain_size` based on the work each thread does (it will need
// to find `col.size(0) / nodes.size(0)` neighbors on average).
Expand All @@ -72,7 +77,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> subgraph_kernel(
const auto search = to_local_node.find(w);
if (search != to_local_node.end()) {
out_col_data[offset] = search->second;
out_edge_id_data[offset] = j;
if (return_edge_id)
out_edge_id_data[offset] = j;
offset++;
}
}
Expand Down
13 changes: 7 additions & 6 deletions pyg_lib/csrc/sampler/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
namespace pyg {
namespace sampler {

std::tuple<at::Tensor, at::Tensor, at::Tensor> subgraph(
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph(
const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& nodes) {
const at::Tensor& nodes,
const bool return_edge_id) {
at::TensorArg rowptr_t{rowptr, "rowtpr", 1};
at::TensorArg col_t{col, "col", 1};
at::TensorArg nodes_t{nodes, "nodes", 1};
Expand All @@ -21,13 +22,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> subgraph(
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::subgraph", "")
.typed<decltype(subgraph)>();
return op.call(rowptr, col, nodes);
return op.call(rowptr, col, nodes, return_edge_id);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(
TORCH_SELECTIVE_SCHEMA("pyg::subgraph(Tensor rowptr, Tensor col, Tensor "
"nodes) -> (Tensor, Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::subgraph(Tensor rowptr, Tensor col, Tensor "
"nodes, bool return_edge_id) -> (Tensor, Tensor, Tensor?)"));
}

} // namespace sampler
Expand Down
5 changes: 3 additions & 2 deletions pyg_lib/csrc/sampler/subgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ namespace sampler {
// Returns the induced subgraph of the graph given by `(rowptr, col)`,
// containing only the nodes in `nodes`.
// Returns: (rowptr, col, edge_id)
PYG_API std::tuple<at::Tensor, at::Tensor, at::Tensor> subgraph(
PYG_API std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph(
const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& nodes);
const at::Tensor& nodes,
const bool return_edge_id = true);

} // namespace sampler
} // namespace pyg
2 changes: 1 addition & 1 deletion test/csrc/sampler/test_subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ TEST(SubgraphTest, BasicAssertions) {
auto expected_col = at::tensor({1, 0, 2, 1, 3, 2}, options);
EXPECT_TRUE(at::equal(std::get<1>(out), expected_col));
auto expected_edge_id = at::tensor({3, 4, 5, 6, 7, 8}, options);
EXPECT_TRUE(at::equal(std::get<2>(out), expected_edge_id));
EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_edge_id));
}

0 comments on commit 0b26fcb

Please sign in to comment.