Skip to content

Commit

Permalink
Require sorted neighborhoods according to time in temporal sampling (#…
Browse files Browse the repository at this point in the history
…108)

* add

* changelog

* update

* add note

* remove import

Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
ZenoTan and rusty1s authored Sep 22, 2022
1 parent 34c08c9 commit c7d12f2
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 91 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `CMake` support ([#5](https://github.com/pyg-team/pyg-lib/pull/5))
- Added `pyg.cuda_version()` ([#4](https://github.com/pyg-team/pyg-lib/pull/4))
### Changed
- Require sorted neighborhoods according to time in temporal sampling ([#108](https://github.com/pyg-team/pyg-lib/pull/108))
- Only sample neighbors with a strictly earlier timestamp than the seed node ([#104](https://github.com/pyg-team/pyg-lib/pull/104))
- Prevent absolute paths in wheel ([#75](https://github.com/pyg-team/pyg-lib/pull/75))
- Improved installation instructions ([#68](https://github.com/pyg-team/pyg-lib/pull/68))
Expand Down
146 changes: 57 additions & 89 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,47 +34,10 @@ class NeighborSampler {
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
if (count == 0)
return;

const auto row_start = rowptr_[to_scalar_t(global_src_node)];
const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];
const auto population = row_end - row_start;

if (population == 0)
return;

// Case 1: Sample the full neighborhood:
if (count < 0 || (!replace && count >= population)) {
for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) {
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}

// Case 2: Sample with replacement:
else if (replace) {
for (size_t i = 0; i < count; ++i) {
const auto edge_id = generator(row_start, row_end);
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}

// Case 3: Sample without replacement:
else {
auto index_tracker = IndexTracker<scalar_t>(population);
for (size_t i = population - count; i < population; ++i) {
auto rnd = generator(0, i + 1);
if (!index_tracker.try_insert(rnd)) {
rnd = i;
index_tracker.insert(i);
}
const auto edge_id = row_start + rnd;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}
_sample(global_src_node, local_src_node, row_start, row_end, count,
dst_mapper, generator, out_global_dst_nodes);
}

void temporal_sample(const node_t global_src_node,
Expand All @@ -85,11 +48,61 @@ class NeighborSampler {
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
const auto row_start = rowptr_[to_scalar_t(global_src_node)];
auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];

// Find new `row_end` such that all neighbors fulfill temporal constraints:
auto it = std::lower_bound(
col_ + row_start, col_ + row_end, seed_time,
[&](const scalar_t& a, const scalar_t& b) { return time[a] < b; });
row_end = it - col_;

_sample(global_src_node, local_src_node, row_start, row_end, count,
dst_mapper, generator, out_global_dst_nodes);
}

std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
get_sampled_edges(bool csc = false) {
TORCH_CHECK(save_edges, "No edges have been stored")
const auto row = pyg::utils::from_vector(sampled_rows_);
const auto col = pyg::utils::from_vector(sampled_cols_);
c10::optional<at::Tensor> edge_id = c10::nullopt;
if (save_edge_ids) {
edge_id = pyg::utils::from_vector(sampled_edge_ids_);
}
if (!csc) {
return std::make_tuple(row, col, edge_id);
} else {
return std::make_tuple(col, row, edge_id);
}
}

private:
// Reduces a node identifier to a single `scalar_t` value.
// Scalar identifiers are returned unchanged.
inline scalar_t to_scalar_t(const scalar_t& node) {
  return node;
}
// Pair identifiers yield their second component.
inline scalar_t to_scalar_t(const std::pair<scalar_t, scalar_t>& node) {
  return node.second;
}

inline scalar_t to_node_t(const scalar_t& node, const scalar_t& ref) {
return node;
}
inline std::pair<scalar_t, scalar_t> to_node_t(
const scalar_t& node,
const std::pair<scalar_t, scalar_t>& ref) {
return {std::get<0>(ref), node};
}

void _sample(const node_t global_src_node,
const scalar_t local_src_node,
const scalar_t row_start,
const scalar_t row_end,
const size_t count,
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
if (count == 0)
return;

const auto row_start = rowptr_[to_scalar_t(global_src_node)];
const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];
const auto population = row_end - row_start;

if (population == 0)
Expand All @@ -98,8 +111,6 @@ class NeighborSampler {
// Case 1: Sample the full neighborhood:
if (count < 0 || (!replace && count >= population)) {
for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) {
if (time[col_[edge_id]] >= seed_time)
continue;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
Expand All @@ -109,70 +120,27 @@ class NeighborSampler {
else if (replace) {
for (size_t i = 0; i < count; ++i) {
const auto edge_id = generator(row_start, row_end);
// TODO (matthias) Improve temporal sampling logic. Currently, we sample
// `count` many random neighbors, and filter them based on temporal
// constraints afterwards. Ideally, we only sample exactly `count`
// neighbors which fulfill the time constraint.
if (time[col_[edge_id]] >= seed_time)
continue;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}

// Case 3: Sample without replacement:
else {
std::unordered_set<scalar_t> rnd_indices;
auto index_tracker = IndexTracker<scalar_t>(population);
for (size_t i = population - count; i < population; ++i) {
auto rnd = generator(0, i + 1);
if (!rnd_indices.insert(rnd).second) {
if (!index_tracker.try_insert(rnd)) {
rnd = i;
rnd_indices.insert(i);
index_tracker.insert(i);
}
const auto edge_id = row_start + rnd;
// TODO (matthias) Improve temporal sampling logic. Currently, we sample
// `count` many random neighbors, and filter them based on temporal
// constraints afterwards. Ideally, we only sample exactly `count`
// neighbors which fulfill the time constraint.
if (time[col_[edge_id]] >= seed_time)
continue;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
}
}
}

std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
get_sampled_edges(bool csc = false) {
TORCH_CHECK(save_edges, "No edges have been stored")
const auto row = pyg::utils::from_vector(sampled_rows_);
const auto col = pyg::utils::from_vector(sampled_cols_);
c10::optional<at::Tensor> edge_id = c10::nullopt;
if (save_edge_ids) {
edge_id = pyg::utils::from_vector(sampled_edge_ids_);
}
if (!csc) {
return std::make_tuple(row, col, edge_id);
} else {
return std::make_tuple(col, row, edge_id);
}
}

private:
// Reduces a node identifier to a single `scalar_t` value.
// Scalar identifiers are returned unchanged.
inline scalar_t to_scalar_t(const scalar_t& node) {
  return node;
}
// Pair identifiers yield their second component.
inline scalar_t to_scalar_t(const std::pair<scalar_t, scalar_t>& node) {
  return node.second;
}

inline scalar_t to_node_t(const scalar_t& node, const scalar_t& ref) {
return node;
}
inline std::pair<scalar_t, scalar_t> to_node_t(
const scalar_t& node,
const std::pair<scalar_t, scalar_t>& ref) {
return {std::get<0>(ref), node};
}

inline void add(const scalar_t edge_id,
const node_t global_src_node,
const scalar_t local_src_node,
Expand Down
10 changes: 9 additions & 1 deletion pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ def neighbor_sample(
r"""Recursively samples neighbors from all node indices in :obj:`seed`
in the graph given by :obj:`(rowptr, col)`.
.. note::
For temporal sampling, the :obj:`col` vector needs to be sorted
according to :obj:`time` within individual neighborhoods since we use
binary search to find neighbors that fulfill temporal constraints.
Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
Expand All @@ -34,7 +40,9 @@ def neighbor_sample(
If set, temporal sampling will be used such that neighbors are
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
an earlier timestamp than the seed node.
Requires :obj:`disjoint=True`. (default: :obj:`None`)
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
(default: :obj:`None`)
csc (bool, optional): If set to :obj:`True`, assumes that the graph is
given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`)
replace (bool, optional): If set to :obj:`True`, will sample with
Expand Down
32 changes: 31 additions & 1 deletion test/csrc/sampler/test_neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,42 @@ TEST(DisjointNeighborTest, BasicAssertions) {
EXPECT_TRUE(at::equal(std::get<1>(out), expected_col));
auto expected_nodes = at::tensor(
{0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 1, 4, 0, 0, 0, 4, 1, 1, 1, 5}, options);
EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({10, 2})));
EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2})));
auto expected_edges =
at::tensor({4, 5, 6, 7, 2, 3, 6, 7, 4, 5, 8, 9}, options);
EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges));
}

// Checks temporal neighbor sampling in disjoint mode: only neighbors with an
// earlier timestamp than the seed node are expected in the output.
TEST(TemporalNeighborTest, BasicAssertions) {
  auto options = at::TensorOptions().dtype(at::kLong);

  // Inputs: a cycle graph on six nodes, seed nodes `arange(2, 4)` and two
  // sampling hops with two neighbors each.
  auto graph = cycle_graph(/*num_nodes=*/6, options);
  auto rowptr = std::get<0>(graph);
  auto seed = at::arange(2, 4, options);
  std::vector<int64_t> num_neighbors = {2, 2};

  // Every node's timestamp equals its node ID. Temporal sampling requires
  // each neighborhood to be ordered by time, so the column vector is viewed
  // as rows of two entries and each row is sorted by node ID:
  auto time = at::arange(6, options);
  auto col_rows = std::get<1>(graph).view({-1, 2});
  auto col = std::get<0>(at::sort(col_rows, /*dim=*/1)).flatten();

  auto out = pyg::sampler::neighbor_sample(
      rowptr, col, seed, num_neighbors, /*time=*/time,
      /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true);

  // Only the earlier neighbors may show up in the result:
  EXPECT_TRUE(at::equal(std::get<0>(out), at::tensor({0, 1, 2, 3}, options)));
  EXPECT_TRUE(at::equal(std::get<1>(out), at::tensor({2, 3, 4, 5}, options)));
  EXPECT_TRUE(at::equal(
      std::get<2>(out),
      at::tensor({0, 2, 1, 3, 0, 1, 1, 2, 0, 0, 1, 1}, options)
          .view({-1, 2})));
  EXPECT_TRUE(at::equal(std::get<3>(out).value(),
                        at::tensor({4, 6, 2, 4}, options)));
}

TEST(HeteroNeighborTest, BasicAssertions) {
auto options = at::TensorOptions().dtype(at::kLong);

Expand Down

0 comments on commit c7d12f2

Please sign in to comment.