document deepspeed.initialize() (deepspeedai#644)
Co-authored-by: Jeff Rasley <[email protected]>
stas00 and jeffra authored Jan 8, 2021
1 parent 4e2dc4e commit 828d75b
Showing 1 changed file with 16 additions and 0 deletions: docs/_tutorials/getting-started.md
@@ -31,6 +31,22 @@ construct and manage the training optimizer, data loader, and the learning rate
scheduler based on the parameters passed to `deepspeed.initialize` and the
DeepSpeed [configuration file](#deepspeed-configuration).

If you already have a distributed environment set up, you need to replace:

```python
torch.distributed.init_process_group(...)
```

with:

```python
deepspeed.init_distributed()
```

The default is to use the NCCL backend, which DeepSpeed has been thoroughly tested with, but you can also [override the default](https://deepspeed.readthedocs.io/en/latest/initialize.html#distributed-initialization).
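
For example, here is a minimal sketch of selecting a different backend, assuming the `dist_backend` keyword documented in the API reference linked above:

```python
import deepspeed

# Override the default NCCL backend, e.g. with Gloo for CPU-only testing;
# dist_backend is the keyword documented in the initialization API reference.
deepspeed.init_distributed(dist_backend="gloo")
```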

But if you don't need the distributed environment set up until after `deepspeed.initialize()`, you don't have to call this function at all: DeepSpeed automatically initializes the distributed environment during `initialize`. Either way, you must remove the `torch.distributed.init_process_group` call if you already have it in place.
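
For instance, here is a minimal sketch that relies on that automatic initialization; the placeholder model is hypothetical, and this assumes you launch the script with the `deepspeed` launcher and a `--deepspeed_config` file, as covered elsewhere in this tutorial:

```python
import argparse

import torch
import deepspeed

# Parse command-line arguments, including DeepSpeed's own flags
# (--deepspeed, --deepspeed_config, ...) added by add_config_arguments.
parser = argparse.ArgumentParser(description="minimal DeepSpeed example")
parser = deepspeed.add_config_arguments(parser)
cmd_args = parser.parse_args()

# A trivial placeholder model, used here only for illustration.
model = torch.nn.Linear(10, 10)

# No torch.distributed.init_process_group call is needed beforehand:
# deepspeed.initialize sets up the distributed environment itself.
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=cmd_args,
    model=model,
    model_parameters=model.parameters(),
)
```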


### Training

