Skip to content

Commit

Permalink
Remove jax.spmd_mode from pmap wrapper around shard_map
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653403904
  • Loading branch information
yashk2810 authored and jax authors committed Jul 17, 2024
1 parent 378a830 commit 174429d
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,15 +1716,13 @@ def infer_params(*args, **kwargs):
def wrapped(*args, **kwargs):
(jitted_f, flat_global_args, out_tree, mesh,
out_specs) = infer_params(*args, **kwargs)
with jax.spmd_mode('allow_all'):
outs = jitted_f(*flat_global_args)
outs = global_array_to_host_local_array(outs, mesh, out_specs())
outs = jitted_f(*flat_global_args)
outs = global_array_to_host_local_array(outs, mesh, out_specs())
return tree_unflatten(out_tree(), outs)

def lower(*args, **kwargs):
jitted_f, _, _, _, _ = infer_params(*args, **kwargs)
with jax.spmd_mode('allow_all'):
return jitted_f.lower(*args, **kwargs)
return jitted_f.lower(*args, **kwargs)
wrapped.lower = lower

return wrapped
Expand Down

0 comments on commit 174429d

Please sign in to comment.