Hi Ayaka,
Very excited to see llama2 built with JAX!
I am trying to run llama2-7b-hf with the provided framework on an A800 cluster. But when I follow the README and run python train.py, it raises jaxlib.xla_extension.XlaRuntimeError. Could you please provide some insight into how to fix that? Thanks!
Error Msg:
Traceback (most recent call last):
File "/root/llama-2-jax/train.py", line 146, in <module>
main()
File "/root/llama-2-jax/train.py", line 89, in main
jax.distributed.initialize(coordinator_address="localhost", num_processes=8, process_id=0)
File "/root/miniconda3/envs/jax/lib/python3.11/site-packages/jax/_src/distributed.py", line 180, in initialize
global_state.initialize(coordinator_address, num_processes, process_id,
File "/root/miniconda3/envs/jax/lib/python3.11/site-packages/jax/_src/distributed.py", line 95, in initialize
self.client.connect()
jaxlib.xla_extension.XlaRuntimeError: DEADLINE_EXCEEDED: Barrier timed out. Barrier_id: PjRT_Client_Connect. Timed out task names:
/job:jax_worker/replica:0/task:1
/job:jax_worker/replica:0/task:6
/job:jax_worker/replica:0/task:3
/job:jax_worker/replica:0/task:2
/job:jax_worker/replica:0/task:5
/job:jax_worker/replica:0/task:7
/job:jax_worker/replica:0/task:4
Additional GRPC error information from remote target unknown_target_for_coordination_leader while calling /tensorflow.CoordinationService/Barrier:
:{"created":"@1700792772.792587791","description":"Error received from peer ipv4:127.0.0.1:443","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Barrier timed out. Barrier_id: PjRT_Client_Connect. Timed out task names:\n/job:jax_worker/replica:0/task:1\n/job:jax_worker/replica:0/task:6\n/job:jax_worker/replica:0/task:3\n/job:jax_worker/replica:0/task:2\n/job:jax_worker/replica:0/task:5\n/job:jax_worker/replica:0/task:7\n/job:jax_worker/replica:0/task:4\n","grpc_status":4}
2023-11-24 02:26:13.167769: E external/tsl/tsl/distributed_runtime/coordination/coordination_service_agent.cc:494] Failed to disconnect from coordination service with status: UNAVAILABLE: failed to connect to all addresses
Additional GRPC error information from remote target unknown_target_for_coordination_leader while calling /tensorflow.CoordinationService/ShutdownTask:
:{"created":"@1700792773.167722780","description":"Failed to pick subchannel","file":"external/com_github_grpc_grpc/src/core/ext/filters/client_channel/client_channel.cc","file_line":3940,"referenced_errors":[{"created":"@1700792773.167720845","description":"failed to connect to all addresses","file":"external/com_github_grpc_grpc/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc","file_line":392,"grpc_status":14}]}
Proceeding with agent shutdown anyway. This is usually caused by an earlier error during execution. Check the logs (this task or the leader) for an earlier error to debug further.
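The barrier timeout above lists tasks 1 through 7 as missing, which suggests only process 0 ever reached the PjRT_Client_Connect barrier: jax.distributed.initialize must be called once in each of the 8 processes, each with a distinct process_id, and coordinator_address is expected to be a host:port pair (the gRPC log shows the bare "localhost" falling back to port 443). A minimal sketch of deriving per-process arguments is below; the PROCESS_ID environment variable, the port number, and the helper name are illustrative assumptions, not part of llama-2-jax:

```python
import os

def distributed_init_args(num_processes=8, coordinator_host="localhost", port=12345):
    """Build per-process kwargs for jax.distributed.initialize.

    Assumes the launcher (e.g. SLURM, mpirun, or a shell loop) exports a rank
    variable for each worker; PROCESS_ID here is a hypothetical name.
    """
    process_id = int(os.environ.get("PROCESS_ID", "0"))
    return {
        # Must include an explicit port, e.g. "localhost:12345".
        "coordinator_address": f"{coordinator_host}:{port}",
        "num_processes": num_processes,
        # Distinct for every worker: 0, 1, ..., num_processes - 1.
        "process_id": process_id,
    }

if __name__ == "__main__":
    args = distributed_init_args()
    # Each of the 8 worker processes would then call:
    # jax.distributed.initialize(**args)
    print(args)
```

Note that on a single multi-GPU host, jax.distributed.initialize is often unnecessary at all; it is intended for multi-host clusters where every host runs its own copy of the script.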
Software details:
Hardware details: