This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Fix format
chang-l committed Nov 21, 2024
1 parent 0efba33 commit a12a0a2
Showing 3 changed files with 38 additions and 18 deletions.
@@ -24,6 +24,8 @@


 def gen_int_embedding(indice_tensor, embedding_dim, output_type):
+    if embedding_dim == 0:
+        embedding_dim = 1  # unsqueeze to 2D tensor for input embeddings (2D is required for scatter op)
     indice_count = indice_tensor.shape[0]
     indice_part = (
         indice_tensor.type(torch.int).reshape(indice_count, 1).repeat(1, embedding_dim)
@@ -54,9 +56,14 @@ def scatter_gather_test_cast(
         "Rank=%d testing scatter gather with embedding_count=%d, embedding_dim=%d, indice_count=%d, dt=%s, mt=%s, ml=%s"
         % (world_rank, embedding_count, embedding_dim, indice_count, dt, mt, ml)
     )
-    wm_embedding = wmb.create_wholememory_matrix(
-        dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition
-    )
+    if embedding_dim == 0:
+        wm_embedding = wmb.create_wholememory_array(
+            dt, embedding_count, wm_comm, mt, ml, entry_partition
+        )
+    else:
+        wm_embedding = wmb.create_wholememory_matrix(
+            dt, embedding_count, embedding_dim, -1, wm_comm, mt, ml, entry_partition
+        )

     scatter_indice = torch.arange(
         world_rank, embedding_count, world_size, dtype=torch.int64
@@ -91,9 +98,13 @@ def scatter_gather_test_cast(
     local_ref_start = wm_embedding.get_local_entry_start()
     local_ref_count = wm_embedding.get_local_entry_count()
     assert local_start == local_ref_start
-    assert local_tensor_cuda.dim() == 2
+    assert local_tensor_cuda.dim() == (2 if embedding_dim > 0 else 1)
     assert local_tensor_cuda.shape[0] == local_ref_count
-    assert local_tensor_cuda.shape[1] == embedding_dim
+    if local_tensor_cuda.dim() == 2:
+        assert local_tensor_cuda.shape[1] == embedding_dim
+    else:
+        # unsqueeze to 2D for comparison
+        local_tensor_cuda = local_tensor_cuda.unsqueeze(1)

     local_tensor = local_tensor_cuda.cpu()
     local_indices = torch.arange(local_ref_start, local_ref_start + local_ref_count, dtype=torch.int64)
@@ -114,6 +125,9 @@ def scatter_gather_test_cast(
     )
     embedding_after_gather = embedding_after_gather_cuda.cpu()
     ref_embedding_gather = gen_int_embedding(gather_indice, embedding_dim, torch.float)
+    if embedding_after_gather.dim() == 1:
+        # unsqueeze to 2D for comparison
+        embedding_after_gather = embedding_after_gather.unsqueeze(1)
     # print('\ngather_indice=%s\nembedding_after_gather=%s\nref_embedding_gather=%s' % (
     #     gather_indice, embedding_after_gather, ref_embedding_gather))
     assert torch.allclose(embedding_after_gather, ref_embedding_gather)
@@ -134,7 +148,6 @@ def routine_func(world_rank: int, world_size: int):
     wm_comm = wm_comm.wmb_comm

     embedding_count = 1024 * 256 * world_size + 3
-    embedding_dim = 256
     indice_count = 100001
     dt = wmb.WholeMemoryDataType.DtFloat
     entry_partition = random_partition(embedding_count, world_size)
@@ -150,11 +163,12 @@ def routine_func(world_rank: int, world_size: int):
             wmb.WholeMemoryMemoryLocation.MlHost,
             wmb.WholeMemoryMemoryLocation.MlDevice,
         ]:
-            if wm_comm.support_type_location(mt, ml):
-                scatter_gather_test_cast(
-                    wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, True, entry_partition
-                )
-                # scatter_gather_test_cast(wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, False)
+            for embedding_dim in [0, 256]:  # 0 is for 1D tensor
+                if wm_comm.support_type_location(mt, ml):
+                    scatter_gather_test_cast(
+                        wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, True, entry_partition
+                    )
+                    # scatter_gather_test_cast(wm_comm, dt, mt, ml, embedding_count, embedding_dim, indice_count, False)
     wmb.finalize()

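The test change above drives both a 1D WholeMemory array (signalled by embedding_dim == 0) and a 2D matrix through the same verification path by widening 1D results to a single column before comparing. Below is a minimal, WholeMemory-free sketch of that comparison pattern in plain PyTorch; ref_embedding is an illustrative stand-in for the test's gen_int_embedding helper, not part of the library.

import torch

def ref_embedding(indices, embedding_dim):
    # Illustrative stand-in for gen_int_embedding: embedding_dim == 0 marks the
    # 1D case and is treated as a single column, mirroring the change above.
    if embedding_dim == 0:
        embedding_dim = 1
    count = indices.shape[0]
    return indices.type(torch.int).reshape(count, 1).repeat(1, embedding_dim).float()

indices = torch.arange(0, 8, dtype=torch.int64)
for embedding_dim in [0, 4]:  # 0 stands for a 1D embedding store
    expected = ref_embedding(indices, embedding_dim)
    # A 1D store hands back a 1D tensor; unsqueeze it to 2D before comparison.
    gathered = expected.view(-1) if embedding_dim == 0 else expected.clone()
    if gathered.dim() == 1:
        gathered = gathered.unsqueeze(1)
    assert torch.allclose(gathered, expected)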
12 changes: 8 additions & 4 deletions python/pylibwholegraph/pylibwholegraph/torch/tensor.py
@@ -63,7 +63,7 @@ def gather(self,
                *,
                force_dtype: Union[torch.dtype, None] = None):
         assert indice.dim() == 1
-        embedding_dim = self.shape[1]
+        embedding_dim = self.shape[1] if self.dim() == 2 else 1
         embedding_count = indice.shape[0]
         current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),)
         output_dtype = (
@@ -80,15 +80,19 @@ def gather(self,
                                   wrap_torch_tensor(output_tensor),
                                   get_wholegraph_env_fns(),
                                   get_stream())
-        return output_tensor
+        return output_tensor.view(-1) if self.dim() == 1 else output_tensor

     def scatter(self,
                 input_tensor: torch.Tensor,
                 indice: torch.Tensor):
         assert indice.dim() == 1
-        assert input_tensor.dim() == 2
+        assert input_tensor.dim() == self.dim()
         assert indice.shape[0] == input_tensor.shape[0]
-        assert input_tensor.shape[1] == self.shape[1]
+        if self.dim() == 2:
+            assert input_tensor.shape[1] == self.shape[1]
+        else:
+            # unsqueeze input to 2D tensor here because wmb_tensor is unsqueezed within scatter_op
+            input_tensor = input_tensor.unsqueeze(1)
         wmb.wholememory_scatter_op(wrap_torch_tensor(input_tensor),
                                    wrap_torch_tensor(indice),
                                    self.wmb_tensor,
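The tensor.py change follows one convention: gather and scatter on a 1D WholeMemory tensor reuse the 2D code path by unsqueezing the input (scatter) or flattening the output with view(-1) (gather). Here is a toy, pure-PyTorch sketch of just that shape handling; FakeStore and its methods are hypothetical stand-ins for the real wmb-backed calls, not the library API.

import torch

class FakeStore:
    """Toy stand-in for a WholeMemory tensor; only the rank/shape handling is modeled."""

    def __init__(self, data: torch.Tensor):
        self.data = data  # 1D [entry_count] or 2D [entry_count, embedding_dim]

    def dim(self) -> int:
        return self.data.dim()

    def gather(self, indice: torch.Tensor) -> torch.Tensor:
        # The backing op always works on [count, dim]; a 1D store is widened to
        # one column and the result flattened back, mirroring the view(-1) above.
        store2d = self.data if self.dim() == 2 else self.data.unsqueeze(1)
        output = store2d[indice]
        return output.view(-1) if self.dim() == 1 else output

    def scatter(self, input_tensor: torch.Tensor, indice: torch.Tensor) -> None:
        assert input_tensor.dim() == self.dim()
        if self.dim() == 1:
            input_tensor = input_tensor.unsqueeze(1)  # widen to 2D for the op
        store2d = self.data if self.dim() == 2 else self.data.unsqueeze(1)
        store2d[indice] = input_tensor  # the unsqueezed view writes through to self.data

store = FakeStore(torch.zeros(10))  # 1D store
store.scatter(torch.arange(5, dtype=torch.float32), torch.arange(5))
print(store.gather(torch.tensor([0, 2, 4])))  # tensor([0., 2., 4.])

The third changed file applies the same rule to the functional gather wrapper: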
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -39,8 +39,10 @@ def wholememory_gather_forward_functor(
     assert indices_tensor.dtype == torch.int32 or indices_tensor.dtype == torch.int64
     if torch_output_dtype is None:
         torch_output_dtype = wholememory_dtype_to_torch_dtype(wholememory_tensor.dtype)
+
+    embedding_dim = wholememory_tensor.shape[1] if wholememory_tensor.dim() == 2 else 1
     output_tensor = torch.empty(
-        [indices_tensor.shape[0], wholememory_tensor.shape[1]],
+        [indices_tensor.shape[0], embedding_dim],
         device="cuda",
         dtype=torch_output_dtype,
         requires_grad=requires_grad,
@@ -52,7 +54,7 @@ def wholememory_gather_forward_functor(
         get_wholegraph_env_fns(),
         get_stream(),
     )
-    return output_tensor
+    return output_tensor.view(-1) if wholememory_tensor.dim() == 1 else output_tensor


 def wholememory_scatter_functor(
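As a small shape-only illustration of the wrapper's convention above (the output buffer is always allocated as [indice_count, embedding_dim], with embedding_dim falling back to 1 for a 1D source, then flattened before returning), the following plain-PyTorch sketch shows the same logic; gather_output_shape is an illustrative name, not a library function.

import torch

def gather_output_shape(source: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Allocate a 2D buffer regardless of the source rank, as the gather op expects,
    embedding_dim = source.shape[1] if source.dim() == 2 else 1
    output = torch.empty(indices.shape[0], embedding_dim, dtype=source.dtype)
    # then flatten for a 1D source so callers see the rank they passed in.
    return output.view(-1) if source.dim() == 1 else output

assert gather_output_shape(torch.zeros(100), torch.arange(7)).shape == (7,)
assert gather_output_shape(torch.zeros(100, 16), torch.arange(7)).shape == (7, 16)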
