Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require sorted neighborhoods according to time in temporal sampling #108

Merged
merged 5 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `CMake` support ([#5](https://github.com/pyg-team/pyg-lib/pull/5))
- Added `pyg.cuda_version()` ([#4](https://github.com/pyg-team/pyg-lib/pull/4))
### Changed
- Require sorted neighborhoods according to time in temporal sampling ([#108](https://github.com/pyg-team/pyg-lib/pull/108))
- Only sample neighbors with a strictly earlier timestamp than the seed node ([#104](https://github.com/pyg-team/pyg-lib/pull/104))
- Prevent absolute paths in wheel ([#75](https://github.com/pyg-team/pyg-lib/pull/75))
- Improved installation instructions ([#68](https://github.com/pyg-team/pyg-lib/pull/68))
Expand Down
147 changes: 58 additions & 89 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <ATen/ATen.h>
#include <torch/library.h>
#include <algorithm>

#include "parallel_hashmap/phmap.h"

Expand Down Expand Up @@ -34,47 +35,10 @@ class NeighborSampler {
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
if (count == 0)
return;

const auto row_start = rowptr_[to_scalar_t(global_src_node)];
const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];
const auto population = row_end - row_start;

if (population == 0)
return;

// Case 1: Sample the full neighborhood:
if (count < 0 || (!replace && count >= population)) {
for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) {
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}

// Case 2: Sample with replacement:
else if (replace) {
for (size_t i = 0; i < count; ++i) {
const auto edge_id = generator(row_start, row_end);
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}

// Case 3: Sample without replacement:
else {
auto index_tracker = IndexTracker<scalar_t>(population);
for (size_t i = population - count; i < population; ++i) {
auto rnd = generator(0, i + 1);
if (!index_tracker.try_insert(rnd)) {
rnd = i;
index_tracker.insert(i);
}
const auto edge_id = row_start + rnd;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}
_sample(global_src_node, local_src_node, row_start, row_end, count,
dst_mapper, generator, out_global_dst_nodes);
}

void temporal_sample(const node_t global_src_node,
Expand All @@ -85,11 +49,61 @@ class NeighborSampler {
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
const auto row_start = rowptr_[to_scalar_t(global_src_node)];
auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];

// Find new `row_end` such that all neighbors fulfill temporal constraints:
auto it = std::lower_bound(
col_ + row_start, col_ + row_end, seed_time,
[&](const scalar_t& a, const scalar_t& b) { return time[a] < b; });
row_end = it - col_;

_sample(global_src_node, local_src_node, row_start, row_end, count,
dst_mapper, generator, out_global_dst_nodes);
}

std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
get_sampled_edges(bool csc = false) {
TORCH_CHECK(save_edges, "No edges have been stored")
const auto row = pyg::utils::from_vector(sampled_rows_);
const auto col = pyg::utils::from_vector(sampled_cols_);
c10::optional<at::Tensor> edge_id = c10::nullopt;
if (save_edge_ids) {
edge_id = pyg::utils::from_vector(sampled_edge_ids_);
}
if (!csc) {
return std::make_tuple(row, col, edge_id);
} else {
return std::make_tuple(col, row, edge_id);
}
}

private:
inline scalar_t to_scalar_t(const scalar_t& node) { return node; }
inline scalar_t to_scalar_t(const std::pair<scalar_t, scalar_t>& node) {
return std::get<1>(node);
}

inline scalar_t to_node_t(const scalar_t& node, const scalar_t& ref) {
return node;
}
inline std::pair<scalar_t, scalar_t> to_node_t(
const scalar_t& node,
const std::pair<scalar_t, scalar_t>& ref) {
return {std::get<0>(ref), node};
}

void _sample(const node_t global_src_node,
const scalar_t local_src_node,
const scalar_t row_start,
const scalar_t row_end,
const size_t count,
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
if (count == 0)
return;

const auto row_start = rowptr_[to_scalar_t(global_src_node)];
const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];
const auto population = row_end - row_start;

if (population == 0)
Expand All @@ -98,8 +112,6 @@ class NeighborSampler {
// Case 1: Sample the full neighborhood:
if (count < 0 || (!replace && count >= population)) {
for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) {
if (time[col_[edge_id]] >= seed_time)
continue;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
Expand All @@ -109,70 +121,27 @@ class NeighborSampler {
else if (replace) {
for (size_t i = 0; i < count; ++i) {
const auto edge_id = generator(row_start, row_end);
// TODO (matthias) Improve temporal sampling logic. Currently, we sample
// `count` many random neighbors, and filter them based on temporal
// constraints afterwards. Ideally, we only sample exactly `count`
// neighbors which fulfill the time constraint.
if (time[col_[edge_id]] >= seed_time)
continue;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}

// Case 3: Sample without replacement:
else {
std::unordered_set<scalar_t> rnd_indices;
auto index_tracker = IndexTracker<scalar_t>(population);
for (size_t i = population - count; i < population; ++i) {
auto rnd = generator(0, i + 1);
if (!rnd_indices.insert(rnd).second) {
if (!index_tracker.try_insert(rnd)) {
rnd = i;
rnd_indices.insert(i);
index_tracker.insert(i);
}
const auto edge_id = row_start + rnd;
// TODO (matthias) Improve temporal sampling logic. Currently, we sample
// `count` many random neighbors, and filter them based on temporal
// constraints afterwards. Ideally, we only sample exactly `count`
// neighbors which fulfill the time constraint.
if (time[col_[edge_id]] >= seed_time)
continue;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}
}

std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
get_sampled_edges(bool csc = false) {
TORCH_CHECK(save_edges, "No edges have been stored")
const auto row = pyg::utils::from_vector(sampled_rows_);
const auto col = pyg::utils::from_vector(sampled_cols_);
c10::optional<at::Tensor> edge_id = c10::nullopt;
if (save_edge_ids) {
edge_id = pyg::utils::from_vector(sampled_edge_ids_);
}
if (!csc) {
return std::make_tuple(row, col, edge_id);
} else {
return std::make_tuple(col, row, edge_id);
}
}

private:
inline scalar_t to_scalar_t(const scalar_t& node) { return node; }
inline scalar_t to_scalar_t(const std::pair<scalar_t, scalar_t>& node) {
return std::get<1>(node);
}

inline scalar_t to_node_t(const scalar_t& node, const scalar_t& ref) {
return node;
}
inline std::pair<scalar_t, scalar_t> to_node_t(
const scalar_t& node,
const std::pair<scalar_t, scalar_t>& ref) {
return {std::get<0>(ref), node};
}

inline void add(const scalar_t edge_id,
const node_t global_src_node,
const scalar_t local_src_node,
Expand Down
10 changes: 9 additions & 1 deletion pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ def neighbor_sample(
r"""Recursively samples neighbors from all node indices in :obj:`seed`
in the graph given by :obj:`(rowptr, col)`.

.. note::

For temporal sampling, the :obj:`col` vector needs to be sorted
according to :obj:`time` within individual neighborhoods since we use
binary search to find neighbors that fulfill temporal constraints.

Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
Expand All @@ -34,7 +40,9 @@ def neighbor_sample(
If set, temporal sampling will be used such that neighbors are
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
an earlier timestamp than the seed node.
Requires :obj:`disjoint=True`. (default: :obj:`None`)
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
(default: :obj:`None`)
csc (bool, optional): If set to :obj:`True`, assumes that the graph is
given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`)
replace (bool, optional): If set to :obj:`True`, will sample with
Expand Down
32 changes: 31 additions & 1 deletion test/csrc/sampler/test_neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,42 @@ TEST(DisjointNeighborTest, BasicAssertions) {
EXPECT_TRUE(at::equal(std::get<1>(out), expected_col));
auto expected_nodes = at::tensor(
{0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 1, 4, 0, 0, 0, 4, 1, 1, 1, 5}, options);
EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({10, 2})));
EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2})));
auto expected_edges =
at::tensor({4, 5, 6, 7, 2, 3, 6, 7, 4, 5, 8, 9}, options);
EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges));
}

// Checks temporal neighbor sampling on the six-node graph produced by
// `cycle_graph`, using node IDs as timestamps and seeds {2, 3}.
TEST(TemporalNeighborTest, BasicAssertions) {
  auto options = at::TensorOptions().dtype(at::kLong);

  auto graph = cycle_graph(/*num_nodes=*/6, options);
  auto rowptr = std::get<0>(graph);
  auto col = std::get<1>(graph);
  auto seed = at::arange(2, 4, options);
  std::vector<int64_t> num_neighbors = {2, 2};

  // Each node's timestamp is simply its own ID.
  auto time = at::arange(6, options);
  // Temporal sampling expects every neighborhood to be ordered by time, so
  // sort each two-entry neighborhood by node ID before flattening again.
  auto sorted_neighborhoods = at::sort(col.view({-1, 2}), /*dim=*/1);
  col = std::get<0>(sorted_neighborhoods).flatten();

  auto result = pyg::sampler::neighbor_sample(
      rowptr, col, seed, num_neighbors, /*time=*/time,
      /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true);

  // Only neighbors that are earlier than their seed should show up:
  auto want_row = at::tensor({0, 1, 2, 3}, options);
  EXPECT_TRUE(at::equal(std::get<0>(result), want_row));

  auto want_col = at::tensor({2, 3, 4, 5}, options);
  EXPECT_TRUE(at::equal(std::get<1>(result), want_col));

  auto want_nodes =
      at::tensor({0, 2, 1, 3, 0, 1, 1, 2, 0, 0, 1, 1}, options);
  EXPECT_TRUE(at::equal(std::get<2>(result), want_nodes.view({-1, 2})));

  auto want_edges = at::tensor({4, 6, 2, 4}, options);
  EXPECT_TRUE(at::equal(std::get<3>(result).value(), want_edges));
}

TEST(HeteroNeighborTest, BasicAssertions) {
auto options = at::TensorOptions().dtype(at::kLong);

Expand Down