Remove the shim of functions in sharding_utils from pxla.py and use those functions directly from sharding_utils in JAX

PiperOrigin-RevId: 522319332
yashk2810 authored and jax authors committed Apr 6, 2023
1 parent 95525e7 commit b926e04
Showing 5 changed files with 27 additions and 26 deletions.
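
The net effect at call sites: code that previously reached these helpers through the pxla re-exports now imports jax._src.sharding_utils directly. A minimal before/after sketch of the pattern (illustrative only, not an excerpt from the commit; the example proto mirrors the ones built in the updated tests):

    from jax._src import sharding_utils as sutils   # new direct import
    from jax._src.lib import xla_client as xc

    # Build a small two-device tiled OpSharding, as the updated tests do.
    op = xc.OpSharding()
    op.type = xc.OpSharding.Type.OTHER
    op.tile_assignment_dimensions = [2, 1]
    op.tile_assignment_devices = [0, 1]

    # Before: pxla.is_op_sharding_replicated(op), pxla.get_num_ways_dim_sharded(op)
    # After: call sharding_utils directly.
    assert not sutils.is_op_sharding_replicated(op)
    partitions, num_replicas = sutils.get_num_ways_dim_sharded(op)  # per-dim partition counts, replica count
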
11 changes: 4 additions & 7 deletions jax/_src/interpreters/pxla.py
@@ -118,9 +118,6 @@ class WeakRefList(list):

PartitionSpec = sharding_impls.PartitionSpec

- get_num_ways_dim_sharded = sutils.get_num_ways_dim_sharded
- is_op_sharding_replicated = sutils.is_op_sharding_replicated


def sharding_spec_mesh_shape(self):
sharded_axis_sizes = []
@@ -219,13 +216,13 @@ def _op_sharding_to_numpy_indices(
# num_devices is required as an argument when op_sharding is
# REPLICATED. `jax.device_count()` cannot be used because you can create
# an opsharding with less number of devices than `jax.device_count()`.
- if is_op_sharding_replicated(op_sharding):
+ if sutils.is_op_sharding_replicated(op_sharding):
indices.fill((slice(None),) * len(shape))
return indices

assert num_devices == len(op_sharding.tile_assignment_devices)

- partitions, num_replicas = get_num_ways_dim_sharded(op_sharding)
+ partitions, num_replicas = sutils.get_num_ways_dim_sharded(op_sharding)
assert len(partitions) == len(shape), (len(partitions), len(shape))

axis_indices: List[Sequence[Index]] = []
@@ -2781,7 +2778,7 @@ def _get_input_indices(
# represent index for each device in the global mesh. But here we want
# indices for the local devices of the global mesh.
proto = sharding._to_xla_op_sharding(aval.ndim)
- if is_op_sharding_replicated(proto):
+ if sutils.is_op_sharding_replicated(proto):
index = tuple(
(slice(None),) * aval.ndim
for _ in range(len(sharding.addressable_devices))) # type: ignore
@@ -3296,7 +3293,7 @@ def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
def are_op_shardings_equal(op1: xc.OpSharding, op2: xc.OpSharding) -> bool:
if id(op1) == id(op2):
return True
- if is_op_sharding_replicated(op1) and is_op_sharding_replicated(op2):
+ if sutils.is_op_sharding_replicated(op1) and sutils.is_op_sharding_replicated(op2):
return True
return xc.HloSharding.from_proto(op1) == xc.HloSharding.from_proto(op2)

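are_op_shardings_equal stays in pxla but now defers to sutils for the replication check. A small usage sketch, assuming the same xla_client proto construction used in the tests updated by this commit:

    from jax._src.interpreters import pxla
    from jax._src import sharding_utils as sutils
    from jax._src.lib import xla_client as xc

    # A proto that is explicitly REPLICATED ...
    op1 = xc.OpSharding()
    op1.type = xc.OpSharding.Type.REPLICATED

    # ... and one that tiles everything onto a single device, which the tests
    # below treat as semantically replicated as well (compare op4 there).
    op2 = xc.OpSharding()
    op2.type = xc.OpSharding.Type.OTHER
    op2.tile_assignment_dimensions = [1]
    op2.tile_assignment_devices = [0]

    assert sutils.is_op_sharding_replicated(op1)
    assert sutils.is_op_sharding_replicated(op2)
    assert pxla.are_op_shardings_equal(op1, op2)  # both replicated, so equal
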
5 changes: 3 additions & 2 deletions jax/_src/pjit.py
@@ -30,6 +30,7 @@
from jax._src import mesh as mesh_lib
from jax._src import linear_util as lu
from jax._src import source_info_util
+ from jax._src import sharding_utils as sutils
from jax._src import traceback_util
from jax._src import util
from jax._src import xla_bridge as xb
@@ -961,7 +962,7 @@ def pjit_check_aval_sharding(
# XLACompatibleSharding.
op_sharding = s._to_xla_op_sharding(len(shape))
assert op_sharding is not None
- num_ways_dim_sharded, _ = pxla.get_num_ways_dim_sharded(
+ num_ways_dim_sharded, _ = sutils.get_num_ways_dim_sharded(
cast(xc.OpSharding, op_sharding))
for i, size in enumerate(num_ways_dim_sharded):
if not allow_uneven_sharding and shape[i] % size != 0:
@@ -1200,7 +1201,7 @@ def _resolve_in_shardings(
raise NotImplementedError('Having uncommitted Array sharded on '
'multiple devices is not supported.')
else:
- if isinstance(arg, np.ndarray) and not pxla.is_op_sharding_replicated(
+ if isinstance(arg, np.ndarray) and not sutils.is_op_sharding_replicated(
pjit_in_s._to_xla_op_sharding(arg.ndim)) and xb.process_count() > 1: # type: ignore
raise ValueError(
'Passing non-trivial shardings for numpy '
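The pjit_check_aval_sharding hunk above uses get_num_ways_dim_sharded to enforce that each dimension divides evenly across its partitions. A hedged sketch of that divisibility check in isolation (check_evenly_shardable is a hypothetical helper, not part of the commit):

    from jax._src import sharding_utils as sutils

    def check_evenly_shardable(shape, op_sharding):
      # Hypothetical helper mirroring the check above with
      # allow_uneven_sharding=False.
      num_ways_dim_sharded, _ = sutils.get_num_ways_dim_sharded(op_sharding)
      for i, size in enumerate(num_ways_dim_sharded):
        if shape[i] % size != 0:
          raise ValueError(
              f"dimension {i} of shape {shape} is not divisible by {size}")
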
7 changes: 4 additions & 3 deletions jax/_src/sharding_impls.py
@@ -23,6 +23,7 @@
from jax._src import core
from jax._src import mesh as mesh_lib
from jax._src import sharding
+ from jax._src import sharding_utils as sutils
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
@@ -73,9 +74,9 @@ def _addressable_device_assignment(self) -> XLADeviceAssignment:
@functools.lru_cache(maxsize=4096)
def shard_shape(self, global_shape: Shape) -> Shape:
op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
- if pxla.is_op_sharding_replicated(op_sharding):
+ if sutils.is_op_sharding_replicated(op_sharding):
return global_shape
- partitions, _ = pxla.get_num_ways_dim_sharded(op_sharding)
+ partitions, _ = sutils.get_num_ways_dim_sharded(op_sharding)
assert len(partitions) == len(global_shape), (len(partitions), len(global_shape))
out = []
for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)):
@@ -622,7 +623,7 @@ def __repr__(self):
return f'GSPMDSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'

def is_compatible_aval(self, aval_shape: Shape):
- num_ways_dim_sharded, _ = pxla.get_num_ways_dim_sharded(self._op_sharding)
+ num_ways_dim_sharded, _ = sutils.get_num_ways_dim_sharded(self._op_sharding)
if len(aval_shape) < len(num_ways_dim_sharded):
raise ValueError(
f"Sharding {self} is only valid for values of rank at least "
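shard_shape (the hunk above) first short-circuits on replicated shardings and then derives the per-shard shape from the per-dimension partition counts. The loop body is truncated in the hunk, so the division below is an assumption about what it computes, not a quote of it:

    from jax._src import sharding_utils as sutils

    def sketch_shard_shape(global_shape, op_sharding):
      # Hypothetical restatement of the logic shown above.
      if sutils.is_op_sharding_replicated(op_sharding):
        return global_shape
      partitions, _ = sutils.get_num_ways_dim_sharded(op_sharding)
      assert len(partitions) == len(global_shape)
      # Assumed: each global dimension is split evenly across its partitions.
      return tuple(s // p for s, p in zip(global_shape, partitions))
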
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/jax2tf.py
@@ -50,6 +50,7 @@
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import prng
+ from jax._src import sharding_utils as sutils
from jax._src import random as random_internal
from jax._src import source_info_util
from jax._src import util
@@ -3142,7 +3143,7 @@ def _shard_value(val: TfVal,
sharding_proto: xla_client.OpSharding = cast(
xla_client.OpSharding, sd._to_xla_op_sharding(aval.ndim))

- if skip_replicated_sharding and pxla.is_op_sharding_replicated(sharding_proto):
+ if skip_replicated_sharding and sutils.is_op_sharding_replicated(sharding_proto):
return val

# To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
27 changes: 14 additions & 13 deletions tests/pjit_test.py
@@ -43,6 +43,7 @@
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import array
from jax._src.sharding import Sharding
+ from jax._src import sharding_utils as sutils
from jax._src.sharding_impls import NamedSharding, GSPMDSharding
import jax._src.pjit as pjit_lib
from jax._src.pjit import (pjit, pjit_p, AUTO)
@@ -664,7 +665,7 @@ def testVMapShardingConstraint(self):
self.assertEqual(op.type, xc.OpSharding.Type.OTHER)
self.assertListEqual(op.tile_assignment_dimensions, [1, 2])
self.assertListEqual(op.tile_assignment_devices, [0, 1])
- self.assertFalse(pxla.is_op_sharding_replicated(op))
+ self.assertFalse(sutils.is_op_sharding_replicated(op))

@jtu.with_mesh([('x', 2)])
def testVMapShardingConstraintWithSpmdAxis(self):
@@ -684,7 +685,7 @@ def testVMapShardingConstraintWithSpmdAxis(self):
self.assertEqual(op.type, xc.OpSharding.Type.OTHER)
self.assertListEqual(op.tile_assignment_dimensions, [2, 1])
self.assertListEqual(op.tile_assignment_devices, [0, 1])
- self.assertFalse(pxla.is_op_sharding_replicated(op))
+ self.assertFalse(sutils.is_op_sharding_replicated(op))

@jtu.with_mesh([('x', 2), ('y', 1)])
def testShardingInXMap(self):
@@ -701,7 +702,7 @@ def _test_rule(*args, **kwargs):
self.assertLen(in_shardings, 1)
self.assertListEqual(in_shardings[0]._op_sharding.tile_assignment_dimensions,
[1, 1, 2])
- self.assertFalse(pxla.is_op_sharding_replicated(in_shardings[0]._op_sharding))
+ self.assertFalse(sutils.is_op_sharding_replicated(in_shardings[0]._op_sharding))

return rule(*args, **kwargs)
try:
@@ -2562,9 +2563,9 @@ def test_pmap_sharding_input_to_pjit_multi_device(self):
self.assertArraysEqual(out2, inp2 * 2)
self.assertLen(out1.devices(), 4)
self.assertLen(out2.devices(), 4)
- self.assertTrue(pxla.is_op_sharding_replicated(
+ self.assertTrue(sutils.is_op_sharding_replicated(
out1.sharding._to_xla_op_sharding(pmap_out.ndim)))
- self.assertTrue(pxla.is_op_sharding_replicated(
+ self.assertTrue(sutils.is_op_sharding_replicated(
out2.sharding._to_xla_op_sharding(inp2.ndim)))

def test_pmap_sharding_input_pjit_in_axis_resources(self):
@@ -2770,7 +2771,7 @@ def f(inp):

with mesh:
out = jax.vmap(jax.jit(f), spmd_axis_name='mdl')(x)
- ns, _ = pxla.get_num_ways_dim_sharded(
+ ns, _ = sutils.get_num_ways_dim_sharded(
out.sharding._to_xla_op_sharding(out.ndim))
self.assertListEqual(ns, [2, 2, 1, 1])

@@ -2780,7 +2781,7 @@ def apply_with_scan(x):

with mesh:
out2 = jax.vmap(apply_with_scan, spmd_axis_name='mdl')(x)
- ns2, _ = pxla.get_num_ways_dim_sharded(
+ ns2, _ = sutils.get_num_ways_dim_sharded(
out2.sharding._to_xla_op_sharding(out2.ndim))
self.assertListEqual(ns2, [2, 2, 1, 1])

@@ -3464,10 +3465,10 @@ def test_op_sharding_semantically_replicated(self):
op4.tile_assignment_dimensions = [1]
op4.tile_assignment_devices = [0]

- self.assertTrue(pxla.is_op_sharding_replicated(op1))
- self.assertTrue(pxla.is_op_sharding_replicated(op2))
- self.assertTrue(pxla.is_op_sharding_replicated(op3))
- self.assertTrue(pxla.is_op_sharding_replicated(op4))
+ self.assertTrue(sutils.is_op_sharding_replicated(op1))
+ self.assertTrue(sutils.is_op_sharding_replicated(op2))
+ self.assertTrue(sutils.is_op_sharding_replicated(op3))
+ self.assertTrue(sutils.is_op_sharding_replicated(op4))
self.assertTrue(pxla.are_op_shardings_equal(op1, op2))
self.assertTrue(pxla.are_op_shardings_equal(op2, op3))
self.assertTrue(pxla.are_op_shardings_equal(op3, op4))
@@ -3488,8 +3489,8 @@ def test_op_sharding_manual_replicated(self):
op3 = xc.OpSharding()
op3.type = xc.OpSharding.Type.REPLICATED

- self.assertTrue(pxla.is_op_sharding_replicated(op1))
- self.assertTrue(pxla.is_op_sharding_replicated(op2))
+ self.assertTrue(sutils.is_op_sharding_replicated(op1))
+ self.assertTrue(sutils.is_op_sharding_replicated(op2))
self.assertTrue(pxla.are_op_shardings_equal(op1, op2))
self.assertTrue(pxla.are_op_shardings_equal(op1, op3))

