Skip to content

Commit

Permalink
[Feature] Add gather/scatter support 1D tensor (#74)
Browse files Browse the repository at this point in the history
Migrated from:  rapidsai/wholegraph#229 



This PR is to add gather/scatter support 1D tensor on python level, as WholeGraph should support basic indexing operations for both 1D (array) and 2D (matrix) wholememory tensors.   Without this PR, if with 1D wholememory tensor, gather/scatter op does not work, e.g., https://github.com/rapidsai/wholegraph/blob/0efba33835d6e4e104b5d7101a91e0ea55a6ca53/python/pylibwholegraph/pylibwholegraph/torch/tensor.py#L89



To test, run 
```
pytest --cache-clear  --import-mode=append  tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py -s
```

**Remaining issue:**

On my local test with single GPU, the test can pass.   
For multiGPU setup, gather op works fine, but 1D scatter seems not working as it would crash at:
https://github.com/rapidsai/wholegraph/blob/2e963b98aa6027c300d60e839010d3dd8ca422eb/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_gather_scatter.py#L108 with incorrect scatter outputs: `Indices where allclose fails:  tensor([0., 0., 0.,  ..., 0., 0., 0.]) tensor([  1435.,   1439.,   1443.,  ..., 257703., 257707., 257711.]) `

This would work if this bugfix is merged: #73

cc. @linhu-nv

Authors:
  - Chang Liu (https://github.com/chang-l)

Approvers:
  - https://github.com/linhu-nv
  - Alex Barghi (https://github.com/alexbarghi-nv)

URL: #74
  • Loading branch information
chang-l authored Nov 22, 2024
1 parent aa099e4 commit 2776772
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@


def gen_int_embedding(indice_tensor, embedding_dim, output_type):
if embedding_dim == 0:
embedding_dim = 1 # unsqueeze 2D for input (2D is required for scatter op)
indice_count = indice_tensor.shape[0]
indice_part = (
indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim)
Expand Down Expand Up @@ -57,9 +59,14 @@ def scatter_gather_test_cast(
f"embedding_dim={embedding_dim}, "
f"indice_count={indice_count}, dt={dt}, mt={mt}, ml={ml}"
)
wm_embedding = wmb.create_wholememory_matrix(
dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition
)
if embedding_dim == 0:
wm_embedding = wmb.create_wholememory_array(
dt, embedding_count, wm_comm, mt, ml, entry_partition
)
else:
wm_embedding = wmb.create_wholememory_matrix(
dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition
)

scatter_indice = torch.arange(
world_rank, embedding_count, world_size, dtype=torch.int64
Expand Down Expand Up @@ -93,9 +100,13 @@ def scatter_gather_test_cast(
local_ref_start = wm_embedding.get_local_entry_start()
local_ref_count = wm_embedding.get_local_entry_count()
assert local_start == local_ref_start
assert local_tensor_cuda.dim() == 2
assert local_tensor_cuda.dim() == 2 if embedding_dim > 0 else 1
assert local_tensor_cuda.shape[0] == local_ref_count
assert local_tensor_cuda.shape[1] == embedding_dim
if local_tensor_cuda.dim() == 2:
assert local_tensor_cuda.shape[1] == embedding_dim
else:
# unsqueeze to 2D for comparison
local_tensor_cuda = local_tensor_cuda.unsqueeze(1)

local_tensor = local_tensor_cuda.cpu()
local_indices = torch.arange(
Expand All @@ -118,6 +129,9 @@ def scatter_gather_test_cast(
)
embedding_after_gather = embedding_after_gather_cuda.cpu()
ref_embedding_gather = gen_int_embedding(gather_indice, embedding_dim, torch.float)
if embedding_after_gather.dim() == 1:
# unsqueeze to 2D for comparison
embedding_after_gather = embedding_after_gather.unsqueeze(1)
# print('\ngather_indice=%s\nembedding_after_gather=%s\nref_embedding_gather=%s' % (
# gather_indice, embedding_after_gather, ref_embedding_gather))
assert torch.allclose(embedding_after_gather, ref_embedding_gather)
Expand All @@ -138,7 +152,6 @@ def routine_func(world_rank: int, world_size: int):
wm_comm = wm_comm.wmb_comm

embedding_count = 1024 * 256 * world_size + 3
embedding_dim = 256
indice_count = 100001
dt = wmb.WholeMemoryDataType.DtFloat
entry_partition = random_partition(embedding_count, world_size)
Expand All @@ -154,18 +167,19 @@ def routine_func(world_rank: int, world_size: int):
wmb.WholeMemoryMemoryLocation.MlHost,
wmb.WholeMemoryMemoryLocation.MlDevice,
]:
if wm_comm.support_type_location(mt, ml):
scatter_gather_test_cast(
wm_comm,
dt,
mt,
ml,
embedding_count,
embedding_dim,
indice_count,
True,
entry_partition,
)
for embedding_dim in [0, 256]: # 0 is for 1D tensor
if wm_comm.support_type_location(mt, ml):
scatter_gather_test_cast(
wm_comm,
dt,
mt,
ml,
embedding_count,
embedding_dim,
indice_count,
True,
entry_partition,
)
wmb.finalize()


Expand Down
12 changes: 8 additions & 4 deletions python/pylibwholegraph/pylibwholegraph/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def gather(
self, indice: torch.Tensor, *, force_dtype: Union[torch.dtype, None] = None
):
assert indice.dim() == 1
embedding_dim = self.shape[1]
embedding_dim = self.shape[1] if self.dim() == 2 else 1
embedding_count = indice.shape[0]
current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),)
output_dtype = force_dtype if force_dtype is not None else self.dtype
Expand All @@ -79,13 +79,17 @@ def gather(
get_wholegraph_env_fns(),
get_stream(),
)
return output_tensor
return output_tensor.view(-1) if self.dim() == 1 else output_tensor

def scatter(self, input_tensor: torch.Tensor, indice: torch.Tensor):
assert indice.dim() == 1
assert input_tensor.dim() == 2
assert input_tensor.dim() == self.dim()
assert indice.shape[0] == input_tensor.shape[0]
assert input_tensor.shape[1] == self.shape[1]
if self.dim() == 2:
assert input_tensor.shape[1] == self.shape[1]
else:
# unsqueeze to 2D tensor because wmb_tensor is unsqueezed within scatter_op
input_tensor = input_tensor.unsqueeze(1)
wmb.wholememory_scatter_op(
wrap_torch_tensor(input_tensor),
wrap_torch_tensor(indice),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ def wholememory_gather_forward_functor(
assert indices_tensor.dtype == torch.int32 or indices_tensor.dtype == torch.int64
if torch_output_dtype is None:
torch_output_dtype = wholememory_dtype_to_torch_dtype(wholememory_tensor.dtype)

embedding_dim = wholememory_tensor.shape[1] if wholememory_tensor.dim() == 2 else 1
output_tensor = torch.empty(
[indices_tensor.shape[0], wholememory_tensor.shape[1]],
[indices_tensor.shape[0], embedding_dim],
device="cuda",
dtype=torch_output_dtype,
requires_grad=requires_grad,
Expand All @@ -52,7 +54,7 @@ def wholememory_gather_forward_functor(
get_wholegraph_env_fns(),
get_stream(),
)
return output_tensor
return output_tensor.view(-1) if wholememory_tensor.dim() == 1 else output_tensor


def wholememory_scatter_functor(
Expand Down

0 comments on commit 2776772

Please sign in to comment.