Skip to content

Commit

Permalink
Merge branch 'master' into feat/cli-shorthand-model
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Sep 22, 2021
2 parents 848d0eb + e64f358 commit e9ce94f
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 35 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554))


- Fixed gradient accumulation for `DDPShardedPlugin` ([#9122](https://github.com/PyTorchLightning/pytorch-lightning/pull/9122))


## [1.4.7] - 2021-09-14

- Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))
Expand Down
4 changes: 2 additions & 2 deletions pl_examples/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Examples

Our most robust examples showing all sorts of implementations
can be found in our sister library [lightning-bolts](https://lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2).
can be found in our sister library [lightning-bolts](https://pytorch-lightning.readthedocs.io/en/latest/ecosystem/bolts.html).

______________________________________________________________________

Expand All @@ -18,5 +18,5 @@ ______________________________________________________________________
## Domain examples

This folder contains older examples. You should instead use the examples
in [lightning-bolts](https://lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2)
in [lightning-bolts](https://pytorch-lightning.readthedocs.io/en/latest/ecosystem/bolts.html)
for advanced use cases.
2 changes: 1 addition & 1 deletion pl_examples/domain_templates/reinforce_learn_Qnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O
Training loss and log metrics
"""
device = self.get_device(batch)
epsilon = max(self.eps_end, self.eps_start - self.global_step + 1 / self.eps_last_frame)
epsilon = max(self.eps_end, self.eps_start - (self.global_step + 1) / self.eps_last_frame)

# step through environment with agent
reward, done = self.agent.play_step(self.net, epsilon, device)
Expand Down
16 changes: 15 additions & 1 deletion pytorch_lightning/plugins/training_type/sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
from contextlib import contextmanager
from typing import Dict, Generator, Optional

import torch

Expand Down Expand Up @@ -100,6 +101,19 @@ def lightning_module(self) -> "pl.LightningModule":
def pre_backward(self, closure_loss: torch.Tensor) -> None:
pass

@contextmanager
def block_backward_sync(self) -> Generator:
"""Blocks syncing gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(self.model, ShardedDataParallel):
with self.model.no_sync():
yield None
else:
yield None

def post_training_step(self):
pass

Expand Down
16 changes: 15 additions & 1 deletion pytorch_lightning/plugins/training_type/sharded_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
from contextlib import contextmanager
from typing import Dict, Generator, Optional

import torch

Expand Down Expand Up @@ -63,6 +64,19 @@ def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

@contextmanager
def block_backward_sync(self) -> Generator:
"""Blocks syncing gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
if isinstance(self.model, ShardedDataParallel):
with self.model.no_sync():
yield None
else:
yield None

@rank_zero_only
def _optim_state_dict(self, optimizer):
"""
Expand Down
27 changes: 0 additions & 27 deletions pytorch_lightning/trainer/deprecated_api.py

This file was deleted.

9 changes: 7 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
Expand Down Expand Up @@ -114,7 +113,6 @@ class Trainer(
TrainerModelHooksMixin,
TrainerOptimizersMixin,
TrainerDataLoadingMixin,
DeprecatedTrainerAttributes,
):
# Needed because of LightningOptimizer
_lightning_optimizers = None
Expand Down Expand Up @@ -1957,6 +1955,13 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
if self.predicting:
return self.predict_loop

@property
def train_loop(self) -> FitLoop:
rank_zero_deprecation(
"`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6."
)
return self.fit_loop

@property
def _ckpt_path(self) -> Optional[str]:
if self.state.fn == TrainerFn.VALIDATING:
Expand Down
4 changes: 3 additions & 1 deletion requirements/adjust_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
import sys
from typing import Dict, Optional

# IMPORTANT: this list needs to be sorted in reverse
VERSIONS = [
dict(torch="1.10.0", torchvision="0.11.*", torchtext=""), # nightly
dict(torch="1.9.1", torchvision="0.10.1", torchtext="0.10.1"),
dict(torch="1.9.0", torchvision="0.10.0", torchtext="0.10.0"),
dict(torch="1.8.2", torchvision="0.9.1", torchtext="0.9.1"),
dict(torch="1.8.1", torchvision="0.9.1", torchtext="0.9.1"),
dict(torch="1.8.0", torchvision="0.9.0", torchtext="0.9.0"),
dict(torch="1.7.1", torchvision="0.8.2", torchtext="0.8.1"),
dict(torch="1.7.0", torchvision="0.8.1", torchtext="0.8.0"),
dict(torch="1.6.0", torchvision="0.7.0", torchtext="0.7"),
]
VERSIONS.sort(key=lambda v: v["torch"], reverse=True)


def find_latest(ver: str) -> Dict[str, str]:
Expand Down
10 changes: 10 additions & 0 deletions tests/plugins/test_sharded_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,13 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe
assert kwargs["reduce_buffer_size"] == DDPShardedPlugin._REDUCE_BUFFER_SIZE_DEFAULT
else:
assert kwargs["reduce_buffer_size"] == expected_buffer_size


@RunIf(skip_windows=True, fairscale=True)
def test_block_backward_sync(tmpdir):
plugin = DDPShardedPlugin()
model = mock.MagicMock(spec=ShardedDataParallel)
with mock.patch.object(plugin, "_model", model):
with plugin.block_backward_sync():
pass
model.no_sync.assert_called_once()

0 comments on commit e9ce94f

Please sign in to comment.