Skip to content

Commit

Permalink
Merge pull request #24167 from rajasekharporeddy:testbranch1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684085868
  • Loading branch information
Google-ML-Automation committed Oct 9, 2024
2 parents c2deae8 + ed028be commit db71965
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
55 changes: 53 additions & 2 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,14 +1229,65 @@ def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return lax.bitwise_xor(*promote_args("bitwise_xor", x, y))

@implements(np.left_shift, module='numpy')

@partial(jit, inline=True)
def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array:
r"""Shift bits of ``x`` to left by the amount specified in ``y``, element-wise.
JAX implementation of :obj:`numpy.left_shift`.
Args:
x: Input array, must be integer-typed.
y: The amount of bits to shift each element in ``x`` to the left, only accepts
integer subtypes. ``x`` and ``y`` must either have same shape or be broadcast
compatible.
Returns:
An array containing the left shifted elements of ``x`` by the amount specified
in ``y``, with the same shape as the broadcasted shape of ``x`` and ``y``.
Note:
Left shifting ``x`` by ``y`` is equivalent to ``x * (2**y)`` within the
bounds of the dtypes involved.
See also:
- :func:`jax.numpy.right_shift`: and :func:`jax.numpy.bitwise_right_shift`:
Shifts the bits of ``x1`` to right by the amount specified in ``x2``,
element-wise.
- :func:`jax.numpy.bitwise_left_shift`: Alias of :func:`jax.left_shift`.
Examples:
>>> def print_binary(x):
... return [bin(int(val)) for val in x]
>>> x1 = jnp.arange(5)
>>> x1
Array([0, 1, 2, 3, 4], dtype=int32)
>>> print_binary(x1)
['0b0', '0b1', '0b10', '0b11', '0b100']
>>> x2 = 1
>>> result = jnp.left_shift(x1, x2)
>>> result
Array([0, 2, 4, 6, 8], dtype=int32)
>>> print_binary(result)
['0b0', '0b10', '0b100', '0b110', '0b1000']
>>> x3 = 4
>>> print_binary([x3])
['0b100']
>>> x4 = jnp.array([1, 2, 3, 4])
>>> result1 = jnp.left_shift(x3, x4)
>>> result1
Array([ 8, 16, 32, 64], dtype=int32)
>>> print_binary(result1)
['0b1000', '0b10000', '0b100000', '0b1000000']
"""
return lax.shift_left(*promote_args_numeric("left_shift", x, y))

@implements(getattr(np, "bitwise_left_shift", np.left_shift), module='numpy')

@partial(jit, inline=True)
def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.left_shift`."""
return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y))

@implements(np.equal, module='numpy')
Expand Down
4 changes: 2 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6349,8 +6349,8 @@ def test_lax_numpy_docstrings(self):

unimplemented = ['fromfile', 'fromiter']
aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2',
'amax', 'amin', 'around', 'bitwise_right_shift', 'conj', 'degrees',
'divide', 'mod', 'pow', 'radians', 'round_']
'amax', 'amin', 'around', 'bitwise_left_shift', 'bitwise_right_shift',
'conj', 'degrees', 'divide', 'mod', 'pow', 'radians', 'round_']
skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split']

for name in dir(jnp):
Expand Down

0 comments on commit db71965

Please sign in to comment.