From a32133e37c6489f443b8f018b6eae8e70c091573 Mon Sep 17 00:00:00 2001
From: Chang Liu
Date: Wed, 16 Oct 2024 23:36:39 -0700
Subject: [PATCH] Add gather/scatter support for 1D tensors

---
 .../ops/test_wholegraph_gather_scatter.py        | 36 +++++++++++++------
 .../pylibwholegraph/torch/tensor.py              | 12 ++++---
 .../pylibwholegraph/torch/wholememory_ops.py     |  6 ++--
 3 files changed, 37 insertions(+), 17 deletions(-)

diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py
index a1fbad89e..c7d8b7762 100644
--- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py
+++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py
@@ -24,6 +24,8 @@


 def gen_int_embedding(indice_tensor, embedding_dim, output_type):
+    if embedding_dim == 0:
+        embedding_dim = 1  # unsqueeze to a 2D (indice_count, 1) tensor for the input embeddings (2D is required by the scatter op)
     indice_count = indice_tensor.shape[0]
     indice_part = (
         indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim)
@@ -54,9 +56,14 @@ def scatter_gather_test_cast(
         "Rank=%d testing scatter gather with embedding_count=%d, embedding_dim=%d, indice_count=%d, dt=%s, mt=%s, ml=%s"
         % (world_rank, embedding_count, embedding_dim, indice_count, dt, mt, ml)
     )
-    wm_embedding = wmb.create_wholememory_matrix(
-        dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition
-    )
+    if embedding_dim == 0:
+        wm_embedding = wmb.create_wholememory_array(
+            dt, embedding_count, wm_comm, mt, ml, entry_partition
+        )
+    else:
+        wm_embedding = wmb.create_wholememory_matrix(
+            dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition
+        )

     scatter_indice = torch.arange(
         world_rank, embedding_count, world_size, dtype=torch.int64
@@ -91,9 +98,13 @@ def scatter_gather_test_cast(
     local_ref_start = wm_embedding.get_local_entry_start()
     local_ref_count = wm_embedding.get_local_entry_count()
     assert local_start == local_ref_start
-    assert local_tensor_cuda.dim() == 2
+    assert local_tensor_cuda.dim() == (2 if embedding_dim > 0 else 1)
     assert local_tensor_cuda.shape[0] == local_ref_count
-    assert local_tensor_cuda.shape[1] == embedding_dim
+    if local_tensor_cuda.dim() == 2:
+        assert local_tensor_cuda.shape[1] == embedding_dim
+    else:
+        # unsqueeze to 2D for comparison
+        local_tensor_cuda = local_tensor_cuda.unsqueeze(1)
     local_tensor = local_tensor_cuda.cpu()
     local_indices = torch.arange(local_ref_start, local_ref_start + local_ref_count,
                                  dtype=torch.int64)
@@ -114,6 +125,9 @@ def scatter_gather_test_cast(
     )
     embedding_after_gather = embedding_after_gather_cuda.cpu()
     ref_embedding_gather = gen_int_embedding(gather_indice, embedding_dim, torch.float)
+    if embedding_after_gather.dim() == 1:
+        # unsqueeze to 2D for comparison
+        embedding_after_gather = embedding_after_gather.unsqueeze(1)
     # print('\ngather_indice=%s\nembedding_after_gather=%s\nref_embedding_gather=%s' % (
     #     gather_indice, embedding_after_gather, ref_embedding_gather))
     assert torch.allclose(embedding_after_gather, ref_embedding_gather)
@@ -134,7 +148,6 @@ def routine_func(world_rank: int, world_size: int):
     wm_comm = wm_comm.wmb_comm

     embedding_count = 1024 * 256 * world_size + 3
-    embedding_dim = 256
     indice_count = 100001
     dt = wmb.WholeMemoryDataType.DtFloat
     entry_partition = random_partition(embedding_count, world_size)
@@ -150,11 +163,12 @@ def routine_func(world_rank: int, world_size: int):
             wmb.WholeMemoryMemoryLocation.MlHost,
             wmb.WholeMemoryMemoryLocation.MlDevice,
         ]:
-            if wm_comm.support_type_location(mt, ml):
-                scatter_gather_test_cast(
-                    wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, True, entry_partition
-                )
-                # scatter_gather_test_cast(wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, False)
+            for embedding_dim in [0, 256]:  # 0 exercises the 1D tensor path
+                if wm_comm.support_type_location(mt, ml):
+                    scatter_gather_test_cast(
+                        wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, True, entry_partition
+                    )
+                    # scatter_gather_test_cast(wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, False)
     wmb.finalize()


diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
index be0b1bfff..40af38aaa 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
@@ -63,7 +63,7 @@ def gather(self,
                *,
                force_dtype: Union[torch.dtype, None] = None):
         assert indice.dim() == 1
-        embedding_dim = self.shape[1]
+        embedding_dim = self.shape[1] if self.dim() == 2 else 1
         embedding_count = indice.shape[0]
         current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),)
         output_dtype = (
@@ -80,15 +80,19 @@ def gather(self,
                                   wrap_torch_tensor(output_tensor),
                                   get_wholegraph_env_fns(),
                                   get_stream())
-        return output_tensor
+        return output_tensor.view(-1) if self.dim() == 1 else output_tensor

     def scatter(self,
                 input_tensor: torch.Tensor,
                 indice: torch.Tensor):
         assert indice.dim() == 1
-        assert input_tensor.dim() == 2
+        assert input_tensor.dim() == self.dim()
         assert indice.shape[0] == input_tensor.shape[0]
-        assert input_tensor.shape[1] == self.shape[1]
+        if self.dim() == 2:
+            assert input_tensor.shape[1] == self.shape[1]
+        else:
+            # unsqueeze the input to 2D here because wmb_tensor is unsqueezed within scatter_op
+            input_tensor = input_tensor.unsqueeze(1)
         wmb.wholememory_scatter_op(wrap_torch_tensor(input_tensor),
                                    wrap_torch_tensor(indice),
                                    self.wmb_tensor,
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py b/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py
index 5bc25d4ca..af4f56135 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py
@@ -39,8 +39,10 @@ def wholememory_gather_forward_functor(
     assert indices_tensor.dtype == torch.int32 or indices_tensor.dtype == torch.int64
     if torch_output_dtype is None:
         torch_output_dtype = wholememory_dtype_to_torch_dtype(wholememory_tensor.dtype)
+
+    embedding_dim = wholememory_tensor.shape[1] if wholememory_tensor.dim() == 2 else 1
     output_tensor = torch.empty(
-        [indices_tensor.shape[0], wholememory_tensor.shape[1]],
+        [indices_tensor.shape[0], embedding_dim],
         device="cuda",
         dtype=torch_output_dtype,
         requires_grad=requires_grad,
@@ -52,7 +54,7 @@ def wholememory_gather_forward_functor(
         output_tensor,
         wholememory_tensor.wmb_tensor,
         get_wholegraph_env_fns(),
         get_stream(),
     )
-    return output_tensor
+    return output_tensor.view(-1) if wholememory_tensor.dim() == 1 else output_tensor


 def wholememory_scatter_functor(
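
For illustration, a minimal sketch of how the 1D path behaves after this change, written only against the gather/scatter methods this patch modifies in tensor.py. The wm_tensor argument (an already-created 1D, float WholeMemory tensor wrapper with at least n entries), and the communicator setup behind it, are assumptions rather than part of this diff; creation would follow the same wmb.create_wholememory_array route the new test branch uses.

    import torch


    def demo_1d_roundtrip(wm_tensor, n: int = 100):
        # wm_tensor is assumed to be a 1D, float32 WholeMemory tensor wrapper from
        # pylibwholegraph.torch.tensor with at least n entries, created elsewhere
        # (e.g. via the same wmb.create_wholememory_array call used in the test).
        indices = torch.arange(0, n, dtype=torch.int64, device="cuda")
        values = torch.rand(n, dtype=torch.float32, device="cuda")  # 1D input, no trailing dim

        # scatter() now accepts a 1D input when the backing tensor is 1D; the patched
        # method unsqueezes it to (n, 1) before calling wmb.wholememory_scatter_op.
        wm_tensor.scatter(values, indices)

        # gather() now returns a flat 1D tensor for a 1D backing tensor (the 2D
        # internal output is view(-1)-ed on return, per the change above).
        gathered = wm_tensor.gather(indices)
        assert gathered.dim() == 1 and gathered.shape[0] == n
        return gathered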