From dc14c613957848585900db217098f379813ff0d5 Mon Sep 17 00:00:00 2001 From: hdoupe Date: Tue, 12 Jan 2021 11:32:34 -0500 Subject: [PATCH 1/4] Add transaction function for multi-stage updates - Rollback to previous state after errors/exceptions - Optionally defer schema-level validation until the update is complete - Change `is_deserialized` arg to `deserialized` --- paramtools/__init__.py | 3 +- paramtools/exceptions.py | 2 + paramtools/parameters.py | 92 +++++++++++++++++++++++++---- paramtools/schema.py | 72 +++++++++++++--------- paramtools/tests/test_parameters.py | 24 +++++++- 5 files changed, 150 insertions(+), 43 deletions(-) diff --git a/paramtools/__init__.py b/paramtools/__init__.py index e90ee2c..e167a36 100644 --- a/paramtools/__init__.py +++ b/paramtools/__init__.py @@ -9,7 +9,7 @@ ParameterNameCollisionException, UnknownTypeException, ) -from paramtools.parameters import Parameters +from paramtools.parameters import Parameters, transaction from paramtools.schema import ( RangeSchema, ChoiceSchema, @@ -66,6 +66,7 @@ "ParameterNameCollisionException", "UnknownTypeException", "Parameters", + "transaction", "RangeSchema", "ChoiceSchema", "ValueValidatorSchema", diff --git a/paramtools/exceptions.py b/paramtools/exceptions.py index d6fb33a..860708f 100644 --- a/paramtools/exceptions.py +++ b/paramtools/exceptions.py @@ -58,9 +58,11 @@ class InconsistentLabelsException(ParamToolsError): "_validator_schema", "_defaults_schema", "_select", + "_defer_validation", "operators", "adjust", "delete", + "validate", "array_first", "clear_state", "defaults", diff --git a/paramtools/parameters.py b/paramtools/parameters.py index a7b20c5..dc85772 100644 --- a/paramtools/parameters.py +++ b/paramtools/parameters.py @@ -1,7 +1,8 @@ import copy import itertools from collections import OrderedDict, defaultdict -from functools import partial, reduce +from contextlib import contextmanager +import functools from typing import Optional, Dict, List, Any, Union, Mapping import warnings @@ -103,6 +104,7 @@ def __init__( self._validator_schema.context["spec"] = self self._warnings = {} self._errors = {} + self._defer_validation = False self._state = self.parse_labels(**(initial_state or {})) self.index_rates = index_rates or self.index_rates self.sel = ParameterSlice(self) @@ -180,7 +182,7 @@ def view_state(self): """ Access the label state of the ``Parameters`` instance. """ - return self._state + return {label: value for label, value in self._state.items()} def read_params( self, @@ -266,18 +268,19 @@ def _adjust( ignore_warnings=False, raise_errors=True, extend_adj=True, - is_deserialized=False, + deserialized=False, + validate=True, clobber=True, ): """ Internal method for performing adjustments. """ # Validate user adjustments. - if is_deserialized: + if deserialized: parsed_params = {} try: parsed_params = self._validator_schema.load( - params_or_path, ignore_warnings, is_deserialized=True + params_or_path, ignore_warnings, deserialized=True ) except MarshmallowValidationError as ve: self._parse_validation_messages(ve.messages, params_or_path) @@ -340,7 +343,6 @@ def _adjust( raise_errors=True, ignore_warnings=ignore_warnings, ) - # set user adjustments. self._adjust( parsed_params, @@ -377,6 +379,37 @@ def _adjust( return parsed_params + def validate(self, params, raise_errors=True, ignore_warnings=False): + """ + Validate parameter adjustment without modifying existing values. + + For example, validate the current parameter values: + + ``` + params.validate( + params.specification(use_state=False) + ) + ``` + + Parameters: + - `params`: Parameters to validate. + - `ignore_warnings`: Whether to raise an error on warnings or ignore them. + - `raise_errors`: Either raise errors or simply store the error messages. + """ + try: + self._validator_schema.load( + params, ignore_warnings, deserialized=True + ) + except MarshmallowValidationError as ve: + self._parse_validation_messages(ve.messages, params) + + has_errors = bool(self._errors.get("messages")) + has_warnings = bool(self._warnings.get("messages")) + if (raise_errors and has_errors) or ( + not ignore_warnings and has_warnings + ): + raise self.validation_error + def delete( self, params_or_path, @@ -633,7 +666,7 @@ def to_array(self, param, **labels): return value else: return data_type(value) - exp_full_shape = reduce(lambda x, y: x * y, shape) + exp_full_shape = functools.reduce(lambda x, y: x * y, shape) act_full_shape = len(value_items) if act_full_shape != exp_full_shape: # maintains label value order over value objects. @@ -866,14 +899,14 @@ def extend( extended[val].append(ext) skl.add(val) adjustment[param].append(OrderedDict(ext, _auto=True)) - # Ensure that the adjust method of paramtools.Parameter is used + # Ensure that the adjust method of paramtools.Parameters is used # in case the child class also implements adjust. - self._adjust( + return self._adjust( adjustment, extend_adj=False, ignore_warnings=ignore_warnings, raise_errors=raise_errors, - is_deserialized=True, + deserialized=True, ) def extend_func( @@ -1280,7 +1313,7 @@ def keyfunc(vo, label, label_values): for param in data: for label in reversed(label_grid): label_values = label_grid[label] - pfunc = partial( + pfunc = functools.partial( keyfunc, label=label, label_values=label_values ) if has_meta_data: @@ -1309,3 +1342,40 @@ def keyfunc(vo, label, label_values): )[param] setattr(self, param, sorted_values) return data + + +@contextmanager +def transaction( + params: Parameters, + defer_validation: bool = False, + ignore_warnings: bool = False, + raise_errors: bool = True, +): + """ + Rollback any changes to parameter state after the context block closes. + + Parameters: + - `defer_validation`: Defer schema-level validation until the end of the block. + - `ignore_warnings`: Whether to raise an error on warnings or ignore them. + - `raise_errors`: Either raise errors or simply store the error messages. + """ + _data = copy.deepcopy(params._data) + _ops = dict(params.operators) + _state = dict(params.view_state()) + + try: + params._defer_validation = defer_validation + yield params + except Exception as e: + params._data = _data + raise e + finally: + params._state = _state + params._ops = _ops + params._defer_validation = False + if defer_validation: + params.validate( + params.specification(use_state=False, meta_data=False), + ignore_warnings=ignore_warnings, + raise_errors=raise_errors, + ) diff --git a/paramtools/schema.py b/paramtools/schema.py index 444e465..5818c1c 100644 --- a/paramtools/schema.py +++ b/paramtools/schema.py @@ -233,10 +233,10 @@ def validate_only(self, data): raise exc return data - def load(self, data, ignore_warnings, is_deserialized=False): + def load(self, data, ignore_warnings, deserialized=False): self.ignore_warnings = ignore_warnings try: - if is_deserialized: + if deserialized: return self.validate_only(data) else: return super().load(data) @@ -270,7 +270,12 @@ def validate_param(self, param_name, param_spec, raw_data): """ Do range validation for a parameter. """ - validators = self.validators(param_name, param_spec, raw_data) + validate_schema = not getattr( + self.context["spec"], "_defer_validation", False + ) + validators = self.validators( + param_name, param_spec, raw_data, validate_schema=validate_schema + ) warnings = [] errors = [] for validator in validators: @@ -296,7 +301,9 @@ def field(self, param_name): data = self.context["spec"]._data[param_name] return get_type(data, self.validators(param_name)) - def validators(self, param_name, param_spec=None, raw_data=None): + def validators( + self, param_name, param_spec=None, raw_data=None, validate_schema=True + ): if param_spec is None: param_spec = {} if raw_data is None: @@ -310,7 +317,12 @@ def validators(self, param_name, param_spec=None, raw_data=None): if vname == "range" and param_info.get("type", None) in ("date",): vname = "date_range" validator = getattr(self, self.WRAPPER_MAP[vname])( - vname, vdata, param_name, param_spec, raw_data + vname, + vdata, + param_name, + param_spec, + raw_data, + validate_schema=validate_schema, ) validators.append(validator) return validators @@ -323,7 +335,10 @@ def _get_when_validator( param_spec, raw_data, ndim_restriction=False, + validate_schema=True, ): + if not validate_schema: + return when_param = when_dict["param"] if ( @@ -334,7 +349,7 @@ def _get_when_validator( f"'{when_param}' is not a specified parameter." ) - oth_param, when_vos = self._resolve_op_value( + oth_param, when_vos = self._get_related_value( when_param, param_name, param_spec, raw_data ) then_validators = [] @@ -393,6 +408,7 @@ def _get_range_validator( param_spec, raw_data, ndim_restriction=False, + validate_schema=True, ): if vname == "range": range_class = contrib.validate.Range @@ -403,20 +419,26 @@ def _get_range_validator( f"{vname} is not an allowed validator." ) min_value = range_dict.get("min", None) - if min_value is not None: - min_oth_param, min_vos = self._resolve_op_value( + is_related_param = min_value == "default" or min_value in self.fields + if min_value is None or (is_related_param and not validate_schema): + min_oth_param, min_vos = None, [] + elif is_related_param and validate_schema: + min_oth_param, min_vos = self._get_related_value( min_value, param_name, param_spec, raw_data ) else: - min_oth_param, min_vos = None, [] + min_oth_param, min_vos = None, [{"value": min_value}] max_value = range_dict.get("max", None) - if max_value is not None: - max_oth_param, max_vos = self._resolve_op_value( + is_related_param = max_value == "default" or max_value in self.fields + if max_value is None or (is_related_param and not validate_schema): + max_oth_param, max_vos = None, [] + elif is_related_param and validate_schema: + max_oth_param, max_vos = self._get_related_value( max_value, param_name, param_spec, raw_data ) else: - max_oth_param, max_vos = None, [] + max_oth_param, max_vos = None, [{"value": max_value}] self._check_ndim_restriction( param_name, min_oth_param, @@ -425,8 +447,14 @@ def _get_range_validator( ) min_vos = self._sort_by_label_to_extend(min_vos) max_vos = self._sort_by_label_to_extend(max_vos) - error_min = f"{param_name}{{labels}} {{input}} < min {{min}} {min_oth_param}{{oth_labels}}" - error_max = f"{param_name}{{labels}} {{input}} > max {{max}} {max_oth_param}{{oth_labels}}" + error_min = ( + f"{param_name}{{labels}} {{input}} < min {{min}} " + f"{min_oth_param or ''}{{oth_labels}}" + ).strip() + error_max = ( + f"{param_name}{{labels}} {{input}} > max {{max}} " + f"{max_oth_param or ''}{{oth_labels}}" + ).strip() return range_class( min_vo=min_vos, max_vo=max_vos, @@ -460,6 +488,7 @@ def _get_choice_validator( param_spec, raw_data, ndim_restriction=False, + validate_schema=True, ): choices = choice_dict["choices"] labels = utils.make_label_str(param_spec) @@ -482,20 +511,7 @@ def _get_choice_validator( choices, error=error, level=choice_dict.get("level") ) - def _resolve_op_value(self, op_value, param_name, param_spec, raw_data): - """ - Operator values (`op_value`) are the values pointed to by the "min" - and "max" keys. These can be values to compare against, another - variable to compare against, or the default value of the adjusted - variable. - """ - if op_value in self.fields or op_value == "default": - return self._get_comparable_value( - op_value, param_name, param_spec, raw_data - ) - return "", [{"value": op_value}] - - def _get_comparable_value( + def _get_related_value( self, oth_param_name, param_name, param_spec, raw_data ): """ diff --git a/paramtools/tests/test_parameters.py b/paramtools/tests/test_parameters.py index 3441040..5980f2d 100644 --- a/paramtools/tests/test_parameters.py +++ b/paramtools/tests/test_parameters.py @@ -20,6 +20,7 @@ Parameters, Values, Slice, + transaction, ) from paramtools.contrib import Bool_ @@ -656,6 +657,23 @@ def test_simultaneous_adjust(self, TestParams): assert params.min_int_param == adjustment["min_int_param"] assert params.max_int_param == adjustment["max_int_param"] + def test_transaction(self, TestParams): + """ + Use transaction manager to defer schema level validation until all adjustments + are complete. + """ + params = TestParams() + params.set_state(label0="zero", label1=1) + adjustment = { + "min_int_param": [{"label0": "zero", "label1": 1, "value": 4}], + "max_int_param": [{"label0": "zero", "label1": 1, "value": 5}], + } + with transaction(params, defer_validation=True): + params.adjust({"min_int_param": adjustment["min_int_param"]}) + params.adjust({"max_int_param": adjustment["max_int_param"]}) + assert params.min_int_param == adjustment["min_int_param"] + assert params.max_int_param == adjustment["max_int_param"] + def test_adjust_many_labels(self, TestParams): """ Adjust min_int_param above original max_int_param value at same time as @@ -1220,9 +1238,9 @@ class Params(Parameters): with pytest.raises(ValidationError): params.adjust({"param": params.when_param - 1}) - def test_is_deserialized(self, TestParams): + def test_deserialized(self, TestParams): params = TestParams() - params._adjust({"min_int_param": [{"value": 1}]}, is_deserialized=True) + params._adjust({"min_int_param": [{"value": 1}]}, deserialized=True) assert params.min_int_param == [ {"label0": "zero", "label1": 1, "value": 1}, {"label0": "one", "label1": 2, "value": 1}, @@ -1231,7 +1249,7 @@ def test_is_deserialized(self, TestParams): params._adjust( {"min_int_param": [{"value": -1}]}, raise_errors=False, - is_deserialized=True, + deserialized=True, ) assert params.errors["min_int_param"] == ["min_int_param -1 < min 0 "] From 699ef0805f265d5040a4809f845e45f0102189a8 Mon Sep 17 00:00:00 2001 From: hdoupe Date: Tue, 12 Jan 2021 12:04:12 -0500 Subject: [PATCH 2/4] Refactor transaction function to be a method on the Parameters class - Update docs strings --- docs/api/reference.rst | 2 +- paramtools/__init__.py | 3 +- paramtools/exceptions.py | 1 + paramtools/parameters.py | 116 +++++++++++++++++----------- paramtools/tests/test_parameters.py | 3 +- 5 files changed, 74 insertions(+), 51 deletions(-) diff --git a/docs/api/reference.rst b/docs/api/reference.rst index ae95b96..b9b9f93 100644 --- a/docs/api/reference.rst +++ b/docs/api/reference.rst @@ -11,7 +11,7 @@ Parameters .. currentmodule:: paramtools.parameters .. autoclass:: Parameters - :members: adjust, read_params, set_state, view_state, clear_state, specification, extend, extend_func, to_array, from_array, parse_labels, sort_values + :members: adjust, read_params, set_state, view_state, clear_state, specification, extend, extend_func, to_array, from_array, parse_labels, sort_values, validate, transaction Values ------------------------------------------ diff --git a/paramtools/__init__.py b/paramtools/__init__.py index e167a36..e90ee2c 100644 --- a/paramtools/__init__.py +++ b/paramtools/__init__.py @@ -9,7 +9,7 @@ ParameterNameCollisionException, UnknownTypeException, ) -from paramtools.parameters import Parameters, transaction +from paramtools.parameters import Parameters from paramtools.schema import ( RangeSchema, ChoiceSchema, @@ -66,7 +66,6 @@ "ParameterNameCollisionException", "UnknownTypeException", "Parameters", - "transaction", "RangeSchema", "ChoiceSchema", "ValueValidatorSchema", diff --git a/paramtools/exceptions.py b/paramtools/exceptions.py index 860708f..fe8b9ef 100644 --- a/paramtools/exceptions.py +++ b/paramtools/exceptions.py @@ -63,6 +63,7 @@ class InconsistentLabelsException(ParamToolsError): "adjust", "delete", "validate", + "transaction", "array_first", "clear_state", "defaults", diff --git a/paramtools/parameters.py b/paramtools/parameters.py index dc85772..eff211e 100644 --- a/paramtools/parameters.py +++ b/paramtools/parameters.py @@ -379,22 +379,83 @@ def _adjust( return parsed_params + @contextmanager + def transaction( + self, defer_validation=True, raise_errors=False, ignore_warnings=False + ): + """ + Rollback any changes to parameter state after the context block closes. + + .. code-block:: Python + + import paramtools + + class Params(paramtools.Parameters): + defaults = { + "min_param": { + "title": "Min param", + "description": "Must be less than 'max_param'", + "type": "int", + "value": 2, + "validators": { + "range": {"max": "max_param"} + } + }, + "max_param": { + "title": "Max param", + "type": "int", + "value": 3 + } + } + + params = Params() + with params.transaction(): + params.adjust({"min_param": 4}) + params.adjust({"max_param": 5}) + + + **Parameters:** + - `defer_validation`: Defer schema-level validation until the end of the block. + - `ignore_warnings`: Whether to raise an error on warnings or ignore them. + - `raise_errors`: Either raise errors or simply store the error messages. + """ + _data = copy.deepcopy(self._data) + _ops = dict(self.operators) + _state = dict(self.view_state()) + + try: + self._defer_validation = defer_validation + yield self + except Exception as e: + self._data = _data + raise e + finally: + self._state = _state + self._ops = _ops + self._defer_validation = False + if defer_validation: + self.validate( + self.specification(use_state=False, meta_data=False), + ignore_warnings=ignore_warnings, + raise_errors=raise_errors, + ) + def validate(self, params, raise_errors=True, ignore_warnings=False): """ Validate parameter adjustment without modifying existing values. For example, validate the current parameter values: - ``` - params.validate( - params.specification(use_state=False) - ) - ``` + .. code-block:: Python - Parameters: - - `params`: Parameters to validate. - - `ignore_warnings`: Whether to raise an error on warnings or ignore them. - - `raise_errors`: Either raise errors or simply store the error messages. + params.validate( + params.specification(use_state=False) + ) + + **Parameters:** + - `params`: Parameters to validate. + - `ignore_warnings`: Whether to raise an error on warnings or ignore them. + - `raise_errors`: Either raise errors or simply store the error messages. """ try: self._validator_schema.load( @@ -1342,40 +1403,3 @@ def keyfunc(vo, label, label_values): )[param] setattr(self, param, sorted_values) return data - - -@contextmanager -def transaction( - params: Parameters, - defer_validation: bool = False, - ignore_warnings: bool = False, - raise_errors: bool = True, -): - """ - Rollback any changes to parameter state after the context block closes. - - Parameters: - - `defer_validation`: Defer schema-level validation until the end of the block. - - `ignore_warnings`: Whether to raise an error on warnings or ignore them. - - `raise_errors`: Either raise errors or simply store the error messages. - """ - _data = copy.deepcopy(params._data) - _ops = dict(params.operators) - _state = dict(params.view_state()) - - try: - params._defer_validation = defer_validation - yield params - except Exception as e: - params._data = _data - raise e - finally: - params._state = _state - params._ops = _ops - params._defer_validation = False - if defer_validation: - params.validate( - params.specification(use_state=False, meta_data=False), - ignore_warnings=ignore_warnings, - raise_errors=raise_errors, - ) diff --git a/paramtools/tests/test_parameters.py b/paramtools/tests/test_parameters.py index 5980f2d..017fddc 100644 --- a/paramtools/tests/test_parameters.py +++ b/paramtools/tests/test_parameters.py @@ -20,7 +20,6 @@ Parameters, Values, Slice, - transaction, ) from paramtools.contrib import Bool_ @@ -668,7 +667,7 @@ def test_transaction(self, TestParams): "min_int_param": [{"label0": "zero", "label1": 1, "value": 4}], "max_int_param": [{"label0": "zero", "label1": 1, "value": 5}], } - with transaction(params, defer_validation=True): + with params.transaction(defer_validation=True): params.adjust({"min_int_param": adjustment["min_int_param"]}) params.adjust({"max_int_param": adjustment["max_int_param"]}) assert params.min_int_param == adjustment["min_int_param"] From f16ba5c5cfaaf00b9276c427afe377fd4f423145 Mon Sep 17 00:00:00 2001 From: hdoupe Date: Tue, 12 Jan 2021 12:08:20 -0500 Subject: [PATCH 3/4] Use github actions for testing --- .github/workflows/test.yml | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..1a010ad --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,35 @@ + +name: Build Package and Test Source Code [Python 3.6, 3.7, 3.8] + +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + + steps: + - name: Checkout + uses: actions/checkout@master + with: + persist-credentials: false + + - name: Setup Miniconda using Python ${{ matrix.python-version }} + uses: conda-incubator/setup-miniconda@v2 + with: + activate-environment: paramtools-dev + environment-file: environment.yml + python-version: ${{ matrix.python-version }} + auto-activate-base: false + + - name: Build + shell: bash -l {0} + run: | + pip install -e . + - name: Test + shell: bash -l {0} + working-directory: ./ + run: | + pytest paramtools -v -s \ No newline at end of file From 4168c627b4f3fff4cdfc78a77514460973c66bd9 Mon Sep 17 00:00:00 2001 From: hdoupe Date: Tue, 12 Jan 2021 12:10:28 -0500 Subject: [PATCH 4/4] Only test on push to master --- .github/workflows/test.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1a010ad..6a6bfa1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,7 +1,11 @@ name: Build Package and Test Source Code [Python 3.6, 3.7, 3.8] -on: [push, pull_request] +on: + push: + branches: + - master + pull_request: {} jobs: build: