Do you know the current best practices for implementing multi-host data parallelism with JAX?
I want to use JAX in a multi-host, multi-process environment. I found the tutorial at "https://jax.readthedocs.io/en/latest/multi_process.html", but it doesn't seem to explain how to load data on different hosts. Suppose I have 100 datasets, labeled "data_001" to "data_100", and 10 processes with 10 devices. I want the first process to load "data_001", "data_011", and so on (every 10th dataset). Then I would like to train the model on the datasets in each process and average the gradients across all processes at the end.
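A minimal sketch of one common pattern, assuming a toy linear model and synthetic data standing in for the real files. Each process derives its own round-robin file list from `jax.process_index()` and `jax.process_count()`, feeds its local batch to a `jax.pmap`-ed step, and `jax.lax.pmean` averages the gradients. In a multi-process run, a `pmap` collective spans the devices of every process, so the average is global. The helpers `files_for_process` and `shard_for_devices`, the loss, and the learning rate are hypothetical illustrations, not part of the JAX API. On a real cluster, `jax.distributed.initialize()` has to be called first.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# On a real multi-host cluster, call jax.distributed.initialize() here, before any
# other JAX call (its arguments depend on your launcher/cluster).

def files_for_process(process_index, process_count, num_files=100):
    """Hypothetical helper: round-robin assignment of data_001..data_NNN to processes.

    With 10 processes, process 0 gets data_001, data_011, ..., data_091.
    """
    return [f"data_{i:03d}" for i in range(process_index + 1, num_files + 1, process_count)]

def shard_for_devices(x, n_devices):
    """Hypothetical helper: reshape a host-local batch [B, ...] to [n_devices, B // n_devices, ...]."""
    return x.reshape((n_devices, -1) + x.shape[1:])

def loss_fn(params, x, y):
    # Toy linear regression with mean squared error (illustrative assumption).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Average gradients over the "devices" axis; in a multi-process run this axis
    # covers every device of every process, so each replica gets the global mean.
    grads = jax.lax.pmean(grads, axis_name="devices")
    # Plain SGD with a fixed learning rate of 0.1 (illustrative).
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# Each process computes its own file list; actually reading the files is left out.
my_files = files_for_process(jax.process_index(), jax.process_count())
n = jax.local_device_count()

# Synthetic placeholder for the data read from my_files (different per process).
rng = np.random.default_rng(jax.process_index())
x = rng.normal(size=(8 * n, 3)).astype(np.float32)
y = x @ np.array([1.0, -2.0, 0.5], dtype=np.float32)

# Replicate the initial parameters once per local device (leading axis = n).
params = {"w": np.zeros(3, np.float32), "b": np.zeros((), np.float32)}
replicated = jax.tree_util.tree_map(lambda a: np.stack([a] * n), params)

for _ in range(100):
    replicated = train_step(replicated, shard_for_devices(x, n), shard_for_devices(y, n))
```

Every process has to make the same sequence of `train_step` calls with identically shaped local batches; otherwise the cross-process collectives can hang. Newer JAX releases also support this pattern through `jax.jit` with explicit shardings (for example, building global arrays with `jax.make_array_from_process_local_data`), so which style fits best may depend on your JAX version.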