Bg/fix bn solvers is applicable #3355

Merged
41 changes: 28 additions & 13 deletions src/batch_norm_api.cpp
@@ -50,7 +50,10 @@ namespace miopen {
namespace debug {

void LogCmdBNorm(const miopenTensorDescriptor_t xDesc,
const miopenTensorDescriptor_t sMeanDesc,
const miopenTensorDescriptor_t yDesc,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t saveMeanDesc,
miopenBatchNormMode_t bn_mode,
const void* resultRunningMean,
const void* resultRunningVariance,
@@ -61,7 +64,10 @@ void LogCmdBNorm(const miopenTensorDescriptor_t xDesc,
if(miopen::IsLoggingCmd())
{
const std::string& str = BnormArgsForMIOpenDriver(xDesc,
sMeanDesc,
yDesc,
scaleDesc,
biasDesc,
saveMeanDesc,
bn_mode,
resultRunningMean,
resultRunningVariance,
@@ -206,7 +212,7 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t yDesc,
void* y,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t estMeanDesc,
const miopenTensorDescriptor_t estVarianceDesc,
void* bnScale,
@@ -222,7 +228,7 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
yDesc,
y,
scaleDesc,
BiasDesc,
biasDesc,
estMeanDesc,
estVarianceDesc,
bnScale,
@@ -232,12 +238,15 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
epsilon);

miopen::debug::LogCmdBNorm(xDesc,
yDesc,
scaleDesc,
biasDesc,
estMeanDesc,
bn_mode,
estimatedMean,
estimatedVariance,
nullptr,
nullptr,
estMeanDesc,
estimatedVariance,
miopen::debug::BatchNormDirection_t::ForwardInference);

// In case of NxCxDxHxW
@@ -256,7 +265,7 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
: miopen::deref(yDesc),
DataCast(y),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(biasDesc),
miopen::deref(estMeanDesc),
miopen::deref(estVarianceDesc),
DataCast(bnScale),
@@ -277,7 +286,7 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t yDesc,
void* y,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t savedMeanDesc,
const miopenTensorDescriptor_t savedVarianceDesc,
void* bnScale,
@@ -296,7 +305,7 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
yDesc,
y,
scaleDesc,
BiasDesc,
biasDesc,
savedMeanDesc,
savedVarianceDesc,
bnScale,
@@ -309,6 +318,9 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
resultSaveInvVariance);

miopen::debug::LogCmdBNorm(xDesc,
yDesc,
scaleDesc,
biasDesc,
savedMeanDesc,
bn_mode,
resultRunningMean,
@@ -332,7 +344,7 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
: miopen::deref(yDesc),
DataCast(y),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(biasDesc),
miopen::deref(savedMeanDesc),
miopen::deref(savedVarianceDesc),
DataCast(bnScale),
@@ -360,7 +372,7 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t dxDesc,
void* dx,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t savedMeanDesc,
const miopenTensorDescriptor_t savedVarianceDesc,
const void* bnScale,
@@ -379,7 +391,7 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
dxDesc,
dx,
scaleDesc,
BiasDesc,
biasDesc,
savedMeanDesc,
savedVarianceDesc,
bnScale,
@@ -389,6 +401,9 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
savedMean,
savedInvVariance);
miopen::debug::LogCmdBNorm(xDesc,
dyDesc,
scaleDesc,
biasDesc,
savedMeanDesc,
bn_mode,
nullptr,
@@ -417,7 +432,7 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
: miopen::deref(dxDesc),
DataCast(dx),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(biasDesc),
miopen::deref(savedMeanDesc),
miopen::deref(savedVarianceDesc),
DataCast(bnScale),
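To make the reworked entry points concrete, here is a minimal sketch of a caller after this change (not part of the PR). The V2 parameter order is inferred from the diff and from the V1 batch-norm API, so it should be checked against miopen.h; everything outside the call itself is hypothetical scaffolding.

```cpp
#include <miopen/miopen.h>

// Sketch: invoke the V2 inference entry point, which takes a separate
// descriptor for scale, bias, estimated mean, and estimated variance
// (the V1 API shared a single bnScaleBiasMeanVarDesc for all four).
void bn_infer_sketch(miopenHandle_t handle,
                     miopenTensorDescriptor_t xDesc, const void* x,
                     miopenTensorDescriptor_t yDesc, void* y,
                     miopenTensorDescriptor_t scaleDesc,
                     miopenTensorDescriptor_t biasDesc,
                     miopenTensorDescriptor_t estMeanDesc,
                     miopenTensorDescriptor_t estVarianceDesc,
                     void* bnScale, void* bnBias,
                     void* estimatedMean, void* estimatedVariance)
{
    float alpha = 1.0f; // blending factors, as in the V1 batch-norm API
    float beta  = 0.0f;
    double epsilon = 1e-5;

    // Per-tensor descriptors are what let the solver checks in
    // problem_description.cpp accept mixed precision, e.g. fp16 x/y
    // with fp32 scale/bias.
    miopenBatchNormalizationForwardInference_V2(handle,
                                                miopenBNSpatial,
                                                &alpha, &beta,
                                                xDesc, x,
                                                yDesc, y,
                                                scaleDesc, biasDesc,
                                                estMeanDesc, estVarianceDesc,
                                                bnScale, bnBias,
                                                estimatedMean, estimatedVariance,
                                                epsilon);
}
```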
125 changes: 117 additions & 8 deletions src/batchnorm/problem_description.cpp
@@ -36,6 +36,109 @@ namespace miopen {

namespace batchnorm {

bool is_fp16_or_bfp16(miopenDataType_t type)
{
return ((type == miopenHalf) || (type == miopenBFloat16));
}

bool is_fp32_or_fp64(miopenDataType_t type)
{
return ((type == miopenFloat) || (type == miopenDouble));
}

bool is_fp32(miopenDataType_t type) { return (type == miopenFloat); }

bool IsOCLInferTypeValid(const ProblemDescription& bn_problem)
{
// case 1 : mix type
return (
(is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType())) ||
// case 2 : float type
(is_fp32(bn_problem.GetXDesc().GetType()) && is_fp32(bn_problem.GetYDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType())));
}

bool IsCKInferTypeValid(const ProblemDescription& bn_problem)
{
// case 1 : mix type
return ((is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetBnScale().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetBnBias().GetType()) &&
is_fp32(bn_problem.GetBnSMean().GetType()) &&
is_fp32(bn_problem.GetBnSVar().GetType())) ||
// case 2 : fp32 or fp64
(is_fp32_or_fp64(bn_problem.GetXDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetYDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnScale().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnBias().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSMean().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSVar().GetType())));
}

bool IsOCLFwdTrainTypeValid(const ProblemDescription& bn_problem)
{
// case 1 : mix type
return (
(is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType())) ||
// case 2 : float type
(is_fp32(bn_problem.GetXDesc().GetType()) && is_fp32(bn_problem.GetYDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType())));
}

bool IsCKFwdTrainTypeValid(const ProblemDescription& bn_problem)
{
// case 1 : mix type
return ((is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetBnScale().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetBnBias().GetType()) &&
is_fp32(bn_problem.GetBnSMean().GetType()) &&
is_fp32(bn_problem.GetBnSVar().GetType())) ||
// case 2 : fp32 or fp64
(is_fp32_or_fp64(bn_problem.GetXDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetYDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnScale().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnBias().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSMean().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSVar().GetType())));
}

bool IsOCLBwdTypeValid(const ProblemDescription& bn_problem)
{
// case 1 : mix type
return (
(is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetDXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetDYDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnSMean().GetType()) &&
is_fp32(bn_problem.GetBnSVar().GetType())) ||
// case 2 : fp32
(is_fp32(bn_problem.GetXDesc().GetType()) && is_fp32(bn_problem.GetDXDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType()) &&
is_fp32(bn_problem.GetBnSMean().GetType()) && is_fp32(bn_problem.GetBnSVar().GetType())));
}

bool IsCKBwdTypeValid(const ProblemDescription& bn_problem)
{
// case 1 : mix type
return ((is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
bn_problem.GetDXDesc().GetType() == miopenFloat &&
is_fp16_or_bfp16(bn_problem.GetBnScale().GetType()) &&
bn_problem.GetDYDesc().GetType() == miopenFloat &&
bn_problem.GetBnSMean().GetType() == miopenFloat &&
bn_problem.GetBnSVar().GetType() == miopenFloat) ||
// case 2 : fp32 or fp64
(is_fp32_or_fp64(bn_problem.GetXDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetDXDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnScale().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnBias().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSMean().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSVar().GetType())));
}

NetworkConfig ProblemDescription::MakeNetworkConfig() const
{
switch(direction)
@@ -67,7 +170,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
size_t ygridsize = 1;

bool bfpmixparm = false;
if(xDesc.GetType() == miopenHalf && GetBnScaleBiasMeanVarDesc().GetType() == miopenFloat)
if(IsMix())
{
bfpmixparm = true;
}
@@ -137,7 +240,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "c" << c;
}
@@ -154,7 +257,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "single" << static_cast<int>(single);
ss << "n" << n;
@@ -173,7 +276,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "gx" << xgridsize;
ss << "gy" << ygridsize;
@@ -187,6 +290,8 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "hw" << in_cstride;
}
ss << "layout" << in_layout;
ss << "scaleType" << static_cast<int>(IsScaleFp16());
ss << "scaleType" << static_cast<int>(IsScaleFp32());

return NetworkConfig{ss.str()};
}
@@ -203,12 +308,14 @@ NetworkConfig ProblemDescription::MakeForwardInferenceNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "mode" << bn_mode;
ss << "HWdims" << in_cstride;
ss << "C" << c;
ss << "layout" << in_layout;
ss << "scaleType" << static_cast<int>(IsScaleFp16());
ss << "scaleType" << static_cast<int>(IsScaleFp32());

return NetworkConfig{ss.str()};
}
@@ -218,7 +325,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
std::ostringstream ss;

bool bfpmixparm = false;
if(xDesc.GetType() == miopenHalf && GetScaleBiasDiffDesc().GetType() == miopenFloat)
if(xDesc.GetType() == miopenHalf && GetBnScale().GetType() == miopenFloat)
{
bfpmixparm = true;
}
@@ -311,7 +418,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "single" << static_cast<int>(single);
ss << "gcn" << ldsgcn;
@@ -334,11 +441,13 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "nhw" << in_nhw;
}
ss << "layout" << in_layout;
ss << "scaleType" << static_cast<int>(IsScaleFp16());
ss << "scaleType" << static_cast<int>(IsScaleFp32());

return NetworkConfig{ss.str()};
}
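The predicates added in this file encode which data-type combinations each backend (OCL kernels vs. composable_kernel) accepts, and the new `scaleType` bits appended to the network config presumably keep kernels built for fp16 and fp32 scale/bias from colliding in the cache (my reading; the diff does not say so explicitly). As a standalone illustration of the CK inference rule, using plain enums and names of my own choosing rather than MIOpen's types:

```cpp
#include <cstdio>

// Hypothetical standalone restatement of IsCKInferTypeValid's rule,
// with a plain enum standing in for miopenDataType_t.
enum class DType { Half, BFloat16, Float, Double };

bool fp16_or_bf16(DType t) { return t == DType::Half || t == DType::BFloat16; }
bool fp32_or_fp64(DType t) { return t == DType::Float || t == DType::Double; }

// CK inference accepts either:
//   case 1: fp16/bf16 x, y, scale, bias with fp32 saved mean/variance, or
//   case 2: fp32/fp64 across all six tensors.
bool ck_infer_type_valid(DType x, DType y, DType scale, DType bias,
                         DType mean, DType var)
{
    const bool mixed = fp16_or_bf16(x) && fp16_or_bf16(y) &&
                       fp16_or_bf16(scale) && fp16_or_bf16(bias) &&
                       mean == DType::Float && var == DType::Float;
    const bool wide  = fp32_or_fp64(x) && fp32_or_fp64(y) &&
                       fp32_or_fp64(scale) && fp32_or_fp64(bias) &&
                       fp32_or_fp64(mean) && fp32_or_fp64(var);
    return mixed || wide;
}

int main()
{
    // fp16 data with fp32 running statistics: accepted (case 1).
    std::printf("%d\n", ck_infer_type_valid(DType::Half, DType::Half,
                                            DType::Half, DType::Half,
                                            DType::Float, DType::Float));
    // fp16 data with fp16 statistics: rejected by the CK rule.
    std::printf("%d\n", ck_infer_type_valid(DType::Half, DType::Half,
                                            DType::Half, DType::Half,
                                            DType::Half, DType::Half));
    return 0;
}
```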