I'm trying to use shard_map to compute independently on each CPU device, and then use io_callback() to launch some other computation on the GPU and return the result. I've registered a shard_map rule for io_callback that seems correct in this specific scenario, but I'm getting an XLA assertion error (from https://github.com/openxla/xla/blob/main/xla/hlo/ir/hlo_sharding.cc#L843):
2023-04-25 16:47:53.206001: F external/xla/xla/hlo/ir/hlo_sharding.cc:843] Check failed: !IsManual()
It seems that using a manual sharding strategy is not possible when running callbacks? jax.debug.debug_callback works (see the sketch below), but it can't return anything, so maybe that's why.
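For contrast, here is a minimal sketch of the side-effect-only variant that does work for me (using the public jax.debug.callback wrapper, and assuming the cpu_mesh / shard_map setup from the full example below):

# Side-effect-only callback: runs fine under shard_map, but returns nothing
# into the traced computation.
@partial(shard_map, mesh=cpu_mesh, in_specs=(P("i"),), out_specs=P("i"))
def compute_with_debug(x):
    jax.debug.callback(lambda a: print("shape:", a.shape), x)
    return x + x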
Here's the failing example (invoke with XLA_FLAGS="--xla_force_host_platform_device_count=8"):
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax._src.callback import io_callback_p
from jax.experimental import io_callback
from jax.experimental.shard_map import register_rule, shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P


# shard_map replication rule for io_callback: outputs inherit the inputs' replication.
@register_rule(io_callback_p)
def _io_callback_rule(mesh, *in_rep, **_kwargs):
    return in_rep


def print_args(x):
    print("args shapes :", x.shape)
    return x


# One mesh axis spanning all (forced) host CPU devices.
cpu_mesh = Mesh(np.array(jax.devices("cpu")), axis_names=("i",))


@jax.jit
@partial(shard_map, mesh=cpu_mesh, in_specs=(P("i"),), out_specs=P("i"))
def compute_in_cpu(x):
    x = x + x
    x = io_callback(print_args, x, x, ordered=False)  # works fine if commenting out this line
    return x


if __name__ == "__main__":
    ones = jnp.ones((8000, 3))
    print(ones.devices())
    out = compute_in_cpu(ones)
    print(out.devices())
    print("out shape=", out.shape)
(My actual objective is to do some JAX computation on many CPU devices and some JAX computation on a single GPU. My attempts at this led me to set XLA_FLAGS="--xla_force_host_platform_device_count=8" to create multiple CPU devices, and to call io_callback from them to later do some work on the GPU. Should I instead implement custom CPU and GPU primitives that communicate under the hood, or stage the two steps explicitly, as in the sketch below?)
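For reference, a hypothetical sketch of that staged alternative (it assumes a GPU device is actually available, and the x * 2.0 body is only a placeholder for the real GPU computation):

# Hypothetical alternative: stage the pipeline in plain Python instead of
# calling back to the GPU from inside shard_map.
cpu_out = compute_in_cpu(ones)  # sharded computation across the CPU devices
gpu_in = jax.device_put(cpu_out, jax.devices("gpu")[0])  # explicit CPU -> GPU transfer
gpu_out = jax.jit(lambda x: x * 2.0)(gpu_in)  # placeholder GPU step; jit runs on the
                                              # device the input is committed to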