Benchmark scripts and pyg::sampler::Mapper #45

Merged · 9 commits · May 11, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added `pyg::sampler::Mapper` utility for mapping global to local node indices ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
- Added benchmark script ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
- Added download script for benchmark data ([#44](https://github.com/pyg-team/pyg-lib/pull/44))
- Added `biased sampling` utils ([#38](https://github.com/pyg-team/pyg-lib/pull/38))
- Added `CHANGELOG.md` ([#39](https://github.com/pyg-team/pyg-lib/pull/39))
31 changes: 31 additions & 0 deletions benchmark/main.py
@@ -0,0 +1,31 @@
import time

import torch

import pyg_lib
from pyg_lib.testing import to_edge_index, withDataset, withSeed


@withSeed
@withDataset('DIMACS10', 'citationCiteseer')
def test_subgraph(dataset, **kwargs):
(rowptr, col), num_nodes = dataset, dataset[0].size(0) - 1
perm = torch.randperm(num_nodes, dtype=rowptr.dtype, device=rowptr.device)
nodes = perm[:num_nodes // 100]

t = time.perf_counter()
for _ in range(10):
pyg_lib.sampler.subgraph(rowptr, col, nodes)
print(time.perf_counter() - t)

edge_index = to_edge_index(rowptr, col)
from torch_geometric.utils import subgraph

t = time.perf_counter()
for _ in range(10):
subgraph(nodes, edge_index, num_nodes=num_nodes, relabel_nodes=True)
print(time.perf_counter() - t)


if __name__ == '__main__':
test_subgraph()
61 changes: 61 additions & 0 deletions pyg_lib/csrc/sampler/cpu/mapper.h
@@ -0,0 +1,61 @@
#pragma once

#include <ATen/ATen.h>

#include <unordered_map>
#include <vector>

namespace pyg {
namespace sampler {

template <typename scalar_t>
class Mapper {
public:
Mapper(scalar_t num_nodes, scalar_t num_entries)
: num_nodes(num_nodes), num_entries(num_entries) {
// Use a simple heuristic to determine whether we can use a std::vector
// to perform the mapping instead of relying on the more memory-friendly,
// but slower std::unordered_map implementation:
use_vec = (num_nodes < 1000000) || (num_entries > num_nodes / 10);

if (use_vec)
to_local_vec = std::vector<scalar_t>(num_nodes, -1);
}

void fill(const scalar_t* nodes_data, const scalar_t size) {
if (use_vec) {
for (scalar_t i = 0; i < size; ++i)
to_local_vec[nodes_data[i]] = i;
} else {
for (scalar_t i = 0; i < size; ++i)
to_local_map.insert({nodes_data[i], i});
}
}

void fill(const at::Tensor& nodes) {
fill(nodes.data_ptr<scalar_t>(), nodes.numel());
}

bool exists(const scalar_t& node) {
if (use_vec)
return to_local_vec[node] >= 0;
else
return to_local_map.count(node) > 0;
}

scalar_t map(const scalar_t& node) {
if (use_vec)
return to_local_vec[node];
else {
const auto search = to_local_map.find(node);
return search != to_local_map.end() ? search->second : -1;
}
}

private:
scalar_t num_nodes, num_entries;

bool use_vec;
std::vector<scalar_t> to_local_vec;
std::unordered_map<scalar_t, scalar_t> to_local_map;
};

} // namespace sampler
} // namespace pyg
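
For readers unfamiliar with the class above, here is a minimal usage sketch. The standalone `main`, the hard-coded node IDs, and the graph size are illustrative only, and it assumes ATen and the pyg-lib headers are on the include path:

```cpp
#include <cstdint>
#include <vector>

#include "pyg_lib/csrc/sampler/cpu/mapper.h"

int main() {
  // Global IDs of the sampled nodes, in a graph with 100 nodes overall.
  std::vector<int64_t> nodes = {10, 42, 7};

  // The heuristic in the constructor picks the std::vector backend here,
  // since num_nodes < 1000000.
  auto mapper = pyg::sampler::Mapper<int64_t>(/*num_nodes=*/100,
                                              /*num_entries=*/nodes.size());
  mapper.fill(nodes.data(), nodes.size());

  mapper.exists(42);  // true: node 42 was sampled.
  mapper.map(42);     // 1: its local index, i.e. its position in `nodes`.
  mapper.map(5);      // -1: node 5 is not part of the sampled set.
  return 0;
}
```

The `-1` sentinel is what `subgraph_kernel.cpp` below relies on to filter out neighbors that are not part of the induced node set.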
24 changes: 12 additions & 12 deletions pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp
@@ -2,6 +2,7 @@
#include <ATen/Parallel.h>
#include <torch/library.h>

#include "pyg_lib/csrc/sampler/cpu/mapper.h"
#include "pyg_lib/csrc/utils/cpu/convert.h"

namespace pyg {
@@ -18,31 +19,31 @@ std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph_kernel(
TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor");
TORCH_CHECK(nodes.is_cpu(), "'nodes' must be a CPU tensor");

const auto deg = rowptr.new_empty({nodes.size(0)});
const auto num_nodes = rowptr.size(0) - 1;
const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1});
at::Tensor out_col;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;

AT_DISPATCH_INTEGRAL_TYPES(nodes.scalar_type(), "subgraph_kernel", [&] {
auto mapper = pyg::sampler::Mapper<scalar_t>(num_nodes, nodes.size(0));
mapper.fill(nodes);

const auto rowptr_data = rowptr.data_ptr<scalar_t>();
const auto col_data = col.data_ptr<scalar_t>();
const auto nodes_data = nodes.data_ptr<scalar_t>();

std::unordered_map<scalar_t, scalar_t> to_local_node;
for (scalar_t i = 0; i < nodes.size(0); ++i) // TODO parallelize
to_local_node.insert({nodes_data[i], i});

// We first iterate over all nodes and collect information about the number
// of edges in the induced subgraph.
const auto deg = rowptr.new_empty({nodes.size(0)});
auto deg_data = deg.data_ptr<scalar_t>();
auto grain_size = at::internal::GRAIN_SIZE;
at::parallel_for(0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) {
for (scalar_t i = _s; i < _e; ++i) {
for (size_t i = _s; i < _e; ++i) {
const auto v = nodes_data[i];
// Iterate over all neighbors and check if they are part of `nodes`:
scalar_t d = 0;
for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
if (to_local_node.count(col_data[j]) > 0)
for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
if (mapper.exists(col_data[j]))
d++;
}
deg_data[i] = d;
@@ -73,10 +74,9 @@ std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph_kernel(
// Iterate over all neighbors and check if they are part of `nodes`:
scalar_t offset = out_rowptr_data[i];
for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
const auto w = col_data[j];
const auto search = to_local_node.find(w);
if (search != to_local_node.end()) {
out_col_data[offset] = search->second;
const auto w = mapper.map(col_data[j]);
if (w >= 0) {
out_col_data[offset] = w;
if (return_edge_id)
out_edge_id_data[offset] = j;
offset++;
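
Since the kernel interleaves the algorithm with ATen type dispatch and `at::parallel_for`, the following sketch restates the same two-pass construction on plain `std::vector` CSR arrays. The free function `induced_subgraph`, its signature, and the serial loops are illustrative only and not part of the PR:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

#include "pyg_lib/csrc/sampler/cpu/mapper.h"

// Returns the (rowptr, col) CSR representation of the subgraph induced on
// `nodes`, with columns relabeled to local indices via the Mapper.
std::pair<std::vector<int64_t>, std::vector<int64_t>> induced_subgraph(
    const std::vector<int64_t>& rowptr,
    const std::vector<int64_t>& col,
    const std::vector<int64_t>& nodes) {
  const auto num_nodes = static_cast<int64_t>(rowptr.size()) - 1;
  pyg::sampler::Mapper<int64_t> mapper(num_nodes, nodes.size());
  mapper.fill(nodes.data(), nodes.size());

  // Pass 1: count, for every sampled node, how many of its neighbors were
  // also sampled, then turn the counts into row pointers via a prefix sum.
  std::vector<int64_t> out_rowptr(nodes.size() + 1, 0);
  for (size_t i = 0; i < nodes.size(); ++i) {
    const auto v = nodes[i];
    for (auto j = rowptr[v]; j < rowptr[v + 1]; ++j)
      if (mapper.exists(col[j]))
        out_rowptr[i + 1]++;
  }
  for (size_t i = 0; i < nodes.size(); ++i)
    out_rowptr[i + 1] += out_rowptr[i];

  // Pass 2: write the surviving neighbors, relabeled to local node indices.
  std::vector<int64_t> out_col(out_rowptr.back());
  for (size_t i = 0; i < nodes.size(); ++i) {
    auto offset = out_rowptr[i];
    const auto v = nodes[i];
    for (auto j = rowptr[v]; j < rowptr[v + 1]; ++j) {
      const auto w = mapper.map(col[j]);
      if (w >= 0)
        out_col[offset++] = w;
    }
  }
  return {out_rowptr, out_col};
}
```

The actual kernel additionally records the original edge position `j` into `out_edge_id` when `return_edge_id` is set, which this sketch omits.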
47 changes: 46 additions & 1 deletion pyg_lib/testing.py
@@ -1,12 +1,51 @@
import os
import os.path as osp
from typing import Optional, Tuple
from typing import Callable, Optional, Tuple

import torch
from torch import Tensor

from pyg_lib import get_home_dir

# Decorators ##################################################################


def withSeed(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
torch.manual_seed(12345)
func(*args, **kwargs)

return wrapper


def withCUDA(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
func(*args, device=torch.device('cpu'), **kwargs)
if torch.cuda.is_available():
func(*args, device=torch.device('cuda:0'), **kwargs)

return wrapper


def withDataset(group: str, name: str) -> Callable:
def decorator(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
dataset = get_sparse_matrix(
group,
name,
dtype=kwargs.get('dtype', torch.long),
device=kwargs.get('device', None),
)

func(*args, dataset=dataset, **kwargs)

return wrapper

return decorator


# Helper functions ############################################################


def get_sparse_matrix(
group: str,
@@ -48,3 +87,9 @@ def get_sparse_matrix(
col = torch.from_numpy(mat.indices).to(device, dtype)

return rowptr, col


def to_edge_index(rowptr: Tensor, col: Tensor) -> Tensor:
row = torch.arange(rowptr.size(0) - 1, dtype=col.dtype, device=col.device)
row = row.repeat_interleave(rowptr[1:] - rowptr[:-1])
return torch.stack([row, col], dim=0)