From 1324c628cab91bb79bcd64a56a79188b087131c7 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Thu, 30 May 2024 22:02:11 -0400 Subject: [PATCH] [Kernel] Marlin_24: Ensure the mma.sp instruction is using the ::ordered_metadata modifier (introduced with PTX 8.5) (#5136) --- csrc/quantization/marlin/sparse/common/mma.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h index 45ab67a78a1de..fd3dbda5b9c93 100644 --- a/csrc/quantization/marlin/sparse/common/mma.h +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -32,7 +32,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, float* c = reinterpret_cast(&frag_c); if (psel == 0) { asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." + "f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x0;\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) @@ -40,7 +41,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), "r"(e[0])); asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." + "f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x0;\n" : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) @@ -49,7 +51,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, "r"(e[0])); } else { asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." + "f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x1;\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) @@ -57,7 +60,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), "r"(e[0])); asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." + "f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x1;\n" : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])