Skip to content

Commit

Permalink
Moved env_vars_connector._defaults_from_env_vars to `utilities.args…
Browse files Browse the repository at this point in the history
…parse._defaults_from_env_vars` (#10501)

Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
kaushikb11 and carmocca authored Nov 22, 2021
1 parent 8ea39d2 commit ce0a977
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 41 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))


- Moved `trainer.connectors.env_vars_connector._defaults_from_env_vars` to `utilities.argsparse._defaults_from_env_vars` ([#10501](https://github.com/PyTorchLightning/pytorch-lightning/pull/10501))


- Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426))


Expand Down
40 changes: 0 additions & 40 deletions pytorch_lightning/trainer/connectors/env_vars_connector.py

This file was deleted.

2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
Expand All @@ -75,6 +74,7 @@
rank_zero_warn,
)
from pytorch_lightning.utilities.argparse import (
_defaults_from_env_vars,
add_argparse_args,
from_argparse_args,
parse_argparser,
Expand Down
20 changes: 20 additions & 0 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from abc import ABC
from argparse import _ArgumentGroup, ArgumentParser, Namespace
from contextlib import suppress
from functools import wraps
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import pytorch_lightning as pl
Expand Down Expand Up @@ -312,3 +313,22 @@ def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
return int(x)
except ValueError:
return x


def _defaults_from_env_vars(fn: Callable) -> Callable:
@wraps(fn)
def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
cls = self.__class__ # get the class
if args: # in case any args passed move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
# convert args to kwargs
kwargs.update(dict(zip(cls_arg_names, args)))
env_variables = vars(parse_env_variables(cls))
# update the kwargs by env variables
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))

# all args were already moved to kwargs
return fn(self, **kwargs)

return insert_env_defaults

0 comments on commit ce0a977

Please sign in to comment.