From 90db89a2372598d070e005ffd74327fd2ff20731 Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Tue, 30 Jul 2024 10:24:59 -0700 Subject: [PATCH] use the correct wg communicator --- python/cugraph-pyg/cugraph_pyg/data/feature_store.py | 2 +- python/cugraph-pyg/cugraph_pyg/data/graph_store.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cugraph-pyg/cugraph_pyg/data/feature_store.py b/python/cugraph-pyg/cugraph_pyg/data/feature_store.py index a3715d3..b6450e7 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/feature_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/feature_store.py @@ -169,7 +169,7 @@ def __init__(self, memory_type="distributed", location="cpu"): self.__features = {} - self.__wg_comm = wgth.get_local_node_communicator() + self.__wg_comm = wgth.get_global_communicator() self.__wg_type = memory_type self.__wg_location = location diff --git a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py index 622b68d..e086bf0 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py @@ -271,7 +271,7 @@ def __get_edgelist(self): torch.tensor( [self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda", - dtype=torch.int32, + dtype=torch.int64, ) )