Skip to content

Commit

Permalink
Change implementation of numpy.min() of torch backend (#19066)
Browse files Browse the repository at this point in the history
* Change implementation of numpy.min() of torch backend

The implementation of min() with torch backend is not working properly. For eg if list of axis passed the reduction is not happening. Changed the implementation like torch.numpy.max function which is working fine.

Attached https://colab.research.google.com/gist/SuryanarayanaY/9b28fd5fa5837d11d660550caa13eea0/19064.ipynb for demo of same.

* Added test case for  numpy.min function

Added test case for  numpy.min function to test when axis is list or axis is empty.This is to verify the same behaviour of torch wrt other backends also.

* Rectify lint errors
  • Loading branch information
SuryanarayanaY authored Jan 20, 2024
1 parent 2c4decf commit 07500c3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 1 addition & 3 deletions keras/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,9 +892,7 @@ def min(x, axis=None, keepdims=False, initial=None):
if axis is None:
result = torch.min(x)
else:
if isinstance(axis, list):
axis = axis[-1]
result = torch.min(x, dim=axis, keepdim=keepdims)
result = amin(x, axis=axis, keepdims=keepdims)

if isinstance(getattr(result, "values", None), torch.Tensor):
result = result.values
Expand Down
6 changes: 6 additions & 0 deletions keras/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3429,6 +3429,12 @@ def test_min(self):
self.assertAllClose(knp.min(x), np.min(x))
self.assertAllClose(knp.Min()(x), np.min(x))

self.assertAllClose(knp.min(x, axis=(0, 1)), np.min(x, (0, 1)))
self.assertAllClose(knp.Min((0, 1))(x), np.min(x, (0, 1)))

self.assertAllClose(knp.min(x, axis=()), np.min(x, axis=()))
self.assertAllClose(knp.Min(())(x), np.min(x, axis=()))

self.assertAllClose(knp.min(x, 0), np.min(x, 0))
self.assertAllClose(knp.Min(0)(x), np.min(x, 0))

Expand Down

0 comments on commit 07500c3

Please sign in to comment.