From 0be30fbf9649a41a77e678fea7c35d8201f36925 Mon Sep 17 00:00:00 2001 From: Qiao Zhang Date: Mon, 25 Oct 2021 15:34:57 -0700 Subject: [PATCH] Add jax.distributed.initialize for multi-host GPU. --- CHANGELOG.md | 2 ++ jax/__init__.py | 1 + jax/_src/distributed.py | 59 ++++++++++++++++++++++++++++++++++++++ jax/_src/lib/xla_bridge.py | 12 ++++---- jax/distributed.py | 16 +++++++++++ 5 files changed, 85 insertions(+), 5 deletions(-) create mode 100644 jax/_src/distributed.py create mode 100644 jax/distributed.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d265cccf3778..114a6352ecc2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.24...main). +* New features: + * (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend. * Breaking changes * Moved `jax.experimental.stax` to `jax.example_libraries.stax` * Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers` diff --git a/jax/__init__.py b/jax/__init__.py index 85fda8b44203..a9d3da6b3e97 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -120,6 +120,7 @@ # jax and rely on the names imported above. from . import abstract_arrays as abstract_arrays from . import api_util as api_util +from . import distributed as distributed from . import dtypes as dtypes from . import errors as errors from . import image as image diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py new file mode 100644 index 000000000000..5487f4f80fae --- /dev/null +++ b/jax/_src/distributed.py @@ -0,0 +1,59 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools + +from absl import logging +from jax._src.lib import xla_bridge +from jax._src.lib import xla_client +from jax._src.lib import xla_extension + +_service = None +def initialize(coordinator_address: str, num_processes: int, process_id: int): + """Initialize distributed system for topology discovery. + + Currently, calling ``initialize`` sets up the multi-host GPU backend, and + is not required for CPU or TPU backends. + + Args: + coordinator_address: IP address of the coordinator. + num_processes: Number of processes. + process_id: Id of the current processe. + + Example: + + Suppose there are two GPU hosts, and host 0 is the designated coordinator + with address '10.0.0.1:1234', to initialize the GPU cluster, run the + following commands before anything else. + + On host 0 + >>> jax.distributed.initialize('10.0.0.1:1234', 2, 0) # doctest: +SKIP + + On host 1 + >>> jax.distributed.initialize('10.0.0.1:1234', 2, 1) # doctest: +SKIP + """ + if process_id == 0: + global _service + assert _service is None, 'initialize should be called once only' + logging.info('Starting JAX distributed service on %s', coordinator_address) + _service = xla_extension.get_distributed_runtime_service(coordinator_address, + num_processes) + + client = xla_extension.get_distributed_runtime_client(coordinator_address, + process_id) + logging.info('Connecting to JAX distributed service on %s', coordinator_address) + client.connect() + + factory = functools.partial(xla_client.make_gpu_client, client, process_id) + xla_bridge.register_backend_factory('gpu', factory, priority=300) diff --git a/jax/_src/lib/xla_bridge.py b/jax/_src/lib/xla_bridge.py index 93b372305a58..afdc26391e43 100644 --- a/jax/_src/lib/xla_bridge.py +++ b/jax/_src/lib/xla_bridge.py @@ -170,8 +170,15 @@ def _log_warning(): # example, there could be multiple backends that provide the same kind of # device. _backend_factories = {} +_default_backend = None +_backends : Dict[str, Any] = {} +_backends_errors : Dict[str, str] = {} +_backend_lock = threading.Lock() def register_backend_factory(name, factory, *, priority=0): + with _backend_lock: + if name in _backends: + raise RuntimeError(f"Backend {name} already initialized") _backend_factories[name] = (factory, priority) @@ -187,11 +194,6 @@ def register_backend_factory(name, factory, *, priority=0): register_backend_factory( 'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300) -_default_backend = None -_backends : Dict[str, Any] = {} -_backends_errors : Dict[str, str] = {} -_backend_lock = threading.Lock() - def backends(): global _backends diff --git a/jax/distributed.py b/jax/distributed.py new file mode 100644 index 000000000000..1f2a0224966b --- /dev/null +++ b/jax/distributed.py @@ -0,0 +1,16 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa: F401 +from jax._src.distributed import initialize