diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index fb4fb8d34c..ef51a54622 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -107,7 +107,9 @@ void fusedL2NN(OutT* min, bool is_skinny = k < 32; size_t bytes = sizeof(DataT) * k; - if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) { + auto px = reinterpret_cast(x); + auto py = reinterpret_cast(y); + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { if (is_skinny) { detail::fusedL2NNImpl( min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); } - } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) { + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { if (is_skinny) { detail::fusedL2NNImpl