diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f893e251..238d929ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `CMake` support ([#5](https://github.com/pyg-team/pyg-lib/pull/5)) - Added `pyg.cuda_version()` ([#4](https://github.com/pyg-team/pyg-lib/pull/4)) ### Changed +- Require sorted neighborhoods according to time in temporal sampling ([#108](https://github.com/pyg-team/pyg-lib/pull/108)) - Only sample neighbors with a strictly earlier timestamp than the seed node ([#104](https://github.com/pyg-team/pyg-lib/pull/104)) - Prevent absolute paths in wheel ([#75](https://github.com/pyg-team/pyg-lib/pull/75)) - Improved installation instructions ([#68](https://github.com/pyg-team/pyg-lib/pull/68)) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 32978053c..f634c792f 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -34,47 +34,10 @@ class NeighborSampler { pyg::sampler::Mapper& dst_mapper, pyg::random::RandintEngine& generator, std::vector& out_global_dst_nodes) { - if (count == 0) - return; - const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - const auto population = row_end - row_start; - - if (population == 0) - return; - - // Case 1: Sample the full neighborhood: - if (count < 0 || (!replace && count >= population)) { - for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) { - add(edge_id, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); - } - } - - // Case 2: Sample with replacement: - else if (replace) { - for (size_t i = 0; i < count; ++i) { - const auto edge_id = generator(row_start, row_end); - add(edge_id, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); - } - } - - // Case 3: Sample without replacement: 
- else { - auto index_tracker = IndexTracker(population); - for (size_t i = population - count; i < population; ++i) { - auto rnd = generator(0, i + 1); - if (!index_tracker.try_insert(rnd)) { - rnd = i; - index_tracker.insert(i); - } - const auto edge_id = row_start + rnd; - add(edge_id, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); - } - } + _sample(global_src_node, local_src_node, row_start, row_end, count, + dst_mapper, generator, out_global_dst_nodes); } void temporal_sample(const node_t global_src_node, @@ -85,11 +48,61 @@ class NeighborSampler { pyg::sampler::Mapper& dst_mapper, pyg::random::RandintEngine& generator, std::vector& out_global_dst_nodes) { + const auto row_start = rowptr_[to_scalar_t(global_src_node)]; + auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; + + // Find new `row_end` such that all neighbors fulfill temporal constraints: + auto it = std::lower_bound( + col_ + row_start, col_ + row_end, seed_time, + [&](const scalar_t& a, const scalar_t& b) { return time[a] < b; }); + row_end = it - col_; + + _sample(global_src_node, local_src_node, row_start, row_end, count, + dst_mapper, generator, out_global_dst_nodes); + } + + std::tuple> + get_sampled_edges(bool csc = false) { + TORCH_CHECK(save_edges, "No edges have been stored") + const auto row = pyg::utils::from_vector(sampled_rows_); + const auto col = pyg::utils::from_vector(sampled_cols_); + c10::optional edge_id = c10::nullopt; + if (save_edge_ids) { + edge_id = pyg::utils::from_vector(sampled_edge_ids_); + } + if (!csc) { + return std::make_tuple(row, col, edge_id); + } else { + return std::make_tuple(col, row, edge_id); + } + } + + private: + inline scalar_t to_scalar_t(const scalar_t& node) { return node; } + inline scalar_t to_scalar_t(const std::pair& node) { + return std::get<1>(node); + } + + inline scalar_t to_node_t(const scalar_t& node, const scalar_t& ref) { + return node; + } + inline std::pair to_node_t( + const scalar_t& node, + const 
std::pair& ref) { + return {std::get<0>(ref), node}; + } + + void _sample(const node_t global_src_node, + const scalar_t local_src_node, + const scalar_t row_start, + const scalar_t row_end, + const size_t count, + pyg::sampler::Mapper& dst_mapper, + pyg::random::RandintEngine& generator, + std::vector& out_global_dst_nodes) { if (count == 0) return; - const auto row_start = rowptr_[to_scalar_t(global_src_node)]; - const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; const auto population = row_end - row_start; if (population == 0) @@ -98,8 +111,6 @@ class NeighborSampler { // Case 1: Sample the full neighborhood: if (count < 0 || (!replace && count >= population)) { for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) { - if (time[col_[edge_id]] >= seed_time) - continue; add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); } @@ -109,12 +120,6 @@ class NeighborSampler { else if (replace) { for (size_t i = 0; i < count; ++i) { const auto edge_id = generator(row_start, row_end); - // TODO (matthias) Improve temporal sampling logic. Currently, we sample - // `count` many random neighbors, and filter them based on temporal - // constraints afterwards. Ideally, we only sample exactly `count` - // neighbors which fullfill the time constraint. - if (time[col_[edge_id]] >= seed_time) - continue; add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); } @@ -122,57 +127,20 @@ class NeighborSampler { // Case 3: Sample without replacement: else { - std::unordered_set rnd_indices; + auto index_tracker = IndexTracker(population); for (size_t i = population - count; i < population; ++i) { auto rnd = generator(0, i + 1); - if (!rnd_indices.insert(rnd).second) { + if (!index_tracker.try_insert(rnd)) { rnd = i; - rnd_indices.insert(i); + index_tracker.insert(i); } const auto edge_id = row_start + rnd; - // TODO (matthias) Improve temporal sampling logic. 
Currently, we sample - `count` many random neighbors, and filter them based on temporal - constraints afterwards. Ideally, we only sample exactly `count` - neighbors which fullfill the time constraint. - if (time[col_[edge_id]] >= seed_time) - continue; add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); } } } - std::tuple> - get_sampled_edges(bool csc = false) { - TORCH_CHECK(save_edges, "No edges have been stored") - const auto row = pyg::utils::from_vector(sampled_rows_); - const auto col = pyg::utils::from_vector(sampled_cols_); - c10::optional edge_id = c10::nullopt; - if (save_edge_ids) { - edge_id = pyg::utils::from_vector(sampled_edge_ids_); - } - if (!csc) { - return std::make_tuple(row, col, edge_id); - } else { - return std::make_tuple(col, row, edge_id); - } - } - - private: - inline scalar_t to_scalar_t(const scalar_t& node) { return node; } - inline scalar_t to_scalar_t(const std::pair& node) { - return std::get<1>(node); - } - - inline scalar_t to_node_t(const scalar_t& node, const scalar_t& ref) { - return node; - } - inline std::pair to_node_t( - const scalar_t& node, - const std::pair& ref) { - return {std::get<0>(ref), node}; - } - inline void add(const scalar_t edge_id, const node_t global_src_node, const scalar_t local_src_node, diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 0a8d6d3d3..a07758b30 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -23,6 +23,12 @@ def neighbor_sample( r"""Recursively samples neighbors from all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`. + .. note:: + + For temporal sampling, the :obj:`col` vector needs to be sorted in + ascending order of :obj:`time` within individual neighborhoods since we + use binary search to find neighbors that fulfill temporal constraints. + Args: rowptr (torch.Tensor): Compressed source node indices. col (torch.Tensor): Target node indices. 
@@ -34,7 +40,9 @@ def neighbor_sample( If set, temporal sampling will be used such that neighbors are guaranteed to fulfill temporal constraints, *i.e.* neighbors have an earlier timestamp than the seed node. - Requires :obj:`disjoint=True`. (default: :obj:`None`) + If used, the :obj:`col` vector needs to be sorted in ascending + order of time within individual neighborhoods. Requires + :obj:`disjoint=True`. (default: :obj:`None`) csc (bool, optional): If set to :obj:`True`, assumes that the graph is given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`) replace (bool, optional): If set to :obj:`True`, will sample with diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index 0c455dadd..25b65d8e6 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -44,12 +44,42 @@ TEST(DisjointNeighborTest, BasicAssertions) { EXPECT_TRUE(at::equal(std::get<1>(out), expected_col)); auto expected_nodes = at::tensor( {0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 1, 4, 0, 0, 0, 4, 1, 1, 1, 5}, options); - EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({10, 2}))); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2}))); auto expected_edges = at::tensor({4, 5, 6, 7, 2, 3, 6, 7, 4, 5, 8, 9}, options); EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); } +TEST(TemporalNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto graph = cycle_graph(/*num_nodes=*/6, options); + auto rowptr = std::get<0>(graph); + auto col = std::get<1>(graph); + auto seed = at::arange(2, 4, options); + std::vector num_neighbors = {2, 2}; + + // Time is equal to node ID ... + auto time = at::arange(6, options); + // ... 
so we need to sort the column vector by time/node ID: + col = std::get<0>(at::sort(col.view({-1, 2}), /*dim=*/1)).flatten(); + + auto out = pyg::sampler::neighbor_sample( + rowptr, col, seed, num_neighbors, /*time=*/time, + /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true); + + // Expect only the earlier neighbors to be sampled: + auto expected_row = at::tensor({0, 1, 2, 3}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); + auto expected_col = at::tensor({2, 3, 4, 5}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_col)); + auto expected_nodes = + at::tensor({0, 2, 1, 3, 0, 1, 1, 2, 0, 0, 1, 1}, options); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2}))); + auto expected_edges = at::tensor({4, 6, 2, 4}, options); + EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); +} + TEST(HeteroNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong);