Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bg/fix bn solvers is applicable #3355

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions src/batch_norm_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ namespace miopen {
namespace debug {

void LogCmdBNorm(const miopenTensorDescriptor_t xDesc,
const miopenTensorDescriptor_t sMeanDesc,
const miopenTensorDescriptor_t yDesc,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t saveMeanDesc,
miopenBatchNormMode_t bn_mode,
const void* resultRunningMean,
const void* resultRunningVariance,
Expand All @@ -61,7 +64,10 @@ void LogCmdBNorm(const miopenTensorDescriptor_t xDesc,
if(miopen::IsLoggingCmd())
{
const std::string& str = BnormArgsForMIOpenDriver(xDesc,
sMeanDesc,
yDesc,
scaleDesc,
biasDesc,
saveMeanDesc,
bn_mode,
resultRunningMean,
resultRunningVariance,
Expand Down Expand Up @@ -206,7 +212,7 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t yDesc,
void* y,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t estMeanDesc,
const miopenTensorDescriptor_t estVarianceDesc,
void* bnScale,
Expand All @@ -222,7 +228,7 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
yDesc,
y,
scaleDesc,
BiasDesc,
biasDesc,
estMeanDesc,
estVarianceDesc,
bnScale,
Expand All @@ -232,12 +238,15 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
epsilon);

miopen::debug::LogCmdBNorm(xDesc,
yDesc,
scaleDesc,
biasDesc,
estMeanDesc,
bn_mode,
estimatedMean,
estimatedVariance,
nullptr,
nullptr,
estMeanDesc,
estimatedVariance,
miopen::debug::BatchNormDirection_t::ForwardInference);

// In case of NxCxDxHxW
Expand All @@ -256,7 +265,7 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
: miopen::deref(yDesc),
DataCast(y),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(biasDesc),
miopen::deref(estMeanDesc),
miopen::deref(estVarianceDesc),
DataCast(bnScale),
Expand All @@ -277,7 +286,7 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t yDesc,
void* y,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t savedMeanDesc,
const miopenTensorDescriptor_t savedVarianceDesc,
void* bnScale,
Expand All @@ -296,7 +305,7 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
yDesc,
y,
scaleDesc,
BiasDesc,
biasDesc,
savedMeanDesc,
savedVarianceDesc,
bnScale,
Expand All @@ -309,6 +318,9 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
resultSaveInvVariance);

miopen::debug::LogCmdBNorm(xDesc,
yDesc,
scaleDesc,
biasDesc,
savedMeanDesc,
bn_mode,
resultRunningMean,
Expand All @@ -332,7 +344,7 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
: miopen::deref(yDesc),
DataCast(y),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(biasDesc),
miopen::deref(savedMeanDesc),
miopen::deref(savedVarianceDesc),
DataCast(bnScale),
Expand Down Expand Up @@ -360,7 +372,7 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t dxDesc,
void* dx,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t savedMeanDesc,
const miopenTensorDescriptor_t savedVarianceDesc,
const void* bnScale,
Expand All @@ -379,7 +391,7 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
dxDesc,
dx,
scaleDesc,
BiasDesc,
biasDesc,
savedMeanDesc,
savedVarianceDesc,
bnScale,
Expand All @@ -389,6 +401,9 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
savedMean,
savedInvVariance);
miopen::debug::LogCmdBNorm(xDesc,
dyDesc,
scaleDesc,
biasDesc,
savedMeanDesc,
bn_mode,
nullptr,
Expand Down Expand Up @@ -417,7 +432,7 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
: miopen::deref(dxDesc),
DataCast(dx),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(biasDesc),
miopen::deref(savedMeanDesc),
miopen::deref(savedVarianceDesc),
DataCast(bnScale),
Expand Down
16 changes: 8 additions & 8 deletions src/batchnorm/problem_description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
size_t ygridsize = 1;

bool bfpmixparm = false;
if(xDesc.GetType() == miopenHalf && GetBnScaleBiasMeanVarDesc().GetType() == miopenFloat)
if(IsMix())
{
bfpmixparm = true;
}
Expand Down Expand Up @@ -137,7 +137,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "c" << c;
}
Expand All @@ -154,7 +154,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "single" << static_cast<int>(single);
ss << "n" << n;
Expand All @@ -173,7 +173,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "gx" << xgridsize;
ss << "gy" << ygridsize;
Expand Down Expand Up @@ -203,7 +203,7 @@ NetworkConfig ProblemDescription::MakeForwardInferenceNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "mode" << bn_mode;
ss << "HWdims" << in_cstride;
Expand All @@ -218,7 +218,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
std::ostringstream ss;

bool bfpmixparm = false;
if(xDesc.GetType() == miopenHalf && GetScaleBiasDiffDesc().GetType() == miopenFloat)
if(xDesc.GetType() == miopenHalf && GetBnScale().GetType() == miopenFloat)
{
bfpmixparm = true;
}
Expand Down Expand Up @@ -311,7 +311,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "single" << static_cast<int>(single);
ss << "gcn" << ldsgcn;
Expand All @@ -334,7 +334,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
ss << "fp16" << static_cast<int>(IsFp16());
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fbf16" << static_cast<int>(IsBFp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "nhw" << in_nhw;
}
Expand Down
119 changes: 103 additions & 16 deletions src/driver_arguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,105 @@ void ConvDataType(std::stringstream& ss, const miopen::TensorDescriptor& desc)
// We choose scaleMean because it's an accumulator type.
// Append to `ss` the MIOpenDriver "bnorm" base-command name that matches the
// tensor data types of a batch-norm invocation.  The suffix encodes the
// precision mix:
//   bnormfp16      : fp16 in/out with fp32 scale/bias/statistics
//   bnormbfp16     : bfp16 in/out with fp32 scale/bias/statistics
//   bnormfp16fp32  : fp16 in/out AND fp16 scale/bias (fp32 statistics)
//   bnormbfp16fp32 : bfp16 analogue of the above
//   bnorm          : plain fp32 or any unrecognized combination
// NOTE(review): in the Backward case the mixed-precision branches test
// yDesc (dy) and biasDesc against miopenFloat rather than the x type —
// confirm this matches the driver's expectations.
void BnDataType(std::stringstream& ss,
                const miopen::TensorDescriptor& xDesc,
                const miopen::TensorDescriptor& yDesc,
                const miopen::TensorDescriptor& scaleDesc,
                const miopen::TensorDescriptor& biasDesc,
                const miopen::TensorDescriptor& sMeanDesc,
                const BatchNormDirection_t bn_mode)
{
    // Hoist the repeated GetType() lookups; they are used in every branch.
    const auto xType     = xDesc.GetType();
    const auto yType     = yDesc.GetType();
    const auto scaleType = scaleDesc.GetType();
    const auto biasType  = biasDesc.GetType();
    const auto meanType  = sMeanDesc.GetType();

    if(bn_mode == BatchNormDirection_t::ForwardInference ||
       bn_mode == BatchNormDirection_t::ForwardTraining)
    {
        // Forward inference and forward training share identical type
        // dispatch.  BUGFIX: the original duplicated this whole block in a
        // separate `else if(bn_mode == ForwardTraining)` branch that was
        // unreachable, because ForwardTraining is already matched here.
        if(xType == miopenHalf && yType == miopenHalf && scaleType == miopenFloat &&
           biasType == miopenFloat && meanType == miopenFloat)
        {
            ss << "bnormfp16";
        }
        else if(xType == miopenBFloat16 && yType == miopenBFloat16 &&
                scaleType == miopenFloat && biasType == miopenFloat && meanType == miopenFloat)
        {
            ss << "bnormbfp16";
        }
        else if(xType == miopenHalf && yType == miopenHalf && scaleType == miopenHalf &&
                biasType == miopenHalf && meanType == miopenFloat)
        {
            ss << "bnormfp16fp32";
        }
        else if(xType == miopenBFloat16 && yType == miopenBFloat16 &&
                scaleType == miopenBFloat16 && biasType == miopenBFloat16 &&
                meanType == miopenFloat)
        {
            ss << "bnormbfp16fp32";
        }
        else
        {
            ss << "bnorm";
        }
    }
    else if(bn_mode == BatchNormDirection_t::Backward)
    {
        // Backward: dy (yDesc) and bias are fp32 in the fully-mixed cases.
        if(xType == miopenHalf && yType == miopenHalf && scaleType == miopenFloat &&
           biasType == miopenFloat && meanType == miopenFloat)
        {
            ss << "bnormfp16";
        }
        else if(xType == miopenBFloat16 && yType == miopenBFloat16 &&
                scaleType == miopenFloat && biasType == miopenFloat && meanType == miopenFloat)
        {
            ss << "bnormbfp16";
        }
        else if(xType == miopenHalf && yType == miopenFloat && scaleType == miopenHalf &&
                biasType == miopenFloat && meanType == miopenFloat)
        {
            ss << "bnormfp16fp32";
        }
        else if(xType == miopenBFloat16 && yType == miopenFloat &&
                scaleType == miopenBFloat16 && biasType == miopenFloat &&
                meanType == miopenFloat)
        {
            ss << "bnormbfp16fp32";
        }
        else
        {
            ss << "bnorm";
        }
    }
    else
    {
        // Unknown direction: fall back to the generic driver name so the
        // logged command is still runnable (matches pre-change behavior,
        // which always emitted at least "bnorm").
        ss << "bnorm";
    }
}

Expand Down Expand Up @@ -228,6 +306,9 @@ std::string ConvArgsForMIOpenDriver(const miopen::TensorDescriptor& xDesc,
}

std::string BnormArgsForMIOpenDriver(const miopenTensorDescriptor_t xDesc,
const miopenTensorDescriptor_t yDesc,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t sMeanDesc,
miopenBatchNormMode_t bn_mode,
const void* resultRunningMean,
Expand All @@ -241,7 +322,13 @@ std::string BnormArgsForMIOpenDriver(const miopenTensorDescriptor_t xDesc,
miopenGetTensorDescriptorSize(xDesc, &size);
std::stringstream ss;
if(print_for_bn_driver)
BnDataType(ss, miopen::deref(xDesc), miopen::deref(sMeanDesc));
BnDataType(ss,
miopen::deref(xDesc),
miopen::deref(yDesc),
miopen::deref(scaleDesc),
miopen::deref(biasDesc),
miopen::deref(sMeanDesc),
dir);

ss << " -n " << miopen::deref(xDesc).GetLengths()[0] // clang-format off
<< " -c " << miopen::deref(xDesc).GetLengths()[1];
Expand Down
21 changes: 12 additions & 9 deletions src/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,15 +389,18 @@ std::string LogCmdBnormFusion(const miopenFusionPlanDescriptor_t fusePlanDesc, i

if(bn_op != nullptr)
{
str += BnormArgsForMIOpenDriver(&bn_op->input_desc,
&bn_op->base_desc,
bn_op->mode,
nullptr,
nullptr,
nullptr,
nullptr,
miopen::debug::BatchNormDirection_t::ForwardInference,
false);
// str += BnormArgsForMIOpenDriver(&bn_op->input_desc,
// &bn_op->base_desc,
// nullptr,
// nullptr,
// nullptr,
// bn_op->mode,
// nullptr,
// nullptr,
// nullptr,
// nullptr,
// miopen::debug::BatchNormDirection_t::ForwardInference,
// false);
bghimireamd marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
Expand Down
Loading