From 79b0849f15b15ead9d4e22088262f58b8d90faa2 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 8 Sep 2022 13:06:39 +0000 Subject: [PATCH 1/4] update --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 31 +++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 4c26f0407..9b849433f 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -322,6 +322,7 @@ sample(const std::vector& node_types, phmap::flat_hash_map> mapper_dict; phmap::flat_hash_map sampler_dict; phmap::flat_hash_map> slice_dict; + std::vector seed_times; for (const auto& k : node_types) { sampled_nodes_dict[k]; // Initialize empty vector; mapper_dict.insert({k, Mapper(num_nodes_dict.at(k))}); @@ -335,6 +336,7 @@ sample(const std::vector& node_types, col_dict.at(to_rel_type(k)).data_ptr())}); } + scalar_t i = 0; for (const auto& kv : seed_dict) { const at::Tensor& seed = kv.value(); slice_dict[kv.key()] = {0, seed.size(0)}; @@ -346,9 +348,21 @@ sample(const std::vector& node_types, auto sampled_nodes = sampled_nodes_dict.at(kv.key()); auto mapper = mapper_dict.at(kv.key()); const auto seed_data = seed.data_ptr(); - for (size_t i = 0; i < seed.numel(); i++) { - sampled_nodes.push_back({i, seed_data[i]}); - mapper.insert({i, seed_data[i]}); + if (!time_dict.has_value()) { + for (size_t j = 0; j < seed.numel(); j++) { + sampled_nodes.push_back({i, seed_data[j]}); + mapper.insert({i, seed_data[j]}); + i++; + } + } else { + const at::Tensor& time = time_dict.value().at(kv.key()); + const auto time_data = time.data_ptr(); + for (size_t j = 0; j < seed.numel(); j++) { + sampled_nodes.push_back({i, seed_data[j]}); + mapper.insert({i, seed_data[j]}); + seed_times.push_back(time_data[j]); + i++; + } } } } @@ -365,13 +379,22 @@ sample(const std::vector& node_types, auto& sampler = sampler_dict.at(k); std::tie(begin, end) = slice_dict.at(src); - if (!time_dict.has_value()) { + if (!time_dict.has_value() || !time_dict.value().contains(dst)) { for (size_t i = begin; i < end; ++i) { sampler.uniform_sample(/*global_src_node=*/src_sampled_nodes[i], /*local_src_node=*/i, count, dst_mapper, generator, dst_sampled_nodes); } } else if constexpr (!std::is_scalar::value) { // Temporal: + const at::Tensor& dst_time = time_dict.value().at(dst); + const auto dst_time_data = dst_time.data_ptr(); + for (size_t i = begin; i < end; ++i) { + const auto batch_idx = src_sampled_nodes[i].first; + sampler.temporal_sample(/*global_src_node=*/src_sampled_nodes[i], + /*local_src_node=*/i, count, + seed_times[batch_idx], dst_time_data, + dst_mapper, generator, dst_sampled_nodes); + } } } for (const auto& k : node_types) { From 3fe4cd01edd8eba20352d8dd6cc60e00af3b7c7d Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 8 Sep 2022 14:02:00 +0000 Subject: [PATCH 2/4] update --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 9b849433f..84ba7f738 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -98,8 +98,11 @@ class NeighborSampler { // Case 1: Sample the full neighborhood: if (count < 0 || (!replace && count >= population)) { for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) { - if (time[col_[edge_id]] <= seed_time) + std::cout << "seed_time " << seed_time << " to " << time[col_[edge_id]] + << std::endl; + if (time[col_[edge_id]] > seed_time) continue; + std::cout << "add" << std::endl; add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); } @@ -113,7 +116,7 @@ class NeighborSampler { // `count` many random neighbors, and filter them based on temporal // constraints afterwards. Ideally, we only sample exactly `count` // neighbors which fullfill the time constraint. - if (time[col_[edge_id]] <= seed_time) + if (time[col_[edge_id]] > seed_time) continue; add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); @@ -134,7 +137,7 @@ class NeighborSampler { // `count` many random neighbors, and filter them based on temporal // constraints afterwards. Ideally, we only sample exactly `count` // neighbors which fullfill the time constraint. - if (time[col_[edge_id]] <= seed_time) + if (time[col_[edge_id]] > seed_time) continue; add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); @@ -268,6 +271,7 @@ sample(const at::Tensor& rowptr, } else if (directed && csc) { std::tie(out_col, out_row, out_edge_id) = sampler.get_sampled_edges(); } else { + TORCH_CHECK(directed, "Undirected subgraphs not yet supported"); TORCH_CHECK(!disjoint, "Disjoint subgraphs not yet supported"); // TODO } @@ -345,8 +349,8 @@ sample(const std::vector& node_types, sampled_nodes_dict[kv.key()] = pyg::utils::to_vector(seed); mapper_dict.at(kv.key()).fill(seed); } else { - auto sampled_nodes = sampled_nodes_dict.at(kv.key()); - auto mapper = mapper_dict.at(kv.key()); + auto& sampled_nodes = sampled_nodes_dict.at(kv.key()); + auto& mapper = mapper_dict.at(kv.key()); const auto seed_data = seed.data_ptr(); if (!time_dict.has_value()) { for (size_t j = 0; j < seed.numel(); j++) { @@ -389,7 +393,9 @@ sample(const std::vector& node_types, const at::Tensor& dst_time = time_dict.value().at(dst); const auto dst_time_data = dst_time.data_ptr(); for (size_t i = begin; i < end; ++i) { + std::cout << src << " " << dst << std::endl; const auto batch_idx = src_sampled_nodes[i].first; + std::cout << "batch " << batch_idx << std::endl; sampler.temporal_sample(/*global_src_node=*/src_sampled_nodes[i], /*local_src_node=*/i, count, seed_times[batch_idx], dst_time_data, From ffffb9886bf516241bd3b47772a4bd87eef72ac9 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 8 Sep 2022 14:03:55 +0000 Subject: [PATCH 3/4] update --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 84ba7f738..31dec6021 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -98,11 +98,8 @@ class NeighborSampler { // Case 1: Sample the full neighborhood: if (count < 0 || (!replace && count >= population)) { for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) { - std::cout << "seed_time " << seed_time << " to " << time[col_[edge_id]] - << std::endl; if (time[col_[edge_id]] > seed_time) continue; - std::cout << "add" << std::endl; add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); } @@ -266,14 +263,13 @@ sample(const at::Tensor& rowptr, out_node_id = pyg::utils::from_vector(sampled_nodes); + TORCH_CHECK(directed, "Undirected subgraphs not yet supported"); if (directed && !csc) { std::tie(out_row, out_col, out_edge_id) = sampler.get_sampled_edges(); } else if (directed && csc) { std::tie(out_col, out_row, out_edge_id) = sampler.get_sampled_edges(); } else { - TORCH_CHECK(directed, "Undirected subgraphs not yet supported"); TORCH_CHECK(!disjoint, "Disjoint subgraphs not yet supported"); - // TODO } }); return std::make_tuple(out_row, out_col, out_node_id, out_edge_id); @@ -393,9 +389,7 @@ sample(const std::vector& node_types, const at::Tensor& dst_time = time_dict.value().at(dst); const auto dst_time_data = dst_time.data_ptr(); for (size_t i = begin; i < end; ++i) { - std::cout << src << " " << dst << std::endl; const auto batch_idx = src_sampled_nodes[i].first; - std::cout << "batch " << batch_idx << std::endl; sampler.temporal_sample(/*global_src_node=*/src_sampled_nodes[i], /*local_src_node=*/i, count, seed_times[batch_idx], dst_time_data, @@ -414,7 +408,7 @@ sample(const std::vector& node_types, k, pyg::utils::from_vector(sampled_nodes_dict.at(k))); } - TORCH_CHECK(directed, "Undirected Heterogeneous graphs not yet supported"); + TORCH_CHECK(directed, "Undirected heterogeneous graphs not yet supported"); if (directed) { for (const auto& k : edge_types) { const auto edges = sampler_dict.at(k).get_sampled_edges(); From 4f2e2353b2c178e143efe7f0abcbf028e19f3982 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 8 Sep 2022 14:04:28 +0000 Subject: [PATCH 4/4] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c04511f38..8fe0309b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added CSC mode to `pyg::sampler::neighbor_sample` and `pyg::sampler::hetero_neighbor_sample` ([#95](https://github.com/pyg-team/pyg-lib/pull/95)) - Speed up `pyg::sampler::neighbor_sample` via `IndexTracker` implementation ([#84](https://github.com/pyg-team/pyg-lib/pull/84)) -- Added `pyg::sampler::hetero_neighbor_sample` implementation ([#90](https://github.com/pyg-team/pyg-lib/pull/90), [#92](https://github.com/pyg-team/pyg-lib/pull/92), [#94](https://github.com/pyg-team/pyg-lib/pull/94)) +- Added `pyg::sampler::hetero_neighbor_sample` implementation ([#90](https://github.com/pyg-team/pyg-lib/pull/90), [#92](https://github.com/pyg-team/pyg-lib/pull/92), [#94](https://github.com/pyg-team/pyg-lib/pull/94), [#97](https://github.com/pyg-team/pyg-lib/pull/97)) - Added `pyg::utils::to_vector` implementation ([#88](https://github.com/pyg-team/pyg-lib/pull/88)) - Added support for PyTorch 1.12 ([#57](https://github.com/pyg-team/pyg-lib/pull/57), [#58](https://github.com/pyg-team/pyg-lib/pull/58)) - Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56), [#61](https://github.com/pyg-team/pyg-lib/pull/61), [#64](https://github.com/pyg-team/pyg-lib/pull/64), [#69](https://github.com/pyg-team/pyg-lib/pull/69))