Skip to content

Commit

Permalink
Add keras_core.operations.convert_to_numpy (keras-team#378)
Browse files Browse the repository at this point in the history
For downstream use cases, nice to be able to have this in the ops layer,
instead of being forced to reach into backend.
  • Loading branch information
mattdangerw authored Jun 20, 2023
1 parent 6f61f78 commit a85eb73
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
1 change: 0 additions & 1 deletion keras_core/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from keras_core.backend import cast
from keras_core.backend import cond
from keras_core.backend import convert_to_tensor
from keras_core.backend import is_tensor
from keras_core.backend import name_scope
from keras_core.backend import random
Expand Down
18 changes: 18 additions & 0 deletions keras_core/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,21 @@ def cast(x, dtype):
if any_symbolic_tensors((x,)):
return backend.KerasTensor(shape=x.shape, dtype=dtype)
return backend.core.cast(x, dtype)


@keras_core_export("keras_core.operations.convert_to_tensor")
def convert_to_tensor(x, dtype=None):
"""Convert a NumPy array to a tensor."""
return backend.convert_to_tensor(x, dtype=dtype)


@keras_core_export("keras_core.operations.convert_to_numpy")
def convert_to_numpy(x):
"""Convert a tensor to a NumPy array."""
if any_symbolic_tensors((x,)):
raise ValueError(
"A symbolic tensor (usually the result of applying layers or "
"operations to a `keras.Input`), cannot be converted to a numpy "
"array. There is no concrete value for the input."
)
return backend.convert_to_numpy(x)
10 changes: 10 additions & 0 deletions keras_core/operations/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,13 @@ def test_shape(self):

x = KerasTensor((None, 3, None, 1))
self.assertAllEqual(core.shape(x), (None, 3, None, 1))

def test_convert_to_tensor(self):
x = np.ones((2,))
x = ops.convert_to_tensor(x)
x = ops.convert_to_numpy(x)
self.assertAllEqual(x, (1, 1))
self.assertIsInstance(x, np.ndarray)

with self.assertRaises(ValueError):
ops.convert_to_numpy(KerasTensor((2,)))

0 comments on commit a85eb73

Please sign in to comment.