From f8165e242561ca6cb17ff1a7e66b117d77b11986 Mon Sep 17 00:00:00 2001 From: ahmedsabie Date: Mon, 12 Sep 2022 16:59:44 -0400 Subject: [PATCH] Fix clamp bug when min = max (#6825) Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com> --- tfjs-core/src/ops/clip_by_value.ts | 5 +++++ tfjs-core/src/ops/clip_by_value_test.ts | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/tfjs-core/src/ops/clip_by_value.ts b/tfjs-core/src/ops/clip_by_value.ts index e597ff0c663..40b0bc0905f 100644 --- a/tfjs-core/src/ops/clip_by_value.ts +++ b/tfjs-core/src/ops/clip_by_value.ts @@ -22,6 +22,7 @@ import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; +import {fill} from './fill'; import {op} from './operation'; @@ -47,6 +48,10 @@ function clipByValue_( () => `Error in clip: min (${clipValueMin}) must be ` + `less than or equal to max (${clipValueMax}).`); + if (clipValueMin === clipValueMax) { + return fill($x.shape, clipValueMin, $x.dtype) as T; + } + const inputs: ClipByValueInputs = {x: $x}; const attrs: ClipByValueAttrs = {clipValueMin, clipValueMax}; diff --git a/tfjs-core/src/ops/clip_by_value_test.ts b/tfjs-core/src/ops/clip_by_value_test.ts index 2bfb067bf85..359672f6abc 100644 --- a/tfjs-core/src/ops/clip_by_value_test.ts +++ b/tfjs-core/src/ops/clip_by_value_test.ts @@ -141,6 +141,14 @@ describeWithFlags('clipByValue', ALL_ENVS, () => { expect(res[1]).toBeCloseTo(max); }); + it('clip min = max', async () => { + const min = 2; + const max = 2; + const tensor = tf.tensor([1, 2, 3, 4, 5], [5], 'float32'); + const result = tf.clipByValue(tensor, min, max); + expectArraysClose(await result.data(), [2, 2, 2, 2, 2]); + }); + it('throws for string tensor', () => { expect(() => tf.clipByValue('q', 0, 1)) .toThrowError(/Argument 'x' passed to 'clipByValue' must be numeric/);