Commit

[Rollback 2] Add keep_unused to pjit's API as a step to merge `jit` and `pjit` frontend API.

PiperOrigin-RevId: 495756613
yashk2810 authored and jax authors committed Dec 16, 2022
1 parent 1598c52 commit ecaa215
Showing 6 changed files with 38 additions and 95 deletions.
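For context: the `keep_unused` parameter being rolled back here controlled whether arguments that `fun` never uses are pruned from the compiled executable (see the docstring text removed from jax/experimental/pjit.py below). A minimal sketch of that behavior, written against `jax.jit`, which carries the same parameter in later JAX releases; the function and operands are illustrative, not from this commit:

import jax
import jax.numpy as jnp

def f(x, unused):
  # `unused` never influences the output, so JAX may prune it.
  return x * 2.0

# Default (keep_unused=False): the unused argument may be dropped from the
# compiled XLA executable and never transferred to the device.
pruned = jax.jit(f)

# keep_unused=True: unused arguments are kept, transferred to the device,
# and passed to the underlying executable.
kept = jax.jit(f, keep_unused=True)

x = jnp.ones((4,))
big = jnp.ones((2048, 2048))  # costly to transfer if it is actually kept
print(pruned(x, big))
print(kept(x, big))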
jax/_src/checkify.py (3 changes: 1 addition & 2 deletions)
@@ -1115,7 +1115,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                      in_shardings, out_shardings, resource_env,
                      donated_invars, name,
                      in_positional_semantics, out_positional_semantics,
-                     keep_unused, inline):
+                     inline):
   checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
                                                       enabled_errors)
   out_error = error._add_placeholder_effects(effects)
@@ -1155,7 +1155,6 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
       name=name,
       in_positional_semantics=new_positional_sems_in,
       out_positional_semantics=new_positional_sems_out,
-      keep_unused=keep_unused,
       inline=inline)
   err, *out = tree_unflatten(out_tree, err_and_out)
   return out, err
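The checkify change above only drops the `keep_unused` parameter threaded through pjit's error-check rule; the rule itself is unchanged. For readers unfamiliar with checkify, a minimal usage sketch of the public API (the function and input are illustrative):

import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  return jnp.log(x)  # produces NaN for negative inputs

# Wrap the jitted function so float errors are returned as a value.
checked_f = checkify.checkify(jax.jit(f), errors=checkify.float_checks)
err, out = checked_f(jnp.array(-1.0))
err.throw()  # raises a descriptive error if a NaN/Inf was produced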
jax/_src/stages.py (4 changes: 2 additions & 2 deletions)
@@ -497,13 +497,13 @@ def call(*args, **kwargs):
       else:
         raise
     outs = tree_util.tree_unflatten(params.out_tree, out_flat)
-    return outs, out_flat, args_flat
+    return outs, out_flat

   def __call__(self, *args, **kwargs):
     if self._cpp_call is not None:
       return self._cpp_call(*args, **kwargs)

-    outs, _, _ = Compiled.call(self._params, *args, **kwargs)
+    outs, _ = Compiled.call(self._params, *args, **kwargs)
     return outs
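The `args_flat` value dropped from `Compiled.call`'s return above existed only so callers could compute which flattened arguments survived pruning. A condensed sketch of that plumbing, using the attribute names visible elsewhere in this diff (`_kept_var_idx` is JAX-internal, so treat this as illustrative):

def build_kept_var_bitvec(executable, args_flat):
  # executable._kept_var_idx holds the indices of the arguments that the
  # pruned executable still expects; all other arguments were dropped.
  return [i in executable._kept_var_idx for i in range(len(args_flat))]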
jax/experimental/jax2tf/jax2tf.py (1 change: 0 additions & 1 deletion)
@@ -3118,7 +3118,6 @@ def _pjit(*args: TfVal,
           name: str,
           in_positional_semantics,
           out_positional_semantics,
-          keep_unused: bool,
           inline: bool,
           _in_avals: Sequence[core.ShapedArray],
           _out_aval: Sequence[core.ShapedArray]) -> TfVal:
jax/experimental/pjit.py (50 changes: 17 additions & 33 deletions)
@@ -112,7 +112,7 @@ def _python_pjit_helper(infer_params, *args, **kwargs):
     _check_arg(arg)
   out_flat = pjit_p.bind(*args_flat, **params)
   outs = tree_unflatten(out_tree, out_flat)
-  return outs, out_flat, out_tree, args_flat
+  return outs, out_flat, out_tree

 def _python_pjit(fun: Callable, infer_params):

@@ -133,8 +133,7 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
   def cache_miss(*args, **kwargs):
     global _most_recent_pjit_call_executable

-    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
-        infer_params, *args, **kwargs)
+    outs, out_flat, out_tree = _python_pjit_helper(infer_params, *args, **kwargs)

     executable = _most_recent_pjit_call_executable.value
     _most_recent_pjit_call_executable.value = None
@@ -151,11 +150,9 @@ def cache_miss(*args, **kwargs):
     if use_fastpath:
       out_avals = [o.aval for o in out_flat]
       out_committed = [o._committed for o in out_flat]
-      kept_var_bitvec = [i in executable._kept_var_idx
-                         for i in range(len(args_flat))]
       fastpath_data = pxla._MeshExecutableFastpathData(
           executable.xla_executable, out_tree, executable._in_shardings,
-          executable._out_shardings, out_avals, out_committed, kept_var_bitvec)
+          executable._out_shardings, out_avals, out_committed)
     else:
       fastpath_data = None

@@ -182,7 +179,6 @@ def pjit(
     static_argnums: Union[int, Sequence[int], None] = None,
     static_argnames: Union[str, Iterable[str], None] = None,
     donate_argnums: Union[int, Sequence[int]] = (),
-    keep_unused: bool = False,
     device: Optional[xc.Device] = None,
     backend: Optional[str] = None,
     inline: bool = False,
@@ -288,10 +284,6 @@ def pjit(
       should not reuse buffers that you donate to a computation, JAX will raise
       an error if you try to.
       For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
-    keep_unused: If `False` (the default), arguments that JAX determines to be
-      unused by `fun` *may* be dropped from resulting compiled XLA executables.
-      Such arguments will not be transferred to the device nor provided to the
-      underlying executable. If `True`, unused arguments will not be pruned.
     device: This argument is deprecated. Please put your arguments on the
       device you want before passing them to jit.
       Optional, the Device the jitted function will run on. (Available devices
@@ -304,7 +296,7 @@ def pjit(
       ``'tpu'``.

   Returns:
     A wrapped version of ``fun``, set up for just-in-time compilation and
-    automatically partitioned by the mesh available at each call site.
+    automaticly partitioned by the mesh available at each call site.

   For example, a convolution operator can be automatically partitioned over
   an arbitrary set of devices by a single :func:`~pjit` application:
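The docstring's convolution example is collapsed in this view. A minimal sketch of such a `pjit` application, using the `in_axis_resources`/`out_axis_resources` API of this era and assuming a one-dimensional mesh over all available devices (the operands are illustrative, not the collapsed example itself):

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental.pjit import pjit, PartitionSpec

# Shard the signal across a 1-D device mesh; replicate the small kernel.
f = pjit(lambda x, k: jnp.convolve(x, k, mode='same'),
         in_axis_resources=(PartitionSpec('devices'), None),
         out_axis_resources=PartitionSpec('devices'))

x = jnp.arange(8, dtype=jnp.float32)
k = jnp.array([0.5, 1.0, 0.5])
with maps.Mesh(np.array(jax.devices()), ('devices',)):
  print(f(x, k))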
@@ -327,7 +319,7 @@ def pjit(
   if not config.jax_array and (_is_unspecified(in_axis_resources) or
                                _is_unspecified(out_axis_resources)):
     raise ValueError(
-        "in_axis_resources and out_axis_resources should not "
+        "in_axis_resources and out_axis_resouces should not "
         "be the unspecified singleton value. Please enable `jax.Array` to use "
         "this feature. You can use jax.config.update('jax_array', True) or "
         "set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
@@ -430,7 +422,7 @@ def infer_params(*args, _global_avals=False, **kwargs):
       out_shardings = tree_map(
           lambda x: x if _is_unspecified(x) else
           _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
-      # This check fails extremely rarely and has a huge cost in the dispatch
+      # This check fails extrememly rarely and has a huge cost in the dispatch
       # path. So hide it behind the jax_enable_checks flag.
       if config.jax_enable_checks:
         _maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
@@ -474,13 +466,12 @@ def infer_params(*args, _global_avals=False, **kwargs):
         name=getattr(flat_fun, '__name__', '<unnamed function>'),
         in_positional_semantics=in_positional_semantics,
         out_positional_semantics=out_positional_semantics,
-        keep_unused=keep_unused,
         inline=inline,
     )
     return (args_flat, local_in_avals, params, in_tree, out_tree(),
             donate_argnums)

-  if FLAGS.experimental_cpp_pjit and xc._version >= 111:
+  if FLAGS.experimental_cpp_pjit and xc._version >= 96:
     wrapped = _cpp_pjit(fun, infer_params, static_argnums)
   else:
     wrapped = _python_pjit(fun, infer_params)
@@ -499,7 +490,7 @@ def lower(*args, _global_avals=False, **kwargs):
     lowering = _pjit_lower(
         params['jaxpr'], in_shardings, params['out_shardings'],
         params['resource_env'], params['donated_invars'], params['name'],
-        in_is_global, params['keep_unused'], always_lower=True)
+        in_is_global, always_lower=True)

     if kwargs:
       args_kwargs_in_tree = in_tree
@@ -1009,7 +1000,7 @@ def _pjit_call_impl(*args, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
                     in_positional_semantics, out_positional_semantics,
-                    keep_unused, inline):
+                    inline):

   global _most_recent_pjit_call_executable

@@ -1024,8 +1015,7 @@ def _pjit_call_impl(*args, jaxpr,
   _allow_propagation_to_outputs = False
   compiled = _pjit_lower(
       jaxpr, in_shardings, out_shardings, resource_env,
-      donated_invars, name, in_is_global, keep_unused,
-      always_lower=False).compile(
+      donated_invars, name, in_is_global, always_lower=False).compile(
           _allow_propagation_to_outputs=_allow_propagation_to_outputs)
   _most_recent_pjit_call_executable.value = compiled
   # This check is expensive so only do it if enable_checks is on.
@@ -1094,7 +1084,6 @@ def _pjit_lower_cached(
     donated_invars,
     name: str,
     in_is_global: Sequence[bool],
-    keep_unused: bool,
     always_lower: bool):
   in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
       Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
@@ -1141,7 +1130,7 @@ def _pjit_lower_cached(
   # the arguments just like dispatch.py in `sharded_lowering`.
   return pxla.lower_sharding_computation(
       fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
-      jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
+      jaxpr.in_avals, in_is_global=in_is_global, keep_unused=True,
       always_lower=always_lower,
       devices_from_context=(None if mesh.empty else list(mesh.devices.flat)))

@@ -1170,7 +1159,7 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
 def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, resource_env, donated_invars,
                    in_positional_semantics, out_positional_semantics,
-                   keep_unused, inline):
+                   inline):
   if not isinstance(ctx.module_context.axis_context,
                     (mlir.SPMDAxisContext, mlir.ShardingContext)):
     raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
@@ -1206,7 +1195,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
                   vals_in, dims_in,
                   jaxpr, in_shardings, out_shardings,
                   resource_env, donated_invars, name, in_positional_semantics,
-                  out_positional_semantics, keep_unused, inline):
+                  out_positional_semantics, inline):
   # batch_jaxpr expects all batching dimensions to be equal to 0
   vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
              else x for x, d in zip(vals_in, dims_in)]
@@ -1235,7 +1224,6 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
       name=name,
       in_positional_semantics=in_positional_semantics,
       out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused,
       inline=inline)
   dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
   return vals_out, dims_out
@@ -1266,7 +1254,7 @@ def _pjit_batcher_for_sharding(
 def _pjit_jvp(primals_in, tangents_in,
               jaxpr, in_shardings, out_shardings,
               resource_env, donated_invars, name, in_positional_semantics,
-              out_positional_semantics, keep_unused, inline):
+              out_positional_semantics, inline):
   is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
   jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
       jaxpr, is_nz_tangents_in, instantiate=False)
@@ -1285,7 +1273,6 @@ def _filter_zeros(is_nz_l, l):
       name=wrap_name(name, 'jvp'),
       in_positional_semantics=(*in_positional_semantics, *_filter_zeros_in(in_positional_semantics)),
       out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused,
       inline=inline)

   primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
@@ -1299,7 +1286,7 @@ def _filter_zeros(is_nz_l, l):
 def _pjit_partial_eval(trace, *in_tracers,
                        jaxpr, in_shardings, out_shardings,
                        resource_env, donated_invars, name, in_positional_semantics,
-                       out_positional_semantics, keep_unused, inline):
+                       out_positional_semantics, inline):
   in_pvals = [t.pval for t in in_tracers]

   known_ins = tuple(pv.is_known() for pv in in_pvals)
@@ -1329,7 +1316,6 @@ def keep_where(l, should_keep):
       name=name,
       in_positional_semantics=keep_where(in_positional_semantics, known_ins),
       out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused,
       inline=inline)

   if num_residuals:
@@ -1339,7 +1325,7 @@ def keep_where(l, should_keep):
         known_params["jaxpr"], known_params["in_shardings"],
         known_params["out_shardings"], known_params["resource_env"],
         known_params["donated_invars"], known_params["name"],
-        in_is_global, known_params['keep_unused'], always_lower=False).compile(
+        in_is_global, always_lower=False).compile(
             _allow_propagation_to_outputs=True,
             _allow_compile_replicated=False)
     da = compiled._device_assignment
@@ -1387,7 +1373,6 @@ def keep_where(l, should_keep):
         in_positional_semantics=(keep_where(
             in_positional_semantics, unknown_ins) + (out_positional_semantics,) * num_residuals),
         out_positional_semantics=out_positional_semantics,
-        keep_unused=keep_unused,
         inline=inline)
     unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
     unknown_tracers_out = [
@@ -1411,7 +1396,7 @@ def keep_where(l, should_keep):
 def _pjit_transpose(reduce_axes, cts_in, *primals_in,
                     jaxpr, in_shardings, out_shardings,
                     resource_env, donated_invars, name, in_positional_semantics,
-                    out_positional_semantics, keep_unused, inline):
+                    out_positional_semantics, inline):
   def prune_type(ty, xs, maybe_zeros):
     return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)

@@ -1456,7 +1441,6 @@ def prune_type(ty, xs, maybe_zeros):
       name=name,
       in_positional_semantics=transpose_in_positional_semantics,
       out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused,
       inline=inline)
   return tree_unflatten(cts_out_treedef, nz_cts_out)
 ad.reducing_transposes[pjit_p] = _pjit_transpose
jax/interpreters/pxla.py (17 changes: 9 additions & 8 deletions)
@@ -3498,7 +3498,6 @@ class _MeshExecutableFastpathData(NamedTuple):
   out_shardings: Sequence[Any]
   out_avals: Sequence[Any]
   out_committed: Sequence[bool]
-  kept_var_bitvec: Iterable[bool]


 class MeshExecutable(stages.XlaExecutable):
@@ -3578,24 +3577,26 @@ def create_cpp_call(self, no_kwargs, in_tree, out_tree):
         not self.unsafe_call.has_host_callbacks):
       return None

-    if not flags.FLAGS.experimental_cpp_pjit or xc._version < 111:
+    if not flags.FLAGS.experimental_cpp_pjit or xc._version < 96:
       return None

     def aot_cache_miss(*args, **kwargs):
       params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
-      outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
+      outs, out_flat = stages.Compiled.call(params, *args, **kwargs)

       use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))

       if use_fastpath:
         out_avals = [o.aval for o in out_flat]
         out_committed = [o._committed for o in out_flat]
-        kept_var_bitvec = [i in self._kept_var_idx
-                           for i in range(len(args_flat))]
-        fastpath_data = _MeshExecutableFastpathData(
-            self.xla_executable, out_tree, self._in_shardings,
-            self._out_shardings, out_avals, out_committed, kept_var_bitvec)
+        fastpath_data = _MeshExecutableFastpathData(self.xla_executable,
+                                                    out_tree,
+                                                    self._in_shardings,
+                                                    self._out_shardings,
+                                                    out_avals, out_committed)
       else:
         fastpath_data = None

       return outs, fastpath_data

     if xc._version < 108:
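Taken together, the pxla.py changes shrink the C++ fastpath payload. A reconstruction of the NamedTuple after this rollback; the leading field names are inferred from the constructor calls in this diff, not confirmed against the full source:

from typing import Any, NamedTuple, Sequence

class _MeshExecutableFastpathData(NamedTuple):
  xla_executable: Any       # receives self.xla_executable at the call sites
  out_pytree_def: Any       # receives `out_tree` at the call sites
  in_shardings: Sequence[Any]
  out_shardings: Sequence[Any]
  out_avals: Sequence[Any]
  out_committed: Sequence[bool]
  # Removed by this rollback:
  # kept_var_bitvec: Iterable[bool]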
(The diff for the sixth changed file did not render in this view.)
