Skip to content

Commit

Permalink
Update webGL kernels for bitwise AND
Browse files Browse the repository at this point in the history
  • Loading branch information
fengwuyao committed May 4, 2023
1 parent 1a8a96f commit 27f359e
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 13 deletions.
2 changes: 2 additions & 0 deletions tfjs-backend-webgl/src/kernel_utils/shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const {
addImpl: addImplCPU,
bincountImpl: bincountImplCPU,
bincountReduceImpl: bincountReduceImplCPU,
bitwiseAndImpl: bitwiseAndImplCPU,
castImpl: castImplCPU,
ceilImpl: ceilImplCPU,
concatImpl: concatImplCPU,
Expand Down Expand Up @@ -82,6 +83,7 @@ export {
addImplCPU,
bincountImplCPU,
bincountReduceImplCPU,
bitwiseAndImplCPU,
castImplCPU,
ceilImplCPU,
concatImplCPU,
Expand Down
66 changes: 54 additions & 12 deletions tfjs-backend-webgl/src/kernels/BitwiseAnd.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,64 @@
* =============================================================================
*/

import {BitwiseAnd, KernelConfig} from '@tensorflow/tfjs-core';
import {BitwiseAnd, BitwiseAndInputs, env, KernelConfig, KernelFunc, TensorInfo, TypedArray} from '@tensorflow/tfjs-core';
import {MathBackendWebGL} from '../backend_webgl';
import {BinaryOpProgram} from '../binaryop_gpu';
import {BinaryOpPackedProgram} from '../binaryop_packed_gpu';
import {bitwiseAndImplCPU as cpuBitWiseAnd} from '../kernel_utils/shared';

import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
import {bitwiseAndImplCPU as cpuBitwiseAnd} from '../kernel_utils/shared';
export const BITWISEAND = `
int r = int(a.r) & int(b.r);
int g = int(a.g) & int(b.g);
int rb = int(a.b) & int(b.b);
int ra = int(a.a) & int(b.a);
return vec4(r, g, rb, ra);
`;

const BITWISEAND = 'return a & b;';
export const BITWISEAND_UNPACKED = `
return float(int(a.r) & int(b.r));
`;

export const addKernelFunc = binaryKernelFunc({
opSnippet: BITWISEAND,
packedOpSnippet: BITWISEAND,
supportsComplex: true,
cpuKernelImpl: cpuBitwiseAnd
});
export function bitwiseAnd(args: {
inputs: BitwiseAndInputs,
backend: MathBackendWebGL,
}): TensorInfo {
const {inputs, backend} = args;
const {a, b} = inputs;
const webglBackend = backend as MathBackendWebGL;
const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
const versionNumber = env().getNumber('WEBGL_VERSION');
if (versionNumber !== 2) {
throw new Error(
`Unsupported webgl version. Current webgl version: ${versionNumber}`);
}

export const addConfig: KernelConfig = {
// The type of a and b are ensured to be `int` in core, therefore no need to
// consider other type situations.
if ((webglBackend.shouldExecuteOnCPU([a, b])) && cpuBitWiseAnd != null) {
const aVals = webglBackend.texData.get(a.dataId).values as TypedArray;
const bVals = webglBackend.texData.get(b.dataId).values as TypedArray;
const [outValues, outShape] =
cpuBitWiseAnd(a.shape, b.shape, aVals, bVals, a.dtype);

const out = webglBackend.makeTensorInfo(outShape, a.dtype);
const outData = webglBackend.texData.get(out.dataId);
outData.values = outValues;
return out;
}

let program: BinaryOpProgram|BinaryOpPackedProgram;
if (shouldUsePackedProgram) {
program = new BinaryOpPackedProgram(BITWISEAND, a.shape, b.shape, false);
} else {
program = new BinaryOpProgram(BITWISEAND_UNPACKED, a.shape, b.shape);
}

return webglBackend.runWebGLProgram(program, [a, b], a.dtype);
}

export const bitwiseAndConfig: KernelConfig = {
kernelName: BitwiseAnd,
backendName: 'webgl',
kernelFunc: addKernelFunc
kernelFunc: bitwiseAnd as unknown as KernelFunc
};
2 changes: 2 additions & 0 deletions tfjs-backend-webgl/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import {batchMatMulConfig} from './kernels/BatchMatMul';
import {batchNormConfig} from './kernels/BatchNorm';
import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND';
import {bincountConfig} from './kernels/Bincount';
import {bitwiseAndConfig} from './kernels/BitwiseAnd';
import {broadcastArgsConfig} from './kernels/BroadcastArgs';
import {castConfig} from './kernels/Cast';
import {ceilConfig} from './kernels/Ceil';
Expand Down Expand Up @@ -211,6 +212,7 @@ const kernelConfigs: KernelConfig[] = [
batchNormConfig,
batchToSpaceNDConfig,
bincountConfig,
bitwiseAndConfig,
broadcastArgsConfig,
castConfig,
ceilConfig,
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgl/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const customInclude = (testName: string) => {
'isBrowser: false', 'dilation gradient',
'throws when index is out of bound',
// otsu tests for threshold op is failing on windows
'method otsu', 'bitwiseAnd'
'method otsu'
];
for (const subStr of toExclude) {
if (testName.includes(subStr)) {
Expand Down

0 comments on commit 27f359e

Please sign in to comment.