diff --git a/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py b/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py index 5120a23180d..af1d0efa76c 100644 --- a/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py +++ b/python/cugraph-equivariant/cugraph_equivariant/nn/tensor_product_conv.py @@ -251,7 +251,10 @@ def forward( if edge_envelope is not None: out = out * edge_envelope.view(-1, 1) - out = scatter_reduce(out, dst, dim=0, dim_size=num_dst_nodes, reduce=reduce) + dtype = out.dtype + out = scatter_reduce( + out.float(), dst, dim=0, dim_size=num_dst_nodes, reduce=reduce + ).to(dtype) if self.batch_norm: out = self.batch_norm(out)