From da0a412ab02587d80897f56ece1610b71fcb9ab4 Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Tue, 28 May 2024 20:35:48 +0800 Subject: [PATCH] a quick fix to wholememory tensor gather default data type (#173) A quick fix to this issue (https://github.com/rapidsai/wholegraph/issues/168). Set correct default wholememory tensor gather results data type. Authors: - https://github.com/linhu-nv Approvers: - Chuang Zhu (https://github.com/chuangz0) URL: https://github.com/rapidsai/wholegraph/pull/173 --- python/pylibwholegraph/pylibwholegraph/torch/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py index cb4923f41..84ee59eee 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py @@ -67,7 +67,7 @@ def gather(self, embedding_count = indice.shape[0] current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),) output_dtype = ( - force_dtype if force_dtype is not None else self.embedding_tensor.dtype + force_dtype if force_dtype is not None else self.dtype ) output_tensor = torch.empty( [embedding_count, embedding_dim],