Skip to content

Commit

Permalink
bugfix : fix N
Browse files Browse the repository at this point in the history
  • Loading branch information
Varun Sundar Rabindranath committed Jul 26, 2024
1 parent b8111a5 commit 163e091
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_dispatch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct sm89_config_default {
using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;

uint32_t const n = a.size(1);
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);

if (np2 <= 4096) {
Expand Down Expand Up @@ -91,7 +91,7 @@ struct sm89_config_M256 {
using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;

uint32_t const n = a.size(1);
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);

if (np2 <= 4096) {
Expand Down Expand Up @@ -131,7 +131,7 @@ struct sm89_config_M128 {
using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;

uint32_t const n = a.size(1);
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);

if (np2 <= 8192) {
Expand Down Expand Up @@ -178,7 +178,7 @@ struct sm89_config_M64 {
using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;

uint32_t const n = a.size(1);
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);

if (np2 <= 8196) {
Expand Down Expand Up @@ -231,7 +231,7 @@ struct sm89_config_M32 {
using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;

uint32_t const n = a.size(1);
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);

if (np2 <= 8192) {
Expand Down Expand Up @@ -283,7 +283,7 @@ struct sm89_config_M16 {
using FallbackGemm =
typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;

uint32_t const n = a.size(1);
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);

if (np2 <= 8192) {
Expand Down

0 comments on commit 163e091

Please sign in to comment.