Skip to content

Commit

Permalink
Remove implements decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasekharporeddy committed Jun 25, 2024
1 parent 6067820 commit 0eb8a85
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def force(x):
"""


@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
Expand All @@ -223,7 +224,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
initial=initial, where_=where, parallel_reduce=lax.psum,
promote_integers=promote_integers)

@implements(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)

def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None, promote_integers: bool = True) -> Array:
Expand Down

0 comments on commit 0eb8a85

Please sign in to comment.