Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved env_vars_connector._defaults_from_env_vars to utilities.argsparse._defaults_from_env_vars #10501

Merged
merged 10 commits into from
Nov 22, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))


- Moved `trainer.connectors.env_vars_connector._defaults_from_env_vars` to `utilities.argsparse._defaults_from_env_vars` ([#10501](https://github.com/PyTorchLightning/pytorch-lightning/pull/10501))


- Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426))


Expand Down
40 changes: 0 additions & 40 deletions pytorch_lightning/trainer/connectors/env_vars_connector.py

This file was deleted.

2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
Expand All @@ -75,6 +74,7 @@
rank_zero_warn,
)
from pytorch_lightning.utilities.argparse import (
_defaults_from_env_vars,
add_argparse_args,
from_argparse_args,
parse_argparser,
Expand Down
20 changes: 20 additions & 0 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from abc import ABC
from argparse import _ArgumentGroup, ArgumentParser, Namespace
from contextlib import suppress
from functools import wraps
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import pytorch_lightning as pl
Expand Down Expand Up @@ -312,3 +313,22 @@ def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
return int(x)
except ValueError:
return x


def _defaults_from_env_vars(fn: Callable) -> Callable:
@wraps(fn)
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
cls = self.__class__ # get the class
if args: # in case any args passed move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
# convert args to kwargs
kwargs.update(dict(zip(cls_arg_names, args)))
env_variables = vars(parse_env_variables(cls))
# update the kwargs by env variables
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))

# all args were already moved to kwargs
return fn(self, **kwargs)

return insert_env_defaults