diff --git a/python/perf-kernels/streamk/tune_streamk.py b/python/perf-kernels/streamk/tune_streamk.py index 8bbe245eed85..0be199127df7 100755 --- a/python/perf-kernels/streamk/tune_streamk.py +++ b/python/perf-kernels/streamk/tune_streamk.py @@ -389,15 +389,15 @@ def matmul(a, b, c, bias, P, locks, num_sms, block_m, block_n, block_k, group_m, EVEN_K = K % block_k == 0 m_tiles = triton.cdiv(M, block_m) n_tiles = triton.cdiv(N, block_n) - streamk_tiles= m_tiles*n_tiles % num_sms + streamk_tiles = m_tiles * n_tiles % num_sms # change num_xcds = 1 if using MI250 num_xcds = 8 streamk_gemm[ grid, ](a, b, c, bias, P, locks, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, - NUM_SMS=num_sms, STREAMK_TILES=streamk_tiles, NUM_XCDS=num_xcds, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K) + NUM_SMS=num_sms, STREAMK_TILES=streamk_tiles, NUM_XCDS=num_xcds, num_warps=num_warps, num_stages=num_stages, + waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K) return c