diff --git a/tfjs-backend-webgl/src/kernel_utils/shared.ts b/tfjs-backend-webgl/src/kernel_utils/shared.ts index b9a7a5a59f2..327a932a897 100644 --- a/tfjs-backend-webgl/src/kernel_utils/shared.ts +++ b/tfjs-backend-webgl/src/kernel_utils/shared.ts @@ -31,6 +31,7 @@ const { addImpl: addImplCPU, bincountImpl: bincountImplCPU, bincountReduceImpl: bincountReduceImplCPU, + bitwiseAndImpl: bitwiseAndImplCPU, castImpl: castImplCPU, ceilImpl: ceilImplCPU, concatImpl: concatImplCPU, @@ -82,6 +83,7 @@ export { addImplCPU, bincountImplCPU, bincountReduceImplCPU, + bitwiseAndImplCPU, castImplCPU, ceilImplCPU, concatImplCPU, diff --git a/tfjs-backend-webgl/src/kernels/BitwiseAnd.ts b/tfjs-backend-webgl/src/kernels/BitwiseAnd.ts index 91e14f022b0..22f8a93557d 100644 --- a/tfjs-backend-webgl/src/kernels/BitwiseAnd.ts +++ b/tfjs-backend-webgl/src/kernels/BitwiseAnd.ts @@ -15,22 +15,64 @@ * ============================================================================= */ -import {BitwiseAnd, KernelConfig} from '@tensorflow/tfjs-core'; +import {BitwiseAnd, BitwiseAndInputs, env, KernelConfig, KernelFunc, TensorInfo, TypedArray} from '@tensorflow/tfjs-core'; +import {MathBackendWebGL} from '../backend_webgl'; +import {BinaryOpProgram} from '../binaryop_gpu'; +import {BinaryOpPackedProgram} from '../binaryop_packed_gpu'; +import {bitwiseAndImplCPU as cpuBitWiseAnd} from '../kernel_utils/shared'; -import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils'; -import {bitwiseAndImplCPU as cpuBitwiseAnd} from '../kernel_utils/shared'; +export const BITWISEAND = ` + int r = int(a.r) & int(b.r); + int g = int(a.g) & int(b.g); + int rb = int(a.b) & int(b.b); + int ra = int(a.a) & int(b.a); + return vec4(r, g, rb, ra); +`; -const BITWISEAND = 'return a & b;'; +export const BITWISEAND_UNPACKED = ` + return float(int(a.r) & int(b.r)); +`; -export const addKernelFunc = binaryKernelFunc({ - opSnippet: BITWISEAND, - packedOpSnippet: BITWISEAND, - supportsComplex: true, - cpuKernelImpl: cpuBitwiseAnd -}); +export function bitwiseAnd(args: { + inputs: BitwiseAndInputs, + backend: MathBackendWebGL, +}): TensorInfo { + const {inputs, backend} = args; + const {a, b} = inputs; + const webglBackend = backend as MathBackendWebGL; + const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + const versionNumber = env().getNumber('WEBGL_VERSION'); + if (versionNumber !== 2) { + throw new Error( + `Unsupported webgl version. Current webgl version: ${versionNumber}`); + } -export const addConfig: KernelConfig = { + // The type of a and b are ensured to be `int` in core, therefore no need to + // consider other type situations. + if ((webglBackend.shouldExecuteOnCPU([a, b])) && cpuBitWiseAnd != null) { + const aVals = webglBackend.texData.get(a.dataId).values as TypedArray; + const bVals = webglBackend.texData.get(b.dataId).values as TypedArray; + const [outValues, outShape] = + cpuBitWiseAnd(a.shape, b.shape, aVals, bVals, a.dtype); + + const out = webglBackend.makeTensorInfo(outShape, a.dtype); + const outData = webglBackend.texData.get(out.dataId); + outData.values = outValues; + return out; + } + + let program: BinaryOpProgram|BinaryOpPackedProgram; + if (shouldUsePackedProgram) { + program = new BinaryOpPackedProgram(BITWISEAND, a.shape, b.shape, false); + } else { + program = new BinaryOpProgram(BITWISEAND_UNPACKED, a.shape, b.shape); + } + + return webglBackend.runWebGLProgram(program, [a, b], a.dtype); +} + +export const bitwiseAndConfig: KernelConfig = { kernelName: BitwiseAnd, backendName: 'webgl', - kernelFunc: addKernelFunc + kernelFunc: bitwiseAnd as unknown as KernelFunc }; diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts index 9f8e8f4ff39..42b1e0e9ae5 100644 --- a/tfjs-backend-webgl/src/register_all_kernels.ts +++ b/tfjs-backend-webgl/src/register_all_kernels.ts @@ -39,6 +39,7 @@ import {batchMatMulConfig} from './kernels/BatchMatMul'; import {batchNormConfig} from './kernels/BatchNorm'; import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND'; import {bincountConfig} from './kernels/Bincount'; +import {bitwiseAndConfig} from './kernels/BitwiseAnd'; import {broadcastArgsConfig} from './kernels/BroadcastArgs'; import {castConfig} from './kernels/Cast'; import {ceilConfig} from './kernels/Ceil'; @@ -211,6 +212,7 @@ const kernelConfigs: KernelConfig[] = [ batchNormConfig, batchToSpaceNDConfig, bincountConfig, + bitwiseAndConfig, broadcastArgsConfig, castConfig, ceilConfig, diff --git a/tfjs-backend-webgl/src/setup_test.ts b/tfjs-backend-webgl/src/setup_test.ts index a22acd98022..343a8e5e9af 100644 --- a/tfjs-backend-webgl/src/setup_test.ts +++ b/tfjs-backend-webgl/src/setup_test.ts @@ -36,7 +36,7 @@ const customInclude = (testName: string) => { 'isBrowser: false', 'dilation gradient', 'throws when index is out of bound', // otsu tests for threshold op is failing on windows - 'method otsu', 'bitwiseAnd' + 'method otsu' ]; for (const subStr of toExclude) { if (testName.includes(subStr)) {