Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convenience functions for Matrix, Vector and NormalizationParameters #161

Merged
merged 9 commits into from
Jan 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions ompy/matrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import warnings
import copy
import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -201,30 +202,50 @@ def load(self, path: Union[str, Path],
self.verify_integrity()

def save(self, path: Union[str, Path], filetype: Optional[str] = None,
**kwargs):
which: Optional[str] = 'values', **kwargs):
"""Save matrix to file

Args:
path (str or Path): path to file to save
filetype (str, optional): Filetype to save. Has an
auto-recognition. Options: ["numpy", "tar", "mama", "txt"]
which (str, optional): Which attribute to save. Default is
'values'. Options: ["values", "std"]
**kwargs: additional keyword arguments

Raises:
ValueError: If filetype is unknown
RuntimeError: If `std` attribute not set.
NotImplementedError: If which is unknown
"""
path = Path(path) if isinstance(path, str) else path
if filetype is None:
filetype = filetype_from_suffix(path)
filetype = filetype.lower()

values = None
if which.lower() == 'values':
values = self.values
if self.std is not None:
warnings.warn(UserWarning("The std attribute of Matrix class has to be saved to file 'manually'. Call with which='std'.")) # noqa
elif which.lower() == 'std':
if self.std is None:
raise RuntimeError(f"Attribute `std` not set.")
values = self.std
else:
raise NotImplementedError(
f"{which} is unsupported: Use 'values' or 'std'")

if filetype == 'numpy':
save_numpy_2D(self.values, self.Eg, self.Ex, path)
save_numpy_2D(values, self.Eg, self.Ex, path)
elif filetype == 'txt':
save_txt_2D(self.values, self.Eg, self.Ex, path, **kwargs)
save_txt_2D(values, self.Eg, self.Ex, path, **kwargs)
elif filetype == 'tar':
save_tar([self.values, self.Eg, self.Ex], path)
save_tar([values, self.Eg, self.Ex], path)
elif filetype == 'mama':
if which.lower() == 'std':
warnings.warn(UserWarning(
"Cannot write std attrbute to MaMa format."))

mama_write(self, path, comment="Made by OMpy",
**kwargs)
else:
Expand Down Expand Up @@ -307,7 +328,7 @@ def plot(self, *, ax: Any = None,
ax.tick_params(axis='x', rotation=40)
ax.yaxis.set_major_locator(MeshLocator(self.Ex))
# ax.xaxis.set_major_locator(ticker.FixedLocator(self.Eg, nbins=10))
#fix_pcolormesh_ticks(ax, xvalues=self.Eg, yvalues=self.Ex)
# fix_pcolormesh_ticks(ax, xvalues=self.Eg, yvalues=self.Ex)

ax.set_title(title if title is not None else self.state)
ax.set_xlabel(r"$\gamma$-ray energy $E_{\gamma}$")
Expand Down Expand Up @@ -612,7 +633,8 @@ def line_mask(self, E1: Iterable[float],
def trapezoid(self, Ex_min: float, Ex_max: float,
Eg_min: float, Eg_max: Optional[float] = None,
inplace: bool = True) -> Optional[Matrix]:
"""Create a trapezoidal cut or mask delimited by the diagonal of the matrix
"""Create a trapezoidal cut or mask delimited by the diagonal of the
matrix

Args:
Ex_min: The bottom edge of the trapezoid
Expand Down Expand Up @@ -729,8 +751,8 @@ def diagonal_elements(self) -> Iterator[Tuple[int, int]]:
entries with `Eg > Ex + dE`.
Args:
mat: The matrix to iterate over
Iterator[Tuple[int, int]]: Indicies (i, j) over the last non-zero (=diagonal)
elements.
Iterator[Tuple[int, int]]: Indicies (i, j) over the last
non-zero(=diagonal) elements.
"""
return diagonal_elements(self.values)

Expand Down Expand Up @@ -863,6 +885,7 @@ def __matmul__(self, other: Matrix) -> Matrix:
result.values = [email protected]
return result


class MeshLocator(ticker.Locator):
def __init__(self, locs, nbins=10):
'place ticks on the i-th data points where (i-offset)%base==0'
Expand Down
46 changes: 45 additions & 1 deletion ompy/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import re
import pickle
import warnings
from pathlib import Path
from dataclasses import dataclass, field, fields, asdict
from typing import Optional, Union, Tuple, Any, Dict, Callable, List
Expand Down Expand Up @@ -391,6 +392,12 @@ class NormalizationParameters(Model):
"""Storage for normalization parameters + some convenience functions
"""

#: Element number of the nucleus
Z: Optional[int] = field(default=None,
metadata="Element number of the nucleus") # noqa
#: Mass number of the nucleus
_A: Optional[int] = field(default=None,
metadata="Mass number of the nucleus") # noqa
#: Average s-wave resonance spacing D0 [eV]
D0: Optional[Tuple[float, float]] = field(default=None,
metadata='Average s-wave resonance spacing D0 [eV]') # noqa
Expand All @@ -416,7 +423,7 @@ class NormalizationParameters(Model):
spincutModel: str = field(default=None,
metadata='Spincut model') # noqa
#: Parameters necessary for the spin cut model
spincutPars: Dict[str, Any] = field(default=None,
_spincutPars: Dict[str, Any] = field(default=None,
metadata='parameters necessary for the spin cut model') # noqa

def E_grid(self,
Expand All @@ -433,6 +440,43 @@ def E_grid(self,
return np.linspace(self.Emin, self.Sn[0], num=self.steps,
retstep=retstep)

@property
def spinMass(self) -> Union[int, None]:
try:
mass = self._spincutPars['mass']
return mass
except: # noqa
return None

@property
def spincutPars(self) -> Dict[str, Any]:
return self._spincutPars

@spincutPars.setter
def spincutPars(self, value: Dict[str, Any]):
try:
mass = value['mass']
if self._A is not None:
if mass != self._A:
warnings.warn(UserWarning("mass number set in `spincutPars` does not match `A`.")) # noqa
except KeyError:
pass
self._spincutPars = value

@property
def A(self) -> Union[int, None]:
if self._A is None:
return self.spinMass
return self._A

@A.setter
def A(self, value: int) -> None:
self._A = value
if self.spinMass is not None:
if self.spinMass != value:
warnings.warn(UserWarning("mass number set in `spincutPars` does not match `A`.")) # noqa
self._A = value

@property
def Emax(self) -> float:
""" Max energy to integrate <Γγ> to """
Expand Down
3 changes: 3 additions & 0 deletions ompy/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def __init__(self, values: Optional[Iterable[float]] = None,
self.load(path)
self.verify_integrity()

def __len__(self):
return len(self.values)

def verify_integrity(self, check_equidistant: bool = False):
""" Verify the internal consistency of the vector

Expand Down
42 changes: 41 additions & 1 deletion tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ def ones(shape: Tuple[int, int]) -> om.Matrix:
mat = np.tril(mat)
return om.Matrix(values=mat)


@pytest.fixture()
def Si28():
return om.example_raw('Si28')


@pytest.mark.parametrize(
"axis,Emin,Emax,shape",
[('Ex', None, None, (10, 10)),
Expand Down Expand Up @@ -78,7 +80,7 @@ def test_index(E, index):
assert mat.index_Ex(E) == index


@pytest.mark.filterwarnings('ignore:divide by zero encountered in true_divide:RuntimeWarning')
@pytest.mark.filterwarnings('ignore:divide by zero encountered in true_divide:RuntimeWarning') # noqa
def test_numericals():
E = np.array([0, 1, 2])
values1 = np.array([[0, 1, 2.], [-2, 1, 2.], [2, 3, -10.]])
Expand All @@ -98,6 +100,7 @@ def test_numericals():
assert_equal((matrix2@matrix1).values, values2@values1)
assert_equal((matrix1@matrix2).values, values1@values2)


@pytest.mark.parametrize(
"Ex,Eg",
[(np.linspace(0, 10., num=10), np.linspace(10, 20., num=15)),
Expand All @@ -120,6 +123,43 @@ def test_bin_shift(Ex, Eg):
assert_almost_equal(Ex, mat.Ex)
assert_almost_equal(Eg, mat.Eg)


@pytest.mark.parametrize(
"Ex,Eg",
[(np.linspace(0, 10., num=10), np.linspace(10, 20., num=15)),
([0, 1, 2, 3, 7, 10.], [0, 1, 2, 3, 80, 90.])
])
def test_save_warning(Ex, Eg):
values = np.ones((len(Ex), len(Eg)), dtype="float")
mat = om.Matrix(values=values, Ex=Ex, Eg=Eg, std=0.5*values)
with pytest.warns(UserWarning):
mat.save("/tmp/mat.npy")


@pytest.mark.parametrize(
"Ex,Eg",
[(np.linspace(0, 10., num=10), np.linspace(10, 20., num=15)),
([0, 1, 2, 3, 7, 10.], [0, 1, 2, 3, 80, 90.])
])
def test_save_std_exception(Ex, Eg):
values = np.ones((len(Ex), len(Eg)), dtype="float")
mat = om.Matrix(values=values, Ex=Ex, Eg=Eg)
with pytest.raises(RuntimeError):
mat.save("/tmp/mat.npy", which='std')


@pytest.mark.parametrize(
"Ex,Eg",
[(np.linspace(0, 10., num=10), np.linspace(10, 20., num=15)),
([0, 1, 2, 3, 7, 10.], [0, 1, 2, 3, 80, 90.])
])
def test_save_which_error(Ex, Eg):
values = np.ones((len(Ex), len(Eg)), dtype="float")
mat = om.Matrix(values=values, Ex=Ex, Eg=Eg, std=0.5*values)
with pytest.raises(NotImplementedError):
mat.save("/tmp/mat.npy", which='Im not real')


# This does not work as of now...
# def test_mutable():
# E = np.array([0, 1, 2])
Expand Down
40 changes: 40 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
import ompy as om
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal
from typing import Tuple


def test_set_mass():

normpar = om.NormalizationParameters(name="Test")

normpar.A = 16
normpar.Z = 8

normpar.spincutModel = 'EB09_CT'
normpar.spincutPars = {'mass': 16}

with pytest.warns(UserWarning):
normpar.A = 17

with pytest.warns(UserWarning):
normpar.spincutPars = {'mass': 18}

# We expect this to NOT trigger a warning.
with pytest.warns(None) as record:
A = normpar.A
assert len(record) == 0

# We do not expect this to trigger a warning.
with pytest.warns(None) as record:
spinpar = normpar.spincutPars
assert len(record) == 0

# Neither should this
with pytest.warns(None) as record:
normpar.spincutPars = {'sigma': 2.9}
assert len(record) == 0

if __name__ == "__main__":
test_set_A()
11 changes: 10 additions & 1 deletion tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ def test_init():
om.Vector(vals, [1, 2, 3, 4, 5])


def test_len():
N = 100
E = np.linspace(0, 1, N)
vals = np.linspace(2, 3.4, N)
vec = om.Vector(values=vals, E=E)
assert_equal(len(vec), N)


def test_save_load_no_std():
E = np.linspace(0, 1, 100)
vals = np.linspace(2, 3.4, 100)
Expand Down Expand Up @@ -184,7 +192,8 @@ def test_cut():
assert_equal(vector.E, Ecut)
assert_equal(vector.values, valcut)

@pytest.mark.filterwarnings('ignore:divide by zero encountered in true_divide:RuntimeWarning')

@pytest.mark.filterwarnings('ignore:divide by zero encountered in true_divide:RuntimeWarning') # noqa
def test_numericals():
E = np.array([0, 1, 2])
values1 = np.array([0, 1, -2.])
Expand Down