Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bfloat16 support to OCL bn kernels #3351

Merged
merged 23 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e44dd04
Add bfloat16 support to OCL bn kernels
BrianHarrisonAMD Nov 1, 2024
c016604
Fix issue
BrianHarrisonAMD Nov 1, 2024
a9e0ecb
Fix spacing
BrianHarrisonAMD Nov 1, 2024
43aa479
Add bf16 support for bwd OCL spatial
BrianHarrisonAMD Nov 1, 2024
ce33a3e
Further changes of OCL kernels for bfp16
BradPepersAMD Nov 1, 2024
09de609
Add bf16 support for remaining OCL solvers
BrianHarrisonAMD Nov 1, 2024
f0a080c
Fix compilation error
BrianHarrisonAMD Nov 1, 2024
8acd52d
Fix compilation errors
BrianHarrisonAMD Nov 1, 2024
24b0a11
Fix formatting
BrianHarrisonAMD Nov 1, 2024
c7bc908
Fix extra bracket being added
BrianHarrisonAMD Nov 1, 2024
8ff2396
Fix kernel compilation and issue with variant selection for BF16 mixed
BrianHarrisonAMD Nov 1, 2024
99326bf
Swap solver order to prefer OCL over CK by default
BrianHarrisonAMD Nov 1, 2024
58bd202
Update batchnorm bfp16mix to match fp16mix declaration
BrianHarrisonAMD Nov 1, 2024
90b0695
Fix improper cast
BrianHarrisonAMD Nov 1, 2024
59ad519
Fix formatting
BrianHarrisonAMD Nov 1, 2024
90237e0
Merge remote-tracking branch 'origin/develop' into bharriso/bn-ocl-bf16
BrianHarrisonAMD Nov 1, 2024
fffc1d1
Merge develop
BrianHarrisonAMD Nov 2, 2024
d811667
Merge remote-tracking branch 'origin/develop' into bharriso/bn-ocl-bf16
BrianHarrisonAMD Nov 2, 2024
6e85cb9
Merge branch 'develop' into bharriso/bn-ocl-bf16
BrianHarrisonAMD Nov 2, 2024
a95531e
Fix kernel typo for activ
BrianHarrisonAMD Nov 2, 2024
a792d50
Merge branch 'bharriso/bn-ocl-bf16' of github.com:ROCm/MIOpen into bh…
BrianHarrisonAMD Nov 2, 2024
75f15a6
Bg/fix bn solvers is applicable (#3355)
bghimireamd Nov 2, 2024
3d5cc27
Fix formatting
BrianHarrisonAMD Nov 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/kernels/MIOpenBatchNormActivBwdPerAct.cl
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,16 @@ MIOpenBatchNormActivBwdPerActivation(const __global _FLOAT* __restrict x_in,
{
// per (x-dims) channel load a block of data into LDS
index = MIO_BN_CHW * n + adjIndex;
xhat = ((_FLOAT_PREC)(*(x_in + index)) - mean) * invVar;
act_dyin = *(dy_in + index);
act_out = *(y_in + index);
xhat = (FLOAT2FLOATPREC(*(x_in + index)) - mean) * invVar;
act_dyin = FLOAT2FLOATPREC(*(dy_in + index));
act_out = FLOAT2FLOATPREC(*(y_in + index));
bn_out = mad(xhat, pvt_scale, pvt_bias);
ActivationFunction_Diff(
1, &bn_dyin, &act_dyin, &bn_out, &act_out, diff_scale, gamma, beta, alpha);
1, &bn_dyin, &act_dyin, &bn_out, &act_out, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));
#if MIO_BN_CBA_WRITE_INTERMEDIATE
// for debugging
bn_out_dev[index] = bn_out;
bn_dyin_dev[index] = bn_dyin;
bn_out_dev[index] = FLOATPREC2FLOAT(bn_out);
bn_dyin_dev[index] = FLOATPREC2FLOAT(bn_dyin);
#endif
dyelem = bn_dyin;
pvt_dbias += dyelem;
Expand All @@ -128,16 +128,16 @@ MIOpenBatchNormActivBwdPerActivation(const __global _FLOAT* __restrict x_in,
for(int n = 0; n < MIO_BN_N; n++)
{
index = MIO_BN_CHW * n + adjIndex;
xhat = ((_FLOAT_PREC)(*(x_in + index)) - mean) * invVar;
xhat = (FLOAT2FLOATPREC(*(x_in + index)) - mean) * invVar;
tmp1 = mad(xhat, dxhathat, dxhat);
bn_out = mad(xhat, pvt_scale, pvt_bias);
act_dyin = *(dy_in + index);
act_out = *(y_in + index);
act_dyin = FLOAT2FLOAPREC(*(dy_in + index));
act_out = FLOAT2FLOATPREC(*(y_in + index));
ActivationFunction_Diff(
1, &bn_dyin, &act_dyin, &bn_out, &act_out, diff_scale, gamma, beta, alpha);
1, &bn_dyin, &act_dyin, &bn_out, &act_out, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));
tmp2 = mad((_FLOAT_PREC)MIO_BN_N, bn_dyin * pvt_scale, -tmp1);
tmp3 = invVar / ((_FLOAT_PREC)MIO_BN_N);
dx_out[index] = (_FLOAT)(tmp3 * tmp2);
dx_out[index] = FLOATPREC2FLOAT(tmp3 * tmp2);
}
// Write out data
dbias[adjIndex] = pvt_dbias;
Expand Down
96 changes: 48 additions & 48 deletions src/kernels/MIOpenBatchNormActivBwdSpatial.cl
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,13 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
{
nid = n * MIO_BN_SEGIHW + lidihw;
index = nid * MIO_BN_CHW + chwid;
_FLOAT_PREC xhat = (((_FLOAT_PREC)(*(x_in + index)) - mean) * invVariance);
_FLOAT_PREC xhat = (FLOAT2FLOATPREC(*(x_in + index)) - mean) * invVariance;
_FLOAT_PREC bn_out = mad(xhat, lbns, lbnb);
_FLOAT_PREC bn_dyin;
_FLOAT_PREC act_dyin = *(dy_in + index);
_FLOAT_PREC act_out = *(y_in + index);
_FLOAT_PREC act_dyin = FLOAT2FLOATPREC(*(dy_in + index));
_FLOAT_PREC act_out = FLOAT2FLOATPREC(*(y_in + index));
ActivationFunction_Diff(
1, &bn_dyin, &act_dyin, &bn_out, &act_out, diff_scale, gamma, beta, alpha);
1, &bn_dyin, &act_dyin, &bn_out, &act_out, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));
dyvalues[n] = bn_dyin;
db += dyvalues[n];
batchvalues[n] = xhat;
Expand All @@ -139,13 +139,13 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
index = nid * MIO_BN_CHW + chwid;
if(index < MIO_BN_NCHW)
{
_FLOAT_PREC xhat = (((_FLOAT_PREC)(*(x_in + index)) - mean) * invVariance);
_FLOAT_PREC xhat = (FLOAT2FLOATPREC(*(x_in + index)) - mean) * invVariance;
_FLOAT_PREC bn_out = mad(xhat, lbns, lbnb);
_FLOAT_PREC bn_dyin;
_FLOAT_PREC act_dyin = (_FLOAT_PREC)(*(dy_in + index));
_FLOAT_PREC act_out = (_FLOAT_PREC)(*(y_in + index));
_FLOAT_PREC act_dyin = FLOAT2FLOATPREC(*(dy_in + index));
_FLOAT_PREC act_out = FLOAT2FLOATPREC(*(y_in + index));
ActivationFunction_Diff(
1, &bn_dyin, &act_dyin, &bn_out, &act_out, diff_scale, gamma, beta, alpha);
1, &bn_dyin, &act_dyin, &bn_out, &act_out, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));
dyvalues[MIO_BN_NLOOPM] = bn_dyin;

#if MIO_BN_CBA_WRITE_INTERMEDIATE
Expand All @@ -161,7 +161,7 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
db += dyvalues[MIO_BN_NLOOPM];

batchvalues[MIO_BN_NLOOPM] = (index < MIO_BN_NCHW)
? (((_FLOAT_PREC)(*(x_in + index)) - mean) * invVariance)
? (FLOAT2FLOATPREC(*(x_in + index)) - mean) * invVariance
: (_FLOAT_PREC)0.;

// batchvalues is now xhat
Expand Down Expand Up @@ -191,7 +191,7 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
tmp1 = mad(NHW, dyvalues[n], -db);
tmp2 = -batchvalues[n] * ds;
tmp3 = (pscale * invVariance) * INHW;
dx_out[index] = (_FLOAT)(tmp3 * (tmp2 + tmp1));
dx_out[index] = FLOATPREC2FLOAT(tmp3 * (tmp2 + tmp1));
} // end for
nid = MIO_BN_SNHW + lidihw;
index = nid * MIO_BN_CHW + chwid;
Expand All @@ -200,7 +200,7 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
tmp1 = mad(NHW, dyvalues[MIO_BN_NLOOPM], -db);
tmp2 = -batchvalues[MIO_BN_NLOOPM] * ds;
tmp3 = (pscale * invVariance) * INHW;
dx_out[index] = (_FLOAT)(tmp3 * (tmp2 + tmp1));
dx_out[index] = FLOATPREC2FLOAT(tmp3 * (tmp2 + tmp1));
}
}
if(lid == 0)
Expand Down Expand Up @@ -293,10 +293,10 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
xread4 = *((const global _FLOAT4*)(x_in + index));
act_dyin4 = *((const global _FLOAT4*)(dy_in + index));
act_out4 = *((const global _FLOAT4*)(y_in + index));
xhat4.x = ((_FLOAT_PREC)xread4.x - mean) * invVariance;
xhat4.y = ((_FLOAT_PREC)xread4.y - mean) * invVariance;
xhat4.z = ((_FLOAT_PREC)xread4.z - mean) * invVariance;
xhat4.w = ((_FLOAT_PREC)xread4.w - mean) * invVariance;
xhat4.x = (FLOAT2FLOATPREC(xread4.x) - mean) * invVariance;
xhat4.y = (FLOAT2FLOATPREC(xread4.y) - mean) * invVariance;
xhat4.z = (FLOAT2FLOATPREC(xread4.z) - mean) * invVariance;
xhat4.w = (FLOAT2FLOATPREC(xread4.w) - mean) * invVariance;

bn_out4.x = mad(xhat4.x, lcl_scale, lcl_bias);
bn_out4.y = mad(xhat4.y, lcl_scale, lcl_bias);
Expand All @@ -308,30 +308,30 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
_FLOAT_PREC pbnout = bn_out4.x;
_FLOAT_PREC pactout = act_out4.x;
ActivationFunction_Diff(
1, &pbndyin, &pactdyin, &pbnout, &pactout, diff_scale, gamma, beta, alpha);
1, &pbndyin, &pactdyin, &pbnout, &pactout, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));

db += pbndyin;
ds = mad(xhat4.x, pbndyin, ds);
pactdyin = act_dyin4.y;
pbnout = bn_out4.y;
pactout = act_out4.y;
ActivationFunction_Diff(
1, &pbndyin, &pactdyin, &pbnout, &pactout, diff_scale, gamma, beta, alpha);
1, &pbndyin, &pactdyin, &pbnout, &pactout, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));

db += pbndyin;
ds = mad(xhat4.y, pbndyin, ds);
pactdyin = act_dyin4.z;
pbnout = bn_out4.z;
pactout = act_out4.z;
ActivationFunction_Diff(
1, &pbndyin, &pactdyin, &pbnout, &pactout, diff_scale, gamma, beta, alpha);
1, &pbndyin, &pactdyin, &pbnout, &pactout, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));
db += pbndyin;
ds = mad(xhat4.z, pbndyin, ds);
pactdyin = act_dyin4.w;
pbnout = bn_out4.w;
pactout = act_out4.w;
ActivationFunction_Diff(
1, &pbndyin, &pactdyin, &pbnout, &pactout, diff_scale, gamma, beta, alpha);
1, &pbndyin, &pactdyin, &pbnout, &pactout, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));
db += pbndyin;
ds = mad(xhat4.w, pbndyin, ds);

Expand Down Expand Up @@ -359,10 +359,10 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
xread4 = *((const global _FLOAT4*)(x_in + index));
act_dyin4 = *((const global _FLOAT4*)(dy_in + index));
act_out4 = *((const global _FLOAT4*)(y_in + index));
xhat4.x = ((_FLOAT_PREC)xread4.x - mean) * invVariance;
xhat4.y = ((_FLOAT_PREC)xread4.y - mean) * invVariance;
xhat4.z = ((_FLOAT_PREC)xread4.z - mean) * invVariance;
xhat4.w = ((_FLOAT_PREC)xread4.w - mean) * invVariance;
xhat4.x = (FLOAT2FLOATPREC(xread4.x) - mean) * invVariance;
xhat4.y = (FLOAT2FLOATPREC(xread4.y) - mean) * invVariance;
xhat4.z = (FLOAT2FLOATPREC(xread4.z) - mean) * invVariance;
xhat4.w = (FLOAT2FLOATPREC(xread4.w) - mean) * invVariance;

bn_out4.x = mad(xhat4.x, lcl_scale, lcl_bias);
bn_out4.y = mad(xhat4.y, lcl_scale, lcl_bias);
Expand All @@ -374,30 +374,30 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
_FLOAT_PREC pbnout = bn_out4.x;
_FLOAT_PREC pactout = act_out4.x;
ActivationFunction_Diff(
1, &pbndyin, &pactdyin, &pbnout, &pactout, diff_scale, gamma, beta, alpha);
1, &pbndyin, &pactdyin, &pbnout, &pactout, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));

db += pbndyin;
ds = mad(xhat4.x, pbndyin, ds);
pactdyin = act_dyin4.y;
pbnout = bn_out4.y;
pactout = act_out4.y;
ActivationFunction_Diff(
1, &pbndyin, &pactdyin, &pbnout, &pactout, diff_scale, gamma, beta, alpha);
1, &pbndyin, &pactdyin, &pbnout, &pactout, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));

db += pbndyin;
ds = mad(xhat4.y, pbndyin, ds);
pactdyin = act_dyin4.z;
pbnout = bn_out4.z;
pactout = act_out4.z;
ActivationFunction_Diff(
1, &pbndyin, &pactdyin, &pbnout, &pactout, diff_scale, gamma, beta, alpha);
1, &pbndyin, &pactdyin, &pbnout, &pactout, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));
db += pbndyin;
ds = mad(xhat4.z, pbndyin, ds);
pactdyin = act_dyin4.w;
pbnout = bn_out4.w;
pactout = act_out4.w;
ActivationFunction_Diff(
1, &pbndyin, &pactdyin, &pbnout, &pactout, diff_scale, gamma, beta, alpha);
1, &pbndyin, &pactdyin, &pbnout, &pactout, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));
db += pbndyin;
ds = mad(xhat4.w, pbndyin, ds);

Expand Down Expand Up @@ -448,12 +448,12 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
hwidx = l - (nidx * MIO_BN_HW);
index = nidx * MIO_BN_CHW + chwid + hwidx;
_FLOAT_PREC bn_dyin;
_FLOAT_PREC act_dyin = (_FLOAT_PREC) * (dy_in + index);
_FLOAT_PREC act_out = (_FLOAT_PREC) * (y_in + index);
xhat = ((_FLOAT_PREC)(*(x_in + index)) - mean) * invVariance;
_FLOAT_PREC act_dyin = FLOAT2FLOATPREC(*(dy_in + index));
_FLOAT_PREC act_out = FLOAT2FLOATPREC(*(y_in + index));
xhat = (FLOAT2FLOATPREC(*(x_in + index)) - mean) * invVariance;
_FLOAT_PREC bn_out = mad(xhat, lcl_scale, lcl_bias);
ActivationFunction_Diff(
1, &bn_dyin, &act_dyin, &bn_out, &act_out, diff_scale, gamma, beta, alpha);
1, &bn_dyin, &act_dyin, &bn_out, &act_out, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));
tmp1 = mad(NHW, bn_dyin, -db);
tmp2 = -xhat * ds;
vals[j] = tmp3 * (tmp2 + tmp1);
Expand All @@ -465,7 +465,7 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
nidx = l / MIO_BN_HW;
hwidx = l - (nidx * MIO_BN_HW);
index = nidx * MIO_BN_CHW + chwid + hwidx;
*(dx_out + index) = (_FLOAT)vals[j];
*(dx_out + index) = FLOATPREC2FLOAT(vals[j]);
}
}

Expand All @@ -480,12 +480,12 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
if(index < MIO_BN_NCHW)
{
_FLOAT_PREC bn_dyin;
_FLOAT_PREC act_dyin = (_FLOAT_PREC) * (dy_in + index);
_FLOAT_PREC act_out = (_FLOAT_PREC) * (y_in + index);
_FLOAT_PREC act_dyin = FLOAT2FLOATPREC(*(dy_in + index));
_FLOAT_PREC act_out = FLOAT2FLOATPREC(*(y_in + index));
xhat = (*(x_in + index) - mean) * invVariance;
_FLOAT_PREC bn_out = mad(xhat, lcl_scale, lcl_bias);
ActivationFunction_Diff(
1, &bn_dyin, &act_dyin, &bn_out, &act_out, diff_scale, gamma, beta, alpha);
1, &bn_dyin, &act_dyin, &bn_out, &act_out, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));

tmp1 = mad(NHW, bn_dyin, -db);
tmp2 = -xhat * ds;
Expand All @@ -501,7 +501,7 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
index = nidx * MIO_BN_CHW + chwid + hwidx;
if(index < MIO_BN_NCHW)
{
*(dx_out + index) = (_FLOAT)vals[j];
*(dx_out + index) = FLOATPREC2FLOAT(vals[j]);
}
}
#endif
Expand Down Expand Up @@ -575,20 +575,20 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
for(unsigned n = 0; n < MIO_BN_N; n++)
{
index = n * MIO_BN_CHW + cidx + lid;
_FLOAT_PREC xhat = ((_FLOAT_PREC) * (x_in + index) - mean) * invVariance;
_FLOAT_PREC xhat = (FLOAT2FLOATPREC(*(x_in + index)) - mean) * invVariance;
_FLOAT_PREC bn_out = mad(xhat, lcl_scale, lcl_bias);
_FLOAT_PREC bn_dyin;
_FLOAT_PREC act_dyin = (_FLOAT_PREC) * (dy_in + index);
_FLOAT_PREC act_out = (_FLOAT_PREC) * (y_in + index);
_FLOAT_PREC act_dyin = FLOAT2FLOATPREC(*(dy_in + index));
_FLOAT_PREC act_out = FLOAT2FLOATPREC(*(y_in + index));
ActivationFunction_Diff(1,
&bn_dyin,
&act_dyin,
&bn_out,
&act_out,
(_FLOAT_PREC)diff_scale,
gamma,
beta,
alpha);
FLOAT2FLOATPREC(diff_scale),
FLOAT2FLOATPREC(gamma),
FLOAT2FLOATPREC(beta),
FLOAT2FLOATPREC(alpha));

#if MIO_BN_CBA_WRITE_INTERMEDIATE
// for debugging
Expand Down Expand Up @@ -638,19 +638,19 @@ MIOpenBatchNormActivBwdSpatial(const __global _FLOAT* __restrict x_in,
tmp1 = mad(NHW, dyvalues[n], -db);
tmp2 = -(batchvalues[n] * ds);
#else
_FLOAT_PREC act_dyin = (_FLOAT_PREC) * (dy_in + index);
_FLOAT_PREC act_out = (_FLOAT_PREC) * (y_in + index);
_FLOAT_PREC xhat = ((_FLOAT_PREC) * (x_in + index) - mean) * invVariance;
_FLOAT_PREC act_dyin = FLOAT2FLOATPREC(*(dy_in + index));
_FLOAT_PREC act_out = FLOAT2FLOATPREC(*(y_in + index));
_FLOAT_PREC xhat = (FLOAT2FLOATPREC(*(x_in + index)) - mean) * invVariance;
_FLOAT_PREC bn_out = mad(xhat, lcl_scale, lcl_bias);
_FLOAT_PREC bn_dyin;
ActivationFunction_Diff(
1, &bn_dyin, &act_dyin, &bn_out, &act_out, diff_scale, gamma, beta, alpha);
1, &bn_dyin, &act_dyin, &bn_out, &act_out, FLOAT2FLOATPREC(diff_scale), FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));

tmp1 = mad(NHW, bn_dyin, -db);
tmp2 = -(xhat)*ds;
#endif
tmp3 = (pscale * invVariance) * INHW;
dx_out[index] = (_FLOAT)(tmp3 * (tmp2 + tmp1));
dx_out[index] = FLOATPREC2FLOAT(tmp3 * (tmp2 + tmp1));
}
}
if(lid == 0)
Expand Down
8 changes: 4 additions & 4 deletions src/kernels/MIOpenBatchNormActivFwdTrainPerAct.cl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ __kernel void MIOpenBatchNormActivFwdTrainPerActivation(
for(unsigned int n = 0; n < MIO_BN_N; n++)
{
index = MIO_BN_CHW * n + adjIndex;
_FLOAT_PREC xin = (_FLOAT_PREC)(*(in + index));
_FLOAT_PREC xin = FLOAT2FLOATPREC(*(in + index));
mean += xin;
variance = mad(xin, xin, variance);
} // end for(n)
Expand All @@ -115,10 +115,10 @@ __kernel void MIOpenBatchNormActivFwdTrainPerActivation(
for(unsigned int n = 0; n < MIO_BN_N; n++)
{ // per (x-dims) channel load a block of data into LDS
index = MIO_BN_CHW * n + adjIndex;
inhat = ((_FLOAT_PREC)(*(in + index)) - mean) * invVariance;
inhat = (FLOAT2FLOATPREC(*(in + index)) - mean) * invVariance;
bn_out = mad(pvt_scale, inhat, pvt_bias);
ActivationFunction(1, &act_out, &bn_out, gamma, beta, alpha);
out[index] = (_FLOAT)act_out;
ActivationFunction(1, &act_out, &bn_out, FLOAT2FLOATPREC(gamma), FLOAT2FLOATPREC(beta), FLOAT2FLOATPREC(alpha));
out[index] = FLOATPREC2FLOAT(act_out);
} // end for(n)
} // end if(inImgIndex)
} // end for(img_offset) //image mini_batch is processed
Expand Down
Loading