Skip to content

Commit

Permalink
extend MBSK WS to 40MB
Browse files Browse the repository at this point in the history
  • Loading branch information
aazz44ss committed Nov 27, 2024
1 parent d001ff3 commit 9dbcfd3
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions library/src/amd_detail/hipblaslt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ try
// TODO: Pass the Synchronizer size into the SynchronizerSizeCheck predicate.
// 1K is just for small sizes for now; the corner cases need to be calculated if all situations are to be supported.
void* d_Synchronizer = nullptr;
CHECK_HIP_ERROR(hipMalloc(&d_Synchronizer, 16 * 40960 * sizeof(int)));
CHECK_HIP_ERROR(hipMemset(d_Synchronizer, 0, sizeof(int) * 16 * 40960));
CHECK_HIP_ERROR(hipMalloc(&d_Synchronizer, 16 * 409600 * sizeof(int)));
CHECK_HIP_ERROR(hipMemset(d_Synchronizer, 0, sizeof(int) * 16 * 409600));

err = hipGetDevice(&deviceId);
if(err == hipSuccess)
Expand Down
2 changes: 1 addition & 1 deletion tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1733,7 +1733,7 @@ def calculateWG():
module.add(SLShiftLeftB32(dst=sgpr(tmpSgpr0), src=sgpr(tmpSgpr0), shiftHex=(2)))
module.add(SAddU32(dst=sgpr("AddressTD"), src0=sgpr("AddressTD"), src1=sgpr(tmpSgpr0)))
module.add(SAddCU32(dst=sgpr("AddressTD+1"), src0=sgpr("AddressTD+1"), src1=hex(0)))
module.add(SAddU32(dst=sgpr("Synchronizer"), src0=sgpr("Synchronizer"), src1=hex(163840)))
module.add(SAddU32(dst=sgpr("Synchronizer"), src0=sgpr("Synchronizer"), src1=hex(1638400)))
module.add(SAddCU32(dst=sgpr("Synchronizer+1"), src0=sgpr("Synchronizer+1"), src1=hex(0)))
module.add(extReadEpilogueLabeltmp)
module.add(SAddU32(dst=sgpr(tmpSgprAddrM), src0=sgpr(tmpSgprAddrM), src1=sgpr(tmpSgprArgOffsett)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ namespace TensileLite
else
{
rv.back().setSynchronizer(
m_constantTypes[ContractionProblemGemm::CONST::ALPHA], 40960);
m_constantTypes[ContractionProblemGemm::CONST::ALPHA], 409600);
}
if(j < m_activationEnumArg.size())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ namespace TensileLite
bool ret = (std::ceil(static_cast<float>(problem.freeSizeA(0)) / value[0])
* std::ceil(static_cast<float>(problem.freeSizeB(0)) / value[1]))
* (value[2]) * (value[4] / 64) * value[3]
<= 40960;
<= 409600;
if(problem.groupedGemm())
ret = ret && (problem.groupedGemmCount() <= 16);

Expand Down

0 comments on commit 9dbcfd3

Please sign in to comment.