Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bfloat16 support to OCL bn kernels #3351

Merged
merged 23 commits into from
Nov 2, 2024
Merged

Conversation

BrianHarrisonAMD
Copy link
Contributor

Draft of updating OCL BN to support bf16

buff[gammaindex] = (_FLOAT_PREC)dscale;
buff[betaindex] = (_FLOAT_PREC)dbias;
buff[gammaindex] = FLOAT2FLOATPREC(dscale);
buff[betaindex] = FLOAT2FLOATPREC(dbias);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is odd since buff is _FLOAT*. but then a bunch of the functions around here don't seem right like the ones that only use _FLOAT so they better never be used for bf16.

Copy link
Contributor Author

@BrianHarrisonAMD BrianHarrisonAMD Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea these look incorrect for BF16, and just generally, but I think these don't get called for mixed.
Only issue is it might have compilation errors that were okay before let me check how it behaved for FP16.

delta_scale[xgid] = (_FLOAT_PREC)ds;
delta_bias[xgid] = (_FLOAT_PREC)db;
delta_scale[xgid] = FLOAT2FLOATPREC(ds);
delta_bias[xgid] = FLOAT2FLOATPREC(db);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

src/kernels/MIOpenBatchNormActivFwdTrainSpatial.cl Outdated Show resolved Hide resolved
src/kernels/batchnorm_functions.h Outdated Show resolved Hide resolved
src/kernels/batchnorm_functions.h Show resolved Hide resolved
@BrianHarrisonAMD BrianHarrisonAMD marked this pull request as ready for review November 1, 2024 19:33
BrianHarrisonAMD and others added 3 commits November 2, 2024 19:04
* fix driver for old ocl kernel

* fix template name

* fix review comments

* fix isApplicable for CK and OCL bn spatial

* fix gtest issue in bn for ck and ocl

* fix minor issue

* fix clang format

* hip tidy

* fix gtest sample and fix hiptidy

* * move type check logic to problem description of batch norm
* add type checks to other ocl solvers
* fix other minor issues

* fix network cache

* fix minor testing issue

* disable ck bn for now

---------

Co-authored-by: BrianHarrisonAMD <[email protected]>
@BrianHarrisonAMD BrianHarrisonAMD merged commit 1922f62 into develop Nov 2, 2024
13 of 128 checks passed
@BrianHarrisonAMD BrianHarrisonAMD deleted the bharriso/bn-ocl-bf16 branch November 2, 2024 20:27
Copy link
Contributor

@averinevg averinevg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Post-merge review

@BrianHarrisonAMD Please create post-merge fix.

CC @junliume @bpepers-me

Comment on lines +39 to +47
bool is_fp16_or_bfp16(miopenDataType_t type)
{
return ((type == miopenHalf) || (type == miopenBFloat16));
}

bool is_fp32_or_fp64(miopenDataType_t type)
{
return ((type == miopenFloat) || (type == miopenDouble));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't use snake case for function names

Comment on lines +43 to +46
const ExecutionContext&, const miopen::batchnorm::ProblemDescription& bn_problem) const
{
if(problem.GetDirection() != miopen::batchnorm::Direction::ForwardTraining ||
problem.GetMode() != miopenBNSpatial)
if(bn_problem.GetDirection() != miopen::batchnorm::Direction::ForwardTraining ||
bn_problem.GetMode() != miopenBNSpatial)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This renaming was not superfluous. I believe that the name problem in this context is sufficient. Inside all solvers this structure is called exactly this way, and renaming it will only bring chaos.

Comment on lines +51 to +140
bool IsOCLInferTypeValid(const ProblemDescription& bn_problem)
{
// case 1 : mix type
return (
(is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType())) ||
// case 2 : float type
(is_fp32(bn_problem.GetXDesc().GetType()) && is_fp32(bn_problem.GetYDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType())));
}

bool IsCKInferTypeValid(const ProblemDescription& bn_problem)
{
// case 1 : mix type
return ((is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetBnScale().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetBnBias().GetType()) &&
is_fp32(bn_problem.GetBnSMean().GetType()) &&
is_fp32(bn_problem.GetBnSVar().GetType())) ||
// case 2 : fp32 or fp64
(is_fp32_or_fp64(bn_problem.GetXDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetYDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnScale().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnBias().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSMean().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSVar().GetType())));
}

bool IsOCLFwdTrainTypeValid(const ProblemDescription& bn_problem)
{
// case 1 : mix type
return (
(is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType())) ||
// case 2 : float type
(is_fp32(bn_problem.GetXDesc().GetType()) && is_fp32(bn_problem.GetYDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType())));
}

bool IsCKFwdTrainTypeValid(const ProblemDescription& bn_problem)
{
// case 1 : mix type
return ((is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetBnScale().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetBnBias().GetType()) &&
is_fp32(bn_problem.GetBnSMean().GetType()) &&
is_fp32(bn_problem.GetBnSVar().GetType())) ||
// case 2 : fp32 or fp64
(is_fp32_or_fp64(bn_problem.GetXDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetYDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnScale().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnBias().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSMean().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSVar().GetType())));
}

bool IsOCLBwdTypeValid(const ProblemDescription& bn_problem)
{
return (
(is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetDXDesc().GetType()) &&
is_fp16_or_bfp16(bn_problem.GetDYDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnSMean().GetType()) &&
is_fp32(bn_problem.GetBnSVar().GetType())) ||
// case 1 : fp32
(is_fp32(bn_problem.GetXDesc().GetType()) && is_fp32(bn_problem.GetDXDesc().GetType()) &&
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType()) &&
is_fp32(bn_problem.GetBnSMean().GetType()) && is_fp32(bn_problem.GetBnSVar().GetType())));
}

bool IsCKBwdTypeValid(const ProblemDescription& bn_problem)
{
return ((is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) &&
bn_problem.GetDXDesc().GetType() == miopenFloat &&
is_fp16_or_bfp16(bn_problem.GetBnScale().GetType()) &&
bn_problem.GetDYDesc().GetType() == miopenFloat &&
bn_problem.GetBnSMean().GetType() == miopenFloat &&
bn_problem.GetBnSVar().GetType() == miopenFloat) ||
// case 1 : fp32 or fp64
(is_fp32_or_fp64(bn_problem.GetXDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetDXDesc().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnScale().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnBias().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSMean().GetType()) &&
is_fp32_or_fp64(bn_problem.GetBnSVar().GetType())));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These functions should not be here. ProblemDescription is a universal structure, not tied to a backend, third-party library, etc. Move them to the appropriate files.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants