diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py index cb4923f41..84ee59eee 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py @@ -67,7 +67,7 @@ def gather(self, embedding_count = indice.shape[0] current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),) output_dtype = ( - force_dtype if force_dtype is not None else self.embedding_tensor.dtype + force_dtype if force_dtype is not None else self.dtype ) output_tensor = torch.empty( [embedding_count, embedding_dim],