-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Benchmark scripts and
pyg::sampler::Mapper
(#45)
* added benchmark * update * update * update * update * changelog * update * todo Co-authored-by: Zeyuan Tan <[email protected]>
Showing
5 changed files
with
154 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import time | ||
|
||
import torch | ||
|
||
import pyg_lib | ||
from pyg_lib.testing import to_edge_index, withDataset, withSeed | ||
|
||
|
||
@withSeed | ||
@withDataset('DIMACS10', 'citationCiteseer') | ||
def test_subgraph(dataset, **kwargs): | ||
(rowptr, col), num_nodes = dataset, dataset[0].size(0) - 1 | ||
perm = torch.randperm(num_nodes, dtype=rowptr.dtype, device=rowptr.device) | ||
nodes = perm[:num_nodes // 100] | ||
|
||
t = time.perf_counter() | ||
for _ in range(10): | ||
pyg_lib.sampler.subgraph(rowptr, col, nodes) | ||
print(time.perf_counter() - t) | ||
|
||
edge_index = to_edge_index(rowptr, col) | ||
from torch_geometric.utils import subgraph | ||
|
||
t = time.perf_counter() | ||
for _ in range(10): | ||
subgraph(nodes, edge_index, num_nodes=num_nodes, relabel_nodes=True) | ||
print(time.perf_counter() - t) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_subgraph() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#pragma once | ||
|
||
#include <ATen/ATen.h> | ||
|
||
namespace pyg { | ||
namespace sampler { | ||
|
||
// TODO Implement `Mapper` as an interface/abstract class to allow for other | ||
// implementations as well. | ||
template <typename scalar_t> | ||
class Mapper { | ||
public: | ||
Mapper(scalar_t num_nodes, scalar_t num_entries) | ||
: num_nodes(num_nodes), num_entries(num_entries) { | ||
// Use a some simple heuristic to determine whether we can use a std::vector | ||
// to perform the mapping instead of relying on the more memory-friendly, | ||
// but slower std::unordered_map implementation: | ||
use_vec = (num_nodes < 1000000) || (num_entries > num_nodes / 10); | ||
|
||
if (use_vec) | ||
to_local_vec = std::vector<scalar_t>(num_nodes, -1); | ||
} | ||
|
||
void fill(const scalar_t* nodes_data, const scalar_t size) { | ||
if (use_vec) { | ||
for (scalar_t i = 0; i < size; ++i) | ||
to_local_vec[nodes_data[i]] = i; | ||
} else { | ||
for (scalar_t i = 0; i < size; ++i) | ||
to_local_map.insert({nodes_data[i], i}); | ||
} | ||
} | ||
|
||
void fill(const at::Tensor& nodes) { | ||
fill(nodes.data_ptr<scalar_t>(), nodes.numel()); | ||
} | ||
|
||
bool exists(const scalar_t& node) { | ||
if (use_vec) | ||
return to_local_vec[node] >= 0; | ||
else | ||
return to_local_map.count(node) > 0; | ||
} | ||
|
||
scalar_t map(const scalar_t& node) { | ||
if (use_vec) | ||
return to_local_vec[node]; | ||
else { | ||
const auto search = to_local_map.find(node); | ||
return search != to_local_map.end() ? search->second : -1; | ||
} | ||
} | ||
|
||
private: | ||
scalar_t num_nodes, num_entries; | ||
|
||
bool use_vec; | ||
std::vector<scalar_t> to_local_vec; | ||
std::unordered_map<scalar_t, scalar_t> to_local_map; | ||
}; | ||
|
||
} // namespace sampler | ||
} // namespace pyg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters