diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index eb5235b796b..bea4d55e174 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -55,6 +55,11 @@ _LOGGER = logging.getLogger(__name__) +# Map defined with option names to flag names for boolean options +# that have a destination(dest) in parser.add_argument() different +# from the flag name and whose default value is `None`. +_FLAG_THAT_SETS_FALSE_VALUE = {'use_public_ips': 'no_use_public_ips'} + def _static_value_provider_of(value_type): """"Helper function to plug a ValueProvider into argparse. @@ -180,7 +185,15 @@ def __init__(self, flags=None, **kwargs): flags: An iterable of command line arguments to be used. If not specified then sys.argv will be used as input for parsing arguments. - **kwargs: Add overrides for arguments passed in flags. + **kwargs: Add overrides for arguments passed in flags. For overrides + of arguments, please pass the `option names` instead of + flag names. + Option names: These are defined as dest in the + parser.add_argument() for each flag. Passing flags + like {no_use_public_ips: True}, for which the dest is + defined to a different flag name in the parser, + would be discarded. Instead, pass the dest of + the flag (dest of no_use_public_ips is use_public_ips). """ # Initializing logging configuration in case the user did not set it up. logging.basicConfig() @@ -237,9 +250,22 @@ def from_dictionary(cls, options): """ flags = [] for k, v in options.items(): + # Note: If a boolean flag is True in the dictionary, + # implicitly the method assumes the boolean flag is + # specified as a command line argument. If the + # boolean flag is False, this method simply discards them. + # Eg: {no_auth: True} is similar to python your_file.py --no_auth + # {no_auth: False} is similar to python your_file.py. if isinstance(v, bool): if v: flags.append('--%s' % k) + elif k in _FLAG_THAT_SETS_FALSE_VALUE: + # Capture overriding flags, which have a different dest + # from the flag name defined in the parser.add_argument + # Eg: no_use_public_ips, which has the dest=use_public_ips + # different from flag name + flag_that_disables_the_option = (_FLAG_THAT_SETS_FALSE_VALUE[k]) + flags.append('--%s' % flag_that_disables_the_option) elif isinstance(v, list): for i in v: flags.append('--%s=%s' % (k, i)) @@ -353,6 +379,7 @@ def view_as(self, cls): """ view = cls(self._flags) + for option_name in view._visible_option_list(): # Initialize values of keys defined by a cls. # diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py index 9392055799c..3195f3594fe 100644 --- a/sdks/python/apache_beam/options/pipeline_options_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_test.py @@ -31,11 +31,14 @@ from apache_beam.options.pipeline_options import ProfilingOptions from apache_beam.options.pipeline_options import TypeOptions from apache_beam.options.pipeline_options import WorkerOptions +from apache_beam.options.pipeline_options import _BeamArgumentParser from apache_beam.options.value_provider import RuntimeValueProvider from apache_beam.options.value_provider import StaticValueProvider from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher +_LOGGER = logging.getLogger(__name__) + class PipelineOptionsTest(unittest.TestCase): def setUp(self): @@ -647,6 +650,53 @@ def test_dataflow_service_options(self): self.assertEqual( options.get_all_options()['dataflow_service_options'], None) + def test_options_store_false_with_different_dest(self): + parser = _BeamArgumentParser() + for cls in PipelineOptions.__subclasses__(): + cls._add_argparse_args(parser) + + actions = parser._actions.copy() + options_to_flags = {} + options_diff_dest_store_true = {} + + for i in range(len(actions)): + flag_names = actions[i].option_strings + option_name = actions[i].dest + + if isinstance(actions[i].const, bool): + for flag_name in flag_names: + flag_name = flag_name.strip('-') + if flag_name != option_name: + # Capture flags which has store_action=True and has a + # different dest. This behavior would be confusing. + if actions[i].const: + options_diff_dest_store_true[flag_name] = option_name + continue + # check the flags like no_use_public_ips + # default is None, action is {True, False} + if actions[i].default is None: + options_to_flags[option_name] = flag_name + + self.assertEqual( + len(options_diff_dest_store_true), + 0, + _LOGGER.error( + "There should be no flags that have a dest " + "different from flag name and action as " + "store_true. It would be confusing " + "to the user. Please specify the dest as the " + "flag_name instead.")) + from apache_beam.options.pipeline_options import ( + _FLAG_THAT_SETS_FALSE_VALUE) + + self.assertDictEqual( + _FLAG_THAT_SETS_FALSE_VALUE, + options_to_flags, + "If you are adding a new boolean flag with default=None," + " with different dest/option_name from the flag name, please add " + "the dest and the flag name to the map " + "_FLAG_THAT_SETS_FALSE_VALUE in PipelineOptions.py") + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)