Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved env_vars_connector._defaults_from_env_vars to utilities.argsparse._defaults_from_env_vars #10501

Merged
merged 10 commits into from
Nov 22, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))


-
- Deprecated `trainer.connectors.env_vars_connector._defaults_from_env_vars` in favor of `utilities.argsparse._defaults_from_env_vars` ([#10501](https://github.com/PyTorchLightning/pytorch-lightning/pull/10501))


-
Expand Down
31 changes: 7 additions & 24 deletions pytorch_lightning/trainer/connectors/env_vars_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps
from typing import Callable
from pytorch_lightning.utilities import rank_zero_deprecation

from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables
rank_zero_deprecation(
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
"Using `pytorch_lightning.trainer.connectors.env_vars_connector._defaults_from_env_vars` is "
"deprecated in v1.6, and will be removed in v1.8. It has been replaced with "
"`pytorch_lightning.utilities.argsparse._defaults_from_env_vars`"
)


def _defaults_from_env_vars(fn: Callable) -> Callable:
"""Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should
be moved automatically to the correct device."""

@wraps(fn)
def insert_env_defaults(self, *args, **kwargs):
cls = self.__class__ # get the class
if args: # inace any args passed move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
# convert args to kwargs
kwargs.update(dict(zip(cls_arg_names, args)))
env_variables = vars(parse_env_variables(cls))
# update the kwargs by env variables
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))

# all args were already moved to kwargs
return fn(self, **kwargs)

return insert_env_defaults
from pytorch_lightning.utilities.argparse import _defaults_from_env_vars # noqa: E402, F401
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
Expand All @@ -75,6 +74,7 @@
rank_zero_warn,
)
from pytorch_lightning.utilities.argparse import (
_defaults_from_env_vars,
add_argparse_args,
from_argparse_args,
parse_argparser,
Expand Down
20 changes: 20 additions & 0 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from abc import ABC
from argparse import _ArgumentGroup, ArgumentParser, Namespace
from contextlib import suppress
from functools import wraps
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import pytorch_lightning as pl
Expand Down Expand Up @@ -312,3 +313,22 @@ def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
return int(x)
except ValueError:
return x


def _defaults_from_env_vars(fn: Callable) -> Callable:
@wraps(fn)
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
def insert_env_defaults(self, *args, **kwargs) -> Any:
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
cls = self.__class__ # get the class
if args: # in case any args passed move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
# convert args to kwargs
kwargs.update(dict(zip(cls_arg_names, args)))
env_variables = vars(parse_env_variables(cls))
# update the kwargs by env variables
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))

# all args were already moved to kwargs
return fn(self, **kwargs)

return insert_env_defaults
23 changes: 23 additions & 0 deletions tests/deprecated_api/test_remove_1-8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.8.0."""
import pytest

from tests.deprecated_api import _soft_unimport_module


def test_v1_8_0_deprecated_env_vars_connector_defaults_from_env_vars():
_soft_unimport_module("pytorch_lightning.trainer.connectors.env_vars_connector._defaults_from_env_vars")
with pytest.deprecated_call(match="deprecated in v1.6, and will be removed in v1.8."):
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars # noqa: F401