diff --git a/docs/requirements.txt b/docs/requirements.txt
index 8e1ad88807b2cf..273c4f0b18fd67 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -4,3 +4,5 @@ docutils==0.16
 sphinxcontrib.katex
 matplotlib
 tensorboard
+# required to build torch.distributed.elastic.rendezvous.etcd* docs
+python-etcd>=0.4.5
diff --git a/docs/source/distributed.elastic.rst b/docs/source/distributed.elastic.rst
new file mode 100644
index 00000000000000..0a23912bfa850b
--- /dev/null
+++ b/docs/source/distributed.elastic.rst
@@ -0,0 +1,42 @@
+Torch Distributed Elastic
+============================
+
+Makes distributed PyTorch fault-tolerant and elastic.
+
+Get Started
+---------------
+.. toctree::
+   :maxdepth: 1
+   :caption: Usage
+
+   elastic/quickstart
+   elastic/train_script
+   elastic/examples
+
+Documentation
+---------------
+
+.. toctree::
+   :maxdepth: 1
+   :caption: API
+
+   elastic/run
+   elastic/agent
+   elastic/multiprocessing
+   elastic/errors
+   elastic/rendezvous
+   elastic/timer
+   elastic/metrics
+   elastic/events
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Advanced
+
+   elastic/customization
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Plugins
+
+   elastic/kubernetes
diff --git a/docs/source/elastic/agent.rst b/docs/source/elastic/agent.rst
new file mode 100644
index 00000000000000..4cf92557e291f8
--- /dev/null
+++ b/docs/source/elastic/agent.rst
@@ -0,0 +1,61 @@
+Elastic Agent
+==============
+
+.. automodule:: torch.distributed.elastic.agent
+.. currentmodule:: torch.distributed.elastic.agent
+
+Server
+--------
+
+.. automodule:: torch.distributed.elastic.agent.server
+
+Below is a diagram of an agent that manages a local group of workers.
+
+.. image:: agent_diagram.jpg
+
+Concepts
+--------
+
+This section describes the high-level classes and concepts that
+are relevant to understanding the role of the ``agent`` in torchelastic.
+
+.. currentmodule:: torch.distributed.elastic.agent.server
+
+.. autoclass:: ElasticAgent
+   :members:
+
+.. autoclass:: WorkerSpec
+   :members:
+
+.. autoclass:: WorkerState
+   :members:
+
+.. autoclass:: Worker
+   :members:
+
+.. autoclass:: WorkerGroup
+   :members:
+
+Implementations
+-------------------
+
+Below are the agent implementations provided by torchelastic.
+
+.. currentmodule:: torch.distributed.elastic.agent.server.local_elastic_agent
+.. autoclass:: LocalElasticAgent
+
+
+Extending the Agent
+---------------------
+
+To extend the agent you can implement ``ElasticAgent`` directly; however,
+we recommend extending ``SimpleElasticAgent`` instead, which provides
+most of the scaffolding and leaves you with a few specific abstract methods
+to implement.
+
+.. currentmodule:: torch.distributed.elastic.agent.server
+.. autoclass:: SimpleElasticAgent
+   :members:
+   :private-members:
+
+.. autoclass:: torch.distributed.elastic.agent.server.api.RunResult
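+
+As an illustration only, below is a minimal sketch of a custom agent built on
+``SimpleElasticAgent``. The hook names and signatures shown
+(``_start_workers``, ``_stop_workers``, ``_monitor_workers``) are our reading
+of the private API (additional hooks such as ``_shutdown`` may also be
+required); treat the ``SimpleElasticAgent`` reference above as authoritative.
+
+.. code-block:: python
+
+   from typing import Any, Dict
+
+   from torch.distributed.elastic.agent.server import (
+       SimpleElasticAgent,
+       WorkerGroup,
+   )
+   from torch.distributed.elastic.agent.server.api import RunResult
+
+   class MyElasticAgent(SimpleElasticAgent):
+       def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
+           # launch one worker per local rank and return {local_rank: worker_id}
+           ...
+
+       def _stop_workers(self, worker_group: WorkerGroup) -> None:
+           # terminate every worker in the group (processes, containers, etc.)
+           ...
+
+       def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
+           # poll the workers and report the aggregate state of the group
+           ...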
diff --git a/docs/source/elastic/agent_diagram.jpg b/docs/source/elastic/agent_diagram.jpg
new file mode 100644
index 00000000000000..79fad343280362
Binary files /dev/null and b/docs/source/elastic/agent_diagram.jpg differ
diff --git a/docs/source/elastic/customization.rst b/docs/source/elastic/customization.rst
new file mode 100644
index 00000000000000..f7975c9b86c649
--- /dev/null
+++ b/docs/source/elastic/customization.rst
@@ -0,0 +1,118 @@
+Customization
+=============
+
+This section describes how to customize TorchElastic to fit your needs.
+
+Launcher
+------------------------
+
+The launcher program that ships with TorchElastic
+should be sufficient for most use cases (see :ref:`launcher-api`).
+You can implement a custom launcher by
+programmatically creating an agent and passing it specs for your workers as
+shown below.
+
+.. code-block:: python
+
+   # my_launcher.py
+
+   if __name__ == "__main__":
+       args = parse_args(sys.argv[1:])
+       rdzv_handler = RendezvousHandler(...)
+       spec = WorkerSpec(
+           local_world_size=args.nproc_per_node,
+           fn=trainer_entrypoint_fn,
+           args=tuple(args.fn_args),  # arguments passed to trainer_entrypoint_fn
+           rdzv_handler=rdzv_handler,
+           max_restarts=args.max_restarts,
+           monitor_interval=args.monitor_interval,
+       )
+
+       agent = LocalElasticAgent(spec, start_method="spawn")
+       try:
+           run_result = agent.run()
+           if run_result.is_failed():
+               print(f"worker 0 failed with: {run_result.failures[0]}")
+           else:
+               print(f"worker 0 return value is: {run_result.return_values[0]}")
+       except Exception as ex:
+           # handle exception
+           ...
+
+
+Rendezvous Handler
+------------------------
+
+To implement your own rendezvous, extend ``torch.distributed.elastic.rendezvous.RendezvousHandler``
+and implement its methods.
+
+.. warning:: Rendezvous handlers are tricky to implement. Before you begin,
+   make sure you completely understand the properties of rendezvous.
+   Please refer to :ref:`rendezvous-api` for more information.
+
+Once implemented, you can pass your custom rendezvous handler to the worker
+spec when creating the agent.
+
+.. code-block:: python
+
+   spec = WorkerSpec(
+       rdzv_handler=MyRendezvousHandler(params),
+       ...
+   )
+   elastic_agent = LocalElasticAgent(spec, start_method=start_method)
+   elastic_agent.run(spec.role)
+
+
+Metric Handler
+-----------------------------
+
+TorchElastic emits platform-level metrics (see :ref:`metrics-api`).
+By default, metrics are emitted to ``/dev/null`` so you will not see them.
+To have the metrics pushed to a metric handling service in your infrastructure,
+implement a ``torch.distributed.elastic.metrics.MetricHandler`` and ``configure`` it in your
+custom launcher.
+
+.. code-block:: python
+
+   # my_launcher.py
+
+   import torch.distributed.elastic.metrics as metrics
+
+   class MyMetricHandler(metrics.MetricHandler):
+       def emit(self, metric_data: metrics.MetricData):
+           # push metric_data to your metric sink
+           ...
+
+   def main():
+       metrics.configure(MyMetricHandler())
+
+       spec = WorkerSpec(...)
+       agent = LocalElasticAgent(spec)
+       agent.run()
+
+Events Handler
+-----------------------------
+
+TorchElastic supports event recording (see :ref:`events-api`).
+The events module defines an API that allows you to record events and
+implement custom event handlers. An ``EventHandler`` is used for publishing events
+produced during torchelastic execution to different sources, e.g. AWS CloudWatch.
+By default it uses ``torch.distributed.elastic.events.NullEventHandler``, which ignores
+events. To configure a custom events handler, implement the
+``torch.distributed.elastic.events.EventHandler`` interface and ``configure`` it
+in your custom launcher.
+
+.. code-block:: python
+
+   # my_launcher.py
+
+   import torch.distributed.elastic.events as events
+
+   class MyEventHandler(events.EventHandler):
+       def record(self, event: events.Event):
+           # process event
+           ...
+
+   def main():
+       events.configure(MyEventHandler())
+
+       spec = WorkerSpec(...)
+       agent = LocalElasticAgent(spec)
+       agent.run()
diff --git a/docs/source/elastic/errors.rst b/docs/source/elastic/errors.rst
new file mode 100644
index 00000000000000..1105d1b253e8d9
--- /dev/null
+++ b/docs/source/elastic/errors.rst
@@ -0,0 +1,17 @@
+Error Propagation
+==================
+
+.. automodule:: torch.distributed.elastic.multiprocessing.errors
+
+Methods and Classes
+---------------------
+
+.. currentmodule:: torch.distributed.elastic.multiprocessing.errors
+
+.. autofunction:: torch.distributed.elastic.multiprocessing.errors.record
+
+.. autoclass:: ChildFailedError
+
+.. autoclass:: ErrorHandler
+
+.. autoclass:: ProcessFailure
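+
+As a quick orientation (the canonical usage is covered by the module
+documentation above), the ``record`` decorator is typically applied to the
+entrypoint of your training script so that uncaught exceptions are written to
+an error file the agent can read and report:
+
+.. code-block:: python
+
+   from torch.distributed.elastic.multiprocessing.errors import record
+
+   @record
+   def main():
+       # run the training loop; any uncaught exception raised here is
+       # recorded to an error file and surfaced by the elastic agent
+       ...
+
+   if __name__ == "__main__":
+       main()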
diff --git a/docs/source/elastic/etcd_rdzv_diagram.png b/docs/source/elastic/etcd_rdzv_diagram.png
new file mode 100644
index 00000000000000..c15b8160391caa
Binary files /dev/null and b/docs/source/elastic/etcd_rdzv_diagram.png differ
diff --git a/docs/source/elastic/events.rst b/docs/source/elastic/events.rst
new file mode 100644
index 00000000000000..86d0be8dad52bb
--- /dev/null
+++ b/docs/source/elastic/events.rst
@@ -0,0 +1,24 @@
+.. _events-api:
+
+Events
+============================
+
+.. automodule:: torch.distributed.elastic.events
+
+API Methods
+------------
+
+.. autofunction:: torch.distributed.elastic.events.record
+
+.. autofunction:: torch.distributed.elastic.events.get_logging_handler
+
+Event Objects
+-----------------
+
+.. currentmodule:: torch.distributed.elastic.events.api
+
+.. autoclass:: torch.distributed.elastic.events.api.Event
+
+.. autoclass:: torch.distributed.elastic.events.api.EventSource
+
+.. autoclass:: torch.distributed.elastic.events.api.EventMetadataValue
diff --git a/docs/source/elastic/examples.rst b/docs/source/elastic/examples.rst
new file mode 100644
index 00000000000000..e12c2264696396
--- /dev/null
+++ b/docs/source/elastic/examples.rst
@@ -0,0 +1,4 @@
+Examples
+==========================
+
+Please refer to the `elastic/examples README `_.
diff --git a/docs/source/elastic/kubernetes.rst b/docs/source/elastic/kubernetes.rst
new file mode 100644
index 00000000000000..55a051b5a76f5c
--- /dev/null
+++ b/docs/source/elastic/kubernetes.rst
@@ -0,0 +1,5 @@
+TorchElastic Kubernetes
+==========================
+
+Please refer to the `Kubernetes README `_ in our GitHub repository
+for more information on the Elastic Job Controller and custom resource definition.
diff --git a/docs/source/elastic/metrics.rst b/docs/source/elastic/metrics.rst
new file mode 100644
index 00000000000000..ca31ff83b86eea
--- /dev/null
+++ b/docs/source/elastic/metrics.rst
@@ -0,0 +1,31 @@
+.. _metrics-api:
+
+Metrics
+=========
+
+.. automodule:: torch.distributed.elastic.metrics
+
+
+Metric Handlers
+-----------------
+
+.. currentmodule:: torch.distributed.elastic.metrics.api
+
+Below are the metric handlers that are included with torchelastic.
+
+.. autoclass:: MetricHandler
+
+.. autoclass:: ConsoleMetricHandler
+
+.. autoclass:: NullMetricHandler
+
+
+
+Methods
+------------
+
+.. autofunction:: torch.distributed.elastic.metrics.configure
+
+.. autofunction:: torch.distributed.elastic.metrics.prof
+
+.. autofunction:: torch.distributed.elastic.metrics.put_metric
diff --git a/docs/source/elastic/multiprocessing.rst b/docs/source/elastic/multiprocessing.rst
new file mode 100644
index 00000000000000..fc5866c01e7c75
--- /dev/null
+++ b/docs/source/elastic/multiprocessing.rst
@@ -0,0 +1,24 @@
+:github_url: https://github.com/pytorch/elastic
+
+Multiprocessing
+================
+
+.. automodule:: torch.distributed.elastic.multiprocessing
+
+Starting Multiple Workers
+---------------------------
+
+.. autofunction:: torch.distributed.elastic.multiprocessing.start_processes
+
+Process Context
+----------------
+
+.. currentmodule:: torch.distributed.elastic.multiprocessing.api
+
+.. autoclass:: PContext
+
+.. autoclass:: MultiprocessContext
+
+.. autoclass:: SubprocessContext
+
+.. autoclass:: RunProcsResult
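+
+To tie the pieces above together, below is an illustrative sketch of launching
+a function on two local workers and waiting for them through the returned
+context. The parameter names shown (``name``, ``entrypoint``, per-rank
+``args``/``envs``, ``log_dir``) are our reading of the ``start_processes``
+API; treat the generated reference above as authoritative.
+
+.. code-block:: python
+
+   import tempfile
+
+   from torch.distributed.elastic.multiprocessing import start_processes
+
+   def trainer(msg):
+       print(msg)
+       return msg
+
+   if __name__ == "__main__":
+       # start two copies of ``trainer``, one per local rank
+       ctx = start_processes(
+           name="trainer",
+           entrypoint=trainer,
+           args={0: ("hello from rank 0",), 1: ("hello from rank 1",)},
+           envs={0: {}, 1: {}},
+           log_dir=tempfile.mkdtemp(),
+       )
+
+       # block until all workers exit and collect their return values
+       result = ctx.wait()
+       if result is not None and not result.is_failed():
+           print(result.return_values)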
diff --git a/docs/source/elastic/quickstart.rst b/docs/source/elastic/quickstart.rst
new file mode 100644
index 00000000000000..4b7788b02cadad
--- /dev/null
+++ b/docs/source/elastic/quickstart.rst
@@ -0,0 +1,50 @@
+Quickstart
+===========
+
+.. code-block:: bash
+
+   pip install torch
+
+   # start a single-node etcd server on ONE host
+   etcd --enable-v2
+        --listen-client-urls http://0.0.0.0:2379,http://127.0.0.1:4001
+        --advertise-client-urls PUBLIC_HOSTNAME:2379
+
+To launch a **fault-tolerant** job, run the following on all nodes.
+
+.. code-block:: bash
+
+   python -m torch.distributed.run
+          --nnodes=NUM_NODES
+          --nproc_per_node=TRAINERS_PER_NODE
+          --rdzv_id=JOB_ID
+          --rdzv_backend=etcd
+          --rdzv_endpoint=ETCD_HOST:ETCD_PORT
+          YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
+
+
+To launch an **elastic** job, run the following on at least ``MIN_SIZE`` nodes
+and at most ``MAX_SIZE`` nodes.
+
+.. code-block:: bash
+
+   python -m torch.distributed.run
+          --nnodes=MIN_SIZE:MAX_SIZE
+          --nproc_per_node=TRAINERS_PER_NODE
+          --rdzv_id=JOB_ID
+          --rdzv_backend=etcd
+          --rdzv_endpoint=ETCD_HOST:ETCD_PORT
+          YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
+
+
+.. note:: The ``--standalone`` option can be passed to launch a single-node job with
+   a sidecar rendezvous server. You do not have to pass ``--rdzv_id``, ``--rdzv_endpoint``,
+   and ``--rdzv_backend`` when the ``--standalone`` option is used.
+
+
+.. note:: Learn more about writing your distributed training script
+   `here `_.
+
+If ``torch.distributed.run`` does not meet your requirements,
+you may use our APIs directly for more powerful customization. Start by
+taking a look at the `elastic agent `_ API.
diff --git a/docs/source/elastic/rendezvous.rst b/docs/source/elastic/rendezvous.rst
new file mode 100644
index 00000000000000..259e5d99b0ffdc
--- /dev/null
+++ b/docs/source/elastic/rendezvous.rst
@@ -0,0 +1,65 @@
+.. _rendezvous-api:
+
+Rendezvous
+==========
+
+.. automodule:: torch.distributed.elastic.rendezvous
+
+Below is a state diagram describing how rendezvous works.
+
+.. image:: etcd_rdzv_diagram.png
+
+Registry
+--------------------
+
+.. autoclass:: RendezvousParameters
+
+.. automodule:: torch.distributed.elastic.rendezvous.registry
+
+Handler
+--------------------
+
+.. currentmodule:: torch.distributed.elastic.rendezvous
+
+.. autoclass:: RendezvousHandler
+   :members:
+
+Exceptions
+-------------
+.. autoclass:: RendezvousError
+.. autoclass:: RendezvousClosedError
+.. autoclass:: RendezvousTimeoutError
+.. autoclass:: RendezvousConnectionError
+.. autoclass:: RendezvousStateError
+
+Implementations
+----------------
+
+Etcd Rendezvous
+****************
+
+.. currentmodule:: torch.distributed.elastic.rendezvous.etcd_rendezvous
+
+.. autoclass:: EtcdRendezvousHandler
+
+.. autoclass:: EtcdRendezvous
+   :members:
+
+.. autoclass:: EtcdStore
+   :members:
+
+Etcd Server
+*************
+
+The ``EtcdServer`` is a convenience class that makes it easy for you to
+start and stop an etcd server on a subprocess. This is useful for testing
+or single-node (multi-worker) deployments where manually setting up an
+etcd server on the side is cumbersome.
+
+.. warning:: For production and multi-node deployments please consider
+   properly deploying a highly available etcd server as this is
+   the single point of failure for your distributed jobs.
+
+.. currentmodule:: torch.distributed.elastic.rendezvous.etcd_server
+
+.. autoclass:: EtcdServer
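+
+As an orientation only, the snippet below shows the intended start/stop
+lifecycle of ``EtcdServer``. The method names used here (``start``,
+``get_client``, ``get_endpoint``, ``stop``) are our reading of the class and
+may differ slightly from the generated reference above.
+
+.. code-block:: python
+
+   from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
+
+   # spin up a standalone etcd server on a subprocess (test / single-node use)
+   server = EtcdServer()
+   server.start()
+   try:
+       client = server.get_client()   # etcd client bound to the subprocess
+       print(server.get_endpoint())   # host:port to pass as --rdzv_endpoint
+   finally:
+       # always tear the subprocess down
+       server.stop()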
diff --git a/docs/source/elastic/run.rst b/docs/source/elastic/run.rst
new file mode 100644
index 00000000000000..6e1eac8055ded9
--- /dev/null
+++ b/docs/source/elastic/run.rst
@@ -0,0 +1,9 @@
+.. _launcher-api:
+
+Elastic Launch
+============================
+
+torch.distributed.run
+----------------------
+
+.. automodule:: torch.distributed.run
diff --git a/docs/source/elastic/timer.rst b/docs/source/elastic/timer.rst
new file mode 100644
index 00000000000000..e9d4228ee7a6a4
--- /dev/null
+++ b/docs/source/elastic/timer.rst
@@ -0,0 +1,41 @@
+Expiration Timers
+==================
+
+.. automodule:: torch.distributed.elastic.timer
+.. currentmodule:: torch.distributed.elastic.timer
+
+Client Methods
+---------------
+.. autofunction:: torch.distributed.elastic.timer.configure
+
+.. autofunction:: torch.distributed.elastic.timer.expires
+
+Server/Client Implementations
+------------------------------
+Below are the timer server and client pairs that are provided by torchelastic.
+
+.. note:: Timer server and clients always have to be implemented and used
+   in pairs since there is a messaging protocol between the server
+   and client.
+
+.. autoclass:: LocalTimerServer
+
+.. autoclass:: LocalTimerClient
+
+Writing a custom timer server/client
+--------------------------------------
+
+To write your own timer server and client, extend
+``torch.distributed.elastic.timer.TimerServer`` for the server and
+``torch.distributed.elastic.timer.TimerClient`` for the client. The
+``TimerRequest`` object is used to pass messages between
+the server and client.
+
+.. autoclass:: TimerRequest
+   :members:
+
+.. autoclass:: TimerServer
+   :members:
+
+.. autoclass:: TimerClient
+   :members:
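+
+As an illustration, the sketch below wires a ``LocalTimerServer``/``LocalTimerClient``
+pair over a multiprocessing queue and guards a block of work with ``expires``.
+The constructor arguments shown (the shared queue and ``max_interval``) are our
+reading of the local implementations; see the class references above for the
+authoritative signatures.
+
+.. code-block:: python
+
+   import torch.distributed.elastic.timer as timer
+   import torch.multiprocessing as mp
+
+   def worker(queue):
+       # the client sends acquire/release messages to the server over the queue
+       timer.configure(timer.LocalTimerClient(queue))
+       with timer.expires(after=60):
+           # this block must finish within 60 seconds, otherwise the server
+           # reaps the worker process
+           ...
+
+   if __name__ == "__main__":
+       ctx = mp.get_context("spawn")
+       queue = ctx.Queue()
+
+       server = timer.LocalTimerServer(queue, max_interval=0.25)
+       server.start()
+
+       p = ctx.Process(target=worker, args=(queue,))
+       p.start()
+       p.join()
+
+       server.stop()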
diff --git a/docs/source/elastic/train_script.rst b/docs/source/elastic/train_script.rst
new file mode 100644
index 00000000000000..261ebd474c9ee1
--- /dev/null
+++ b/docs/source/elastic/train_script.rst
@@ -0,0 +1,46 @@
+Train script
+-------------
+
+If your train script works with ``torch.distributed.launch`` it will continue
+working with ``torch.distributed.run``, with these differences:
+
+1. There is no need to manually pass ``RANK``, ``WORLD_SIZE``,
+   ``MASTER_ADDR``, and ``MASTER_PORT``.
+
+2. ``rdzv_backend`` and ``rdzv_endpoint`` must be provided. For most users
+   this will be set to ``etcd`` (see `rendezvous `_).
+
+3. Make sure you have ``load_checkpoint(path)`` and
+   ``save_checkpoint(path)`` logic in your script. When workers fail,
+   we restart all the workers with the same program arguments, so you will
+   lose progress up to the most recent checkpoint
+   (see `elastic launch `_).
+
+4. The ``use_env`` flag has been removed. If you were reading the local rank from
+   the ``--local_rank`` option, you now need to get it from the
+   ``LOCAL_RANK`` environment variable (e.g. ``os.environ["LOCAL_RANK"]``).
+
+Below is an expository example of a training script that checkpoints on each
+epoch; hence the worst-case progress lost on failure is one full epoch's worth
+of training.
+
+.. code-block:: python
+
+   def main():
+       args = parse_args(sys.argv[1:])
+       state = load_checkpoint(args.checkpoint_path)
+       initialize(state)
+
+       # torch.distributed.run ensures that this will work
+       # by exporting all the env vars needed to initialize the process group
+       torch.distributed.init_process_group(backend=args.backend)
+
+       for i in range(state.epoch, state.total_num_epochs):
+           for batch in iter(state.dataset):
+               train(batch, state.model)
+
+           state.epoch += 1
+           save_checkpoint(state)
+
+For concrete examples of torchelastic-compliant train scripts, visit
+our `examples `_ page.
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 2ced217819f444..da434fb0ef272c 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -58,6 +58,7 @@ Features described in this documentation are classified by release status:
    torch.cuda.amp
    torch.backends
    torch.distributed
+   torch.distributed.elastic
    torch.distributed.optim
    torch.distributions
    torch.fft
@@ -102,7 +103,6 @@ Features described in this documentation are classified by release status:
    torchaudio
    torchtext
    torchvision
-   TorchElastic
    TorchServe
    PyTorch on XLA Devices
diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py
index cd47d5bcc00469..55b5b10e35537f 100644
--- a/torch/cuda/streams.py
+++ b/torch/cuda/streams.py
@@ -38,7 +38,7 @@ def wait_event(self, event):
         r"""Makes all future work submitted to the stream wait for an event.
 
         Args:
-            event (Event): an event to wait for.
+            event (torch.cuda.Event): an event to wait for.
 
         .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
            `CUDA Stream documentation`_ for more info.
@@ -69,7 +69,7 @@ def record_event(self, event=None):
         r"""Records an event.
 
         Args:
-            event (Event, optional): event to record. If not given, a new one
+            event (torch.cuda.Event, optional): event to record. If not given, a new one
                 will be allocated.
 
         Returns:
diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py
index db44bb7faf189b..ef6433d9ecd3e8 100644
--- a/torch/distributed/elastic/rendezvous/__init__.py
+++ b/torch/distributed/elastic/rendezvous/__init__.py
@@ -98,7 +98,7 @@
       can participate in *next* rendezvous.
 
    2. Setting the rendezvous *closed* to signal all workers not
-      to participate in next rendezvous.
+      to participate in next rendezvous
 """

 from .api import *  # noqa: F403
diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py
index fcdeb04bae4a46..cdfe88ed0a6396 100644
--- a/torch/distributed/elastic/rendezvous/registry.py
+++ b/torch/distributed/elastic/rendezvous/registry.py
@@ -47,6 +47,21 @@ def _register_default_handlers() -> None:
     handler_registry.register("static", _create_static_handler)
 
 
-# The legacy function kept for backwards compatibility.
 def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
+    """
+    Obtains a reference to a :py:class:`RendezvousHandler` for the rendezvous
+    backend specified in ``params``. Custom rendezvous handlers can be registered by
+
+    ::
+
+      from torch.distributed.elastic.rendezvous import rendezvous_handler_registry
+      from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
+
+      def create_my_rdzv(params: RendezvousParameters):
+        return MyCustomRdzv(params)
+
+      rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)
+
+      my_rdzv_params = RendezvousParameters(backend="my_rdzv_backend_name", ...)
+      my_rdzv_handler = get_rendezvous_handler(my_rdzv_params)
+    """
     return handler_registry.create_handler(params)