diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h index ea9ed77fb5..8dcccfc14f 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h @@ -66,7 +66,7 @@ struct PairwiseDistanceGemm { /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op using InstructionShape = - cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 + cutlass::gemm::GemmShape<16, 8, 4>; // <- MMA Op tile M = 16, N = 8, K = 4 /// Operation performed by GEMM using Operator = cutlass::arch::OpMultiplyAddFastF32;