From ecaa215043fea78a79216df6075e9c6dd49fc31d Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 15 Dec 2022 19:25:42 -0800
Subject: [PATCH] [Rollback 2] Add `keep_unused` to `pjit`'s API as a step to
 merge `jit` and `pjit` frontend API.

PiperOrigin-RevId: 495756613
---
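Note (editorial, not part of the original change description): this rollback
removes the `keep_unused` parameter from `pjit` again. As the deleted
docstring below describes, `keep_unused=False` (the old default) allowed JAX
to drop arguments it proved unused from the compiled XLA executable, while
`keep_unused=True` kept them all; after this rollback, pjit always lowers
with `keep_unused=True`, so no arguments are pruned. A minimal sketch of the
rolled-back behavior, adapted from the tests deleted below; the two-argument
shape is illustrative only, and `_kept_var_idx` is a private attribute of
the compiled executable:

    from functools import partial
    import jax.numpy as jnp
    from jax.experimental.pjit import pjit

    def f(unused, c):  # `unused` is never read by the computation
      return c @ c.T

    unused_inp = jnp.arange(8.)
    inp = jnp.arange(4.)

    # Old default (keep_unused=False): the executable may prune `unused`,
    # so only argument index 1 is kept:
    #   pjit(f).lower(unused_inp, inp).compile() \
    #       ._executable._kept_var_idx == {1}
    # With keep_unused=True, both arguments stayed alive:
    #   partial(pjit, keep_unused=True)(f).lower(unused_inp, inp).compile() \
    #       ._executable._kept_var_idx == {0, 1}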
 jax/_src/checkify.py              |  3 +-
 jax/_src/stages.py                |  4 +--
 jax/experimental/jax2tf/jax2tf.py |  1 -
 jax/experimental/pjit.py          | 50 +++++++++-----------------
 jax/interpreters/pxla.py          | 17 ++++-----
 tests/pjit_test.py                | 58 +++++-------------------------
 6 files changed, 38 insertions(+), 95 deletions(-)

diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index ae1782940acc..f851f74cdc90 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -1115,7 +1115,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                      in_shardings, out_shardings, resource_env,
                      donated_invars, name, in_positional_semantics,
                      out_positional_semantics,
-                     keep_unused, inline):
+                     inline):
   checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
                                                       enabled_errors)
   out_error = error._add_placeholder_effects(effects)
@@ -1155,7 +1155,6 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
       name=name,
       in_positional_semantics=new_positional_sems_in,
       out_positional_semantics=new_positional_sems_out,
-      keep_unused=keep_unused,
       inline=inline)
   err, *out = tree_unflatten(out_tree, err_and_out)
   return out, err
diff --git a/jax/_src/stages.py b/jax/_src/stages.py
index ff74b79badda..137efd4f5a5c 100644
--- a/jax/_src/stages.py
+++ b/jax/_src/stages.py
@@ -497,13 +497,13 @@ def call(*args, **kwargs):
       else:
         raise
     outs = tree_util.tree_unflatten(params.out_tree, out_flat)
-    return outs, out_flat, args_flat
+    return outs, out_flat

   def __call__(self, *args, **kwargs):
     if self._cpp_call is not None:
       return self._cpp_call(*args, **kwargs)
-    outs, _, _ = Compiled.call(self._params, *args, **kwargs)
+    outs, _ = Compiled.call(self._params, *args, **kwargs)
     return outs
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index db6af4ed6f3f..1323ad568481 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -3118,7 +3118,6 @@ def _pjit(*args: TfVal,
           name: str,
           in_positional_semantics,
           out_positional_semantics,
-          keep_unused: bool,
           inline: bool,
           _in_avals: Sequence[core.ShapedArray],
           _out_aval: Sequence[core.ShapedArray]) -> TfVal:
diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index 139b21ace239..260d45ac7f6c 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -112,7 +112,7 @@ def _python_pjit_helper(infer_params, *args, **kwargs):
     _check_arg(arg)
   out_flat = pjit_p.bind(*args_flat, **params)
   outs = tree_unflatten(out_tree, out_flat)
-  return outs, out_flat, out_tree, args_flat
+  return outs, out_flat, out_tree


 def _python_pjit(fun: Callable, infer_params):
@@ -133,8 +133,7 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
   def cache_miss(*args, **kwargs):
     global _most_recent_pjit_call_executable

-    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
-        infer_params, *args, **kwargs)
+    outs, out_flat, out_tree = _python_pjit_helper(infer_params, *args, **kwargs)

     executable = _most_recent_pjit_call_executable.value
     _most_recent_pjit_call_executable.value = None
@@ -151,11 +150,9 @@ def cache_miss(*args, **kwargs):
     if use_fastpath:
       out_avals = [o.aval for o in out_flat]
       out_committed = [o._committed for o in out_flat]
-      kept_var_bitvec = [i in executable._kept_var_idx
-                         for i in range(len(args_flat))]
       fastpath_data = pxla._MeshExecutableFastpathData(
           executable.xla_executable, out_tree, executable._in_shardings,
-          executable._out_shardings, out_avals, out_committed, kept_var_bitvec)
+          executable._out_shardings, out_avals, out_committed)
     else:
       fastpath_data = None
@@ -182,7 +179,6 @@ def pjit(
     static_argnums: Union[int, Sequence[int], None] = None,
     static_argnames: Union[str, Iterable[str], None] = None,
     donate_argnums: Union[int, Sequence[int]] = (),
-    keep_unused: bool = False,
     device: Optional[xc.Device] = None,
     backend: Optional[str] = None,
     inline: bool = False,
@@ -288,10 +284,6 @@ def pjit(
       should not reuse buffers that you donate to a computation, JAX will raise
      an error if you try to. For more details on buffer donation see the
      [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
-    keep_unused: If `False` (the default), arguments that JAX determines to be
-      unused by `fun` *may* be dropped from resulting compiled XLA executables.
-      Such arguments will not be transferred to the device nor provided to the
-      underlying executable. If `True`, unused arguments will not be pruned.
     device: This argument is deprecated. Please put your arguments on the
       device you want before passing them to jit.
       Optional, the Device the jitted function will run on. (Available devices
@@ -304,7 +296,7 @@ def pjit(
       ``'tpu'``.

   Returns:
     A wrapped version of ``fun``, set up for just-in-time compilation and
-    automatically partitioned by the mesh available at each call site.
+    automatically partitioned by the mesh available at each call site.

   For example, a convolution operator can be automatically partitioned over
   an arbitrary set of devices by a single :func:`~pjit` application:
@@ -327,7 +319,7 @@ def pjit(
   if not config.jax_array and (_is_unspecified(in_axis_resources) or
                                _is_unspecified(out_axis_resources)):
     raise ValueError(
-        "in_axis_resources and out_axis_resources should not "
+        "in_axis_resources and out_axis_resources should not "
        "be the unspecified singleton value. Please enable `jax.Array` to use "
        "this feature. You can use jax.config.update('jax_array', True) or "
        "set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
@@ -430,7 +422,7 @@ def infer_params(*args, _global_avals=False, **kwargs):
     out_shardings = tree_map(
         lambda x: x if _is_unspecified(x) else
         _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
         out_axis_resources)
-    # This check fails extremely rarely and has a huge cost in the dispatch
+    # This check fails extremely rarely and has a huge cost in the dispatch
     # path. So hide it behind the jax_enable_checks flag.
     if config.jax_enable_checks:
       _maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
@@ -474,13 +466,12 @@ def infer_params(*args, _global_avals=False, **kwargs):
         name=getattr(flat_fun, '__name__', ''),
         in_positional_semantics=in_positional_semantics,
         out_positional_semantics=out_positional_semantics,
-        keep_unused=keep_unused,
         inline=inline,
     )
     return (args_flat, local_in_avals, params, in_tree, out_tree(),
             donate_argnums)

-  if FLAGS.experimental_cpp_pjit and xc._version >= 111:
+  if FLAGS.experimental_cpp_pjit and xc._version >= 96:
     wrapped = _cpp_pjit(fun, infer_params, static_argnums)
   else:
     wrapped = _python_pjit(fun, infer_params)
@@ -499,7 +490,7 @@ def lower(*args, _global_avals=False, **kwargs):
     lowering = _pjit_lower(
         params['jaxpr'], in_shardings, params['out_shardings'],
         params['resource_env'], params['donated_invars'], params['name'],
-        in_is_global, params['keep_unused'], always_lower=True)
+        in_is_global, always_lower=True)

     if kwargs:
       args_kwargs_in_tree = in_tree
@@ -1009,7 +1000,7 @@ def _pjit_call_impl(*args, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
                     in_positional_semantics, out_positional_semantics,
-                    keep_unused, inline):
+                    inline):

   global _most_recent_pjit_call_executable

@@ -1024,8 +1015,7 @@ def _pjit_call_impl(*args, jaxpr,
     _allow_propagation_to_outputs = False
   compiled = _pjit_lower(
       jaxpr, in_shardings, out_shardings, resource_env,
-      donated_invars, name, in_is_global, keep_unused,
-      always_lower=False).compile(
+      donated_invars, name, in_is_global, always_lower=False).compile(
           _allow_propagation_to_outputs=_allow_propagation_to_outputs)
   _most_recent_pjit_call_executable.value = compiled
   # This check is expensive so only do it if enable_checks is on.
@@ -1094,7 +1084,6 @@ def _pjit_lower_cached(
     donated_invars,
     name: str,
     in_is_global: Sequence[bool],
-    keep_unused: bool,
     always_lower: bool):
   in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
       Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
@@ -1141,7 +1130,7 @@ def _pjit_lower_cached(
   # the arguments just like dispatch.py in `sharded_lowering`.
   return pxla.lower_sharding_computation(
       fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
-      jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
+      jaxpr.in_avals, in_is_global=in_is_global, keep_unused=True,
       always_lower=always_lower,
       devices_from_context=(None if mesh.empty else list(mesh.devices.flat)))
@@ -1170,7 +1159,7 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
 def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, resource_env, donated_invars,
                    in_positional_semantics, out_positional_semantics,
-                   keep_unused, inline):
+                   inline):
   if not isinstance(ctx.module_context.axis_context,
                     (mlir.SPMDAxisContext, mlir.ShardingContext)):
     raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
@@ -1206,7 +1195,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
                   vals_in, dims_in,
                   jaxpr, in_shardings, out_shardings,
                   resource_env, donated_invars, name, in_positional_semantics,
-                  out_positional_semantics, keep_unused, inline):
+                  out_positional_semantics, inline):
   # batch_jaxpr expects all batching dimensions to be equal to 0
   vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
              else x for x, d in zip(vals_in, dims_in)]
@@ -1235,7 +1224,6 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
       name=name,
       in_positional_semantics=in_positional_semantics,
       out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused,
       inline=inline)
   dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
   return vals_out, dims_out
@@ -1266,7 +1254,7 @@ def _pjit_batcher_for_sharding(
 def _pjit_jvp(primals_in, tangents_in,
               jaxpr, in_shardings, out_shardings, resource_env,
               donated_invars, name, in_positional_semantics,
-              out_positional_semantics, keep_unused, inline):
+              out_positional_semantics, inline):
   is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
   jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
       jaxpr, is_nz_tangents_in, instantiate=False)
@@ -1285,7 +1273,6 @@ def _filter_zeros(is_nz_l, l):
       name=wrap_name(name, 'jvp'),
       in_positional_semantics=(*in_positional_semantics,
                                *_filter_zeros_in(in_positional_semantics)),
       out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused,
       inline=inline)

   primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
@@ -1299,7 +1286,7 @@ def _filter_zeros(is_nz_l, l):
 def _pjit_partial_eval(trace, *in_tracers,
                        jaxpr, in_shardings, out_shardings,
                        resource_env, donated_invars, name, in_positional_semantics,
-                       out_positional_semantics, keep_unused, inline):
+                       out_positional_semantics, inline):
   in_pvals = [t.pval for t in in_tracers]

   known_ins = tuple(pv.is_known() for pv in in_pvals)
@@ -1329,7 +1316,6 @@ def keep_where(l, should_keep):
       name=name,
       in_positional_semantics=keep_where(in_positional_semantics, known_ins),
       out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused,
       inline=inline)

   if num_residuals:
@@ -1339,7 +1325,7 @@ def keep_where(l, should_keep):
         known_params["jaxpr"], known_params["in_shardings"],
         known_params["out_shardings"], known_params["resource_env"],
         known_params["donated_invars"], known_params["name"],
-        in_is_global, known_params['keep_unused'], always_lower=False).compile(
+        in_is_global, always_lower=False).compile(
             _allow_propagation_to_outputs=True,
             _allow_compile_replicated=False)
     da = compiled._device_assignment
@@ -1387,7 +1373,6 @@ def keep_where(l, should_keep):
       in_positional_semantics=(keep_where(
           in_positional_semantics, unknown_ins) +
                               (out_positional_semantics,) * num_residuals),
       out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused,
       inline=inline)
   unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
   unknown_tracers_out = [
@@ -1411,7 +1396,7 @@ def keep_where(l, should_keep):
 def _pjit_transpose(reduce_axes, cts_in, *primals_in,
                     jaxpr, in_shardings, out_shardings,
                     resource_env, donated_invars, name, in_positional_semantics,
-                    out_positional_semantics, keep_unused, inline):
+                    out_positional_semantics, inline):
   def prune_type(ty, xs, maybe_zeros):
     return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)

@@ -1456,7 +1441,6 @@ def prune_type(ty, xs, maybe_zeros):
       name=name,
       in_positional_semantics=transpose_in_positional_semantics,
       out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused,
       inline=inline)
   return tree_unflatten(cts_out_treedef, nz_cts_out)
 ad.reducing_transposes[pjit_p] = _pjit_transpose
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 28656e95aafd..4fbab1e37d5f 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -3498,7 +3498,6 @@ class _MeshExecutableFastpathData(NamedTuple):
   out_shardings: Sequence[Any]
   out_avals: Sequence[Any]
   out_committed: Sequence[bool]
-  kept_var_bitvec: Iterable[bool]


 class MeshExecutable(stages.XlaExecutable):
@@ -3578,24 +3577,26 @@ def create_cpp_call(self, no_kwargs, in_tree, out_tree):
         not self.unsafe_call.has_host_callbacks):
       return None

-    if not flags.FLAGS.experimental_cpp_pjit or xc._version < 111:
+    if not flags.FLAGS.experimental_cpp_pjit or xc._version < 96:
       return None

     def aot_cache_miss(*args, **kwargs):
       params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
-      outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
+      outs, out_flat = stages.Compiled.call(params, *args, **kwargs)
+
       use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))

       if use_fastpath:
         out_avals = [o.aval for o in out_flat]
         out_committed = [o._committed for o in out_flat]
-        kept_var_bitvec = [i in self._kept_var_idx
-                           for i in range(len(args_flat))]
-        fastpath_data = _MeshExecutableFastpathData(
-            self.xla_executable, out_tree, self._in_shardings,
-            self._out_shardings, out_avals, out_committed, kept_var_bitvec)
+        fastpath_data = _MeshExecutableFastpathData(self.xla_executable,
+                                                    out_tree,
+                                                    self._in_shardings,
+                                                    self._out_shardings,
+                                                    out_avals, out_committed)
       else:
         fastpath_data = None
+
       return outs, fastpath_data

     if xc._version < 108:
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 35bdf5ccd644..b9363693f82a 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -728,7 +728,7 @@ def f_for_jit(x):
     # execution of the compiled function is blocking, so transferring data
     # to infeed before executing ensures that the execution does not deadlock
     # waiting for the infeed data.
-    logging.info('Transferring to infeed for the jit call')
+    logging.info('Transferring to infeed for the jit call')
     d = devices[0]
     d.transfer_to_infeed((y,))
     d.transfer_to_infeed((z,))
@@ -759,7 +759,7 @@ def f_for_pjit(x):
                       partitions=(P(1, nr_devices),))
       return x + y + z + w

-    logging.info('Transferring to infeed for the pjit call')
+    logging.info('Transferring to infeed for the pjit call')
     for didx, d in enumerate(devices):
       # Transfer the whole array to all devices for replicated.
       d.transfer_to_infeed((y,))
@@ -801,7 +801,7 @@ def check_outfeed(d, x):
           xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
       self.assertAllClose(x, y, check_dtypes=True)

-    logging.info('Transferring from outfeed for the pjit call')
+    logging.info('Transferring from outfeed for the pjit call')
     for didx, d in enumerate(devices):
       # Transfer the whole array from all devices for replicated.
       check_outfeed(d, x)
@@ -2548,9 +2548,9 @@ def test_multi_device_pjit_mul(self):
     self.assertEqual(out2.shape, (8, 2))

   @jax_array(True)
-  def test_single_device_pjit_cpp_dispatch(self):
-    if xla_extension_version < 111:
-      self.skipTest('Does not work for xla_extension_version < 111')
+  def test_single_device_pjit_perf(self):
+    if xla_extension_version < 103:
+      self.skipTest('Does not work for xla_extension_version < 103')

     shape = (8, 2)
     mesh = jtu.create_global_mesh((1,), ('x',))
@@ -2579,8 +2579,8 @@ def pjit_lower_and_count(*args, **kwargs):

   @jax_array(True)
   def test_single_device_add_single_compile(self):
-    if xla_extension_version < 111:
-      self.skipTest('Does not work for xla_extension_version < 111')
+    if xla_extension_version < 103:
+      self.skipTest('Does not work for xla_extension_version < 103')

     f1 = pjit(lambda x, y: x + y)
     a = jax.device_put(jnp.array([1, 2, 3], dtype=jnp.float32),
@@ -2635,7 +2635,7 @@ def test_unspecified_error_without_jax_array(self):
     with self.assertRaisesRegex(
         ValueError,
-        ("in_axis_resources and out_axis_resources should not "
+        ("in_axis_resources and out_axis_resources should not "
          "be the unspecified singleton value. Please enable `jax.Array` to use "
          "this feature.")):
       pjit(lambda x: x)
@@ -2824,46 +2824,6 @@ def test_pjit_kwargs_axis_resources_error(self):
         "pjit does not support kwargs when in_axis_resources is specified."):
       pjit(lambda x: x, in_axis_resources=None)(x=jnp.arange(8.))

-  def test_pjit_keep_unused_true(self):
-    @partial(pjit, keep_unused=True)
-    def f(x, y, z, a, b, c):  # pylint: disable=unused-argument
-      return c @ c.T
-
-    inp = jnp.arange(4)
-    unused_inp = jnp.arange(8)
-
-    out = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)
-    # Run it again to take the C++ dispatch.
-    out_again = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)
-
-    self.assertArraysEqual(out, inp @ inp.T)
-    self.assertArraysEqual(out_again, inp @ inp.T)
-
-    compiled = f.lower(
-        unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp).compile()
-    self.assertEqual(compiled._executable._kept_var_idx, {0, 1, 2, 3, 4, 5})
-    self.assertLen(compiled._executable.in_avals, 6)
-
-  def test_pjit_keep_unused_default_false(self):
-    @pjit
-    def f(x, y, z, a, b, c):  # pylint: disable=unused-argument
-      return c @ c.T
-
-    inp = jax.device_put(jnp.arange(4), jax.devices()[0])
-    unused_inp = jax.device_put(jnp.arange(8), jax.devices()[0])
-
-    out = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)
-    # Run it again to take the C++ dispatch.
-    out_again = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)
-
-    self.assertArraysEqual(out, inp @ inp.T)
-    self.assertArraysEqual(out_again, inp @ inp.T)
-
-    compiled = f.lower(
-        unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp).compile()
-    self.assertEqual(compiled._executable._kept_var_idx, {5})
-    self.assertLen(compiled._executable.in_avals, 1)
-
   def test_pjit_with_device_arg(self):
     def mul(x):
       return x @ x.T