From bee127a950f75f33cdcaf7ce79e8b8fd018ba208 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Fri, 2 Dec 2022 21:21:51 +0530 Subject: [PATCH] switch mma instruction shape to 1684 from current 1688 for 3xTF32 L2/cosine kernel (#1057) -- switch mma instruction shape to 1684 from current 1688 as it is always faster for all the inputs tried from DISTANCE_BENCH for L2 and cosine distances. -- the speedup in best case is 1.37x, and at minimum it is 1.05x faster. Authors: - Mahesh Doijade (https://github.com/mdoijade) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1057 --- cpp/include/raft/distance/detail/pairwise_distance_gemm.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h index ea9ed77fb5..8dcccfc14f 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h @@ -66,7 +66,7 @@ struct PairwiseDistanceGemm { /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op using InstructionShape = - cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 + cutlass::gemm::GemmShape<16, 8, 4>; // <- MMA Op tile M = 16, N = 8, K = 4 /// Operation performed by GEMM using Operator = cutlass::arch::OpMultiplyAddFastF32;