
Commit

Fixed pep8
OlhaBabicheva committed Jan 5, 2023
1 parent 6cb6c92 commit 77f55d4
Showing 2 changed files with 80 additions and 70 deletions.
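Every change in this commit is a PEP 8 formatting fix: over-long lines are wrapped and continuation lines are re-indented. As a minimal illustration of the wrapping pattern applied below (this snippet is not part of the diff, only a sketch reusing one of the touched argparse calls):

import argparse

argparser = argparse.ArgumentParser()

# Too long for flake8's default 79-character limit (E501):
# argparser.add_argument('--libraries', nargs="*", type=str, default=['pyg-lib', 'torch-sparse'])

# Wrapped after a comma, with the continuation aligned to the opening
# parenthesis, which is the style used throughout this commit:
argparser.add_argument('--libraries', nargs="*", type=str,
                       default=['pyg-lib', 'torch-sparse'])

args = argparser.parse_args()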
69 changes: 37 additions & 32 deletions benchmark/sampler/hetero_neighbor.py
@@ -35,7 +35,8 @@
 argparser.add_argument('--temporal-strategy', choices=['uniform', 'last'],
                        default='uniform')
 argparser.add_argument('--write-csv', action='store_true')
-argparser.add_argument('--libraries', nargs="*", type=str, default=['pyg-lib', 'torch-sparse'])
+argparser.add_argument('--libraries', nargs="*", type=str,
+                       default=['pyg-lib', 'torch-sparse'])
 args = argparser.parse_args()

@@ -61,9 +62,10 @@ def test_hetero_neighbor(dataset, **kwargs):
         node_perm = torch.randperm(num_nodes_dict['paper'])
     else:
         node_perm = torch.arange(0, num_nodes_dict['paper'])

     if args.write_csv:
-        data = {'num_neighbors': [], 'batch-size': [], 'pyg-lib': [], 'torch-sparse': []}
+        data = {'num_neighbors': [], 'batch-size': [],
+                'pyg-lib': [], 'torch-sparse': []}

     for num_neighbors in args.num_neighbors:
         num_neighbors_dict = {key: num_neighbors for key in colptr_dict.keys()}
@@ -77,19 +79,19 @@ def test_hetero_neighbor(dataset, **kwargs):
                 for seed in tqdm(node_perm.split(batch_size)[:20]):
                     seed_dict = {'paper': seed}
                     pyg_lib.sampler.hetero_neighbor_sample(
-                    colptr_dict,
-                    row_dict,
-                    seed_dict,
-                    num_neighbors_dict,
-                    node_time_dict,
-                    None,  # seed_time_dict
-                    True,  # csc
-                    False,  # replace
-                    True,  # directed
-                    disjoint=args.disjoint,
-                    temporal_strategy=args.temporal_strategy,
-                    return_edge_id=True,
-                )
+                        colptr_dict,
+                        row_dict,
+                        seed_dict,
+                        num_neighbors_dict,
+                        node_time_dict,
+                        None,  # seed_time_dict
+                        True,  # csc
+                        False,  # replace
+                        True,  # directed
+                        disjoint=args.disjoint,
+                        temporal_strategy=args.temporal_strategy,
+                        return_edge_id=True,
+                    )
                 pyg_lib_duration = time.perf_counter() - t
                 data['pyg-lib'].append(round(pyg_lib_duration, 3))
                 print(f'     pyg-lib={pyg_lib_duration:.3f} seconds')
@@ -103,32 +105,35 @@ def test_hetero_neighbor(dataset, **kwargs):
                     colptr_dict_sparse = remap_keys(colptr_dict)
                     row_dict_sparse = remap_keys(row_dict)
                     seed_dict = {'paper': seed}
-                    num_neighbors_dict_sparse = remap_keys(num_neighbors_dict)
+                    num_neighbors_dict_sparse = remap_keys(
+                        num_neighbors_dict)
                     num_hops = max([len(v) for v in [num_neighbors]])
                     torch.ops.torch_sparse.hetero_neighbor_sample(
-                    node_types,
-                    edge_types,
-                    colptr_dict_sparse,
-                    row_dict_sparse,
-                    seed_dict,
-                    num_neighbors_dict_sparse,
-                    num_hops,
-                    False,  # replace
-                    True,  # directed
-                )
+                        node_types,
+                        edge_types,
+                        colptr_dict_sparse,
+                        row_dict_sparse,
+                        seed_dict,
+                        num_neighbors_dict_sparse,
+                        num_hops,
+                        False,  # replace
+                        True,  # directed
+                    )
                 torch_sparse_duration = time.perf_counter() - t
-                data['torch-sparse'].append(round(torch_sparse_duration, 3))
+                data['torch-sparse'].append(
+                    round(torch_sparse_duration, 3))
                 print(f'torch-sparse={torch_sparse_duration:.3f} seconds')

             # TODO (kgajdamo): Add dgl hetero sampler.
             print()

     if args.write_csv:
         df = pd.DataFrame()
-        for key in data.keys():
-            if len(data[key]) != 0: df[key] = data[key]
+        for key in data.keys():
+            if len(data[key]) != 0:
+                df[key] = data[key]
         df.to_csv(f'hetero_neighbor{datetime.now()}.csv', index=False)


 if __name__ == '__main__':
-    test_hetero_neighbor()
+    test_hetero_neighbor()
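For readers skimming the diff above: the benchmark pattern in hetero_neighbor.py is to time a loop of sampler calls with time.perf_counter(), append durations rounded to three decimals into the data dict, and, when --write-csv is set, copy the non-empty columns into a pandas DataFrame and write it to CSV. A self-contained sketch of that pattern, with a placeholder workload standing in for pyg_lib.sampler.hetero_neighbor_sample (the batch sizes and workload here are invented for illustration):

import time
from datetime import datetime

import pandas as pd


def dummy_workload(batch):
    # Placeholder for a sampler call such as
    # pyg_lib.sampler.hetero_neighbor_sample(...).
    return sum(batch)


data = {'batch-size': [], 'pyg-lib': []}
for batch_size in [512, 1024]:
    batches = [list(range(batch_size))] * 20
    t = time.perf_counter()
    for batch in batches:
        dummy_workload(batch)
    duration = time.perf_counter() - t
    data['batch-size'].append(batch_size)
    data['pyg-lib'].append(round(duration, 3))

# Mirror the CSV-writing loop from the script: keep only non-empty columns.
df = pd.DataFrame()
for key in data.keys():
    if len(data[key]) != 0:
        df[key] = data[key]
df.to_csv(f'hetero_neighbor{datetime.now()}.csv', index=False)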
81 changes: 43 additions & 38 deletions benchmark/sampler/neighbor.py
@@ -33,7 +33,8 @@
 argparser.add_argument('--temporal-strategy', choices=['uniform', 'last'],
                        default='uniform')
 argparser.add_argument('--write-csv', action='store_true')
-argparser.add_argument('--libraries', nargs="*", type=str, default=['pyg-lib', 'torch-sparse', 'dgl'])
+argparser.add_argument('--libraries', nargs="*", type=str,
+                       default=['pyg-lib', 'torch-sparse', 'dgl'])
 args = argparser.parse_args()

@@ -46,7 +47,8 @@ def test_neighbor(dataset, **kwargs):

     (rowptr, col), num_nodes = dataset, dataset[0].size(0) - 1
     if 'dgl' in args.libraries:
-        dgl_graph = dgl.graph(('csc', (rowptr, col, torch.arange(col.size(0)))))
+        dgl_graph = dgl.graph(
+            ('csc', (rowptr, col, torch.arange(col.size(0)))))

     if args.temporal:
         # generate random timestamps
@@ -58,10 +60,11 @@ def test_neighbor(dataset, **kwargs):
         node_perm = torch.randperm(num_nodes)
     else:
         node_perm = torch.arange(num_nodes)

     if args.write_csv:
-        data = {'num_neighbors': [], 'batch-size': [], 'pyg-lib': [], 'torch-sparse': [], 'dgl': []}
+        data = {'num_neighbors': [], 'batch-size': [],
+                'pyg-lib': [], 'torch-sparse': [], 'dgl': []}

     for num_neighbors in args.num_neighbors:
         for batch_size in args.batch_sizes:
             print(f'batch_size={batch_size}, num_neighbors={num_neighbors}):')
@@ -71,18 +74,18 @@ def test_neighbor(dataset, **kwargs):
                 t = time.perf_counter()
                 for seed in tqdm(node_perm.split(batch_size)):
                     pyg_lib.sampler.neighbor_sample(
-                    rowptr,
-                    col,
-                    seed,
-                    num_neighbors,
-                    time=node_time,
-                    seed_time=None,
-                    replace=args.replace,
-                    directed=args.directed,
-                    disjoint=args.disjoint,
-                    temporal_strategy=args.temporal_strategy,
-                    return_edge_id=True,
-                )
+                        rowptr,
+                        col,
+                        seed,
+                        num_neighbors,
+                        time=node_time,
+                        seed_time=None,
+                        replace=args.replace,
+                        directed=args.directed,
+                        disjoint=args.disjoint,
+                        temporal_strategy=args.temporal_strategy,
+                        return_edge_id=True,
+                    )
                 pyg_lib_duration = time.perf_counter() - t
                 data['pyg-lib'].append(round(pyg_lib_duration, 3))
                 print(f'     pyg-lib={pyg_lib_duration:.3f} seconds')
@@ -92,42 +95,44 @@ def test_neighbor(dataset, **kwargs):
                 t = time.perf_counter()
                 for seed in tqdm(node_perm.split(batch_size)):
                     torch.ops.torch_sparse.neighbor_sample(
-                    rowptr,
-                    col,
-                    seed,
-                    num_neighbors,
-                    args.replace,
-                    args.directed,
-                )
+                        rowptr,
+                        col,
+                        seed,
+                        num_neighbors,
+                        args.replace,
+                        args.directed,
+                    )
                 torch_sparse_duration = time.perf_counter() - t
-                data['torch-sparse'].append(round(torch_sparse_duration, 3))
+                data['torch-sparse'].append(
+                    round(torch_sparse_duration, 3))
                 print(f'torch-sparse={torch_sparse_duration:.3f} seconds')

             if 'dgl' in args.libraries:
                 dgl_sampler = dgl.dataloading.NeighborSampler(
-                num_neighbors,
-                replace=args.replace,
-            )
+                    num_neighbors,
+                    replace=args.replace,
+                )
                 dgl_loader = dgl.dataloading.DataLoader(
-                dgl_graph,
-                node_perm,
-                dgl_sampler,
-                batch_size=batch_size,
-            )
+                    dgl_graph,
+                    node_perm,
+                    dgl_sampler,
+                    batch_size=batch_size,
+                )
                 t = time.perf_counter()
                 for _ in tqdm(dgl_loader):
                     pass
                 dgl_duration = time.perf_counter() - t
                 data['dgl'].append(round(dgl_duration, 3))
                 print(f'         dgl={dgl_duration:.3f} seconds')
             print()

     if args.write_csv:
         df = pd.DataFrame()
-        for key in data.keys():
-            if len(data[key]) != 0: df[key] = data[key]
+        for key in data.keys():
+            if len(data[key]) != 0:
+                df[key] = data[key]
         df.to_csv(f'neighbor{datetime.now()}.csv', index=False)


 if __name__ == '__main__':
-    test_neighbor()
+    test_neighbor()
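As a usage note on the DGL code path re-indented above (a sketch assuming dgl is installed; the toy CSC graph and fan-out are invented, while the calls mirror those in neighbor.py): dgl.graph(('csc', ...)) builds the graph directly from the CSC arrays, and the benchmark then only measures how long it takes to iterate the DataLoader:

import torch
import dgl

# Tiny 4-node graph in CSC form: rowptr has num_nodes + 1 entries and col
# lists the source node of each incoming edge, matching how the benchmark
# builds dgl_graph from (rowptr, col).
rowptr = torch.tensor([0, 2, 3, 5, 6])
col = torch.tensor([1, 2, 0, 0, 3, 2])
dgl_graph = dgl.graph(('csc', (rowptr, col, torch.arange(col.size(0)))))

dgl_sampler = dgl.dataloading.NeighborSampler(
    [2],  # fan-out per layer (illustrative)
    replace=False,
)
dgl_loader = dgl.dataloading.DataLoader(
    dgl_graph,
    torch.arange(dgl_graph.num_nodes()),  # seed nodes
    dgl_sampler,
    batch_size=2,
)
for _ in dgl_loader:
    pass  # the benchmark only times this iteration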
