From 174429d7cf1f8b04fc5a390ceb2198d6e3a779bd Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 17 Jul 2024 16:50:30 -0700 Subject: [PATCH] Remove `jax.spmd_mode` from pmap wrapper around shard_map PiperOrigin-RevId: 653403904 --- jax/experimental/shard_map.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index ad3d6bf46ece..efd109804147 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1716,15 +1716,13 @@ def infer_params(*args, **kwargs): def wrapped(*args, **kwargs): (jitted_f, flat_global_args, out_tree, mesh, out_specs) = infer_params(*args, **kwargs) - with jax.spmd_mode('allow_all'): - outs = jitted_f(*flat_global_args) - outs = global_array_to_host_local_array(outs, mesh, out_specs()) + outs = jitted_f(*flat_global_args) + outs = global_array_to_host_local_array(outs, mesh, out_specs()) return tree_unflatten(out_tree(), outs) def lower(*args, **kwargs): jitted_f, _, _, _, _ = infer_params(*args, **kwargs) - with jax.spmd_mode('allow_all'): - return jitted_f.lower(*args, **kwargs) + return jitted_f.lower(*args, **kwargs) wrapped.lower = lower return wrapped