Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transaction method for multi-stage updates #123

Merged
merged 4 commits into from
Jan 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

name: Build Package and Test Source Code [Python 3.6, 3.7, 3.8]

on:
push:
branches:
- master
pull_request: {}

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]

steps:
- name: Checkout
uses: actions/checkout@master
with:
persist-credentials: false

- name: Setup Miniconda using Python ${{ matrix.python-version }}
uses: conda-incubator/setup-miniconda@v2
with:
activate-environment: paramtools-dev
environment-file: environment.yml
python-version: ${{ matrix.python-version }}
auto-activate-base: false

- name: Build
shell: bash -l {0}
run: |
pip install -e .
- name: Test
shell: bash -l {0}
working-directory: ./
run: |
pytest paramtools -v -s
2 changes: 1 addition & 1 deletion docs/api/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Parameters
.. currentmodule:: paramtools.parameters

.. autoclass:: Parameters
:members: adjust, read_params, set_state, view_state, clear_state, specification, extend, extend_func, to_array, from_array, parse_labels, sort_values
:members: adjust, read_params, set_state, view_state, clear_state, specification, extend, extend_func, to_array, from_array, parse_labels, sort_values, validate, transaction

Values
------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions paramtools/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,12 @@ class InconsistentLabelsException(ParamToolsError):
"_validator_schema",
"_defaults_schema",
"_select",
"_defer_validation",
"operators",
"adjust",
"delete",
"validate",
"transaction",
"array_first",
"clear_state",
"defaults",
Expand Down
116 changes: 105 additions & 11 deletions paramtools/parameters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import copy
import itertools
from collections import OrderedDict, defaultdict
from functools import partial, reduce
from contextlib import contextmanager
import functools
from typing import Optional, Dict, List, Any, Union, Mapping
import warnings

Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
self._validator_schema.context["spec"] = self
self._warnings = {}
self._errors = {}
self._defer_validation = False
self._state = self.parse_labels(**(initial_state or {}))
self.index_rates = index_rates or self.index_rates
self.sel = ParameterSlice(self)
Expand Down Expand Up @@ -180,7 +182,7 @@ def view_state(self):
"""
Access the label state of the ``Parameters`` instance.
"""
return self._state
return {label: value for label, value in self._state.items()}

def read_params(
self,
Expand Down Expand Up @@ -266,18 +268,19 @@ def _adjust(
ignore_warnings=False,
raise_errors=True,
extend_adj=True,
is_deserialized=False,
deserialized=False,
validate=True,
clobber=True,
):
"""
Internal method for performing adjustments.
"""
# Validate user adjustments.
if is_deserialized:
if deserialized:
parsed_params = {}
try:
parsed_params = self._validator_schema.load(
params_or_path, ignore_warnings, is_deserialized=True
params_or_path, ignore_warnings, deserialized=True
)
except MarshmallowValidationError as ve:
self._parse_validation_messages(ve.messages, params_or_path)
Expand Down Expand Up @@ -340,7 +343,6 @@ def _adjust(
raise_errors=True,
ignore_warnings=ignore_warnings,
)

# set user adjustments.
self._adjust(
parsed_params,
Expand Down Expand Up @@ -377,6 +379,98 @@ def _adjust(

return parsed_params

@contextmanager
def transaction(
self, defer_validation=True, raise_errors=False, ignore_warnings=False
):
"""
Rollback any changes to parameter state after the context block closes.

.. code-block:: Python

import paramtools

class Params(paramtools.Parameters):
defaults = {
"min_param": {
"title": "Min param",
"description": "Must be less than 'max_param'",
"type": "int",
"value": 2,
"validators": {
"range": {"max": "max_param"}
}
},
"max_param": {
"title": "Max param",
"type": "int",
"value": 3
}
}

params = Params()
with params.transaction():
params.adjust({"min_param": 4})
params.adjust({"max_param": 5})


**Parameters:**
- `defer_validation`: Defer schema-level validation until the end of the block.
- `ignore_warnings`: Whether to raise an error on warnings or ignore them.
- `raise_errors`: Either raise errors or simply store the error messages.
"""
_data = copy.deepcopy(self._data)
_ops = dict(self.operators)
_state = dict(self.view_state())

try:
self._defer_validation = defer_validation
yield self
except Exception as e:
self._data = _data
raise e
finally:
self._state = _state
self._ops = _ops
self._defer_validation = False
if defer_validation:
self.validate(
self.specification(use_state=False, meta_data=False),
ignore_warnings=ignore_warnings,
raise_errors=raise_errors,
)

def validate(self, params, raise_errors=True, ignore_warnings=False):
"""
Validate parameter adjustment without modifying existing values.

For example, validate the current parameter values:

.. code-block:: Python

params.validate(
params.specification(use_state=False)
)

**Parameters:**
- `params`: Parameters to validate.
- `ignore_warnings`: Whether to raise an error on warnings or ignore them.
- `raise_errors`: Either raise errors or simply store the error messages.
"""
try:
self._validator_schema.load(
params, ignore_warnings, deserialized=True
)
except MarshmallowValidationError as ve:
self._parse_validation_messages(ve.messages, params)

has_errors = bool(self._errors.get("messages"))
has_warnings = bool(self._warnings.get("messages"))
if (raise_errors and has_errors) or (
not ignore_warnings and has_warnings
):
raise self.validation_error

def delete(
self,
params_or_path,
Expand Down Expand Up @@ -633,7 +727,7 @@ def to_array(self, param, **labels):
return value
else:
return data_type(value)
exp_full_shape = reduce(lambda x, y: x * y, shape)
exp_full_shape = functools.reduce(lambda x, y: x * y, shape)
act_full_shape = len(value_items)
if act_full_shape != exp_full_shape:
# maintains label value order over value objects.
Expand Down Expand Up @@ -866,14 +960,14 @@ def extend(
extended[val].append(ext)
skl.add(val)
adjustment[param].append(OrderedDict(ext, _auto=True))
# Ensure that the adjust method of paramtools.Parameter is used
# Ensure that the adjust method of paramtools.Parameters is used
# in case the child class also implements adjust.
self._adjust(
return self._adjust(
adjustment,
extend_adj=False,
ignore_warnings=ignore_warnings,
raise_errors=raise_errors,
is_deserialized=True,
deserialized=True,
)

def extend_func(
Expand Down Expand Up @@ -1280,7 +1374,7 @@ def keyfunc(vo, label, label_values):
for param in data:
for label in reversed(label_grid):
label_values = label_grid[label]
pfunc = partial(
pfunc = functools.partial(
keyfunc, label=label, label_values=label_values
)
if has_meta_data:
Expand Down
Loading