Skip to content

Commit

Permalink
[NFC] ConvolutionContext made accessible in SetNextValue() (#1033)
Browse files Browse the repository at this point in the history
  • Loading branch information
atamazov authored Jul 14, 2021
1 parent d214708 commit 9e6cf55
Show file tree
Hide file tree
Showing 27 changed files with 67 additions and 77 deletions.
4 changes: 2 additions & 2 deletions src/include/miopen/generic_search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_COMPILE_ONLY)
/// Constructs an instance with invalid value.
/// - (ctor)(bool)
/// Constructs an instance with minimal value.
/// - SetNextValue()
/// - SetNextValue(const Context& c)
/// Advances instance value to the next available value and returns true.
/// If max value reached, returns false.
/// - IsValid(const Context& c) const
Expand All @@ -86,7 +86,7 @@ class ComputedIterator : public std::iterator<std::input_iterator_tag, Performan
{
do
{
if(!v.SetNextValue())
if(!v.SetNextValue(*p))
{ // Wraparound, end reached. Iterator is useless from now.
p = nullptr;
break;
Expand Down
62 changes: 32 additions & 30 deletions src/include/miopen/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ struct PerformanceConfigConvAsm3x3U : Serializable<PerformanceConfigConvAsm3x3U>

void HeuristicInit(const ConvolutionContext& config);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& config) const;
bool operator==(const PerformanceConfigConvAsm3x3U& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -233,7 +233,7 @@ struct PerformanceConfigConvAsm1x1U : Serializable<PerformanceConfigConvAsm1x1U>

void HeuristicInit(const ConvolutionContext& config);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& config) const;
bool operator==(const PerformanceConfigConvAsm1x1U& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -329,7 +329,7 @@ struct PerformanceConfigConvAsm1x1UV2 : Serializable<PerformanceConfigConvAsm1x1

void HeuristicInit(const ConvolutionContext& config);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& config) const;
bool operator==(const PerformanceConfigConvAsm1x1UV2& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -438,7 +438,7 @@ struct PerformanceImplicitGemm : Serializable<PerformanceImplicitGemm>

void HeuristicInit(const ConvolutionContext& config);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& ctx) const;
bool operator==(const PerformanceImplicitGemm& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -530,7 +530,7 @@ struct PerformanceImplicitGemmV4R4Fwd : Serializable<PerformanceImplicitGemmV4R4
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
std::string ToString() const;
};

Expand Down Expand Up @@ -586,7 +586,7 @@ struct PerformanceImplicitGemmV4R4WrW : Serializable<PerformanceImplicitGemmV4R4
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
std::string ToString() const;
};

Expand Down Expand Up @@ -643,7 +643,7 @@ struct PerformanceImplicitGemmBwdDataV1R1 : Serializable<PerformanceImplicitGemm
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
std::string ToString() const;
};

Expand Down Expand Up @@ -700,7 +700,7 @@ struct PerformanceImplicitGemmBwdDataV4R1 : Serializable<PerformanceImplicitGemm
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
std::string ToString() const;
};

Expand Down Expand Up @@ -757,7 +757,7 @@ struct PerformanceImplicitGemmBwdDataV4R1Xdlops
bool IsReallyValid(const ConvolutionContext& ctx) const;
bool IsFastToBeUsedForTuning(const ConvolutionContext& ctx) const;
void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
std::string ToString() const;
};

Expand Down Expand Up @@ -842,7 +842,7 @@ struct PerformanceImplicitGemmV4R4GenXdlopsFwdFp32

void HeuristicInit(const ConvolutionContext& ctx);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& ctx) const;
bool operator==(const PerformanceImplicitGemmV4R4GenXdlopsFwdFp32& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -933,7 +933,7 @@ struct PerformanceImplicitGemmXdlops : Serializable<PerformanceImplicitGemmXdlop

void HeuristicInit(const ConvolutionContext& ctx);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& ctx) const;
bool operator==(const PerformanceImplicitGemmXdlops& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -974,7 +974,7 @@ struct PerformanceImplicitGemmForwardV4R4Xdlops
std::string ToString() const;

void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
bool IsReallyValid(const ConvolutionContext& ctx) const;
Expand Down Expand Up @@ -1032,7 +1032,7 @@ struct PerformanceImplicitGemmForwardV4R5Xdlops
std::string ToString() const;

void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
bool IsReallyValid(const ConvolutionContext& ctx) const;
Expand Down Expand Up @@ -1092,7 +1092,7 @@ struct PerformanceImplicitGemmForwardV4R4Xdlops_Padded_Gemm
std::string ToString() const;

void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
bool IsReallyValid(const ConvolutionContext& ctx) const;
Expand Down Expand Up @@ -1139,7 +1139,7 @@ struct PerformanceImplicitGemmBwdV1R1Xdlops : Serializable<PerformanceImplicitGe
std::string ToString() const;

void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
bool IsReallyValid(const ConvolutionContext& ctx) const;
Expand Down Expand Up @@ -1237,7 +1237,7 @@ struct PerformanceImplicitGemmV4R4GenXdlopsWrWFp32

void HeuristicInit(const ConvolutionContext& ctx);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& ctx) const;
bool operator==(const PerformanceImplicitGemmV4R4GenXdlopsWrWFp32& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -1468,7 +1468,7 @@ struct PerformanceConfigConvBinWinogradRxSf3x2

void HeuristicInit(const ConvolutionContext& config);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& config) const;
bool operator==(const PerformanceConfigConvBinWinogradRxSf3x2& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -1511,7 +1511,7 @@ struct PerformanceConfigConvBinWinogradRxSf2x3

void HeuristicInit(const ConvolutionContext& config);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& config) const;
bool operator==(const PerformanceConfigConvBinWinogradRxSf2x3& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -1756,7 +1756,7 @@ struct PerformanceConfigAsmDirect3x3WrW : Serializable<PerformanceConfigAsmDirec

void HeuristicInit(const ConvolutionContext& config);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& config) const;
bool operator==(const PerformanceConfigAsmDirect3x3WrW& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -1860,7 +1860,7 @@ struct PerformanceConfigConvAsmBwdWrW1x1 : Serializable<PerformanceConfigConvAsm

void HeuristicInit(const ConvolutionContext& config);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& config) const;
bool operator==(const PerformanceConfigConvAsmBwdWrW1x1& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -1935,7 +1935,7 @@ struct PerformanceConfigConvOclBwdWrw2

void HeuristicInit(const ConvolutionContext& params);
bool IsValidValue() const;
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValid(const ConvolutionContext& params) const;
bool operator==(const PerformanceConfigConvOclBwdWrw2<N_BATCH_LOOPS>& other) const;
std::string ToString() const;
Expand Down Expand Up @@ -2042,7 +2042,7 @@ struct PerformanceImplicitGemmWrwV4R4Xdlops : Serializable<PerformanceImplicitGe
std::string ToString() const;

void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
bool IsReallyValid(const ConvolutionContext& ctx) const;
Expand Down Expand Up @@ -2117,7 +2117,7 @@ struct PerformanceImplicitGemmWrwV4R4Xdlops_Padded_Gemm
std::string ToString() const;

void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
bool IsReallyValid(const ConvolutionContext& ctx) const;
Expand Down Expand Up @@ -2567,10 +2567,12 @@ struct PerformanceConfigAsmImplicitGemmGTC : Serializable<PerformanceConfigAsmIm
f(self.index, "index");
}

void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
// Chilrden must provide support for ComputedContainer.
void HeuristicInit(const ConvolutionContext&) = delete;
bool SetNextValue(const ConvolutionContext&) = delete;
bool IsValidValue() const = delete;
bool IsValid(const ConvolutionContext&) const = delete;

bool IsDefaultConstructed() const;
bool operator==(const PerformanceConfigAsmImplicitGemmGTC& other) const;
void CopyParameters(const PerformanceConfigAsmImplicitGemmGTC& other);
Expand Down Expand Up @@ -2691,7 +2693,7 @@ struct PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC : PerformanceConfigAsmIm
}

void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
};
Expand Down Expand Up @@ -2823,7 +2825,7 @@ struct PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC : PerformanceConfigAsmIm
{
}
void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
};
Expand Down Expand Up @@ -2956,7 +2958,7 @@ struct PerformanceConfigAsmImplicitGemmGTCWrwXdlopsNHWC : PerformanceConfigAsmIm
}

void HeuristicInit(const ConvolutionContext& ctx);
bool SetNextValue();
bool SetNextValue(const ConvolutionContext& config);
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const;
size_t ComputeKernelOccupancy() const;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_1x1u.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ inline static bool Next_1_4(int& v)
return false;
}

bool PerformanceConfigConvAsm1x1U::SetNextValue()
bool PerformanceConfigConvAsm1x1U::SetNextValue(const ConvolutionContext& /*config*/)
{
// Increment with wrap-around:
do
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_1x1u_stride2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ struct buff_info
}
};

bool PerformanceConfigConvAsm1x1UV2::SetNextValue()
bool PerformanceConfigConvAsm1x1UV2::SetNextValue(const ConvolutionContext& /*config*/)
{
// Increment with wrap-around:
do
Expand Down
5 changes: 4 additions & 1 deletion src/solver/conv_asm_3x3u.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ auto PerfFieldRules()

} // namespace

bool PerformanceConfigConvAsm3x3U::SetNextValue() { return !PerfFieldRules().Next(*this); }
bool PerformanceConfigConvAsm3x3U::SetNextValue(const ConvolutionContext& /*config*/)
{
return !PerfFieldRules().Next(*this);
}

PerformanceConfigConvAsm3x3U::PerformanceConfigConvAsm3x3U(int lwc, int fpw, int olpw)
: limit_wave_cnt(lwc), filters_per_wave(fpw), output_lines_per_wave(olpw)
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_dir_BwdWrW1x1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ inline static bool Inc_1_2_4(int& v)

inline static bool Is_1_2_4(const int& v) { return v == 1 || v == 2 || v == 4; }

bool PerformanceConfigConvAsmBwdWrW1x1::SetNextValue()
bool PerformanceConfigConvAsmBwdWrW1x1::SetNextValue(const ConvolutionContext& /*config*/)
{
// Increment with wrap-around:
// select fast or full method
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_asm_dir_BwdWrW3x3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ inline static bool Inc_1_2_4_8(int& v)

inline static bool Is_1_2_4_8(const int& v) { return v == 1 || v == 2 || v == 4 || v == 8; }

bool PerformanceConfigAsmDirect3x3WrW::SetNextValue()
bool PerformanceConfigAsmDirect3x3WrW::SetNextValue(const ConvolutionContext& /*config*/)
{
// Increment with wrap-around:
do
Expand Down
3 changes: 2 additions & 1 deletion src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ bool PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::IsValidValue() const
return false;
return *this == config_list[index];
}
bool PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::SetNextValue()
bool PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::SetNextValue(
const ConvolutionContext& /*config*/)
{
if(use_spare_set)
{
Expand Down
3 changes: 2 additions & 1 deletion src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ void PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC::HeuristicInit(const Convo
}
}

bool PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC::SetNextValue()
bool PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC::SetNextValue(
const ConvolutionContext& /*config*/)
{
if(use_spare_set)
{
Expand Down
21 changes: 0 additions & 21 deletions src/solver/conv_asm_implicit_gemm_gtc_perf_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,27 +88,6 @@ PerformanceConfigAsmImplicitGemmGTC::PerformanceConfigAsmImplicitGemmGTC(
{
}

void PerformanceConfigAsmImplicitGemmGTC::HeuristicInit(const ConvolutionContext& ctx)
{
// need override in child struct
(void)ctx;
}
bool PerformanceConfigAsmImplicitGemmGTC::SetNextValue()
{
// need override in child struct
return false;
}
bool PerformanceConfigAsmImplicitGemmGTC::IsValidValue() const
{
// need override in child struct
return false;
}
bool PerformanceConfigAsmImplicitGemmGTC::IsValid(const ConvolutionContext& ctx) const
{
// need override in child struct
(void)ctx;
return false;
}
bool PerformanceConfigAsmImplicitGemmGTC::IsDefaultConstructed() const
{
int default_lengths[4] = {1, 1, 1, 1};
Expand Down
3 changes: 2 additions & 1 deletion src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ void PerformanceConfigAsmImplicitGemmGTCWrwXdlopsNHWC::HeuristicInit(const Convo
}
}

bool PerformanceConfigAsmImplicitGemmGTCWrwXdlopsNHWC::SetNextValue()
bool PerformanceConfigAsmImplicitGemmGTCWrwXdlopsNHWC::SetNextValue(
const ConvolutionContext& /*config*/)
{
if(use_spare_set)
{
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ void PerformanceImplicitGemmBwdDataV1R1::HeuristicInit(const ConvolutionContext&
MIOPEN_LOG_I(ToString());
}

bool PerformanceImplicitGemmBwdDataV1R1::SetNextValue()
bool PerformanceImplicitGemmBwdDataV1R1::SetNextValue(const ConvolutionContext& /*config*/)
{
// always search full space, no matter if use_spare_set or not
do
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ operator==(const PerformanceImplicitGemmBwdV1R1Xdlops& other) const
// clang-format on
}

bool PerformanceImplicitGemmBwdV1R1Xdlops::SetNextValue()
bool PerformanceImplicitGemmBwdV1R1Xdlops::SetNextValue(const ConvolutionContext& /*config*/)
{
do
{
Expand Down
Loading

0 comments on commit 9e6cf55

Please sign in to comment.