Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Temporal heterogeneous neighbor sampling #97

Merged
merged 5 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added CSC mode to `pyg::sampler::neighbor_sample` and `pyg::sampler::hetero_neighbor_sample` ([#95](https://github.com/pyg-team/pyg-lib/pull/95), [#96](https://github.com/pyg-team/pyg-lib/pull/96))
- Speed up `pyg::sampler::neighbor_sample` via `IndexTracker` implementation ([#84](https://github.com/pyg-team/pyg-lib/pull/84))
- Added `pyg::sampler::hetero_neighbor_sample` implementation ([#90](https://github.com/pyg-team/pyg-lib/pull/90), [#92](https://github.com/pyg-team/pyg-lib/pull/92), [#94](https://github.com/pyg-team/pyg-lib/pull/94))
- Added `pyg::sampler::hetero_neighbor_sample` implementation ([#90](https://github.com/pyg-team/pyg-lib/pull/90), [#92](https://github.com/pyg-team/pyg-lib/pull/92), [#94](https://github.com/pyg-team/pyg-lib/pull/94), [#97](https://github.com/pyg-team/pyg-lib/pull/97))
- Added `pyg::utils::to_vector` implementation ([#88](https://github.com/pyg-team/pyg-lib/pull/88))
- Added support for PyTorch 1.12 ([#57](https://github.com/pyg-team/pyg-lib/pull/57), [#58](https://github.com/pyg-team/pyg-lib/pull/58))
- Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56), [#61](https://github.com/pyg-team/pyg-lib/pull/61), [#64](https://github.com/pyg-team/pyg-lib/pull/64), [#69](https://github.com/pyg-team/pyg-lib/pull/69))
Expand Down
45 changes: 34 additions & 11 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class NeighborSampler {
// Case 1: Sample the full neighborhood:
if (count < 0 || (!replace && count >= population)) {
for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) {
if (time[col_[edge_id]] <= seed_time)
if (time[col_[edge_id]] > seed_time)
continue;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
Expand All @@ -113,7 +113,7 @@ class NeighborSampler {
// `count` many random neighbors, and filter them based on temporal
// constraints afterwards. Ideally, we only sample exactly `count`
// neighbors which fulfill the time constraint.
if (time[col_[edge_id]] <= seed_time)
if (time[col_[edge_id]] > seed_time)
continue;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
Expand All @@ -134,7 +134,7 @@ class NeighborSampler {
// `count` many random neighbors, and filter them based on temporal
// constraints afterwards. Ideally, we only sample exactly `count`
// neighbors which fulfill the time constraint.
if (time[col_[edge_id]] <= seed_time)
if (time[col_[edge_id]] > seed_time)
continue;
add(edge_id, global_src_node, local_src_node, dst_mapper,
out_global_dst_nodes);
Expand Down Expand Up @@ -267,11 +267,11 @@ sample(const at::Tensor& rowptr,

out_node_id = pyg::utils::from_vector(sampled_nodes);

TORCH_CHECK(directed, "Undirected subgraphs not yet supported");
if (directed) {
std::tie(out_row, out_col, out_edge_id) = sampler.get_sampled_edges(csc);
} else {
TORCH_CHECK(!disjoint, "Disjoint subgraphs not yet supported");
// TODO
}
});
return std::make_tuple(out_row, out_col, out_node_id, out_edge_id);
Expand Down Expand Up @@ -324,6 +324,7 @@ sample(const std::vector<node_type>& node_types,
phmap::flat_hash_map<node_type, Mapper<node_t, scalar_t>> mapper_dict;
phmap::flat_hash_map<edge_type, NeighborSamplerImpl> sampler_dict;
phmap::flat_hash_map<node_type, std::pair<size_t, size_t>> slice_dict;
std::vector<scalar_t> seed_times;
for (const auto& k : node_types) {
sampled_nodes_dict[k]; // Initialize empty vector;
mapper_dict.insert({k, Mapper<node_t, scalar_t>(num_nodes_dict.at(k))});
Expand All @@ -337,6 +338,7 @@ sample(const std::vector<node_type>& node_types,
col_dict.at(to_rel_type(k)).data_ptr<scalar_t>())});
}

scalar_t i = 0;
for (const auto& kv : seed_dict) {
const at::Tensor& seed = kv.value();
slice_dict[kv.key()] = {0, seed.size(0)};
Expand All @@ -345,12 +347,24 @@ sample(const std::vector<node_type>& node_types,
sampled_nodes_dict[kv.key()] = pyg::utils::to_vector<scalar_t>(seed);
mapper_dict.at(kv.key()).fill(seed);
} else {
auto sampled_nodes = sampled_nodes_dict.at(kv.key());
auto mapper = mapper_dict.at(kv.key());
auto& sampled_nodes = sampled_nodes_dict.at(kv.key());
auto& mapper = mapper_dict.at(kv.key());
const auto seed_data = seed.data_ptr<scalar_t>();
for (size_t i = 0; i < seed.numel(); i++) {
sampled_nodes.push_back({i, seed_data[i]});
mapper.insert({i, seed_data[i]});
if (!time_dict.has_value()) {
for (size_t j = 0; j < seed.numel(); j++) {
sampled_nodes.push_back({i, seed_data[j]});
mapper.insert({i, seed_data[j]});
i++;
}
} else {
const at::Tensor& time = time_dict.value().at(kv.key());
const auto time_data = time.data_ptr<scalar_t>();
for (size_t j = 0; j < seed.numel(); j++) {
sampled_nodes.push_back({i, seed_data[j]});
mapper.insert({i, seed_data[j]});
seed_times.push_back(time_data[j]);
i++;
}
}
}
}
Expand All @@ -367,13 +381,22 @@ sample(const std::vector<node_type>& node_types,
auto& sampler = sampler_dict.at(k);
std::tie(begin, end) = slice_dict.at(src);

if (!time_dict.has_value()) {
if (!time_dict.has_value() || !time_dict.value().contains(dst)) {
for (size_t i = begin; i < end; ++i) {
sampler.uniform_sample(/*global_src_node=*/src_sampled_nodes[i],
/*local_src_node=*/i, count, dst_mapper,
generator, dst_sampled_nodes);
}
} else if constexpr (!std::is_scalar<node_t>::value) { // Temporal:
const at::Tensor& dst_time = time_dict.value().at(dst);
const auto dst_time_data = dst_time.data_ptr<scalar_t>();
for (size_t i = begin; i < end; ++i) {
const auto batch_idx = src_sampled_nodes[i].first;
sampler.temporal_sample(/*global_src_node=*/src_sampled_nodes[i],
/*local_src_node=*/i, count,
seed_times[batch_idx], dst_time_data,
dst_mapper, generator, dst_sampled_nodes);
}
}
}
for (const auto& k : node_types) {
Expand All @@ -387,7 +410,7 @@ sample(const std::vector<node_type>& node_types,
k, pyg::utils::from_vector(sampled_nodes_dict.at(k)));
}

TORCH_CHECK(directed, "Undirected Heterogeneous graphs not yet supported");
TORCH_CHECK(directed, "Undirected heterogeneous graphs not yet supported");
if (directed) {
for (const auto& k : edge_types) {
const auto edges = sampler_dict.at(k).get_sampled_edges(csc);
Expand Down