From a85eb73252dda52b0aa4d9a122e2836cc961c512 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Tue, 20 Jun 2023 14:46:57 -0700 Subject: [PATCH] Add keras_core.operations.convert_to_numpy (#378) For downstream use cases, nice to be able to have this in the ops layer, instead of being forced to reach into backend. --- keras_core/operations/__init__.py | 1 - keras_core/operations/core.py | 18 ++++++++++++++++++ keras_core/operations/core_test.py | 10 ++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/keras_core/operations/__init__.py b/keras_core/operations/__init__.py index 5777ea5766d7..2b9988d23b21 100644 --- a/keras_core/operations/__init__.py +++ b/keras_core/operations/__init__.py @@ -4,7 +4,6 @@ from keras_core.backend import cast from keras_core.backend import cond -from keras_core.backend import convert_to_tensor from keras_core.backend import is_tensor from keras_core.backend import name_scope from keras_core.backend import random diff --git a/keras_core/operations/core.py b/keras_core/operations/core.py index 1f613a2add7c..e8d8efdd12db 100644 --- a/keras_core/operations/core.py +++ b/keras_core/operations/core.py @@ -279,3 +279,21 @@ def cast(x, dtype): if any_symbolic_tensors((x,)): return backend.KerasTensor(shape=x.shape, dtype=dtype) return backend.core.cast(x, dtype) + + +@keras_core_export("keras_core.operations.convert_to_tensor") +def convert_to_tensor(x, dtype=None): + """Convert a NumPy array to a tensor.""" + return backend.convert_to_tensor(x, dtype=dtype) + + +@keras_core_export("keras_core.operations.convert_to_numpy") +def convert_to_numpy(x): + """Convert a tensor to a NumPy array.""" + if any_symbolic_tensors((x,)): + raise ValueError( + "A symbolic tensor (usually the result of applying layers or " + "operations to a `keras.Input`), cannot be converted to a numpy " + "array. There is no concrete value for the input." + ) + return backend.convert_to_numpy(x) diff --git a/keras_core/operations/core_test.py b/keras_core/operations/core_test.py index 3fed1bd2e5dd..d4a2f98b6555 100644 --- a/keras_core/operations/core_test.py +++ b/keras_core/operations/core_test.py @@ -231,3 +231,13 @@ def test_shape(self): x = KerasTensor((None, 3, None, 1)) self.assertAllEqual(core.shape(x), (None, 3, None, 1)) + + def test_convert_to_tensor(self): + x = np.ones((2,)) + x = ops.convert_to_tensor(x) + x = ops.convert_to_numpy(x) + self.assertAllEqual(x, (1, 1)) + self.assertIsInstance(x, np.ndarray) + + with self.assertRaises(ValueError): + ops.convert_to_numpy(KerasTensor((2,)))