Skip to content

Commit

Permalink
Merge pull request #5977 from apaszke:xmap-with-control-flow
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 361854852
  • Loading branch information
jax authors committed Mar 9, 2021
2 parents f1ba3bc + ec29275 commit 6515b5f
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 39 deletions.
26 changes: 25 additions & 1 deletion jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,31 @@ def register_name(axis_name):
def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst) -> ParamDict:
if primitive in axis_substitution_rules:
return axis_substitution_rules[primitive](params, subst)
return params
# Default implementation: substitute names in all jaxpr parameters
if isinstance(primitive, MapPrimitive):
def shadowed_subst(name):
return (name,) if name == params['axis_name'] else subst(name)
else:
shadowed_subst = subst
jaxpr_params = [(n, v) for n, v in params.items() if isinstance(v, (Jaxpr, ClosedJaxpr))]
if not jaxpr_params:
return params
new_params = dict(params)
for name, jaxpr in jaxpr_params:
new_params[name] = subst_axis_names_jaxpr(jaxpr, shadowed_subst)
return new_params

def subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst):
consts = None
if isinstance(jaxpr, ClosedJaxpr):
consts = jaxpr.consts
jaxpr = jaxpr.jaxpr
eqns = [eqn._replace(params=subst_axis_names(eqn.primitive, eqn.params, subst))
for eqn in jaxpr.eqns]
new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, eqns)
if consts is not None:
return ClosedJaxpr(new_jaxpr, consts)
return new_jaxpr

axis_substitution_rules: Dict[Primitive, Callable[[ParamDict, AxisSubst], ParamDict]] = {}

Expand Down
57 changes: 19 additions & 38 deletions jax/experimental/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
with core.extend_axis_env_nd(axis_sizes.items()):
jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
out_axes = out_axes_thunk()
jaxpr = subst_jaxpr_axis_names(jaxpr, plan.axis_subst)
jaxpr = core.subst_axis_names_jaxpr(jaxpr, plan.axis_subst)

f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts)))
f = hide_mapped_axes(f, tuple(in_axes), tuple(out_axes))
Expand Down Expand Up @@ -563,9 +563,13 @@ def make_xmap_callable(fun: lu.WrappedFun,
class EvaluationPlan(NamedTuple):
"""Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""
physical_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
axis_subst: Dict[AxisName, Tuple[ResourceAxisName, ...]]
axis_subst_dict: Dict[AxisName, Tuple[ResourceAxisName, ...]]
axis_vmap_size: Dict[AxisName, Optional[int]]

@property
def axis_subst(self) -> core.AxisSubst:
return lambda name: self.axis_subst_dict.get(name, (name,))

@classmethod
def from_axis_resources(cls,
axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
Expand All @@ -574,7 +578,7 @@ def from_axis_resources(cls,
# TODO: Support sequential resources
physical_axis_resources = axis_resources # NB: We only support physical resources at the moment
resource_shape = resource_env.shape
axis_subst = dict(axis_resources)
axis_subst_dict = dict(axis_resources)
axis_vmap_size: Dict[AxisName, Optional[int]] = {}
for naxis, raxes in axis_resources.items():
num_resources = int(np.prod([resource_shape[axes] for axes in raxes], dtype=np.int64))
Expand All @@ -587,13 +591,13 @@ def from_axis_resources(cls,
# when every resource gets chunks of values.
if not raxes or tile_size > 1:
axis_vmap_size[naxis] = tile_size
axis_subst[naxis] += (fresh_resource_name(naxis),)
axis_subst_dict[naxis] += (fresh_resource_name(naxis),)
else:
axis_vmap_size[naxis] = None
return cls(physical_axis_resources, axis_subst, axis_vmap_size)
return cls(physical_axis_resources, axis_subst_dict, axis_vmap_size)

def vectorize(self, f: lu.WrappedFun, in_axes, out_axes):
for naxis, raxes in self.axis_subst.items():
for naxis, raxes in self.axis_subst_dict.items():
tile_size = self.axis_vmap_size[naxis]
if tile_size is None:
continue
Expand Down Expand Up @@ -643,6 +647,13 @@ def _process_xmap_default(self, call_primitive, f, tracers, params):
raise NotImplementedError(f"{type(self)} must override process_xmap to handle xmap")
core.Trace.process_xmap = _process_xmap_default # type: ignore

def _xmap_axis_subst(params, subst):
def shadowed_subst(name):
return (name,) if name in params['axis_sizes'] else subst(name)
new_jaxpr = core.subst_axis_names_jaxpr(params['call_jaxpr'], shadowed_subst)
return dict(params, call_jaxpr=new_jaxpr)
core.axis_substitution_rules[xmap_p] = _xmap_axis_subst


# This is DynamicJaxprTrace.process_map with some very minor modifications
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
Expand Down Expand Up @@ -747,7 +758,7 @@ def _xmap_translation_rule_replica(c, axis_env,
in zip(call_jaxpr.invars, in_axes, mesh_in_axes)]
# We have to substitute before tracing, because we want the vectorized
# axes to be used in the jaxpr.
resource_call_jaxpr = subst_jaxpr_axis_names(call_jaxpr, plan.axis_subst)
resource_call_jaxpr = core.subst_axis_names_jaxpr(call_jaxpr, plan.axis_subst)
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
f = hide_mapped_axes(f, tuple(in_axes), tuple(out_axes))
f = plan.vectorize(f, in_axes, out_axes)
Expand Down Expand Up @@ -922,7 +933,7 @@ def _unsqueeze_mapped_axes(out, axes: AxisNamePos):
yield map(_unsqueeze_mapped_axes, flat_outputs, flat_out_axes)


def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]:
def _jaxpr_resources(jaxpr: core.Jaxpr, resource_env) -> Set[ResourceAxisName]:
used_resources = set()
for eqn in jaxpr.eqns:
if eqn.primitive is xmap_p:
Expand All @@ -936,36 +947,6 @@ def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]:
used_resources |= update
return used_resources

def subst_jaxpr_axis_names(jaxpr, axis_subst: Dict[AxisName, Tuple[AxisName]]):
eqns = [subst_eqn_axis_names(eqn, axis_subst) for eqn in jaxpr.eqns]
return core.Jaxpr(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, eqns)

def subst_eqn_axis_names(eqn, axis_subst: Dict[AxisName, Tuple[AxisName]]):
# TODO: Support custom_vjp, custom_jvp
if eqn.primitive is xmap_p:
shadowed_axes = set(eqn.params['axis_sizes']) & set(axis_subst)
if shadowed_axes:
shadowed_subst = dict(axis_subst)
for saxis in shadowed_axes:
del shadowed_subst[saxis]
else:
shadowed_subst = axis_subst
new_call_jaxpr = subst_jaxpr_axis_names(eqn.params['call_jaxpr'], shadowed_subst)
return eqn._replace(params=dict(eqn.params, call_jaxpr=new_call_jaxpr))
if isinstance(eqn.primitive, (core.CallPrimitive, core.MapPrimitive)):
bound_name = eqn.params.get('axis_name', None)
if bound_name in axis_subst: # Check for shadowing
sub_subst = dict(axis_subst)
del sub_subst[bound_name]
else:
sub_subst = axis_subst
new_call_jaxpr = subst_jaxpr_axis_names(eqn.params['call_jaxpr'], sub_subst)
return eqn._replace(params=dict(eqn.params, call_jaxpr=new_call_jaxpr))
new_params = core.subst_axis_names(eqn.primitive, eqn.params,
lambda name: axis_subst.get(name, (name,)))
return eqn if new_params is eqn.params else eqn._replace(params=new_params)


# -------- soft_pmap --------

def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
Expand Down
6 changes: 6 additions & 0 deletions tests/xmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,12 @@ def testBufferDonation(self, mesh, axis_resources):
self.assertNotDeleted(y)
self.assertDeleted(x)

@ignore_xmap_warning()
def testControlFlow(self):
x = jnp.arange(5)
xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x),
in_axes=['i', ...], out_axes=['i', ...])(x)

@with_and_without_mesh
@ignore_xmap_warning()
def testAxisSizes(self, mesh, axis_resources):
Expand Down

0 comments on commit 6515b5f

Please sign in to comment.