-
Notifications
You must be signed in to change notification settings - Fork 242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add bfloat16 support to OCL bn kernels #3351
Conversation
buff[gammaindex] = (_FLOAT_PREC)dscale; | ||
buff[betaindex] = (_FLOAT_PREC)dbias; | ||
buff[gammaindex] = FLOAT2FLOATPREC(dscale); | ||
buff[betaindex] = FLOAT2FLOATPREC(dbias); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is odd, since `buff` is `_FLOAT*`. Then again, a number of the functions around here don't look right (for example, the ones that only use `_FLOAT`), so they had better never be used for bf16.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, these look incorrect for BF16, and in general, but I don't think they get called for the mixed-precision case.
The only issue is that this might introduce compilation errors in code that was fine before; let me check how it behaved for FP16.
delta_scale[xgid] = (_FLOAT_PREC)ds; | ||
delta_bias[xgid] = (_FLOAT_PREC)db; | ||
delta_scale[xgid] = FLOAT2FLOATPREC(ds); | ||
delta_bias[xgid] = FLOAT2FLOATPREC(db); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
* fix driver for old ocl kernel * fix template name * fix review comments * fix isApplicable for CK and OCL bn spatial * fix gtest issue in bn for ck and ocl * fix minor issue * fix clang format * hip tidy * fix gtest sample and fix hiptidy * * move type check logic to problem description of batch norm * add type checks to other ocl solvers * fix other minor issues * fix network cache * fix minor testing issue * disable ck bn for now --------- Co-authored-by: BrianHarrisonAMD <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bool is_fp16_or_bfp16(miopenDataType_t type) | ||
{ | ||
return ((type == miopenHalf) || (type == miopenBFloat16)); | ||
} | ||
|
||
bool is_fp32_or_fp64(miopenDataType_t type) | ||
{ | ||
return ((type == miopenFloat) || (type == miopenDouble)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please don't use snake case for function names; use PascalCase (e.g. `IsFp16OrBfp16`), like the other functions in this file.
const ExecutionContext&, const miopen::batchnorm::ProblemDescription& bn_problem) const | ||
{ | ||
if(problem.GetDirection() != miopen::batchnorm::Direction::ForwardTraining || | ||
problem.GetMode() != miopenBNSpatial) | ||
if(bn_problem.GetDirection() != miopen::batchnorm::Direction::ForwardTraining || | ||
bn_problem.GetMode() != miopenBNSpatial) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This renaming was unnecessary. I believe that the name `problem` is sufficient in this context. Inside all solvers this structure is called exactly that, and renaming it will only bring chaos.
bool IsOCLInferTypeValid(const ProblemDescription& bn_problem) | ||
{ | ||
// case 1 : mix type | ||
return ( | ||
(is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) && | ||
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) && | ||
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType())) || | ||
// case 2 : float type | ||
(is_fp32(bn_problem.GetXDesc().GetType()) && is_fp32(bn_problem.GetYDesc().GetType()) && | ||
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType()))); | ||
} | ||
|
||
bool IsCKInferTypeValid(const ProblemDescription& bn_problem) | ||
{ | ||
// case 1 : mix type | ||
return ((is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) && | ||
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) && | ||
is_fp16_or_bfp16(bn_problem.GetBnScale().GetType()) && | ||
is_fp16_or_bfp16(bn_problem.GetBnBias().GetType()) && | ||
is_fp32(bn_problem.GetBnSMean().GetType()) && | ||
is_fp32(bn_problem.GetBnSVar().GetType())) || | ||
// case 2 : fp32 or fp64 | ||
(is_fp32_or_fp64(bn_problem.GetXDesc().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetYDesc().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnScale().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnBias().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnSMean().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnSVar().GetType()))); | ||
} | ||
|
||
bool IsOCLFwdTrainTypeValid(const ProblemDescription& bn_problem) | ||
{ | ||
// case 1 : mix type | ||
return ( | ||
(is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) && | ||
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) && | ||
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType())) || | ||
// case 2 : float type | ||
(is_fp32(bn_problem.GetXDesc().GetType()) && is_fp32(bn_problem.GetYDesc().GetType()) && | ||
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType()))); | ||
} | ||
|
||
bool IsCKFwdTrainTypeValid(const ProblemDescription& bn_problem) | ||
{ | ||
// case 1 : mix type | ||
return ((is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) && | ||
is_fp16_or_bfp16(bn_problem.GetYDesc().GetType()) && | ||
is_fp16_or_bfp16(bn_problem.GetBnScale().GetType()) && | ||
is_fp16_or_bfp16(bn_problem.GetBnBias().GetType()) && | ||
is_fp32(bn_problem.GetBnSMean().GetType()) && | ||
is_fp32(bn_problem.GetBnSVar().GetType())) || | ||
// case 2 : fp32 or fp64 | ||
(is_fp32_or_fp64(bn_problem.GetXDesc().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetYDesc().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnScale().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnBias().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnSMean().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnSVar().GetType()))); | ||
} | ||
|
||
bool IsOCLBwdTypeValid(const ProblemDescription& bn_problem) | ||
{ | ||
return ( | ||
(is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) && | ||
is_fp16_or_bfp16(bn_problem.GetDXDesc().GetType()) && | ||
is_fp16_or_bfp16(bn_problem.GetDYDesc().GetType()) && | ||
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnSMean().GetType()) && | ||
is_fp32(bn_problem.GetBnSVar().GetType())) || | ||
// case 1 : fp32 | ||
(is_fp32(bn_problem.GetXDesc().GetType()) && is_fp32(bn_problem.GetDXDesc().GetType()) && | ||
is_fp32(bn_problem.GetBnScale().GetType()) && is_fp32(bn_problem.GetBnBias().GetType()) && | ||
is_fp32(bn_problem.GetBnSMean().GetType()) && is_fp32(bn_problem.GetBnSVar().GetType()))); | ||
} | ||
|
||
bool IsCKBwdTypeValid(const ProblemDescription& bn_problem) | ||
{ | ||
return ((is_fp16_or_bfp16(bn_problem.GetXDesc().GetType()) && | ||
bn_problem.GetDXDesc().GetType() == miopenFloat && | ||
is_fp16_or_bfp16(bn_problem.GetBnScale().GetType()) && | ||
bn_problem.GetDYDesc().GetType() == miopenFloat && | ||
bn_problem.GetBnSMean().GetType() == miopenFloat && | ||
bn_problem.GetBnSVar().GetType() == miopenFloat) || | ||
// case 1 : fp32 or fp64 | ||
(is_fp32_or_fp64(bn_problem.GetXDesc().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetDXDesc().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnScale().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnBias().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnSMean().GetType()) && | ||
is_fp32_or_fp64(bn_problem.GetBnSVar().GetType()))); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These functions should not be here. `ProblemDescription` is a universal structure; it is not tied to a particular backend, third-party library, etc. Move them to the appropriate files.
Draft of updating OCL BN to support bf16