diff --git a/tfjs-backend-wasm/src/cc/BUILD.bazel b/tfjs-backend-wasm/src/cc/BUILD.bazel index 347a3934fa4..65f5836d083 100644 --- a/tfjs-backend-wasm/src/cc/BUILD.bazel +++ b/tfjs-backend-wasm/src/cc/BUILD.bazel @@ -380,6 +380,7 @@ tfjs_cc_library( ":Dilation2DBackpropInput", ":Elu", ":Equal", + ":Erf", ":Exp", ":FlipLeftRight", ":FloorDiv", @@ -913,6 +914,15 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Erf", + srcs = ["kernels/Erf.cc"], + deps = [ + ":unary", + ":util", + ], +) + tfjs_cc_library( name = "Exp", srcs = ["kernels/Exp.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Erf.cc b/tfjs-backend-wasm/src/cc/kernels/Erf.cc new file mode 100644 index 00000000000..70fe654b08f --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Erf.cc @@ -0,0 +1,53 @@ +/* Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "tfjs-backend-wasm/src/cc/unary.h" +#include "tfjs-backend-wasm/src/cc/util.h" + +namespace tfjs { +namespace wasm { + +namespace { +template +inline T ErfImpl(T n) { + return static_cast(std::erff(static_cast(n))); +} +} // namespace + +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void Erf(const int x_id, const DType dtype, const int out_id) { + switch (dtype) { + case DType::float32: + unary_f32(x_id, out_id, ErfImpl); + break; + default: + util::warn("Erf for tensor id %d failed. Unsupported dtype %d", x_id, + dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Erf.ts b/tfjs-backend-wasm/src/kernels/Erf.ts new file mode 100644 index 00000000000..af39a05a4a9 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Erf.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Erf, KernelConfig} from '@tensorflow/tfjs-core'; + +import {createUnaryKernelConfig} from './unary_kernel'; + +export const erfConfig: KernelConfig = createUnaryKernelConfig(Erf); diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts index 303f315d182..f2cfc346771 100644 --- a/tfjs-backend-wasm/src/register_all_kernels.ts +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -66,6 +66,7 @@ import {dilation2DBackpropInputConfig} from './kernels/Dilation2DBackpropInput'; import {eluConfig} from './kernels/Elu'; import {eluGradConfig} from './kernels/EluGrad'; import {equalConfig} from './kernels/Equal'; +import {erfConfig} from './kernels/Erf'; import {expConfig} from './kernels/Exp'; import {expandDimsConfig} from './kernels/ExpandDims'; import {expm1Config} from './kernels/Expm1'; @@ -221,6 +222,7 @@ const kernelConfigs: KernelConfig[] = [ eluConfig, eluGradConfig, equalConfig, + erfConfig, expConfig, expandDimsConfig, expm1Config, diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 4fdcac5decd..42175b7b643 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -297,6 +297,7 @@ const TEST_FILTERS: TestFilter[] = [ 'string tensor' // String tensors not yet implemented. ] }, + {startsWith: 'erf'}, {startsWith: 'sin '}, {startsWith: 'sinh '}, {